# Source code for pynucastro.reduction.sensitivity_analysis

import numpy as np

from pynucastro.nucdata import Nucleus
from pynucastro.reduction.reduction_utils import mpi_importer

MPI = mpi_importer()


def binary_search_trim(network, nuclei, errfunc, thresh=0.05):
    """
    Given an array of nuclei sorted roughly by relative importance, perform a
    binary search to trim out nuclei from the network until the error is as
    close as possible to the given threshold without exceeding it.

    Nuclei whose removal will result in only a small increase in error need
    to be packed towards the back of the array for the binary search to work
    effectively.

    :param network: The network to reduce.
    :param nuclei: Nuclei to consider for the final network, sorted by
        decreasing importance (i.e. most important nuclei first). Importance
        can be determined by something like the *drgep* function.
    :param errfunc: Error function to use when evaluating error. Should take
        a reduced network as an argument and return the relative error
        produced by the reduction. This can be a parallel (MPI) function.
    :param thresh: Threshold for acceptable error. Default is 0.05.
    :return: A reduced reaction network with an evaluated error approximately
        equal to the supplied threshold.
    """
    nuclei = Nucleus.cast_list(nuclei)

    # Invariant: nuclei[:start_idx] is either empty (start_idx == 0) or has
    # been evaluated and found to exceed the threshold, so the kept prefix
    # must be longer than start_idx.
    start_idx = 0
    seg_size = len(nuclei) / 2

    while seg_size >= 0.5:

        # Width of the segment to try appending to the current prefix.
        step = round(seg_size)
        seg_size /= 2

        # round() uses banker's rounding, so round(0.5) == 0. A zero-width
        # step would just re-evaluate nuclei[:start_idx], an outcome that
        # cannot move start_idx. When start_idx == 0 it would also evaluate
        # an empty network. Skip that wasted (possibly MPI-collective) error
        # evaluation; the decision is identical on every process.
        if step == 0:
            continue

        # Divide up into segments
        end_idx = start_idx + step
        red_net = network.linking_nuclei(nuclei[:end_idx])

        # Evaluate error
        err = errfunc(red_net)

        if err > thresh:
            # Prefix is still too small: advance past the rejected segment.
            start_idx = end_idx

    return network.linking_nuclei(nuclei[:start_idx+1])
def _progress_bar(frac, size=50): n = round(size*frac) progress_bar = '[' + '⊙'*n + ' '*(size-n) + ']' if frac < 1.0: end = '\r' else: end = '\n' print(progress_bar, f'{round(100*frac)}%', end=end)
def sens_analysis(network, errfunc, thresh=0.05, use_mpi=False, print_prog=True):
    """
    Given a reaction network, remove nuclei from the network one-by-one until
    the induced error is as close to the given threshold as possible without
    exceeding it. This will test nuclei for removal individually and remove
    the one that induces the smallest error on each pass. Since it requires
    O(n^2) error function evaluations, this routine is much more expensive
    than *binary_search_trim*, but it will typically trim the network down
    significantly more.

    :param network: The network to reduce.
    :param errfunc: Error function to use when evaluating error. Should take
        a reduced network as an argument and return the relative error
        produced by the reduction. If *use_mpi* is *False*, the error function
        can be parallelized with MPI. Otherwise *sens_analysis* will be
        parallelized and the error function should not be.
    :param thresh: Threshold for acceptable error. Default is 0.05.
    :param use_mpi: Whether to parallelize the loop over nuclei with each pass
        or not using MPI. For *p* MPI processes, the parallelized function will
        require O(n^2/p) error function evaluations per process. This option
        is *False* by default. If the error function is parallelized using
        MPI, this option should be set to *False*.
    :param print_prog: Whether to print out the progress of the function as
        it runs or not. Includes a progress bar for each pass and messages
        indicating when the algorithm starts and ends.
    :return: A reduced reaction network with an evaluated error approximately
        equal to the supplied threshold.
    """
    if use_mpi:
        comm = MPI.COMM_WORLD
        MPI_N = comm.Get_size()
        MPI_rank = comm.Get_rank()
    else:
        MPI_N = 1
        MPI_rank = 0

    nuclei = list(network.unique_nuclei)
    nrem = 0

    # Only one process reports progress.
    print_prog = print_prog and (MPI_rank == 0)

    if print_prog:
        print("Performing sensitivity analysis...")

    while True:

        # Best (smallest) error found by this process on this pass. min_idx
        # is reset every pass. A process may be assigned no nuclei (fewer
        # nuclei than processes), or every candidate may fail the strict
        # comparison (e.g. NaN/inf). It must never gather or pop an unbound
        # or stale index.
        err = float('inf')
        min_idx = -1

        # Each process tries removing a strided subset of the nuclei.
        for i in range(MPI_rank, len(nuclei), MPI_N):

            if print_prog:
                print(f"Pass {nrem+1}:", end=' ')
                _progress_bar(i/len(nuclei))

            # Temporarily drop nucleus i, evaluate, then restore it so the
            # list is identical on every iteration and on every process.
            nuc = nuclei.pop(i)
            err_i = errfunc(network.linking_nuclei(nuclei, print_warning=False))
            nuclei.insert(i, nuc)

            if err_i < err:
                err = err_i
                min_idx = i

        if use_mpi:
            # Reduce to the global minimum on rank 0, then broadcast it so
            # every process removes the same nucleus.
            errs = comm.gather(err, root=0)
            idxs = comm.gather(min_idx, root=0)
            if MPI_rank == 0:
                min_rank = np.argmin(errs)
                err = errs[min_rank]
                min_idx = idxs[min_rank]
            err = comm.bcast(err, root=0)
            min_idx = comm.bcast(min_idx, root=0)

        if print_prog:
            print(f"Pass {nrem+1}:", end=' ')
            _progress_bar(1.0)

        # min_idx < 0 means no candidate was found this pass; stop rather
        # than popping an arbitrary element (e.g. when thresh is inf).
        if min_idx >= 0 and err <= thresh:
            nuclei.pop(min_idx)
            nrem += 1
        else:
            break

    if print_prog:
        print(f"Done. Removed {nrem} nuclei.")

    return network.linking_nuclei(nuclei)