package main
import (
"context"
"log"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/mongo/readpref"
"go.mongodb.org/mongo-driver/x/bsonx"
)
// Book is the document model that main inserts into, and decodes from, the
// "book" collection.
//
// Only Id carries an explicit bson tag. The other fields are untagged and rely
// on the driver's default key naming; the queries in this file address them as
// "name" and "author.country" (presumably the lowercased field names — confirm
// against the driver's struct codec documentation).
type Book struct {
// Id is the document's primary key, stored under the "_id" key.
Id primitive.ObjectID `bson:"_id"`
// Name is the book title; main builds a unique index on the "name" key.
Name string
// Category is the genre label (see the category* constants).
Category string
// Weight is an integer attribute; its meaning is not evident from this file.
Weight int
// Author is embedded as a sub-document and queried via "author.country".
Author AuthorInfo
}
// AuthorInfo describes a book's author and is embedded in Book as the Author
// field.
type AuthorInfo struct {
// Name is the author's display name.
Name string
// Country is the author's country (see the country* constants).
Country string
}
// Fixed values used for Book.Category and AuthorInfo.Country in the sample
// data. The literals are stored in the database as-is, so they are data, not
// display text.
const (
// categoryComputer is the Category value of the computing books.
categoryComputer = "计算机"
// categorySciFi is the Category value of the science-fiction books.
categorySciFi = "科幻"
// countryChina is the Country value matched by the nested-field query in main.
countryChina = "中国"
// countryAmerica is the Country value of the two computing books' authors.
countryAmerica = "美国"
)
var (
// books is the sample data set used by main. It is typed []interface{}
// because it is passed directly to InsertOne (books[0]) and InsertMany
// (books[1:]). Each element is a *Book, which main relies on when it type
// asserts books[0] back to *Book for the update example.
books = []interface{}{
// Inserted on its own via InsertOne and later modified by UpdateOne.
&Book{
Id: primitive.NewObjectID(),
Name: "深入理解计算机操作系统",
Category: categoryComputer,
Weight: 1,
Author: AuthorInfo{
Name: "兰德尔 E.布莱恩特",
Country: countryAmerica,
},
},
// Inserted via InsertMany; shares the "深入" prefix matched by the regex query.
&Book{
Id: primitive.NewObjectID(),
Name: "深入理解Linux内核",
Category: categoryComputer,
Weight: 1,
Author: AuthorInfo{
Name: "博韦,西斯特",
Country: countryAmerica,
},
},
// Inserted via InsertMany; the only entry whose author country is countryChina.
&Book{
Id: primitive.NewObjectID(),
Name: "三体",
Category: categorySciFi,
Weight: 1,
Author: AuthorInfo{
Name: "刘慈欣",
Country: countryChina,
},
},
}
)
// main walks through the basic MongoDB Go driver operations against a local
// server at mongodb://localhost:27017, using the "book" collection of the
// "mydb" database: drop, unique index creation, single and bulk insert,
// count, single and multi document queries (including a regex query and a
// nested-field query), update, upsert and delete.
//
// Every failure is fatal: the error is logged and the process exits.
func main() {
	log.SetFlags(log.Llongfile | log.LstdFlags)
	ctx := context.Background()

	// Connect to the database.
	opts := options.Client().ApplyURI("mongodb://localhost:27017")
	client, err := mongo.Connect(ctx, opts)
	if err != nil {
		log.Fatal(err)
	}
	// Release the connection pool on a normal return. log.Fatal exits without
	// running deferred calls, which is acceptable because the process is
	// terminating anyway.
	defer func() {
		if err := client.Disconnect(ctx); err != nil {
			log.Println("client.Disconnect:", err)
		}
	}()

	// Check that the server is reachable.
	if err = client.Ping(ctx, readpref.Primary()); err != nil {
		log.Fatal(err)
	}

	// Get the database and collection handles.
	collection := client.Database("mydb").Collection("book")

	// Start from an empty collection.
	if err = collection.Drop(ctx); err != nil {
		log.Fatal(err)
	}

	// Create a unique ascending index on "name".
	idx := mongo.IndexModel{
		Keys:    bson.D{{Key: "name", Value: int32(1)}},
		Options: options.Index().SetUnique(true),
	}
	idxRet, err := collection.Indexes().CreateOne(ctx, idx)
	if err != nil {
		log.Fatal(err)
	}
	log.Println("collection.Indexes().CreateOne:", idxRet)

	// Insert a single document.
	insertOneResult, err := collection.InsertOne(ctx, books[0])
	if err != nil {
		log.Fatal(err)
	}
	log.Println("collection.InsertOne: ", insertOneResult.InsertedID)

	// Insert several documents at once.
	insertManyResult, err := collection.InsertMany(ctx, books[1:])
	if err != nil {
		log.Fatal(err)
	}
	log.Println("collection.InsertMany: ", insertManyResult.InsertedIDs)

	// Count all documents.
	count, err := collection.CountDocuments(ctx, bson.D{})
	if err != nil {
		// Report the error itself; the count is meaningless on failure.
		log.Fatal(err)
	}
	log.Println("collection.CountDocuments:", count)

	// Query a single document.
	var one Book
	if err = collection.FindOne(ctx, bson.M{"name": "三体"}).Decode(&one); err != nil {
		log.Fatal(err)
	}
	log.Println("collection.FindOne: ", one)

	// Query multiple documents (approach 1): decode everything with All.
	cur, err := collection.Find(ctx, bson.D{})
	if err != nil {
		log.Fatal(err)
	}
	var all []*Book
	if err = cur.All(ctx, &all); err != nil {
		log.Fatal(err)
	}
	// Closing again after All is harmless and keeps cleanup explicit.
	_ = cur.Close(ctx)
	log.Println("collection.Find curl.All: ", all)
	for _, b := range all {
		log.Println(b)
	}

	// Query multiple documents (approach 2): iterate with Next/Decode.
	cur, err = collection.Find(ctx, bson.D{})
	if err != nil {
		log.Fatal(err)
	}
	for cur.Next(ctx) {
		var b Book
		if err = cur.Decode(&b); err != nil {
			log.Fatal(err)
		}
		log.Println("collection.Find cur.Next:", b)
	}
	// Next returns false both at the end of the results and on failure; Err
	// distinguishes the two, so it must be checked after the loop.
	if err = cur.Err(); err != nil {
		log.Fatal(err)
	}
	_ = cur.Close(ctx) // best-effort: the results are already consumed

	// Fuzzy query: names matching the regular expression "深入".
	cur, err = collection.Find(ctx, bson.M{"name": primitive.Regex{Pattern: "深入"}})
	if err != nil {
		log.Fatal(err)
	}
	for cur.Next(ctx) {
		var b Book
		if err = cur.Decode(&b); err != nil {
			log.Fatal(err)
		}
		log.Println("collection.Find name=primitive.Regex{深入}: ", b)
	}
	if err = cur.Err(); err != nil {
		log.Fatal(err)
	}
	_ = cur.Close(ctx) // best-effort: the results are already consumed

	// Query on a field of the embedded Author document. The equivalent
	// ordered form is bson.D{{Key: "author.country", Value: countryChina}}.
	cur, err = collection.Find(ctx, bson.M{"author.country": countryChina})
	if err != nil {
		log.Fatal(err)
	}
	for cur.Next(ctx) {
		var b Book
		if err = cur.Decode(&b); err != nil {
			log.Fatal(err)
		}
		log.Println("collection.Find author.country=", countryChina, ":", b)
	}
	if err = cur.Err(); err != nil {
		log.Fatal(err)
	}
	_ = cur.Close(ctx) // best-effort: the results are already consumed

	// Update one document, matched by name.
	b1 := books[0].(*Book)
	b1.Weight = 2
	update := bson.M{"$set": b1}
	updateResult, err := collection.UpdateOne(ctx, bson.M{"name": b1.Name}, update)
	if err != nil {
		log.Fatal(err)
	}
	log.Println("collection.UpdateOne:", updateResult)

	// Update one document, inserting it when no match exists (upsert).
	newBook := &Book{
		Id:       primitive.NewObjectID(),
		Name:     "球状闪电",
		Category: categorySciFi,
		Author: AuthorInfo{
			Name:    "刘慈欣",
			Country: countryChina,
		},
	}
	update = bson.M{"$set": newBook}
	updateOpts := options.Update().SetUpsert(true)
	updateResult, err = collection.UpdateOne(ctx, bson.M{"_id": newBook.Id}, update, updateOpts)
	if err != nil {
		log.Fatal(err)
	}
	log.Println("collection.UpdateOne:", updateResult)

	// Delete the document that was just upserted.
	deleteResult, err := collection.DeleteOne(ctx, bson.M{"_id": newBook.Id})
	if err != nil {
		log.Fatal(err)
	}
	log.Println("collection.DeleteOne:", deleteResult)
}
// createUniqueIndex creates one unique index on collection (in the database
// named by setting.DatabaseSetting.DBName) covering all of keys, in order.
//
// A key prefixed with "-" is indexed in descending order (-1); any other key
// is indexed ascending (1). Passing several keys yields a compound index.
// The operation is issued with a 10-second SetMaxTime option.
//
// Nothing is returned: a failure, or an empty index name from the driver, is
// only reported through Logger.
func createUniqueIndex(collection string, keys ...string) {
	coll := DB.Mongo.Database(setting.DatabaseSetting.DBName).Collection(collection)
	opts := options.CreateIndexes().SetMaxTime(10 * time.Second)

	// Build the (possibly compound) key specification.
	keysDoc := bsonx.Doc{}
	for _, key := range keys {
		if strings.HasPrefix(key, "-") {
			keysDoc = keysDoc.Append(strings.TrimLeft(key, "-"), bsonx.Int32(-1))
		} else {
			keysDoc = keysDoc.Append(key, bsonx.Int32(1))
		}
	}

	// Create the index.
	result, err := coll.Indexes().CreateOne(
		context.Background(),
		mongo.IndexModel{
			Keys:    keysDoc,
			Options: options.Index().SetUnique(true),
		},
		opts,
	)
	// The two failure modes are handled separately: calling err.Error() when
	// only the result is empty would dereference a nil error and panic.
	switch {
	case err != nil:
		Logger.Error("EnsureIndex error", zap.String("error", err.Error()))
	case result == "":
		Logger.Error("EnsureIndex error", zap.String("error", "driver returned an empty index name"))
	}
}
// FindNodes returns every document of the superNodeC collection decoded as a
// DBNode, sorted by "vote_count" in descending order.
//
// The whole operation, query and cursor iteration alike, is bounded by a
// 10-second timeout. On any error the returned slice is nil.
func FindNodes() ([]DBNode, error) {
	c := Connect(setting.DatabaseSetting.DBName, superNodeC)
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	opts := options.Find().SetSort(bsonx.Doc{{"vote_count", bsonx.Int32(-1)}})
	cursor, err := c.Find(ctx, M{}, opts)
	if err != nil {
		return nil, err
	}
	// Always release the server-side cursor. A fresh context is used so the
	// cleanup still goes through when ctx has already expired.
	defer cursor.Close(context.Background())

	var nodes []DBNode
	for cursor.Next(ctx) {
		var node DBNode
		if err := cursor.Decode(&node); err != nil {
			return nil, err
		}
		nodes = append(nodes, node)
	}
	// Next also returns false on failure (e.g. timeout); surface that instead
	// of silently returning a truncated result.
	if err := cursor.Err(); err != nil {
		return nil, err
	}
	return nodes, nil
}
// Insert writes docs to the given collection of database db with a single
// InsertMany call that is bounded by a 10-second timeout, and returns the
// driver's result and error unchanged.
func Insert(db, collection string, docs ...interface{}) (*mongo.InsertManyResult, error) {
	target := Connect(db, collection)

	timeoutCtx, release := context.WithTimeout(context.Background(), 10*time.Second)
	defer release()

	inserted, err := target.InsertMany(timeoutCtx, docs)
	return inserted, err
}
// InsertNode stores node in the superNodeC collection of the configured
// database via Insert, discarding the insert result and returning only the
// error.
func InsertNode(node DBNode) error {
	if _, insertErr := Insert(setting.DatabaseSetting.DBName, superNodeC, node); insertErr != nil {
		return insertErr
	}
	return nil
}
// Update applies update to the first document matching query in the given
// collection of database db. Upsert is always enabled, so a document is
// inserted when nothing matches. The call is bounded by a 10-second timeout.
func Update(db, collection string, query, update interface{}) (*mongo.UpdateResult, error) {
	target := Connect(db, collection)

	timeoutCtx, release := context.WithTimeout(context.Background(), 10*time.Second)
	defer release()

	upsert := options.Update().SetUpsert(true)
	outcome, err := target.UpdateOne(timeoutCtx, query, update, upsert)
	return outcome, err
}
// Remove deletes the first document matching query from the given collection
// of database db, bounded by a 10-second timeout, and returns the driver's
// result and error unchanged.
func Remove(db, collection string, query interface{}) (*mongo.DeleteResult, error) {
	target := Connect(db, collection)

	timeoutCtx, release := context.WithTimeout(context.Background(), 10*time.Second)
	defer release()

	outcome, err := target.DeleteOne(timeoutCtx, query)
	return outcome, err
}
// RemoveNode deletes, via Remove, the document in the superNodeC collection
// whose "pub_key" equals pubKey. The delete result is discarded; only the
// error is returned.
func RemoveNode(pubKey string) error {
	filter := M{"pub_key": pubKey}
	if _, removeErr := Remove(setting.DatabaseSetting.DBName, superNodeC, filter); removeErr != nil {
		return removeErr
	}
	return nil
}
- 需要使用 SessionContext。
- 凡是修改之前需要先查询的操作,请一律使用 SessionContext(即都放在事务中执行)。
因为多处都会用到,所以封装了一个方法。
使用这个封装时,需要自行实现的是传给 Exec 的 operator 函数。
// DBTransaction bundles the callbacks and logger used to run work inside a
// MongoDB session. The function fields are nil on a zero value; they are
// installed by SetCommit and SetRun (NewDBTransaction calls both).
type DBTransaction struct {
// Commit commits the transaction on the given session context and logs
// the outcome. Installed by SetCommit.
Commit func(mongo.SessionContext) error
// Run invokes the supplied transaction function with the session context
// and a copy of this DBTransaction, logging any error it returns.
// Installed by SetRun.
Run func(mongo.SessionContext, func(mongo.SessionContext, DBTransaction) error) error
// Logger receives the messages emitted by Commit and Run.
Logger *logging.Logger
}
// NewDBTransaction returns a DBTransaction that logs to logger and has its
// Run and Commit callbacks installed.
func NewDBTransaction(logger *logging.Logger) *DBTransaction {
	txn := &DBTransaction{Logger: logger}
	txn.SetRun()
	txn.SetCommit()
	return txn
}
// SetCommit installs d.Commit: a function that commits the transaction on the
// given session context, logs success at Info level or failure at Error
// level, and returns the commit error (nil on success).
func (d *DBTransaction) SetCommit() {
	d.Commit = func(sctx mongo.SessionContext) error {
		if commitErr := sctx.CommitTransaction(sctx); commitErr != nil {
			d.Logger.Error("Error during commit...")
			return commitErr
		}
		d.Logger.Info("Transaction committed.")
		return nil
	}
}
// SetRun installs d.Run: a function that executes txnFn with the session
// context and a copy of d, logs any error txnFn returns, and passes that
// error (or nil) back to the caller.
func (d *DBTransaction) SetRun() {
	d.Run = func(sctx mongo.SessionContext, txnFn func(mongo.SessionContext, DBTransaction) error) error {
		// txnFn performs the transaction; it receives d by value.
		failure := txnFn(sctx, *d)
		if failure != nil {
			d.Logger.Error("Transaction aborted. Caught exception during transaction.",
				zap.String("error", failure.Error()))
			return failure
		}
		return nil
	}
}
// SetLogger stores logger in d.Logger, replacing any previous value. The
// callbacks installed by SetCommit and SetRun read d.Logger when they run.
func (d *DBTransaction) SetLogger(logger *logging.Logger) {
d.Logger = logger
}
// Exec opens a session on mongoClient with primary as the default read
// preference and runs operator inside it through d.Run. The session is bound
// to a context that times out after 20 minutes. The error from the session
// call is returned unchanged.
func (d *DBTransaction) Exec(mongoClient *mongo.Client, operator func(mongo.SessionContext, DBTransaction) error) error {
	deadlineCtx, stop := context.WithTimeout(context.Background(), 20*time.Minute)
	defer stop()

	sessionOpts := options.Session().SetDefaultReadPreference(readpref.Primary())
	runInSession := func(sctx mongo.SessionContext) error {
		return d.Run(sctx, operator)
	}
	return mongoClient.UseSessionWithOptions(deadlineCtx, sessionOpts, runInSession)
}
// Example call site.

// SyncBlockData inserts node inside a MongoDB transaction started with
// snapshot read concern and majority write concern. If the insert fails the
// transaction is aborted and the insert error is returned; otherwise the
// transaction is committed and the commit result is returned.
func SyncBlockData(node models.DBNode) error {
	insertInTxn := func(sctx mongo.SessionContext, d db_session_service.DBTransaction) error {
		txnOpts := options.Transaction().
			SetReadConcern(readconcern.Snapshot()).
			SetWriteConcern(writeconcern.New(writeconcern.WMajority()))
		if startErr := sctx.StartTransaction(txnOpts); startErr != nil {
			return startErr
		}

		if insertErr := models.InsertNodeWithSession(sctx, node); insertErr != nil {
			// The insert error is what the caller needs to see, so a failure
			// of the abort itself is deliberately dropped.
			_ = sctx.AbortTransaction(sctx)
			d.Logger.Info("caught exception during transaction, aborting.")
			return insertErr
		}
		return d.Commit(sctx)
	}

	return db_session_service.NewDBTransaction(Logger).Exec(models.DB.Mongo, insertInTxn)
}
func Find(database *mongo.Database,collection string,limit,index int64) (data []map[string]interface,err error){
ctx, cannel := context.WithTimeout(context.Background(), time.Minute)
defer cannel()
var findoptions *options.FindOptions
if limit > 0 {
findoptions = &options.FindOptions{}
findoptions.SetLimit(limit)
findoptions.SetSkip(limit * index)
}
cur, err := database.Collection(collection).Find(ctx, bson.M{}, findoptions)
if err != nil {
return nil, err
}
defer cur.Close(context.Background())
err = cur.All(context.Background(), &data)
return
}
nice job