Skip to content

Instantly share code, notes, and snippets.

@shimwell
Last active November 22, 2022 18:31
Show Gist options
  • Save shimwell/b14a674dff0cf94be79d69610f50be69 to your computer and use it in GitHub Desktop.
Save shimwell/b14a674dff0cf94be79d69610f50be69 to your computer and use it in GitHub Desktop.
minimal_magic_from_ps
#!/usr/bin/env python
from copy import deepcopy
import numpy as np
import openmc
import openmc.lib
from openmc.mpi import comm
from matplotlib import pyplot as plt
from matplotlib.colors import LogNorm
# Filter types a weight-window tally may carry, listed in the order they must
# appear on the tally. check_tally compares the tally's filter types against a
# prefix of this tuple, and generate_wws pads missing energy/particle axes at
# positions 1 and 2 on the assumption that this ordering holds.
_ALLOWED_FILTER_TYPES = (
    openmc.MeshFilter,
    openmc.EnergyFilter,
    openmc.ParticleFilter,
)
def magic(model, tally_id, iterations, rel_err_tol=0.7):
    """
    Performs weight window generation using the MAGIC method

    Davis, A., & Turner, A. (2011). Comparison of global variance reduction
    techniques for Monte Carlo radiation transport simulations of ITER. Fusion
    Engineering and Design, 86, 2698–2700.
    https://doi.org/10.1016/j.fusengdes.2011.01.059

    Each iteration:
      1. runs OpenMC,
      2. (rank 0 only) builds weight windows from the tally in the resulting
         statepoint, writes them into the model's XML and saves a diagnostic
         plot named ``flux_std_dev_<iteration>.png``, and
      3. records and prints the summed standard deviation of the tally.

    Parameters
    ----------
    model : openmc.Model
        The OpenMC model to run
    tally_id : int
        The ID of the tally to use for weight window generation
    iterations : int
        The number of iterations to perform
    rel_err_tol : float (default: 0.7)
        Upper limit on relative error of flux values used to produce
        weight windows.

    Returns
    -------
    list
        Summed tally standard deviation after each iteration, in order.

    Raises
    ------
    RuntimeError, ValueError
        Propagated from ``check_tally`` if the tally is missing or unsuitable.
    """
    check_tally(model, tally_id)
    if comm.rank == 0:
        model.export_to_xml()
    comm.barrier()

    sum_of_std_dev = []
    for iteration in range(iterations):
        openmc.run()
        sp_file = f'statepoint.{model.settings.batches}.h5'
        if comm.rank == 0:
            wws = generate_wws(sp_file, tally_id, rel_err_tol)
            model.settings.weight_windows = wws
            model.export_to_xml()
            plot_flux_with_ww(wws, model, sp_file, tally_id,
                              filename=f'flux_std_dev_{iteration}.png')
        # Rank 0 has just rewritten the XML inputs; hold every rank here so the
        # next openmc.run() cannot start from stale or half-written files.
        comm.barrier()
        sum_of_std_dev.append(get_std_dev_sum(sp_file, tally_id))
        print(sum_of_std_dev)
    return sum_of_std_dev
def get_std_dev_sum(sp_file, tally_id):
    """Return the standard deviation of a tally summed over all of its bins.

    Parameters
    ----------
    sp_file : str
        Path of the statepoint file to read.
    tally_id : int
        ID of the tally whose standard deviations are summed.
    """
    with openmc.StatePoint(sp_file) as statepoint:
        # Touch .std_dev while the statepoint is still open.
        std_dev = statepoint.get_tally(id=tally_id).std_dev
    return std_dev.sum()
def check_tally(model, tally_id):
    """Validate that ``model`` holds a tally usable for weight window generation.

    Parameters
    ----------
    model : openmc.Model
        Model whose ``tallies`` are searched (first matching ID wins).
    tally_id : int
        ID of the tally to validate.

    Raises
    ------
    RuntimeError
        If no tally with ``tally_id`` exists, or if the tally's filter types
        are not an ordered prefix of ``_ALLOWED_FILTER_TYPES``.
    ValueError
        If the tally has no MeshFilter.
    """
    # next() with a default also covers an empty tally list; a for/break
    # search would leave ``tally`` unbound there and raise UnboundLocalError
    # instead of the intended RuntimeError.
    tally = next((t for t in model.tallies if t.id == tally_id), None)
    if tally is None:
        raise RuntimeError(f'No tally with ID "{tally_id}" is present in the model.')

    filter_types = tuple(type(f) for f in tally.filters)
    if openmc.MeshFilter not in filter_types:
        raise ValueError('This script requires a MeshFilter on the specified tally')
    # Filters must appear in the same order as _ALLOWED_FILTER_TYPES (trailing
    # ones may be omitted).
    if filter_types != _ALLOWED_FILTER_TYPES[:len(filter_types)]:
        raise RuntimeError(f'This script accepts the following types: {_ALLOWED_FILTER_TYPES}\n in that order. '
                           'Only the MeshFilter is required.')
def _magic_lower_bounds(e_mean, e_rel_err, rel_err_tol):
    """Normalise one energy group's flux to its peak; mark unusable bins -1.0.

    Parameters
    ----------
    e_mean : numpy.ndarray
        Flattened per-mesh-bin flux means for one particle and energy group.
    e_rel_err : numpy.ndarray
        Relative errors matching ``e_mean``.
    rel_err_tol : float
        Bins whose relative error exceeds this value are marked -1.0.

    Returns
    -------
    numpy.ndarray
        ``e_mean / max(e_mean)`` with zero, over-tolerance and non-finite
        entries replaced by -1.0.
    """
    peak = np.max(e_mean)
    if peak == 0 or not np.isfinite(peak):
        # Dividing by a zero/NaN/inf peak yields only zeros and non-finite
        # values, all of which end up as -1.0 below; skip the division and
        # the RuntimeWarnings it would emit.
        return np.full(np.shape(e_mean), -1.0)
    bounds = e_mean / peak
    bounds[(bounds == 0) | (e_rel_err > rel_err_tol)] = -1.0
    bounds[~np.isfinite(bounds)] = -1.0
    return bounds


def generate_wws(sp_file, tally_id, rel_err_tol, upper_bound_ratio=5.0):
    """
    Generates weight windows based on a tally.

    Parameters
    ----------
    sp_file : str
        Path of the statepoint file containing the tally.
    tally_id : int
        ID of the tally to use. It must carry a MeshFilter and may add an
        EnergyFilter and a ParticleFilter, in that order.
    rel_err_tol : float
        Bins whose flux relative error exceeds this value get a -1.0 lower
        bound.
    upper_bound_ratio : float (default: 5.0)
        Ratio of upper to lower weight window bound passed to
        ``openmc.WeightWindows``.

    Returns
    -------
    Iterable of openmc.WeightWindows
        One entry per particle type (``'neutron'`` when the tally has no
        ParticleFilter), all sharing a deep copy of the tally's mesh.

    Raises
    ------
    ValueError
        If the tally carries a filter type not in ``_ALLOWED_FILTER_TYPES``.
    """
    with openmc.StatePoint(sp_file) as sp:
        tally = sp.get_tally(id=tally_id)

        filter_types = [type(f) for f in tally.filters]
        for ft in filter_types:
            if ft not in _ALLOWED_FILTER_TYPES:
                raise ValueError(f'Filter type {ft} is unsupported for weight window generation')

        mesh = tally.find_filter(openmc.MeshFilter).mesh
        mesh_copy = deepcopy(mesh)

        # get the tally mean and relative error while the statepoint is open
        mean = tally.get_reshaped_data()
        rel_err = tally.get_reshaped_data(value='rel_err')

        # in case other scores / nuclides are applied to this tally, make sure
        # to use the "flux" score and the 'total' nuclide entry
        score_idx = tally.get_score_index("flux")
        nuclide_idx = tally.get_nuclide_index("total")

        if openmc.EnergyFilter in filter_types:
            e_filter = tally.find_filter(openmc.EnergyFilter)
            e_bounds = e_filter.values
            n_e_bins = e_filter.num_bins
        else:
            n_e_bins = 1
            e_bounds = [0, 1e40]

        if openmc.ParticleFilter in filter_types:
            particles = tally.find_filter(openmc.ParticleFilter).bins
        else:
            particles = ['neutron']

    # The score axis is last; once it is selected the nuclide axis is last.
    mean = mean[..., score_idx]
    rel_err = rel_err[..., score_idx]
    mean = mean[..., nuclide_idx]
    rel_err = rel_err[..., nuclide_idx]

    # sanity check: number of dimensions should now be no more than three
    assert mean.ndim <= 3

    # Pad missing filter axes so the data is always (mesh, energy, particle);
    # this relies on the filters appearing in _ALLOWED_FILTER_TYPES order.
    if openmc.EnergyFilter not in filter_types:
        mean = np.expand_dims(mean, 1)
        rel_err = np.expand_dims(rel_err, 1)
    if openmc.ParticleFilter not in filter_types:
        mean = np.expand_dims(mean, 2)
        rel_err = np.expand_dims(rel_err, 2)
    assert mean.ndim == 3

    wws = []
    # .T reverses the axes to (particle, energy, mesh)
    for particle, p_mean, p_rel_err in zip(particles, mean.T, rel_err.T):
        ww_lower_bounds = np.empty((*mesh.dimension, n_e_bins), dtype=float)
        for i, (e_mean, e_rel_err) in enumerate(zip(p_mean, p_rel_err)):
            bounds = _magic_lower_bounds(e_mean, e_rel_err, rel_err_tol)
            # Reshape to the reversed mesh dimensions and transpose to get
            # (nx, ny, nz); assumes the flattened mesh index varies fastest
            # along the first mesh axis.
            ww_lower_bounds[..., i] = bounds.reshape(mesh.dimension[::-1]).T
        wws.append(openmc.WeightWindows(
            mesh_copy,
            ww_lower_bounds,
            upper_bound_ratio=upper_bound_ratio,
            energy_bounds=e_bounds,
            particle_type=particle,
        ))
    return wws
def create_model():
    """Build the demo fixed-source model.

    The geometry is a single steel-filled sphere of radius 300 with a vacuum
    boundary. A point source at the origin emits neutrons isotropically at a
    single energy of 1.0e4. The run uses 5 batches of 1000 particles.

    Returns
    -------
    openmc.Model
    """
    steel = openmc.Material(name='Steel')
    steel.set_density('g/cc', 8.0)
    element_fractions = (
        ('Si', 0.010048),
        ('S', 0.00023),
        ('Fe', 0.669),
        ('Ni', 0.12),
        ('Mo', 0.025),
    )
    for symbol, fraction in element_fractions:
        steel.add_element(symbol, fraction)
    for nuclide, fraction in (('P31', 0.00023), ('Mn55', 0.011014)):
        steel.add_nuclide(nuclide, fraction)

    outer_sphere = openmc.Sphere(r=300, boundary_type='vacuum')
    steel_cell = openmc.Cell(fill=steel, region=-outer_sphere)

    model = openmc.Model()
    model.geometry = openmc.Geometry([steel_cell])

    source = openmc.Source(
        space=openmc.stats.Point((0.0, 0.0, 0.0)),
        angle=openmc.stats.Isotropic(),
        energy=openmc.stats.Discrete([1.0e4], [1.0]),
    )
    source.particle = 'neutron'

    settings = model.settings
    settings.run_mode = 'fixed source'
    settings.source = source
    settings.particles = 1000
    settings.batches = 5
    # NOTE: the author observed that also setting settings.max_splits = 1000
    # did not appear to change anything.
    return model
def create_mesh(geometry, dimension=(80, 80, 80)):
    """Return a RegularMesh covering the bounding box of ``geometry``.

    Parameters
    ----------
    geometry : openmc.Geometry
        Geometry whose bounding box sets the mesh extent.
    dimension : tuple of int (default: (80, 80, 80))
        Number of mesh cells along each axis.
    """
    lower_left, upper_right = geometry.bounding_box
    regular_mesh = openmc.RegularMesh()
    regular_mesh.lower_left = lower_left
    regular_mesh.upper_right = upper_right
    regular_mesh.dimension = dimension
    return regular_mesh
def create_tallies(mesh):
    """Return one tally named 'flux tally' that scores flux on ``mesh``.

    Despite the plural name, a single openmc.Tally is returned.
    """
    tally = openmc.Tally(name='flux tally')
    tally.filters = [openmc.MeshFilter(mesh)]
    tally.scores = ['flux']
    return tally
def plot_flux_with_ww(wws, model, sp_file, tally_id, filename='flux_std_dev.png', slice_index=40):
    """Save side-by-side images of flux mean, flux relative error and WW lower bounds.

    All three panels show the same plane: index ``slice_index`` along the
    third mesh axis, drawn over the x/y extent of the model's bounding box.

    Parameters
    ----------
    wws : sequence of openmc.WeightWindows
        Weight windows as produced by ``generate_wws``. Only the first entry
        and its first energy group are plotted.
    model : openmc.Model
        Model whose geometry bounding box sets the plot extent.
    sp_file : str
        Statepoint file holding the tally.
    tally_id : int
        ID of the flux tally. Its values are reshaped straight to the mesh
        shape, so it is assumed to carry only a MeshFilter with a single
        score and nuclide.
    filename : str (default: 'flux_std_dev.png')
        Output image path.
    slice_index : int (default: 40)
        Slice index along the third mesh axis.
    """
    with openmc.StatePoint(sp_file) as sp:
        flux_tally = sp.get_tally(id=tally_id)
        # Take the mesh from the tally instead of relying on a module-level
        # ``mesh`` variable that only exists when the file runs as a script.
        mesh = flux_tally.find_filter(openmc.MeshFilter).mesh
        flux_mean = flux_tally.mean
        flux_rel_err = flux_tally.get_values(value='rel_err')

    llc, urc = model.geometry.bounding_box
    extent = (llc[0], urc[0], llc[1], urc[1])

    # Reshaping flattened mesh bins to the reversed dimensions (the same
    # convention generate_wws uses) gives arrays indexed [z, y, x], so
    # [slice_index] is the x-y plane at z = slice_index.
    zyx_shape = mesh.dimension[::-1]
    mean_slice = flux_mean.reshape(zyx_shape)[slice_index]
    rel_err_slice = flux_rel_err.reshape(zyx_shape)[slice_index]
    # Lower bounds are laid out (nx, ny, nz, n_energy); take the same z plane
    # and first energy group, transposed to [y, x] to match the flux panels.
    ww_slice = wws[0].lower_ww_bounds[:, :, slice_index, 0].T

    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(16, 4))
    fig.suptitle('Magic method of variance reduction with OpenMC')

    # flux mean
    img1 = ax1.imshow(mean_slice, origin='lower', extent=extent, norm=LogNorm())
    ax1.set_title('Flux Mean')
    plt.colorbar(img1, ax=ax1, fraction=0.046)
    img1.set_clim(vmin=1e-30, vmax=1.0)

    # flux relative error
    img2 = ax2.imshow(rel_err_slice, origin='lower', extent=extent)
    ax2.set_title('Flux Rel. Err.')
    plt.colorbar(img2, ax=ax2, fraction=0.046)
    img2.set_clim(vmin=0.0, vmax=1.0)

    # weight window lower bounds (non-positive "no window" bins are masked by LogNorm)
    img3 = ax3.imshow(ww_slice, origin='lower', extent=extent, norm=LogNorm())
    ax3.set_title('lower_ww_bounds')
    plt.colorbar(img3, ax=ax3, fraction=0.046)

    plt.savefig(filename, dpi=400)
    # magic() calls this once per iteration; close the figure so figures do
    # not accumulate in memory across iterations.
    plt.close(fig)
# Run the MAGIC demo only when executed as a script, not when imported.
# Statements under this guard still bind module-level names.
if __name__ == '__main__':
    model = create_model()
    mesh = create_mesh(model.geometry)
    flux_tally = create_tallies(mesh)
    model.tallies = [flux_tally]
    magic(model, flux_tally.id, 16)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment