Skip to content

Instantly share code, notes, and snippets.

@arkottke
Created May 3, 2022 22:17
Show Gist options
  • Save arkottke/b04720ea4d8c00c513266398ed12f312 to your computer and use it in GitHub Desktop.
Save arkottke/b04720ea4d8c00c513266398ed12f312 to your computer and use it in GitHub Desktop.
Python logic tree

Logic tree

Simple interface for defining a logic tree with a JSON file, and then iterating over the branches. Branches with zero weight are excluded. Branches can be limited by requires or excludes. Additional parameters can be passed with dictionaries to params.

from __future__ import annotations

import itertools
import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

import numpy as np
import pytest
# Types accepted as the value of an alternative. ``Tuple[str, ...]`` is a
# tuple holding any number of strings (``Tuple[str]`` would only describe a
# tuple of exactly one string).
AlternativeValueType = Union[str, int, float, Tuple[str, ...]]
@dataclass
class Realization:
    """One selected alternative of a node, as it appears on a branch.

    ``Node.__iter__`` creates these from the node name and the value, weight
    and params of each alternative with a positive weight.
    """

    # Name of the node the alternative belongs to.
    name: str
    # Value of the selected alternative (same type as ``Alternative.value``).
    value: AlternativeValueType
    # Weight of the selected alternative.
    weight: float
    # Additional parameters copied from the alternative.
    params: Dict[str, Any]
@dataclass
class Alternative:
    """A single option of a node, optionally constrained by other nodes.

    ``requires`` and ``excludes`` map the name of another node to a reference
    value, or to a list of acceptable reference values. Float references are
    compared with ``np.isclose``; everything else is compared with ``==``.
    """

    # Value identifying this alternative within its node.
    value: AlternativeValueType
    # Weight of the alternative.
    weight: float = 1.0
    # Node name -> value(s) that must be selected for this alternative to apply.
    requires: Dict[str, Any] = field(default_factory=dict)
    # Node name -> value(s) that, when all selected, rule this alternative out.
    excludes: Dict[str, Any] = field(default_factory=dict)
    # Additional parameters carried along with the alternative.
    params: Dict[str, Any] = field(default_factory=dict)

    def is_valid(self, branch):
        """Return whether this alternative is permitted on *branch*.

        *branch* must support ``branch[name].value`` for every node name used
        in ``requires`` and ``excludes``; a missing name raises whatever the
        branch's item lookup raises.

        The alternative is valid when every ``requires`` entry matches and the
        ``excludes`` entries do not *all* match at the same time.
        """

        def matches_one(ref, check):
            # Only use the tolerant comparison when both sides are numeric;
            # np.isclose raises TypeError for e.g. a float against a string.
            if isinstance(ref, float) and isinstance(
                check, (int, float, np.number)
            ):
                return bool(np.isclose(ref, check))
            return ref == check

        def matches(ref, check):
            # A list reference matches when any of its items matches.
            if isinstance(ref, list):
                return any(matches_one(item, check) for item in ref)
            return matches_one(ref, check)

        # Every required realization has to be present.
        if self.requires and not all(
            matches(ref, branch[name].value)
            for name, ref in self.requires.items()
        ):
            return False

        # The excluded combination must not be present in full.
        if self.excludes and all(
            matches(ref, branch[name].value)
            for name, ref in self.excludes.items()
        ):
            return False

        return True
@dataclass
class Node:
    """A named decision point of the logic tree and its alternatives.

    ``alts`` may mix ``Alternative`` instances and bare values; bare values
    are wrapped in an ``Alternative`` with default weight and no constraints.
    Any iterable works, e.g. the string ``"ab"`` gives alternatives "a" and "b".
    """

    name: str
    alts: List[Union[Alternative, AlternativeValueType]]

    def __post_init__(self):
        # Normalize so that every entry of ``alts`` is an Alternative.
        self.alts = [
            a if isinstance(a, Alternative) else Alternative(a) for a in self.alts
        ]

    def __len__(self):
        """Return the number of alternatives, including zero-weight ones."""
        return len(self.alts)

    def __getitem__(self, index):
        """Return the alternative at position *index*."""
        return self.alts[index]

    def by_value(self, value) -> Optional[Alternative]:
        """Return the first alternative whose value matches *value*.

        Values are compared with ``==``; a float *value* additionally matches
        numeric alternatives within the ``np.isclose`` tolerance. Returns
        ``None`` when nothing matches.
        """
        for alt in self.alts:
            if alt.value == value:
                return alt
            # np.isclose raises TypeError on non-numeric input, so only try
            # the tolerant comparison when both sides are numbers.
            if (
                isinstance(value, float)
                and isinstance(alt.value, (int, float, np.number))
                and np.isclose(alt.value, value)
            ):
                return alt
        return None

    def __iter__(self) -> Iterator[Realization]:
        """Yield a Realization for every alternative with a positive weight."""
        for a in self.alts:
            if a.weight > 0:
                yield Realization(self.name, a.value, a.weight, a.params)

    @property
    def options(self):
        """Tuple of the values of all alternatives, in order."""
        return tuple(a.value for a in self.alts)

    @classmethod
    def from_dict(cls, d):
        """Build a node from ``{"name": ..., "alts": [{...}, ...]}``.

        Each entry of ``alts`` is passed as keyword arguments to Alternative.
        """
        return cls(d["name"], [Alternative(**a) for a in d["alts"]])
@dataclass
class Branch:
    """One combination of realizations, keyed by node name."""

    params: Dict[str, Realization]

    def __getitem__(self, key):
        """Return the realization stored under node name *key*."""
        return self.params[key]

    def __iter__(self):
        """Iterate over the realizations of the branch."""
        yield from self.params.values()

    @property
    def weight(self):
        """Product of the weights of all realizations (1.0 when empty)."""
        # np.product was removed in NumPy 2.0; np.prod is the supported name.
        return np.prod([p.weight for p in self])

    def value(self, key):
        """Return the value of the realization stored under node name *key*."""
        return self.params[key].value

    def as_dict(self):
        """Return a plain ``{node name: value}`` dictionary."""
        return {k: a.value for k, a in self.params.items()}
@dataclass
class LogicTree:
    """An ordered collection of nodes that enumerates its valid branches."""

    nodes: List[Node]

    def __iter__(self) -> Iterator[Branch]:
        """Yield every valid branch of the tree.

        Candidates are the Cartesian product of the realizations of all nodes;
        a candidate is yielded only when ``is_valid`` accepts it.
        """
        for reals in itertools.product(*self.nodes):
            branch = Branch({r.name: r for r in reals})
            if self.is_valid(branch):
                yield branch

    def is_valid(self, branch) -> bool:
        """Return whether every realization on *branch* is permitted.

        Each realization is mapped back to its alternative on the tree and the
        alternative's own ``is_valid`` is consulted. A realization that refers
        to an unknown node or an unknown value makes the branch invalid.
        """
        for param in branch.params.values():
            # Select the alternative on the logic tree
            node = self[param.name]
            alt = node.by_value(param.value) if node is not None else None
            if alt is None or not alt.is_valid(branch):
                return False
        return True

    def __getitem__(self, key):
        """Return the first node named *key*, or ``None`` when there is none."""
        for n in self.nodes:
            if n.name == key:
                return n
        return None

    @classmethod
    def from_list(cls, dicts: List[Dict[str, Any]]) -> LogicTree:
        """Build a tree from a list of node dictionaries (see Node.from_dict)."""
        nodes = [Node.from_dict(d) for d in dicts]
        return cls(nodes)

    @classmethod
    def from_json(cls, fpath: Union[str, Path]) -> LogicTree:
        """Build a tree from a JSON file holding a list of node dictionaries."""
        with Path(fpath).open(encoding="utf-8") as fp:
            return cls.from_list(json.load(fp))
@pytest.fixture
def my_tree():
    """Three-node tree (foo, bar, baz) exercising requires and excludes."""
    foo = Node("foo", "ab")
    bar = Node(
        "bar",
        [
            Alternative("c"),
            Alternative("d"),
            # "e" is only allowed together with foo == "a".
            Alternative("e", requires={"foo": "a"}),
        ],
    )
    baz = Node(
        "baz",
        [
            # "f" is only allowed together with bar == "c" or bar == "d".
            Alternative("f", requires={"bar": ["c", "d"]}),
            Alternative("g"),
            # "h" is not allowed together with foo == "a".
            Alternative("h", excludes={"foo": "a"}),
        ],
    )
    return LogicTree([foo, bar, baz])
def test_parse_json():
    """Load the JSON example next to this file and check it yields a tree."""
    fpath = Path(__file__).parent / "data/test_logic_tree.json"
    # Close the file deterministically instead of leaving it to the GC.
    with fpath.open(encoding="utf-8") as fp:
        lt = LogicTree.from_list(json.load(fp))
    assert isinstance(lt, LogicTree)
    assert lt.nodes
def test_node_init():
    """Bare values given to a Node are wrapped in Alternative instances."""
    first = Node("foo", ["a", "b"])[0]
    assert isinstance(first, Alternative)
