Skip to content

Instantly share code, notes, and snippets.

@yangfch3
Created December 7, 2021 10:37
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save yangfch3/654240f75e2df33d9632ba4acd82051b to your computer and use it in GitHub Desktop.
Save yangfch3/654240f75e2df33d9632ba4acd82051b to your computer and use it in GitHub Desktop.
Lua 决策树
-- @author: yangfch3
-- @date: 2020/09/27 15:20
------------------------------------------
-- Base class of all decision-tree nodes.
-- NOTE(review): `BaseClass` is a project helper not visible here; presumably
-- BaseClass(name[, parent]) builds a class whose `ctor` runs on
-- construction — confirm against its definition.
local DTNode = BaseClass("DTNode")

--- Constructor hook; the base node stores no state.
function DTNode:ctor() end

--- Evaluate this node.
-- The base implementation does nothing and returns no values; concrete
-- node types override it.
-- @param decisionTree the tree being evaluated
-- @param context caller-supplied evaluation context
function DTNode:Eval(decisionTree, context) end

--- Attach child nodes.
-- Node types that have children override this. Reaching the base version
-- means the node type is a leaf (e.g. DTExeNode does not override it), so
-- attaching children is a programming error. Level 2 makes the error point
-- at the caller rather than at this line.
function DTNode:SetChildren()
    error("DTNode:SetChildren: this node type does not accept children", 2)
end
-- Conditional node: evaluates one of several children, chosen by index.
local DTCondNode = BaseClass("DTCondNode", DTNode)

--- @param cond function(decisionTree, context) that returns the 1-based
--   index of the child to evaluate
function DTCondNode:ctor(cond)
    self._cond = cond
end

--- Set the candidate children. The first argument is child 1, the second
-- is child 2, and so on.
function DTCondNode:SetChildren(...)
    self._childrenNodes = {...}
end

--- Evaluate the child selected by `cond` and return its result.
-- Raises an error if SetChildren was never called, or if `cond` returns a
-- value that does not identify an existing child.
function DTCondNode:Eval(decisionTree, context)
    local children = self._childrenNodes
    assert(children, "DTCondNode:Eval: SetChildren was never called")
    local nodeIdx = self._cond(decisionTree, context)
    -- Look the child up directly instead of range-checking against `#`.
    -- `#` is unspecified when SetChildren received a nil (a hole in the
    -- list), and a range check alone accepts fractional indices such as 1.5.
    -- The table only has positional keys, so any non-nil hit is a real
    -- child. Reading a nil/NaN/boolean key simply yields nil.
    local node = children[nodeIdx]
    if node == nil then
        -- The message is built only on failure, to keep Eval allocation-free.
        error("DTCondNode:Eval: cond returned invalid child index "
            .. tostring(nodeIdx), 2)
    end
    return node:Eval(decisionTree, context)
end
-- Boolean node: evaluates one of two children depending on a predicate.
local DTBoolNode = BaseClass("DTBoolNode", DTNode)

--- @param cond predicate function(decisionTree, context). Any result other
--   than nil/false selects the true branch.
function DTBoolNode:ctor(cond)
    self._cond = cond
end

--- Set both branches; both are required.
-- Both arguments are validated before any field is written, so a failed
-- call leaves the node's previous children intact.
function DTBoolNode:SetChildren(trueNode, falseNode)
    assert(trueNode, "DTBoolNode:SetChildren: trueNode is required")
    assert(falseNode, "DTBoolNode:SetChildren: falseNode is required")
    self._trueNode = trueNode
    self._falseNode = falseNode
end

--- Evaluate the branch chosen by `cond` and return its result.
function DTBoolNode:Eval(decisionTree, context)
    -- An explicit if/else avoids the `a and b or c` pitfall, where a falsy
    -- `b` silently selects `c`.
    local node
    if self._cond(decisionTree, context) then
        node = self._trueNode
    else
        node = self._falseNode
    end
    return node:Eval(decisionTree, context)
end
-- Task node: runs a function whose return values are discarded, then
-- continues evaluation at the next node.
local DTTaskNode = BaseClass("DTTaskNode", DTNode)

--- @param func function(decisionTree, context); its results are ignored
function DTTaskNode:ctor(func)
    self._func = func
end

--- Set the node evaluated after the task; it is required.
-- The argument is validated before assignment, so a failed call leaves the
-- node unchanged.
function DTTaskNode:SetChildren(nextNode)
    assert(nextNode, "DTTaskNode:SetChildren: nextNode is required")
    self._nextNode = nextNode
end

--- Run the task, then return the result of evaluating the next node.
function DTTaskNode:Eval(decisionTree, context)
    self._func(decisionTree, context)
    return self._nextNode:Eval(decisionTree, context)
end
-- Execution node: a terminal (leaf) node of the decision tree.
-- It does not override SetChildren, so attaching children to it raises the
-- base-class error. Whatever its function returns is returned from Eval.
local DTExeNode = BaseClass("DTExeNode", DTNode)

--- Remember the function that produces this leaf's result.
-- @param func function(decisionTree, context)
function DTExeNode:ctor(func)
    self._func = func
end

--- Call the stored function and hand back all of its return values.
function DTExeNode:Eval(decisionTree, context)
    return self._func(decisionTree, context)
end
-- Decision tree assembled by hand from node objects.
local ManualDecisionTree = BaseClass("ManualDecisionTree")

--- @param rootNode node where evaluation starts. It is not validated here,
--   so it may be nil and supplied later through SetRootNode.
function ManualDecisionTree:ctor(rootNode)
    -- Go through the setter so subclasses that override it are honored.
    self:SetRootNode(rootNode)
end

--- Replace the node where evaluation starts.
function ManualDecisionTree:SetRootNode(node)
    self._rootNode = node
end

--- Evaluate the tree from its root.
-- The tree itself is passed to every node as `decisionTree`.
-- @param context caller-supplied evaluation context, forwarded to the nodes
-- @return whatever the root node's Eval returns
function ManualDecisionTree:Eval(context)
    return self._rootNode:Eval(self, context)
end
-- Public API of the module. DTNode is exported as well, so callers can
-- derive custom node types from the common base class, e.g.
-- BaseClass("MyNode", DTNode).
return {
    DTNode = DTNode,
    DTCondNode = DTCondNode,
    DTBoolNode = DTBoolNode,
    DTTaskNode = DTTaskNode,
    DTExeNode = DTExeNode,
    ManualDecisionTree = ManualDecisionTree,
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment