Source code for pynucastro.reduction.sensitivity_analysis
"""Methods and functions used for sensitivity analysis."""
import numpy as np
from pynucastro.mpi_utils import mpi_importer
from pynucastro.nucdata import Nucleus
MPI = mpi_importer()
[docs]
def binary_search_trim(network, nuclei, errfunc, *, thresh=0.05, args=None):
"""Given an array of nuclei sorted roughly by relative importance,
perform a binary search to trim out nuclei from the network until
the error is as close as possible to the given threshold without
exceeding it. Nuclei whose removal will result in only a small
increase in error need to be packed towards the back of the array
for the binary search to work effectively.
Parameters
----------
network : RateCollection
The network to reduce.
nuclei : Iterable(Nucleus) or Iterable(str)
Nuclei to consider for the final network, sorted by decreasing
importance (i.e. most important nuclei first). Importance can
be determined by something like the *drgep* function.
errfunc : Callable
Error function to use when evaluating error, with the signature
``error(net, *args)``, where ``net`` is the reduced network as
an argument and return the relative error produced by the reduction.
If ``use_mpi`` is ``False``, the error function can be parallelized
with MPI. Otherwise ``sens_analysis`` will be parallelized and the
error function should not be.
thresh : float
Threshold for acceptable error. Default is 0.05.
args : tuple
Additional arguments to pass through to the error function
Returns
-------
net : RateCollection
A reduced reaction network with an evaluated error
approximately equal to the supplied threshold.
"""
nuclei = Nucleus.cast_list(nuclei)
start_idx = 0
seg_size = len(nuclei) / 2
while seg_size >= 0.5:
# Divide up into segments
end_idx = start_idx + round(seg_size)
red_net = network.linking_nuclei(nuclei[:end_idx])
# Evaluate error
if args is None:
err = errfunc(red_net)
else:
err = errfunc(red_net, *args)
if err <= thresh:
seg_size /= 2
else:
start_idx += round(seg_size)
seg_size /= 2
return network.linking_nuclei(nuclei[:start_idx+1])
def _progress_bar(frac, size=50):
n = round(size*frac)
progress_bar = '[' + '⊙'*n + ' '*(size-n) + ']'
if frac < 1.0:
end = '\r'
else:
end = '\n'
print(progress_bar, f'{round(100*frac)}%', end=end)
[docs]
def sens_analysis(network, errfunc, *, thresh=0.05, args=None,
use_mpi=False, print_prog=True):
"""Given a reaction network, remove nuclei from the network
one-by-one until the induced error is as close to the given
threshold as possible without exceeding it. This will test nuclei
for removal individually and remove the one that induces the
smallest error on each pass. Since it requires O(n^2) error
function evaluations, this routine is much more expensive than
``binary_search``, but it will typically trim the network down
significantly more.
Parameters
----------
network : RateCollection
The network to reduce. Can be a RateCollection or a subclass.
errfunc : Callable
Error function to use when evaluating error, with the signature
``error(net, *args)``, where ``net`` is the reduced network as
an argument and return the relative error produced by the reduction.
If ``use_mpi`` is ``False``, the error function can be parallelized
with MPI. Otherwise ``sens_analysis`` will be parallelized and the
error function should not be.
thresh : float
Threshold for acceptable error. Default is 0.05.
args : tuple
Additional arguments to pass through to the error function
use_mpi : bool
Whether to parallelize the loop over nuclei with each pass or
not using MPI. For p MPI processes, the parallelized
function will require O(n^2/p) error function evaluations per
process. This option is ``False`` by default. If the error
function is parallelized using MPI, this option should be set
to ``False``.
print_prog : bool
Whether to print out the progress of the function as it runs
or not. Includes a progress bar for each pass and messages
indicating when the algorithm starts and ends.
Returns
-------
net : RateCollection
A reduced reaction network with an evaluated error
approximately equal to the supplied threshold.
"""
if use_mpi:
comm = MPI.COMM_WORLD
MPI_N = comm.Get_size()
MPI_rank = comm.Get_rank()
else:
MPI_N = 1
MPI_rank = 0
nuclei = list(network.unique_nuclei)
err = 0.0
nrem = 0
print_prog = print_prog and (MPI_rank == 0)
if print_prog:
print("Performing sensitivity analysis...")
while True:
err = float('inf')
for i in range(MPI_rank, len(nuclei), MPI_N):
if print_prog:
print(f"Pass {nrem+1}:", end=' ')
_progress_bar(i/len(nuclei))
nuc = nuclei.pop(i)
if args is None:
err_i = errfunc(network.linking_nuclei(nuclei, print_warning=False))
else:
err_i = errfunc(network.linking_nuclei(nuclei, print_warning=False), *args)
if err_i < err:
err = err_i
min_idx = i
nuclei.insert(i, nuc)
if use_mpi:
err = comm.gather(err, root=0)
min_idx = comm.gather(min_idx, root=0)
if MPI_rank == 0:
min_rank = np.argmin(err)
err = err[min_rank]
min_idx = min_idx[min_rank]
err = comm.bcast(err, root=0)
min_idx = comm.bcast(min_idx, root=0)
if print_prog:
print(f"Pass {nrem+1}:", end=' ')
_progress_bar(1.0)
if err <= thresh:
nuclei.pop(min_idx)
nrem += 1
else:
break
if print_prog:
print(f"Done. Removed {nrem} nuclei.")
return network.linking_nuclei(nuclei)