Skip to content

Instantly share code, notes, and snippets.

@hemebond
Created August 29, 2020 12:34
Show Gist options
  • Save hemebond/7ec81a9d202437c0fb7919be389f892e to your computer and use it in GitHub Desktop.
Save hemebond/7ec81a9d202437c0fb7919be389f892e to your computer and use it in GitHub Desktop.
A custom YAML constructor that does a deep merge of dicts.
from copy import copy, deepcopy

import yaml
from yaml.constructor import ConstructorError, SafeConstructor
from yaml.loader import Loader
from yaml.nodes import MappingNode
# Copyright Ferry Boender, released under the MIT license.
def deepupdate(tgt, src):
    """Return a deep copy of *tgt* deep-updated with the contents of *src*.

    Neither argument is modified. For each ``k, v`` in *src*:

    * if ``k`` is missing from the result, a deep copy of ``v`` is stored;
    * if ``v`` and the existing value are both dicts, they are merged
      recursively by these same rules;
    * if both are lists, the existing list is extended with a deep copy
      of ``v``;
    * if both are sets, the existing set is updated with ``v``;
    * otherwise (scalars or mismatched types) the existing value is
      replaced by a deep copy of ``v``.

    :param tgt: mapping providing the base values.
    :param src: mapping whose values are merged on top of *tgt*.
    :returns: a new mapping holding the merged result.

    Examples:
    >>> t = {'name': 'Ferry', 'hobbies': ['programming', 'sci-fi']}
    >>> deepupdate(t, {'hobbies': ['gaming']})
    {'name': 'Ferry', 'hobbies': ['programming', 'sci-fi', 'gaming']}
    >>> t
    {'name': 'Ferry', 'hobbies': ['programming', 'sci-fi']}
    """
    target = deepcopy(tgt)
    for key, value in src.items():
        if key not in target:
            target[key] = deepcopy(value)
            continue
        current = target[key]
        # Each branch stores a NEW container instead of mutating ``current``
        # in place: deepcopy preserves aliasing inside *tgt* (e.g. one dict
        # referenced under two keys), so an in-place change would leak into
        # every alias.
        if isinstance(value, dict) and isinstance(current, dict):
            # deepupdate returns a new dict; it must be stored, otherwise
            # the nested merge is silently lost.
            target[key] = deepupdate(current, value)
        elif isinstance(value, list) and isinstance(current, list):
            merged = copy(current)
            merged.extend(deepcopy(value))
            target[key] = merged
        elif isinstance(value, set) and isinstance(current, set):
            merged = copy(current)
            merged.update(value)
            target[key] = merged
        else:
            # Scalars and type mismatches: src wins, like dict.update().
            target[key] = deepcopy(value)
    return target
class Constructor(SafeConstructor):
    """
    Customise the mapping constructor to do a deep merge instead
    of the regular shallow merge.

    When the same key occurs more than once in a mapping and both values
    are dicts, the later dict is merged into the earlier one with
    ``deepupdate``. Any other duplicate simply replaces the earlier value.
    NOTE(review): duplicates presumably arise mostly from ``<<`` merge keys,
    which ``flatten_mapping`` is expected to splice into ``node.value``
    ahead of the explicit pairs -- confirm against PyYAML.
    """

    def construct_mapping(self, node, deep=False):
        """Build a dict from a YAML mapping node, deep-merging dict duplicates.

        :param node: node to construct; must be a ``MappingNode``.
        :param deep: passed to ``construct_object`` when building keys.
        :returns: the constructed ``dict``.
        :raises ConstructorError: if *node* is not a mapping node, or if a
            key is unhashable.
        """
        if not isinstance(node, MappingNode):
            raise ConstructorError(
                None, None,
                "expected a mapping node, but found %s" % node.id,
                node.start_mark)
        self.flatten_mapping(node)
        mapping = {}
        for key_node, value_node in node.value:
            key = self.construct_object(key_node, deep=deep)
            try:
                hash(key)
            except TypeError as exc:
                raise ConstructorError(
                    "while constructing a mapping", node.start_mark,
                    "found unacceptable key (%s)" % exc,
                    key_node.start_mark) from exc
            # Values are always built with deep=True (not ``deep``),
            # presumably so nested dicts are fully populated before they are
            # merged below -- confirm against PyYAML's generator-based
            # constructors before changing this.
            value = self.construct_object(value_node, deep=True)
            # mapping.get() yields None for a new key, so only a genuine
            # dict-over-dict duplicate takes the deep-merge path.
            if isinstance(value, dict) and isinstance(mapping.get(key), dict):
                mapping[key] = deepupdate(mapping[key], value)
            else:
                mapping[key] = value
        return mapping
class CustomLoader(Loader, Constructor):
    """YAML loader combining ``Loader`` with the deep-merging ``Constructor``.

    NOTE(review): ``Constructor`` is listed after ``Loader``, so its
    ``construct_mapping`` only takes effect if nothing earlier in the MRO
    overrides it -- verify. This class also builds on yaml's full ``Loader``
    rather than ``SafeLoader``; confirm that is intended before loading
    untrusted input.
    """
# Demo: load sample.yaml with the deep-merging loader and print the result.
# NOTE(review): CustomLoader derives from yaml's full ``Loader`` (not
# ``SafeLoader``); do not point this at untrusted YAML.
with open('sample.yaml', 'rb') as stream:
    # The ``with`` block closes the file instead of leaking the handle.
    # Binary mode leaves encoding detection (UTF-8/UTF-16) to PyYAML
    # rather than the platform's default text encoding.
    print(yaml.load(stream, Loader=CustomLoader))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment