Source code for pynucastro.reduction.reduction

#!/usr/bin/env python3

import argparse
import sys
import time
from collections import namedtuple

import numpy as np

from pynucastro import Composition, Nucleus
from pynucastro.constants import constants
from pynucastro.reduction import mpi_importer, sens_analysis
from pynucastro.reduction.drgep import drgep
from pynucastro.reduction.generate_data import dataset
from pynucastro.reduction.load_network import load_network

MPI = mpi_importer()


def _wrap_conds(conds):
    """Return [conds] if conds has 1 dimension, and give back conds otherwise."""

    try:
        _ = conds[0]
        try:
            _ = conds[0][0]
            return conds
        except IndexError:
            return [conds]
    except IndexError:
        raise ValueError('Conditions must be non-empty subscriptable object') from None


NetInfo = namedtuple("NetInfo", "y ydot z a ebind m")


[docs] def get_net_info(net, comp, rho, T): y_dict = comp.get_molar() # Can alternatively use the NumPy-based method (evaluate_ydots_arr) ydots_dict = net.evaluate_ydots(rho, T, comp) y = np.zeros(len(net.unique_nuclei), dtype=np.float64) ydot = np.zeros(len(net.unique_nuclei), dtype=np.float64) z = np.zeros(len(net.unique_nuclei), dtype=np.float64) a = np.zeros(len(net.unique_nuclei), dtype=np.float64) ebind = np.zeros(len(net.unique_nuclei), dtype=np.float64) m = np.zeros(len(net.unique_nuclei), dtype=np.float64) for i, n in enumerate(net.unique_nuclei): y[i] = y_dict[n] ydot[i] = ydots_dict[n] z[i] = n.Z a[i] = n.A ebind[i] = n.nucbind or 0.0 m[i] = n.A_nuc * constants.m_u return NetInfo(y, ydot, z, a, ebind, m)
[docs] def enuc_dot(net_info): """Calculate the nuclear energy generation rate.""" return -np.sum(net_info.ydot * net_info.m) * constants.N_A * constants.c_light**2
[docs] def ye_dot(net_info): """Calculate the time rate of change of electron fraction.""" y, ydot, z, a, _, _ = net_info norm_fac = np.sum(y * a) return np.sum(ydot * z) / norm_fac - np.sum(y * z) / norm_fac**2 * np.sum(ydot * a)
[docs] def abar_dot(net_info): """Calculate the time rate of change of mean molecular weight.""" abar_inv = np.sum(net_info.y) return -1 / abar_inv**2 * np.sum(net_info.ydot)
[docs] def map_comp(comp, net): """Create new composition object with nuclei in net, and copy their mass fractions over.""" comp_new = Composition(net.unique_nuclei) for nuc in comp_new.X: comp_new.X[nuc] = comp.X[nuc] return comp_new
[docs] def rel_err(x, x0): """Compute the relative error between two NumPy arrays.""" return np.abs((x - x0) / x0)
[docs] def get_errfunc_enuc(net_old, conds): """Function for computing error in nuclear energy generation.""" enucdot_list = [] for comp, rho, T in conds: net_info_old = get_net_info(net_old, comp, rho, T) enucdot_list.append(enuc_dot(net_info_old)) def erf(net_new): err = 0.0 for cond, enucdot_old in zip(conds, enucdot_list): net_info_new = get_net_info(net_new, *cond) enucdot_new = enuc_dot(net_info_new) err = max(err, rel_err(enucdot_new, enucdot_old)) return err return erf
[docs] def main(): #----------------------------------- # Setup parser and process arguments #----------------------------------- description = "Example/test script for running a selected reduction algorithm." algorithm_help = "The algorithm to use. Currently only supports 'drgep'." endpoint_help = "The nucleus to use as an endpoint in the unreduced network." datadim_help = """The dimensions of the dataset to generate. Ordering is the number of densities, then temperatures, then metallicities. Default is [4, 4, 4].""" drho_help = "Number of density points to include in the dataset." dtemp_help = "Number of temperature points to include in the dataset." dmetal_help = "Number of metallicity points to include in the dataset." brho_help = "Range of densities to include in the dataset (in g/cm^3)." btemp_help = "Range of temperatures to include in the dataset (in K)." bmetal_help = """Range of metallicities to include in the dataset. Half of the metallicity will be in C12, and the rest will be divided evenly among the remaining nuclei.""" library_help = "Name of the library to use to load the network." targets_help = """Target nuclei to use for the algorithm. Will be protons and the endpoint by default.""" tol_help = """Tolerance(s) to use for the algorithm. Will be [1e-3] + [1e-2]*(len(targets)-1) by default for 'drgep'.""" use_mpi_help = "Enable MPI for this run." use_numpy_help = """If the algorithm has a 'use_numpy' option, turn it on. Some algorithms always run as if use_numpy=True.""" sa_help = "Error threshold to use when performing sensitivity analysis." parser = argparse.ArgumentParser(description=description) parser.add_argument('-a', '--algorithm', default='drgep', help=algorithm_help) parser.add_argument('-e', '--endpoint', default=Nucleus('te108'), type=Nucleus, help=endpoint_help) parser.add_argument('-d', '--datadim', nargs=3, default=[4]*3, help=datadim_help) parser.add_argument('-dr', '--drho', type=int, help=drho_help) parser.add_argument('-dt', '--dtemp', help=dtemp_help) parser.add_argument('-dz', '--dmetal', type=int, help=dmetal_help) parser.add_argument('-br', '--brho', nargs=2, type=float, help=brho_help) parser.add_argument('-bt', '--btemp', nargs=2, type=float, help=btemp_help) parser.add_argument('-bz', '--bmetal', nargs=2, type=float, help=bmetal_help) parser.add_argument('-l', '--library', default="rp-process-lib", help=library_help) parser.add_argument('-t', '--targets', nargs='*', type=Nucleus, help=targets_help) parser.add_argument('--tol', nargs='*', type=float, help=tol_help) parser.add_argument('--use_mpi', action='store_true', help=use_mpi_help) parser.add_argument('--use_numpy', action='store_true', help=use_numpy_help) parser.add_argument('-s', '--sens_analysis', default=0.05, type=float, help=sa_help) args = parser.parse_args(sys.argv[1:]) args.algorithm = args.algorithm.lower() if args.algorithm == 'drgep': alg = drgep permute = not args.use_numpy else: raise ValueError(f'Algorithm {args.algorithm} is not supported.') if args.drho: args.datadim[0] = args.drho if args.dtemp: args.datadim[1] = args.dtemp if args.dmetal: args.datadim[2] = args.dmetal if not args.targets: args.targets = [Nucleus('p'), args.endpoint] if not args.tol: if args.algorithm == 'drgep': args.tol = [1e-3] + [1e-2] * (len(args.targets)-1) else: if args.algorithm == 'drgep': if len(args.tol) == 1: args.tol = [args.tol] * len(args.targets) elif len(args.tol) != len(args.targets): raise ValueError( f"For '{args.algorithm}', there should be one tolerance for" " each target nucleus." ) #----------------------------- # Load the network and dataset #----------------------------- net = load_network(args.endpoint, library_name=args.library) data = list(dataset(net, args.datadim, permute, args.brho, args.btemp, args.bmetal)) #----------------------------- # Prepare to run the algorithm #----------------------------- alg_args = \ { 'net': net, 'conds': data, 'targets': args.targets, 'tols': args.tol, 'returnobj': 'nuclei', 'use_mpi': args.use_mpi, 'use_numpy': args.use_numpy } if args.algorithm == 'drgep': args.algorithm = args.algorithm.upper() #------------------ # Run the algorithm #------------------ if args.use_mpi and MPI.COMM_WORLD.Get_rank() == 0: print(f"Commencing reduction with {MPI.COMM_WORLD.Get_size()} processes.\n") elif not args.use_mpi: print("Commencing reduction without MPI.\n") t0 = time.time() nuclei = alg(**alg_args) dt = time.time() - t0 red_net = net.linking_nuclei(nuclei) second_data = list(dataset(net, args.datadim, True, args.brho, args.btemp, args.bmetal)) errfunc = get_errfunc_enuc(net, second_data) if not args.use_mpi or MPI.COMM_WORLD.Get_rank() == 0: print() print(f"{args.algorithm} reduction took {dt:.3f} s.") print("Number of species in full network: ", len(net.unique_nuclei)) print(f"Number of species in {args.algorithm} reduced network: ", len(nuclei)) print("Reduced Network Error:", f"{errfunc(red_net)*100:.2f}%") print() # Perform sensitivity analysis if args.use_mpi: MPI.COMM_WORLD.Barrier() t0 = time.time() red_net = sens_analysis(red_net, errfunc, args.sens_analysis, args.use_mpi) dt = time.time() - t0 if not args.use_mpi or MPI.COMM_WORLD.Get_rank() == 0: print(f"Greedy sensitivity analysis reduction took {dt:.3f} s.") print(f"Number of species in {args.algorithm} + sensitivity analysis reduced network: ", len(red_net.unique_nuclei)) print("Error: ", f"{errfunc(red_net)*100:.2f}%")
if __name__ == "__main__": main()