Skip to content

Instantly share code, notes, and snippets.

@lesismal
Created May 4, 2023 16:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lesismal/6b397b12bb1e328395873a2c35a71af0 to your computer and use it in GitHub Desktop.
Save lesismal/6b397b12bb1e328395873a2c35a71af0 to your computer and use it in GitHub Desktop.
TaskTree.go
package main
import (
"log"
"sync"
"sync/atomic"
"time"
)
// Node is one task in a task tree. F is the node's own work and Nodes are its
// children; every child must finish before F runs. A node without children is
// a leaf and runs on a goroutine of its own.
//
// The unexported fields are overwritten by doNodes at the start of every run,
// so the same Node must not take part in two runs at once.
type Node struct {
	F     func() error
	Nodes []*Node

	selfCnt    int32           // children that have not finished yet
	parentCnt  *int32          // parent's selfCnt; used when parentWg is nil
	parentWg   *sync.WaitGroup // set for the top-level nodes of a run
	parentFunc func() error    // parent's doSelf, run by the last child to finish
	rollback   func(error)     // error callback shared by the whole run
}

// doNodes wires node into its parent and starts the subtree below it.
//
// Completion is reported either through parentWg (top-level node) or through
// parentCnt/parentFunc (child node); see doSelf. A leaf is started on a new
// goroutine immediately. An inner node only records its child count here: its
// own F is run later by whichever child finishes last.
func (node *Node) doNodes(parentWg *sync.WaitGroup, parentCnt *int32, parentFunc func() error, rollback func(error)) {
	node.selfCnt = int32(len(node.Nodes))
	node.parentWg = parentWg
	node.parentCnt = parentCnt
	node.parentFunc = parentFunc
	node.rollback = rollback
	if node.selfCnt > 0 {
		// selfCnt is fully set before any child starts, so the children's
		// decrements always start from the real child count.
		for _, v := range node.Nodes {
			v.doNodes(nil, &node.selfCnt, node.doSelf, rollback)
		}
		return
	}
	go node.doSelf()
}

// doSelf runs the node's F and then reports completion to the parent.
//
// A non-nil error from F is passed to rollback (if set). Failure does not stop
// the tree: the parent is still notified and, once all of its children are
// done, still runs. A nil F is treated as a task that succeeds immediately, so
// a node may be used purely to group children.
//
// Completion is reported from a deferred call. A top-level node calls
// parentWg.Done. Any other node decrements the parent's counter, and the child
// that brings it to zero runs the parent's doSelf on its own goroutine.
// doSelf always returns nil; the error result only exists so that doSelf can
// be used as a parentFunc.
func (node *Node) doSelf() error {
	defer func() {
		if node.parentWg != nil {
			node.parentWg.Done()
		} else if atomic.AddInt32(node.parentCnt, -1) == 0 && node.parentFunc != nil {
			if err := node.parentFunc(); err != nil && node.rollback != nil {
				node.rollback(err)
			}
		}
	}()
	if node.F == nil {
		// Grouping node: nothing to do besides notifying the parent.
		return nil
	}
	if err := node.F(); err != nil && node.rollback != nil {
		node.rollback(err)
	}
	return nil
}
// TaskTree starts task trees built from Node values and lets the caller wait
// for them to finish.
type TaskTree struct{}

// Go runs the task tree rooted at root. It is GoN with a single tree; see GoN
// for the semantics of rollback and of the returned wait function. A nil root
// is treated as an empty tree.
func (t *TaskTree) Go(root *Node, rollback func(error)) func() <-chan error {
	return t.GoN([]*Node{root}, rollback)
}

// GoN runs every tree in nodes concurrently and returns a wait function.
//
// rollback, if non-nil, is called at most once, with the first error that any
// task reports; later errors are dropped. A failure does not stop the other
// tasks (see Node.doSelf).
//
// The wait function blocks until every tree has finished. It then returns a
// channel holding exactly one value: that first error, or nil if no task
// failed. It may be called any number of times, including concurrently, and
// every call reports the same result. Nil entries in nodes are skipped.
func (t *TaskTree) GoN(nodes []*Node, rollback func(error)) func() <-chan error {
	wg := &sync.WaitGroup{}
	// Receives the first error. It is written at most once, and always before
	// the failing tree's wg.Done, so it is settled once wg.Wait returns.
	chErr := make(chan error, 1)

	var (
		once   sync.Once
		result error
	)
	waitFunc := func() <-chan error {
		once.Do(func() {
			wg.Wait()
			select {
			case result = <-chErr:
			default: // no task failed
			}
		})
		// A fresh channel per call, so repeated or concurrent callers each get
		// the result instead of racing for (or blocking on) a shared one.
		ch := make(chan error, 1)
		ch <- result
		return ch
	}

	roots := make([]*Node, 0, len(nodes))
	for _, v := range nodes {
		if v != nil {
			roots = append(roots, v)
		}
	}
	if len(roots) == 0 {
		return waitFunc
	}

	// Top-level nodes report through wg, so n is never decremented; doNodes
	// still expects a counter pointer.
	var n int32
	var rollbackCalled int32
	onErr := func(err error) {
		if atomic.AddInt32(&rollbackCalled, 1) == 1 {
			if rollback != nil {
				rollback(err)
			}
			chErr <- err
		}
	}
	wg.Add(len(roots))
	for _, v := range roots {
		v.doNodes(wg, &n, nil, onErr)
	}
	return waitFunc
}
func main() {
	// step returns a demo task that sleeps for one second, logs id and
	// succeeds. To exercise the rollback path, make one of the tasks return
	// an error instead, e.g. fmt.Errorf("[failed-333]") (which also needs the
	// "fmt" import).
	step := func(id int) func() error {
		return func() error {
			time.Sleep(time.Second)
			log.Println(id)
			return nil
		}
	}
	f1, f2, f3, f4 := step(111), step(222), step(333), step(444)

	commit := func() { log.Println("commit") }
	rollback := func(err error) { log.Println("rollback:", err) }

	// Scheduling:
	//  1. Leaves run first. Once all children of a node have finished, the
	//     node itself runs, level by level up to the root.
	//  2. Every leaf gets its own goroutine.
	//  3. When all children of a node are done, the child goroutine that
	//     finished last goes on to run the node itself and the others exit,
	//     so a tree never has more goroutines at once than it has leaves.
	// In this example:
	//  1. f1 and f2 each start on their own goroutine;
	//  2. once both are done, the one that finished first exits and the other
	//     goes on to run f3;
	//  3. after f3, the same goroutine runs f4, completing the tree.
	leaf := func(f func() error) *Node { return &Node{F: f} }
	tree := &Node{
		F: f4,
		Nodes: []*Node{
			{F: f3, Nodes: []*Node{leaf(f1), leaf(f2)}},
		},
	}

	wait := (&TaskTree{}).Go(tree, rollback)
	if err := <-wait(); err == nil {
		commit()
	}
}
@lesismal
Copy link
Author

lesismal commented May 4, 2023

执行成功时的输出:

2023/05/05 00:18:59 111
2023/05/05 00:18:59 222
2023/05/05 00:19:00 333
2023/05/05 00:19:01 444
2023/05/05 00:19:01 commit

@lesismal
Copy link
Author

lesismal commented May 4, 2023

把 f3 或者其他 func 返回 err 来测试执行失败时回滚:

f3 := func() error {
	time.Sleep(time.Second)
	log.Println(333)
	return fmt.Errorf("[failed-333]")
}

输出:

2023/05/05 00:20:29 111
2023/05/05 00:20:29 222
2023/05/05 00:20:30 333
2023/05/05 00:20:30 rollback: [failed-333]
2023/05/05 00:20:31 444

@lesismal
Copy link
Author

  1. 增加context
  2. 子任务失败不再向上调用父任务
package main

import (
	"context"
	"fmt"
	"log"
	"sync"
	"sync/atomic"
	"time"
)

// Node is one task in a task tree. F is the node's own work and receives the
// run's context; Nodes are its children, all of which must succeed before F
// runs. A node without children is a leaf and runs on a goroutine of its own.
//
// The unexported fields are overwritten by doNodes at the start of every run,
// so the same Node must not take part in two runs at once.
type Node struct {
	F     func(context.Context) error
	Nodes []*Node

	selfCnt    int32                       // children that have not succeeded yet
	parentCnt  *int32                      // parent's selfCnt; used when parentDone is nil
	parentDone chan struct{}               // set only for the root of a run
	parentFunc func(context.Context) error // parent's doSelf, run by the last child to succeed
	failFunc   func(error)                 // error callback shared by the whole run
}

// doNodes wires node into its parent and starts the subtree below it.
//
// The root reports completion through parentDone; every other node reports
// through parentCnt/parentFunc (see doSelf). A leaf is started on a new
// goroutine immediately. An inner node only records its child count here: its
// own F is run later by whichever child succeeds last.
func (node *Node) doNodes(ctx context.Context, parentCnt *int32, parentDone chan struct{}, parentFunc func(context.Context) error, failFunc func(error)) {
	node.selfCnt = int32(len(node.Nodes))
	node.parentDone = parentDone
	node.parentCnt = parentCnt
	node.parentFunc = parentFunc
	node.failFunc = failFunc
	if node.selfCnt > 0 {
		// selfCnt is fully set before any child starts, so the children's
		// decrements always start from the real child count.
		for _, v := range node.Nodes {
			v.doNodes(ctx, &node.selfCnt, nil, node.doSelf, failFunc)
		}
		return
	}
	go node.doSelf(ctx)
}

// doSelf runs the node's F with ctx and then reports to the parent.
//
// A non-nil error from F is passed to failFunc (if set), and the parent is not
// notified. That child never decrements the parent's counter, so no ancestor
// of a failed node ever runs. The root (parentDone != nil) always does a
// non-blocking send on parentDone, even when its own F failed. In that case
// failFunc has already been called on this goroutine before the send.
//
// A nil F is treated as a task that succeeds immediately, so a node may be used
// purely to group children. doSelf always returns nil; the error result only
// exists so that doSelf can be used as a parentFunc.
func (node *Node) doSelf(ctx context.Context) error {
	var err error
	defer func() {
		if node.parentDone != nil {
			select {
			case node.parentDone <- struct{}{}:
			default:
			}
		} else if err == nil && atomic.AddInt32(node.parentCnt, -1) == 0 && node.parentFunc != nil {
			// This was the last child to succeed: run the parent here.
			if perr := node.parentFunc(ctx); perr != nil && node.failFunc != nil {
				node.failFunc(perr)
			}
		}
	}()
	if node.F != nil {
		err = node.F(ctx)
	}
	if err != nil && node.failFunc != nil {
		node.failFunc(err)
	}
	return nil
}

// TaskTree starts task trees built from Node values.
type TaskTree struct{}

// Go starts the task tree rooted at root, passing ctx to every task, and
// returns a wait function.
//
// The wait function blocks until either the root task has finished or some
// task has failed. It then returns a channel holding exactly one value: the
// first error reported, or nil on success. On failure it returns without
// waiting for the remaining tasks. Go never cancels ctx itself, so tasks that
// are still running keep going until they return or the caller cancels ctx.
// The wait function may be called any number of times, including
// concurrently, and every call reports the same result. A nil root is treated
// as an empty, successful tree.
func (t *TaskTree) Go(ctx context.Context, root *Node) func() <-chan error {
	chErr := make(chan error, 1)
	chDone := make(chan struct{}, 1)

	var (
		once   sync.Once
		result error
	)
	waitFunc := func() <-chan error {
		once.Do(func() {
			select {
			case result = <-chErr:
			case <-chDone:
				// The root signals chDone even when its own F failed. In that
				// case its error was delivered to chErr earlier on the same
				// goroutine, so it is already available here.
				select {
				case result = <-chErr:
				default:
				}
			}
		})
		// A fresh channel per call, so repeated or concurrent callers each get
		// the result instead of blocking on an already-drained channel.
		ch := make(chan error, 1)
		ch <- result
		return ch
	}

	if root == nil {
		chDone <- struct{}{}
		return waitFunc
	}

	// The root reports through chDone, so n is never decremented; doNodes
	// still expects a counter pointer.
	var n int32
	var failFuncCalled int32
	root.doNodes(ctx, &n, chDone, nil, func(err error) {
		// Only the first failure is reported; chErr has room for it.
		if atomic.AddInt32(&failFuncCalled, 1) == 1 {
			chErr <- err
		}
	})
	return waitFunc
}

func main() {
	// newTask builds a demo task. It sleeps for delay, then logs and returns
	// ctx.Err() if the context is already done. Otherwise it logs id and
	// returns an error with text failMsg, or nil when failMsg is empty.
	newTask := func(id int, delay time.Duration, failMsg string) func(context.Context) error {
		return func(ctx context.Context) error {
			time.Sleep(delay)
			select {
			case <-ctx.Done():
				log.Println(id, ctx.Err())
				return ctx.Err()
			default:
			}
			log.Println(id)
			if failMsg != "" {
				return fmt.Errorf("%s", failMsg)
			}
			return nil
		}
	}

	// Change the failMsg arguments to try other failure scenarios.
	f1 := newTask(111, time.Second, "[failed-111]")
	f2 := newTask(222, 2*time.Second, "")
	f3 := newTask(333, time.Second/2, "[failed-333]")
	f4 := newTask(444, time.Second, "")

	var err error
	commit := func() { log.Println("commit") }
	// rollback is deferred in run below, so it sees the result stored in err.
	rollback := func() {
		if err != nil {
			log.Println("rollback")
		}
	}

	// Scheduling:
	//  1. Leaves run first. Once all children of a node have finished, the
	//     node itself runs, level by level up to the root.
	//  2. Every leaf gets its own goroutine.
	//  3. When all children of a node are done, the child goroutine that
	//     finished last goes on to run the node itself and the others exit,
	//     so a tree never has more goroutines at once than it has leaves.
	// For this tree, when every task succeeds:
	//  1. f1 and f2 each start on their own goroutine;
	//  2. once both are done, the one that finished first exits and the other
	//     goes on to run f3;
	//  3. after f3, the same goroutine runs f4, completing the tree.
	// A failed task stops propagation, so with f1 failing as configured above,
	// f3 and f4 never run.
	root := &Node{
		F: f4,
		Nodes: []*Node{{
			F:     f3,
			Nodes: []*Node{{F: f1}, {F: f2}},
		}},
	}

	run := func() {
		defer rollback()

		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()

		tree := &TaskTree{}
		wait := tree.Go(ctx, root)
		err = <-wait()
		if err == nil {
			commit()
		}
	}
	run()

	// Keep the process alive for a while so the log shows what the tasks that
	// are still running do after the result has been handled.
	for i := range [5]struct{}{} {
		time.Sleep(time.Second)
		log.Println("-----", i)
	}
}

output:

2023/05/10 14:17:20 111
2023/05/10 14:17:20 rollback
2023/05/10 14:17:21 222 context canceled
2023/05/10 14:17:21 ----- 0
2023/05/10 14:17:22 ----- 1
2023/05/10 14:17:23 ----- 2
2023/05/10 14:17:24 ----- 3
2023/05/10 14:17:25 ----- 4

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment