Created
May 4, 2023 16:19
-
-
Save lesismal/6b397b12bb1e328395873a2c35a71af0 to your computer and use it in GitHub Desktop.
TaskTree.go
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"log" | |
"sync" | |
"sync/atomic" | |
"time" | |
) | |
// Node is one task in a TaskTree. A node runs its own F only after every
// node in Nodes has finished, so leaves run first and execution climbs level
// by level towards the root.
//
// Every leaf gets its own goroutine. When the last child of a node finishes,
// that child's goroutine goes on to run the parent while the other children's
// goroutines exit, so a tree never runs more goroutines at once than it has
// leaves.
//
// In this version a failing F is reported through rollback, but the parent
// is still executed afterwards.
type Node struct {
	// F is the node's own work. A nil F is treated as a no-op, so a node can
	// be used purely to group its children.
	F func() error
	// Nodes are the children; all of them finish before F runs.
	Nodes []*Node

	// Bookkeeping written by doNodes before any goroutine is started.
	selfCnt    int32           // children of this node that have not finished yet
	parentCnt  *int32          // the parent's selfCnt (unused when parentWg is set)
	parentWg   *sync.WaitGroup // non-nil only for a top-level node; Done is called when it finishes
	parentFunc func() error    // the parent's doSelf, run by whichever child finishes last
	rollback   func(error)     // failure callback shared by the whole tree
}

// doNodes records how this node reports completion to its parent and then
// schedules the work. An inner node recurses into its children, handing them
// its own counter and doSelf as their parent hooks. A leaf starts a goroutine
// running doSelf.
//
// All fields are written before the goroutines that read them are started,
// so the go statement orders those writes before the reads.
func (node *Node) doNodes(parentWg *sync.WaitGroup, parentCnt *int32, parentFunc func() error, rollback func(error)) {
	node.selfCnt = int32(len(node.Nodes))
	node.parentWg = parentWg
	node.parentCnt = parentCnt
	node.parentFunc = parentFunc
	node.rollback = rollback
	if node.selfCnt > 0 {
		for _, v := range node.Nodes {
			// Children are never top-level, so they get no WaitGroup.
			v.doNodes(nil, &node.selfCnt, node.doSelf, rollback)
		}
		return
	}
	go node.doSelf()
}

// doSelf runs this node's F (if any) and passes a non-nil error to rollback.
// It then signals completion:
//   - a top-level node calls parentWg.Done;
//   - any other node decrements its parent's counter, and if it was the last
//     child to finish, it runs the parent on the current goroutine.
//
// It always returns nil; the error return only exists so that doSelf can be
// stored as a child's parentFunc.
func (node *Node) doSelf() error {
	defer func() {
		if node.parentWg != nil {
			node.parentWg.Done()
			return
		}
		if atomic.AddInt32(node.parentCnt, -1) == 0 && node.parentFunc != nil {
			if err := node.parentFunc(); err != nil && node.rollback != nil {
				node.rollback(err)
			}
		}
	}()
	if node.F == nil {
		// Grouping node: nothing of its own to run.
		return nil
	}
	if err := node.F(); err != nil && node.rollback != nil {
		node.rollback(err)
	}
	return nil
}
type TaskTree struct{} | |
func (t *TaskTree) Go(root *Node, rollback func(error)) func() <-chan error { | |
wg := &sync.WaitGroup{} | |
chErr := make(chan error, 1) | |
wg.Add(1) | |
waitFunc := func() <-chan error { | |
wg.Wait() | |
if len(chErr) == 0 { | |
chErr <- nil | |
} | |
return chErr | |
} | |
var n int32 | |
var rollbackCalled int32 | |
root.doNodes(wg, &n, nil, func(err error) { | |
if atomic.AddInt32(&rollbackCalled, 1) == 1 { | |
rollback(err) | |
chErr <- err | |
} | |
}) | |
return waitFunc | |
} | |
func (t *TaskTree) GoN(nodes []*Node, rollback func(error)) func() <-chan error { | |
wg := &sync.WaitGroup{} | |
chErr := make(chan error, 1) | |
waitFunc := func() <-chan error { | |
wg.Wait() | |
if len(chErr) == 0 { | |
chErr <- nil | |
} | |
return chErr | |
} | |
if len(nodes) == 0 { | |
return waitFunc | |
} | |
wg.Add(len(nodes)) | |
var n int32 | |
var rollbackCalled int32 | |
for _, v := range nodes { | |
v.doNodes(wg, &n, nil, func(err error) { | |
if atomic.AddInt32(&rollbackCalled, 1) == 1 { | |
rollback(err) | |
chErr <- err | |
} | |
}) | |
} | |
return waitFunc | |
} | |
func main() { | |
f1 := func() error { | |
time.Sleep(time.Second) | |
log.Println(111) | |
return nil | |
} | |
f2 := func() error { | |
time.Sleep(time.Second) | |
log.Println(222) | |
return nil | |
} | |
f3 := func() error { | |
time.Sleep(time.Second) | |
log.Println(333) | |
return nil //fmt.Errorf("[failed-333]") | |
} | |
f4 := func() error { | |
time.Sleep(time.Second) | |
log.Println(444) | |
return nil | |
} | |
commit := func() { | |
log.Println("commit") | |
} | |
rollback := func(err error) { | |
log.Println("rollback:", err) | |
} | |
task := &Node{ | |
F: f4, | |
Nodes: []*Node{ | |
{ | |
F: f3, | |
// 并发控制: | |
// 1. 先执行叶子节点,单个Node的所有叶子执行完之后执行Node自己、层层向上最终执行到根; | |
// 2. 每个叶子节点会创建一个协程; | |
// 3. 单个Node的所有叶子执行完后,其中最后执行完的叶子协程继续执行Node自己,其他叶子协程退出,一个任务树同时最多协程数量即为叶子数量。 | |
// 单就这个示例: | |
// 1. f1 与 f2 先各自一个协程去执行; | |
// 2. f1 和 f2 都执行完之后,它们之中先执行完的协程退出、后执行完的协程继续执行 f3; | |
// 3. f3 执行完后,该协程继续执行 f4 直到整个任务树执行完毕。 | |
Nodes: []*Node{ | |
{ | |
F: f1, | |
}, | |
{ | |
F: f2, | |
}, | |
}, | |
}, | |
}, | |
} | |
t := &TaskTree{} | |
chWait := t.Go(task, rollback) | |
if err := <-chWait(); err == nil { | |
commit() | |
} | |
} |
将 f3(或其他任一 func)改为返回 err,即可测试执行失败时的回滚(第一版代码未导入 fmt,需先在 import 中加入 "fmt"):
f3 := func() error {
time.Sleep(time.Second)
log.Println(333)
return fmt.Errorf("[failed-333]")
}
输出(注意:第一版中 f3 失败后仍会继续执行父任务 f4):
2023/05/05 00:20:29 111
2023/05/05 00:20:29 222
2023/05/05 00:20:30 333
2023/05/05 00:20:30 rollback: [failed-333]
2023/05/05 00:20:31 444
- 增加 context 支持,并传递给每个任务
- 子任务失败后不再继续向上执行父任务
package main
import (
	"context"
	"fmt"
	"log"
	"sync"
	"sync/atomic"
	"time"
)
// Node is one task in a TaskTree. A node runs its own F only after every node
// in Nodes has finished successfully, so leaves run first and execution
// climbs level by level towards the root.
//
// Every leaf gets its own goroutine. When the last child of a node finishes
// successfully, that child's goroutine goes on to run the parent while the
// other children's goroutines exit. A tree therefore never runs more
// goroutines at once than it has leaves.
//
// A failing child is reported through failFunc and does not count towards
// its parent, so the parent (and everything above it) is never run.
type Node struct {
	// F is the node's own work and receives the tree's context. A nil F is
	// treated as a successful no-op, so a node can be used purely to group
	// its children.
	F func(context.Context) error
	// Nodes are the children; all of them must succeed before F runs.
	Nodes []*Node

	// Bookkeeping written by doNodes before any goroutine is started.
	selfCnt    int32                       // children of this node that have not succeeded yet
	parentCnt  *int32                      // the parent's selfCnt (unused when parentDone is set)
	parentDone chan struct{}               // non-nil only for a top-level node; signaled when it finishes
	parentFunc func(context.Context) error // the parent's doSelf, run by the last child to succeed
	failFunc   func(error)                 // failure callback shared by the whole tree
}

// doNodes records how this node reports completion to its parent and then
// schedules the work. An inner node recurses into its children, handing them
// its own counter and doSelf as their parent hooks. A leaf starts a goroutine
// running doSelf.
//
// All fields are written before the goroutines that read them are started,
// so the go statement orders those writes before the reads.
func (node *Node) doNodes(ctx context.Context, parentCnt *int32, parentDone chan struct{}, parentFunc func(context.Context) error, failFunc func(error)) {
	node.selfCnt = int32(len(node.Nodes))
	node.parentDone = parentDone
	node.parentCnt = parentCnt
	node.parentFunc = parentFunc
	node.failFunc = failFunc
	if node.selfCnt > 0 {
		for _, v := range node.Nodes {
			// Children are never top-level, so they get no done channel.
			v.doNodes(ctx, &node.selfCnt, nil, node.doSelf, failFunc)
		}
		return
	}
	go node.doSelf(ctx)
}

// doSelf runs this node's F (if any) and passes a non-nil error to failFunc.
// It then signals completion:
//   - a top-level node does a non-blocking send on parentDone. This happens
//     even when F failed; the failure has already gone through failFunc.
//   - any other node that succeeded decrements its parent's counter, and if
//     it was the last child to do so, it runs the parent on the current
//     goroutine. A node that failed leaves the counter alone, so its parent
//     never runs.
//
// It always returns nil; the error return only exists so that doSelf can be
// stored as a child's parentFunc.
func (node *Node) doSelf(ctx context.Context) error {
	var err error
	defer func() {
		if node.parentDone != nil {
			// Non-blocking so an unread or full channel never stalls this goroutine.
			select {
			case node.parentDone <- struct{}{}:
			default:
			}
			return
		}
		if err == nil && atomic.AddInt32(node.parentCnt, -1) == 0 && node.parentFunc != nil {
			if perr := node.parentFunc(ctx); perr != nil && node.failFunc != nil {
				node.failFunc(perr)
			}
		}
	}()
	if node.F != nil {
		err = node.F(ctx)
	}
	if err != nil && node.failFunc != nil {
		node.failFunc(err)
	}
	return nil
}
// TaskTree starts Node trees and lets the caller wait for their outcome.
type TaskTree struct{}

// Go starts the tree rooted at root, passing ctx to every node's F, and
// returns a wait function.
//
// The wait function blocks until one of three things happens, whichever
// comes first, and returns a channel holding the outcome:
//   - the whole tree succeeds: the channel holds nil;
//   - some node fails: it holds the first node error;
//   - ctx is done: it holds ctx.Err().
//
// A tree result that is already available takes precedence over ctx. Wait
// returns without waiting for nodes that are still running; cancel ctx to
// tell them to stop. The wait function may be called any number of times,
// and every call reports the same outcome.
func (t *TaskTree) Go(ctx context.Context, root *Node) func() <-chan error {
	chErr := make(chan error, 1)     // first node failure; written at most once
	chDone := make(chan struct{}, 1) // signaled when the root node has finished

	// A nil Done channel is never ready, so a nil ctx (or one that can never
	// be cancelled) simply disables the cancellation case below.
	var ctxDone <-chan struct{}
	if ctx != nil {
		ctxDone = ctx.Done()
	}

	// outcome blocks until the tree's result is known.
	outcome := func() error {
		select {
		case <-chDone:
		case err := <-chErr:
			return err
		case <-ctxDone:
			// Prefer a tree result that became available at the same moment.
			select {
			case <-chDone:
			case err := <-chErr:
				return err
			default:
				return ctx.Err()
			}
		}
		// The root has finished. A root whose F failed reports through
		// failFunc before it signals chDone, so its error is already here.
		select {
		case err := <-chErr:
			return err
		default:
			return nil
		}
	}

	var (
		once   sync.Once
		result error
	)
	waitFunc := func() <-chan error {
		once.Do(func() { result = outcome() })
		// Hand out a fresh channel per call so that reading one result never
		// changes what a later call reports.
		ch := make(chan error, 1)
		ch <- result
		return ch
	}

	var n int32 // the root reports through chDone, so this counter is never used
	var failFuncCalled int32
	root.doNodes(ctx, &n, chDone, nil, func(err error) {
		// Only the first failure is recorded, so this send never blocks.
		if atomic.AddInt32(&failFuncCalled, 1) == 1 {
			chErr <- err
		}
	})
	return waitFunc
}
// main builds the demo tree (f1 and f2 feed f3, which feeds f4) and runs it
// with a 5-second timeout.
//
// As written, f1 returns an error, so the run fails, f3 and f4 never start,
// and "rollback" is logged. The deferred cancel then cancels the context;
// f2, which checks it after its sleep, reports the cancellation.
//
// The final loop keeps the process alive for five more seconds so that such
// late log lines remain visible.
func main() {
	f1 := func(ctx context.Context) error {
		time.Sleep(time.Second)
		if cerr := ctx.Err(); cerr != nil {
			log.Println(111, cerr)
			return cerr
		}
		log.Println(111)
		// Use `return nil` here to let f1 succeed.
		return fmt.Errorf("[failed-111]")
	}
	f2 := func(ctx context.Context) error {
		time.Sleep(2 * time.Second)
		if cerr := ctx.Err(); cerr != nil {
			log.Println(222, cerr)
			return cerr
		}
		log.Println(222)
		// Use `return fmt.Errorf("[failed-222]")` here to make f2 fail.
		return nil
	}
	f3 := func(ctx context.Context) error {
		time.Sleep(500 * time.Millisecond)
		if cerr := ctx.Err(); cerr != nil {
			log.Println(333, cerr)
			return cerr
		}
		log.Println(333)
		// Use `return nil` here to let f3 succeed.
		return fmt.Errorf("[failed-333]")
	}
	f4 := func(ctx context.Context) error {
		time.Sleep(time.Second)
		if cerr := ctx.Err(); cerr != nil {
			log.Println(444, cerr)
			return cerr
		}
		log.Println(444)
		return nil
	}

	var err error
	commit := func() { log.Println("commit") }
	rollback := func() {
		if err == nil {
			return
		}
		log.Println("rollback")
	}

	// Concurrency model:
	// 1. Leaves run first. Once all children of a node have finished, the node
	//    itself runs, level by level up to the root.
	// 2. Each leaf node gets its own goroutine.
	// 3. When all children of a node have finished, the goroutine of the child
	//    that finished last goes on to run the node itself, and the other
	//    children's goroutines exit. A tree therefore never runs more
	//    goroutines at once than it has leaves.
	// In this example:
	// 1. f1 and f2 each start in their own goroutine.
	// 2. Once both have finished, the goroutine that finished first exits,
	//    and the other goes on to run f3.
	// 3. After f3, that same goroutine runs f4, completing the tree.
	leaves := []*Node{{F: f1}, {F: f2}}
	middle := &Node{F: f3, Nodes: leaves}
	task := &Node{F: f4, Nodes: []*Node{middle}}

	func() {
		tree := &TaskTree{}
		defer rollback()
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		wait := tree.Go(ctx, task)
		if err = <-wait(); err == nil {
			commit()
		}
	}()

	for i := 0; i < 5; i++ {
		time.Sleep(time.Second)
		log.Println("-----", i)
	}
}
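output (f1 returns an error, so f3 and f4 never run, and the still-running f2 then sees the cancelled context):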
2023/05/10 14:17:20 111
2023/05/10 14:17:20 rollback
2023/05/10 14:17:21 222 context canceled
2023/05/10 14:17:21 ----- 0
2023/05/10 14:17:22 ----- 1
2023/05/10 14:17:23 ----- 2
2023/05/10 14:17:24 ----- 3
2023/05/10 14:17:25 ----- 4
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
执行成功时的输出: