# Source code for pynucastro.rates.tabular_rate

import math
import re
from enum import Enum
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

import pynucastro.numba_util as numba
from pynucastro.nucdata import Nucleus, UnsupportedNucleus
from pynucastro.numba_util import jitclass
from pynucastro.rates.files import RateFileError, _find_rate_file
from pynucastro.rates.rate import Rate, RateSource


class TableIndex(Enum):
    """Column positions within a row of a tabulated weak-rate data table.

    Each member's ``value`` is the integer column index used when slicing
    the 2D table array, e.g. ``table[:, TableIndex.RATE.value]``.

    """

    # independent variables of the table, stored as log10 values
    RHOY = 0
    T = 1

    # NOTE(review): the physical meaning of MU, DQ, VS and GAMMA is not
    # established anywhere in this file -- confirm against the table format
    MU = 2
    DQ = 3
    VS = 4

    # log10 of the rate; exponentiated by TabularRate.eval()
    RATE = 5

    # log10 of the neutrino loss; exponentiated by TabularRate.get_nu_loss()
    NU = 6

    GAMMA = 7
@jitclass([
    ('data', numba.float64[:, :]),
    ('table_rhoy_lines', numba.int32),
    ('table_temp_lines', numba.int32),
    ('rhoy', numba.float64[:]),
    ('temp', numba.float64[:])
])
class TableInterpolator:
    """Bilinear interpolation of one column of a tabulated rate table.

    The object keeps a reference to the table together with 1D views of
    the log10(ρ Y_e) and log10(T) grid points.

    Parameters
    ----------
    table_rhoy_lines : int
        the number of (ρ Y_e) values at which the rate is tabulated
    table_temp_lines : int
        the number of T values at which the rate is tabulated
    table_data : numpy.ndarray
        a 2D array of the form (index, component), where index is the
        flattened (rhoY, T) position with T varying fastest.

    """

    def __init__(self, table_rhoy_lines, table_temp_lines, table_data):
        self.table_rhoy_lines = table_rhoy_lines
        self.table_temp_lines = table_temp_lines
        self.data = table_data

        # 1D grids for lookups: the first table_temp_lines rows sweep
        # through every T, and every table_temp_lines-th row starts a
        # new rhoY
        self.rhoy = self.data[::self.table_temp_lines, TableIndex.RHOY.value]
        self.temp = self.data[0:self.table_temp_lines, TableIndex.T.value]

    def _get_logT_idx(self, logt0):
        """Locate the temperature interval containing ``logt0``.

        With i such that T[i-1] < t0 <= T[i], the lower index i-1 is
        returned, clamped so that it is never negative and never the
        last grid point.  Everything is in terms of log10().

        Parameters
        ----------
        logt0 : float
            log10(temperature) to interpolate at

        Returns
        -------
        int

        """

        last = len(self.temp) - 1
        upper = min(last, np.searchsorted(self.temp, logt0))
        return max(0, upper - 1)

    def _get_logrhoy_idx(self, logrhoy0):
        """Locate the (ρ Y_e) interval containing ``logrhoy0``.

        With i such that rhoY[i-1] < rhoy0 <= rhoY[i], the lower index
        i-1 is returned, clamped so that it is never negative and never
        the last grid point.  Everything is in terms of log10().

        Parameters
        ----------
        logrhoy0 : float
            log10(ρ Y_e) to interpolate at

        Returns
        -------
        int

        """

        last = len(self.rhoy) - 1
        upper = min(last, np.searchsorted(self.rhoy, logrhoy0))
        return max(0, upper - 1)

    def _rhoy_T_to_idx(self, irhoy, jtemp):
        """Convert a grid position (irhoy, jtemp) into the row number of
        the underlying data array, assuming row-major ordering.

        Parameters
        ----------
        irhoy : int
            the index in the (ρ Y_e) dimension
        jtemp : int
            the index into the T dimension

        Returns
        -------
        int

        """

        return irhoy * self.table_temp_lines + jtemp

    def interpolate(self, logrhoy, logT, component):
        """Bilinearly interpolate one data column at (logrhoy, logT).

        Parameters
        ----------
        logrhoy : float
            log10(ρ Y_e) to interpolate at
        logT : float
            log10(T) to interpolate at
        component : int
            the column of the data table to interpolate.  This should
            correspond to a :py:class:`TableIndex` component.

        Returns
        -------
        float

        Raises
        ------
        ValueError
            if ``logT`` or ``logrhoy`` lies outside the tabulated range

        """

        # The interpolant within a cell is
        #
        # f = A [log(rho) - log(rho_i)] [log(T) - log(T_j)] +
        #     B [log(rho) - log(rho_i)] +
        #     C [log(T) - log(T_j)] +
        #     D
        #
        # where (i,j) is the lower-left corner of the cell containing
        # our point and A, B, C, D are fixed by the data at the cell's
        # four corners.

        if logT < self.temp.min() or logT > self.temp.max():
            raise ValueError("temperature out of table bounds")

        if logrhoy < self.rhoy.min() or logrhoy > self.rhoy.max():
            raise ValueError("rhoy out of table bounds")

        # lower-left corner of the enclosing cell
        i_lo = self._get_logrhoy_idx(logrhoy)
        j_lo = self._get_logT_idx(logT)

        # cell widths -- the grids already hold log10 values
        cell_dlogrho = self.rhoy[i_lo+1] - self.rhoy[i_lo]
        cell_dlogT = self.temp[j_lo+1] - self.temp[j_lo]

        # table values at the four corners
        f_ll = self.data[self._rhoy_T_to_idx(i_lo, j_lo), component]
        f_rl = self.data[self._rhoy_T_to_idx(i_lo+1, j_lo), component]
        f_lu = self.data[self._rhoy_T_to_idx(i_lo, j_lo+1), component]
        f_ru = self.data[self._rhoy_T_to_idx(i_lo+1, j_lo+1), component]

        D = f_ll
        C = (f_lu - f_ll) / cell_dlogT
        B = (f_rl - f_ll) / cell_dlogrho
        A = (f_ru - B * cell_dlogrho - C * cell_dlogT - D) / (cell_dlogrho * cell_dlogT)

        # offsets of the requested point from the lower-left corner
        off_rho = logrhoy - self.rhoy[i_lo]
        off_T = logT - self.temp[j_lo]

        r = (A * off_rho * off_T +
             B * off_rho +
             C * off_T +
             D)

        return r
class TabularRate(Rate):
    """A rate tabulated in terms of log10(ρ Y_e) and log10(T).

    Parameters
    ----------
    rfile : str, pathlib.Path
        the file containing the data table.  Only path-like arguments
        are handled: the table is located with ``_find_rate_file`` and
        then read with ``open()``.

    Raises
    ------
    TypeError
        if ``rfile`` is not a ``str`` or ``pathlib.Path`` (this includes
        the default of ``None``)
    RateFileError
        if the table file cannot be interpreted

    """

    def __init__(self, rfile=None):
        super().__init__()
        self.rate_eval_needs_rho = True
        self.rate_eval_needs_comp = True

        self.rfile_path = None
        self.rfile = None
        self.source = None

        if isinstance(rfile, (str, Path)):
            rfile = Path(rfile)
            self.rfile_path = _find_rate_file(rfile)
            self.source = RateSource.source(self.rfile_path.parent.name)
            self.rfile = rfile.name

        self.fname = None

        self.label = "tabular"
        self.tabular = True

        # we should initialize this somehow
        self.weak_type = ""

        self._read_from_file(self.rfile_path)

        self._set_rhs_properties()
        self._set_screening()
        self._set_print_representation()

        # store the extrema of the thermodynamics -- the table holds
        # log10 values, so convert back to physical units
        _rhoy = self.tabular_data_table[::self.table_temp_lines, TableIndex.RHOY.value]
        _temp = self.tabular_data_table[0:self.table_temp_lines, TableIndex.T.value]

        self.table_Tmin = 10.0**(_temp.min())
        self.table_Tmax = 10.0**(_temp.max())
        self.table_rhoYmin = 10.0**(_rhoy.min())
        self.table_rhoYmax = 10.0**(_rhoy.max())

        self.interpolator = TableInterpolator(self.table_rhoy_lines, self.table_temp_lines,
                                              self.tabular_data_table)

    def __hash__(self):
        return hash(self.__repr__())

    def __eq__(self, other):
        """Determine whether two TabularRate objects are equal.  They
        are equal if they contain identical reactants and products.
        Anything that is not a TabularRate compares unequal.

        """

        if not isinstance(other, TabularRate):
            return False

        return self.reactants == other.reactants and self.products == other.products

    def __add__(self, other):
        raise NotImplementedError("addition not defined for tabular rates")

    def _read_from_file(self, table_file):
        """Given a filename, read rate data from the file.

        Parameters
        ----------
        table_file : str, pathlib.Path
            The path of the file that contains the table data

        Raises
        ------
        TypeError
            if ``table_file`` is ``None``, i.e. the rate was not
            created from a ``str`` or ``pathlib.Path``
        RateFileError
            if the first line does not name the reactant and product,
            if the nuclei are not recognized, or if the file has no
            data rows

        """

        if table_file is None:
            # without this, open(None) fails with a message that says
            # nothing about how the rate was constructed
            raise TypeError("a TabularRate must be created from a str or pathlib.Path naming the table file")

        # just store the filename as the original source
        self.original_source = f"{table_file}"

        # set weak type
        if "electroncapture" in str(table_file):
            self.weak_type = "electron_capture"
        elif "betadecay" in str(table_file):
            self.weak_type = "beta_decay"

        # for backwards compatibility, we'll set a chapter to "t"
        self.chapter = "t"

        # the first line defines the nuclei, in one of two forms,
        # tried in this order:
        #
        #   !65ni -> 65co, e- capture
        #
        # or, with spins (mainly Suzuki rates):
        #
        #   !17F (5/2+, 1/2+) -> 17O  e-capture with screening effects
        #
        # The stuff in the (...) giving the spins can be complicated,
        # but the key is that it is in parentheses.
        header_patterns = (
            r"!([\da-zA-Z]*) \-\> ([\da-zA-Z]*)[\w,\-]*",
            r"!([\da-zA-Z]*)\s*\([\w\:\=\d/\+,\.\s\_\{\}]*\)\s+\-\> ([\da-zA-Z]*)[\w,\-]*",
        )

        # read in the table data
        # there are a few header lines that start with "!", which we skip,
        # except for the very first, which defines the nuclei in the form
        # reactant -> product

        t_data2d = []
        reactant = None
        product = None

        header_lines = 0
        with open(table_file) as tabular_file:
            for i, line in enumerate(tabular_file):
                if i == 0:
                    for pattern in header_patterns:
                        g = re.match(pattern, line)
                        if g is not None:
                            break
                    else:
                        raise RateFileError(f'could not find "reactant -> product" in the first line of {self.original_source}')

                    reactant = g.group(1)
                    product = g.group(2)

                    header_lines += 1
                    continue

                if line.startswith("!"):
                    header_lines += 1
                    continue

                line = line.strip()
                # skip empty lines
                if not line:
                    continue
                # split the column values on whitespace
                t_data2d.append(line.split())

        if reactant is None or product is None:
            # the loop body never ran
            raise RateFileError(f'{self.original_source} is empty')

        if not t_data2d:
            # an empty list converts to a 1D array that cannot be
            # indexed by column below
            raise RateFileError(f'no table data found in {self.original_source}')

        try:
            self.reactants.append(Nucleus.from_cache(reactant.lower()))
            self.products.append(Nucleus.from_cache(product.lower()))
        except UnsupportedNucleus as ex:
            raise RateFileError(f'Nucleus objects could not be identified in {self.original_source}') from ex

        self.table_file = table_file

        # convert the nested list of string values into a numpy float array
        self.tabular_data_table = np.array(t_data2d, dtype=np.float64)

        # the grid dimensions are the number of distinct values in the
        # first two columns
        self.table_header_lines = header_lines
        self.table_rhoy_lines = len(np.unique(self.tabular_data_table[:, TableIndex.RHOY.value]))
        self.table_temp_lines = len(np.unique(self.tabular_data_table[:, TableIndex.T.value]))
        self.table_num_vars = 6  # Hard-coded number of variables in tables for now.
        self.table_index_name = f'j_{self.reactants[0]}_{self.products[0]}'
        self.labelprops = 'tabular'

        # since the reactants and products were only now set, we need
        # to recompute Q -- this is used for finding rate pairs
        self._set_q()

    def _set_rhs_properties(self):
        """Compute statistical prefactor and density exponent from the reactants."""
        self.prefactor = 1.0  # this is 1/2 for rates like a + a (double counting)
        self.inv_prefactor = 1
        if self.use_identical_particle_factor:
            for r in set(self.reactants):
                self.inv_prefactor = self.inv_prefactor * math.factorial(self.reactants.count(r))
        self.prefactor = self.prefactor/float(self.inv_prefactor)
        self.dens_exp = len(self.reactants)-1

    def _set_screening(self):
        """Tabular rates are not currently screened (they are e-capture or beta-decay)"""
        self.ion_screen = []
        self.symmetric_screen = []

        if not self.fname:
            # This is used to determine which rates to detect as the same reaction
            # from multiple sources in a Library file, so it should not be unique
            # to a given source, e.g. wc12, but only unique to the reaction.
            reactants_str = '_'.join([repr(nuc) for nuc in self.reactants])
            products_str = '_'.join([repr(nuc) for nuc in self.products])
            self.fname = f'{reactants_str}__{products_str}'

    def get_rate_id(self):
        """Get an identifying string for this rate.

        Returns
        -------
        str

        """

        ssrc = 'tabular'

        return f'{self.rid} <{self.label.strip()}_{ssrc}>'

    def function_string_py(self):
        """Construct the python function that computes the rate.

        Returns
        -------
        str

        """

        # NOTE(review): the leading whitespace inside these literals is
        # reproduced exactly as found.  A one-space body indent is valid
        # Python, but confirm it against the canonical source.
        fstring = ""
        fstring += "@numba.njit()\n"
        fstring += f"def {self.fname}(rate_eval, T, rho, Y):\n"
        fstring += f" # {self.rid}\n"
        fstring += " rhoY = rho * ye(Y)\n"
        fstring += f" {self.fname}_interpolator = TableInterpolator(*{self.fname}_info)\n"
        fstring += f" r = {self.fname}_interpolator.interpolate(np.log10(rhoY), np.log10(T), TableIndex.RATE.value)\n"
        fstring += f" rate_eval.{self.fname} = 10.0**r\n\n"

        return fstring

    def _interpolate_component(self, T, rho, comp, component):
        """Interpolate the table column ``component`` (a TableIndex
        member) at the given thermodynamic state and return the raw
        (log10) table value."""

        if rho is None or comp is None:
            # the keyword defaults exist for signature compatibility
            # only -- a tabulated rate cannot be evaluated without them
            raise TypeError("tabular rates require both rho and comp to be evaluated")

        rhoY = rho * comp.ye
        return self.interpolator.interpolate(np.log10(rhoY), np.log10(T), component.value)

    def eval(self, T, *, rho=None, comp=None):
        """Evaluate the reaction rate.

        Parameters
        ----------
        T : float
            the temperature to evaluate the rate at
        rho : float
            the density to evaluate the rate at (required for
            tabular rates).
        comp : Composition
            the composition (of type
            :py:class:`Composition <pynucastro.networks.rate_collection.Composition>`)
            to evaluate the rate with (required for tabular rates;
            its ``ye`` is used to form ρ Y_e).

        Returns
        -------
        float

        Raises
        ------
        TypeError
            if ``rho`` or ``comp`` is not supplied
        ValueError
            if the state lies outside of the table bounds

        """

        r = self._interpolate_component(T, rho, comp, TableIndex.RATE)
        return 10.0**r

    def get_nu_loss(self, T, *, rho=None, comp=None):
        """Evaluate the neutrino loss for the rate.

        Parameters
        ----------
        T : float
            the temperature to evaluate the rate at
        rho : float
            the density to evaluate the rate at (required).
        comp : Composition
            the composition (of type
            :py:class:`Composition <pynucastro.networks.rate_collection.Composition>`)
            to evaluate the rate with (required; its ``ye`` is used
            to form ρ Y_e).

        Returns
        -------
        float

        Raises
        ------
        TypeError
            if ``rho`` or ``comp`` is not supplied
        ValueError
            if the state lies outside of the table bounds

        """

        r = self._interpolate_component(T, rho, comp, TableIndex.NU)
        return 10.0**r

    def plot(self, *, Tmin=None, Tmax=None, rhoYmin=None, rhoYmax=None,
             color_field='rate', figsize=(10, 10)):
        """Plot the rate or neutrino loss in the log10(ρ Y_e) and
        log10(T) plane.

        Parameters
        ----------
        Tmin : float
            minimum temperature for the plot (defaults to the table minimum)
        Tmax : float
            maximum temperature for the plot (defaults to the table maximum)
        rhoYmin : float
            minimum (ρ Y_e) for the plot (defaults to the table minimum)
        rhoYmax : float
            maximum (ρ Y_e) for the plot (defaults to the table maximum)
        color_field : str
            the field to plot.  Possible values are "rate" or "nu_loss"
        figsize : tuple
            the horizontal, vertical size (in inches) for the plot

        Returns
        -------
        matplotlib.figure.Figure

        Raises
        ------
        ValueError
            if ``color_field`` is not "rate" or "nu_loss"

        """

        # validate the request before creating a figure, so a bad
        # color_field does not leave an orphaned figure open in pyplot
        if color_field == 'rate':
            icol = TableIndex.RATE.value
            title = f"{self.weak_type} rate in log10(1/s)"
            cmap = 'magma'

        elif color_field == 'nu_loss':
            icol = TableIndex.NU.value
            title = "neutrino energy loss rate in log10(erg/s)"
            cmap = 'viridis'

        else:
            raise ValueError("color_field must be either 'rate' or 'nu_loss'.")

        if Tmin is None:
            Tmin = self.table_Tmin
        if Tmax is None:
            Tmax = self.table_Tmax
        if rhoYmin is None:
            rhoYmin = self.table_rhoYmin
        if rhoYmax is None:
            rhoYmax = self.table_rhoYmax

        data = self.tabular_data_table

        # keep only the table rows inside the requested window (the
        # table stores log10 values); boolean indexing returns a copy
        logT = data[:, TableIndex.T.value]
        logrhoy = data[:, TableIndex.RHOY.value]
        in_window = ((logT <= np.log10(Tmax)) & (logT >= np.log10(Tmin)) &
                     (logrhoy <= np.log10(rhoYmax)) & (logrhoy >= np.log10(rhoYmin)))
        data_heatmap = data[in_window]

        # build a (rhoY, T) image from the flat list of rows
        rows, row_pos = np.unique(data_heatmap[:, TableIndex.RHOY.value], return_inverse=True)
        cols, col_pos = np.unique(data_heatmap[:, TableIndex.T.value], return_inverse=True)
        pivot_table = np.zeros((len(rows), len(cols)), dtype=data_heatmap.dtype)

        # NOTE(review): it is unclear what can raise ValueError here;
        # the handler is kept so existing behavior is unchanged
        try:
            pivot_table[row_pos, col_pos] = data_heatmap[:, icol]
        except ValueError:
            print("Divide by zero encountered in log10\nChange the scale of T or rhoY")

        fig, ax = plt.subplots(figsize=figsize)

        im = ax.imshow(pivot_table, cmap=cmap, origin="lower",
                       extent=[np.log10(Tmin), np.log10(Tmax),
                               np.log10(rhoYmin), np.log10(rhoYmax)])
        fig.colorbar(im, ax=ax)

        ax.set_xlabel(r"$\log(T)$ [K]")
        ax.set_ylabel(r"$\log(\rho Y_e)$ [g/cm$^3$]")
        ax.set_title(fr"{self.pretty_string}" + "\n" + title)

        return fig