def test_branch_count(my_tree):
    """The tree enumerates all combinations minus the constrained ones."""
    unconstrained = 2 * 3 * 3
    # The expected total is the full product reduced by 3, 1 and 3 branches
    # that the requires/excludes rules of the fixture remove.
    expected = unconstrained - 3 - 1 - 3
    assert len(list(my_tree)) == expected
def test_valid_branches(my_tree):
    """Spot-check which value combinations survive the constraints."""
    branches = list(my_tree)

    def is_branch(values):
        # Search every enumerated branch, not only the first one: a branch
        # matches when all of the given node values agree with it.
        return any(
            all(b[k].value == v for k, v in values.items()) for b in branches
        )

    assert is_branch({"foo": "a", "bar": "c", "baz": "f"})
    assert not is_branch({"foo": "a", "bar": "e", "baz": "f"})
    assert not is_branch({"foo": "a", "bar": "d", "baz": "h"})
if __name__ == "__main__":
    # Load the JSON example stored next to this file and print how many valid
    # branches the resulting tree enumerates.
    fpath = Path(__file__).parent / "data/test_logic_tree.json"
    # Use a context manager so the file handle is closed after parsing.
    with fpath.open(encoding="utf-8") as fp:
        lt = LogicTree.from_list(json.load(fp))
    print(len(list(lt)))
[
{
"name": "soil_thick",
"alts": [
{
"value": 0,
"weight": 1,
"params": {
"outputs": [
{
"depth": 0,
"wavefield": "outcrop"
}
]
}
},
{
"value": 4.6,
"weight": 1,
"params": {
"outputs": [
{
"depth": 0,
"wavefield": "outcrop"
}
]
}
},
{
"value": 12.2,
"weight": 1,
"params": {
"outputs": [
{
"depth": 0,
"wavefield": "outcrop"
},
{
"depth": 13.7,
"wavefield": "outcrop"
}
]
}
}
]
},
{
"name": "method",
"alts": [
{
"value": "surface wave",
"weight": 0.5
},
{
"value": "borehole",
"weight": 0.5
}
]
},
{
"name": "vel_source",
"alts": [
{
"value": "masw & mam",
"weight": 0.6,
"requires": {
"method": "surface wave"
}
},
{
"value": "sasw",
"weight": 0.4,
"requires": {
"method": "surface wave"
}
},
{
"value": "litho",
"weight": 1.0,
"requires": {
"method": "borehole"
}
}
]
},
{
"name": "shallow_source",
"alts": [
{
"value": "R0-LR1.5",
"weight": 0.25,
"requires": {
"vel_source": "masw & mam"
}
},
{
"value": "R0-LR2.0",
"weight": 0.25,
"requires": {
"vel_source": "masw & mam"
}
},
{
"value": "R0-LR2.5",
"weight": 0.25,
"requires": {
"vel_source": "masw & mam"
}
},
{
"value": "R0-LR3.0",
"weight": 0.25,
"requires": {
"vel_source": "masw & mam"
}
},
{
"value": "Array 1",
"weight": 0.4,
"requires": {
"vel_source": "sasw"
}
},
{
"value": "Array 3",
"weight": 0.4,
"requires": {
"vel_source": "sasw"
}
},
{
"value": "Array 4",
"weight": 0.2,
"requires": {
"vel_source": "sasw"
}
},
{
"value": "w/o thin interbeds",
"weight": 0.3,
"requires": {
"vel_source": "litho"
}
},
{
"value": "w/ thin interbeds",
"weight": 0.7,
"requires": {
"vel_source": "litho"
}
}
]
},
{
"name": "deep_source",
"alts": [
{
"value": "INEL-1",
"weight": 1
}
]
},
{
"name": "deep_adj",
"alts": [
{
"value": "lower",
"weight": 0.185
},
{
"value": "center",
"weight": 0.630
},
{
"value": "upper",
"weight": 0.185
}
]
},
{
"name": "site_atten",
"alts": [
{
"value": 0.0212,
"weight": 1
},
{
"value": 0.0308,
"weight": 1
},
{
"value": 0.0401,
"weight": 1
},
{
"value": 0.0510,
"weight": 1
},
{
"value": 0.0648,
"weight": 1
},
{
"value": 0.0846,
"weight": 1
},
{
"value": 0.1225,
"weight": 1
}
]
},
{
"name": "mrd_curves",
"alts": [
{
"value": "site specific",
"weight": 0.7
},
{
"value": "darendeli",
"weight": 0.3
}
]
}
]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment