diff --git a/holodeck/__init__.py b/holodeck/__init__.py index 747408d1..ec0edf24 100644 --- a/holodeck/__init__.py +++ b/holodeck/__init__.py @@ -88,15 +88,35 @@ def log_to_file(**kwargs): def set_log_level(level): log.setLevel(level) -# ---- Load cosmology instance - -# NOTE: Must load and initialize cosmology before importing other submodules! -import cosmopy # noqa -cosmo = cosmopy.Cosmology( - h=Parameters.HubbleParam, Om0=Parameters.Omega0, Ob0=Parameters.OmegaBaryon, - size=200, -) -del cosmopy +# ---- DEFERRED Import of cosmology instance for faster import of holodeck +class _CosmoProxy: + def __init__(self): + self._real_cosmo = None + + def _get_real_cosmo(self): + if self._real_cosmo is None: + import cosmopy + + self._real_cosmo = cosmopy.Cosmology( + h=Parameters.HubbleParam, + Om0=Parameters.Omega0, + Ob0=Parameters.OmegaBaryon, + size=200, + ) + return self._real_cosmo + + def __getattr__(self, name): + return getattr(self._get_real_cosmo(), name) + + def __call__(self, *args, **kwargs): + return self._get_real_cosmo()(*args, **kwargs) + + def __repr__(self): + if self._real_cosmo is None: + return "<_CosmoProxy (Uninitialized)>" + return repr(self._real_cosmo) + +cosmo = _CosmoProxy() # ---- Import submodules diff --git a/holodeck/constants.py b/holodeck/constants.py index c5191bce..8bc88c0a 100644 --- a/holodeck/constants.py +++ b/holodeck/constants.py @@ -4,6 +4,16 @@ whenever possible. Constants and units should only be added when they are frequently used (i.e. in multiple files/submodules). +DEVELOPER NOTE: +To preserve fast module import times across holodeck, do NOT import `astropy` or `numpy` +in this file to calculate constants at module initialization. If you need to add a new constant: + 1. Open a temporary terminal or notebook. + 2. Run your astropy conversion (e.g., `import astropy as ap; print(ap.constants.m_e.cgs.value)`). + 3. Copy the raw float literal value and paste it directly below. + +Commented section at the bottom shows how constants were calculated from astropy.constants (ap.constants) +for future reference in case the values are updated. + Notes ----- * [cm] = centimeter @@ -15,44 +25,86 @@ * [K] Kelvin """ -import numpy as np -import astropy as ap -import astropy.constants # noqa # ---- Fundamental Constants -NWTG = ap.constants.G.cgs.value #: Newton's Gravitational Constant [cm^3/g/s^2] -SPLC = ap.constants.c.cgs.value #: Speed of light [cm/s] -MELC = ap.constants.m_e.cgs.value #: Electron Mass [g] -MPRT = ap.constants.m_p.cgs.value #: Proton Mass [g] -QELC = ap.constants.e.gauss.value #: Fundamental unit of charge (electron charge) [fr] -KBOLTZ = ap.constants.k_B.cgs.value #: Boltzmann constant [erg/K] -HPLANCK = ap.constants.h.cgs.value #: Planck constant [erg/s] -SIGMA_SB = ap.constants.sigma_sb.cgs.value #: Stefan-Boltzmann constant [erg/cm^2/s/K^4] -SIGMA_T = ap.constants.sigma_T.cgs.value #: Thomson/Electron -Scattering cross-section [cm^2] - -# ---- Typical astronomy units -MSOL = ap.constants.M_sun.cgs.value #: Solar Mass [g] -LSOL = ap.constants.L_sun.cgs.value #: Solar Luminosity [erg/s] -RSOL = ap.constants.R_sun.cgs.value #: Solar Radius [cm] -PC = ap.constants.pc.cgs.value #: Parsec [cm] -AU = ap.constants.au.cgs.value #: Astronomical Unit [cm] -ARCSEC = ap.units.arcsec.cgs.scale #: arcsecond in radians [] -YR = ap.units.year.to(ap.units.s) #: year [s] -EVOLT = ap.units.eV.to(ap.units.erg) #: Electronvolt in ergs -JY = ap.units.jansky.to(ap.units.g/ap.units.s**2) #: Jansky [erg/s/cm^2/Hz] -KMPERSEC = (ap.units.km / ap.units.s).to(ap.units.cm/ap.units.s) #: km/s [cm/s] +NWTG = 6.674299999999999e-08 #: Newton's Gravitational Constant [cm^3/g/s^2] +SPLC = 29979245800.0 #: Speed of light [cm/s] +MELC = 9.1093837015e-28 #: Electron Mass [g] +MPRT = 1.67262192369e-24 #: Proton Mass [g] +QELC = 4.803204712570263e-10 #: Fundamental unit of charge (electron charge) [fr] +KBOLTZ = 1.380649e-16 #: Boltzmann constant [erg/K] +HPLANCK = 6.62607015e-27 #: Planck constant [erg/s] +SIGMA_SB = 5.6703744191844314e-05 #: Stefan-Boltzmann constant [erg/s/cm^2/K^4] + # ---- Derived Constants -SCHW = 2*NWTG/(SPLC*SPLC) #: Schwarzschild Constant (2*G/c^2) [cm] -EDDT = 4.0*np.pi*NWTG*SPLC*MPRT/SIGMA_T #: Eddington Luminosity prefactor factor [erg/s/g] +# Schwarzschild constant calculated as +# SCHW = 2 * NWTG / (SPLC * SPLC) #: Schwarzschild Constant (2*G/c^2) [cm] +SCHW = 1.4852320538237328e-28 #: Schwarzschild Constant (2*G/c^2) [cm] +# Thomson scattering cross section calculated as: +# SIGMA_T = (8.0 * math.pi / 3.0) * ((QELC * QELC) / (MELC * SPLC * SPLC))**2 +SIGMA_T = 6.6524587321000005e-25 #: Thomson/Electron -Scattering cross-section [cm^2] +# Eddington luminosity calculated as: +# EDDT = 4.0 * math.pi * NWTG * SPLC * MPRT / SIGMA_T #: Eddington Luminosity prefactor factor [erg/s/g] +EDDT = 63219.620781981204 #: Eddington Luminosity prefactor factor [erg/s/g] -# Electron-Scattering Opacity ($\kappa_{es} = n_e \sigma_T / \rho = \mu_e \sigma_T / m_p$) -# Where $\mu_e$ is the mean-mass per electron, for a total mass-density $\rho$. -# KAPPA_ES = SIGMA_T / MPRT #: Electron scattering opacity [cm^2/g] +# ---- Astronomical Constants +MSOL = 1.988409870698051e+33 #: Solar Mass [g] +LSOL = 3.828e+33 #: Solar Luminosity [erg/s] +RSOL = 69570000000.0 #: Solar Radius [cm] +PC = 3.0856775814913674e+18 #: Parsec [cm] +AU = 14959787070000.0 #: Astronomical Unit [cm] +ARCSEC = 4.84813681109536e-06 #: arcsecond in radians [] +YR = 31557600.0 #: year [s] +EVOLT = 1.6021766339999997e-12 #: Electronvolt in ergs +JY = 1e-23 #: Jansky [erg/s/cm^2/Hz] +KMPERSEC = 100000.0 #: km/s [cm/s] DAY = 86400.0 #: Day [s] -MYR = 1.0e6*YR #: Mega-year [s] -GYR = 1.0e9*YR #: Giga-year [s] -KPC = 1.0e3*PC #: Kilo-parsec [cm] -MPC = 1.0e6*PC #: Mega-parsec [cm] -GPC = 1.0e9*PC #: Giga-parsec [cm] +MYR = 31557600000000.0 #: Mega [s] +GYR = 3.15576e+16 #: Giga-year [s] +KPC = 3.0856775814913673e+21 #: Kilo-parsec [cm] +MPC = 3.0856775814913676e+24 #: Mega-parsec [cm] +GPC = 3.085677581491367e+27 #: Giga-parsec [cm] + + +# ----- Constants as calculated from astropy.constants This is a repetition of the above +# ----- and is provided as a reference for the user in case they want to check the values +# ----- with astropy themselves. +# ---- Fundamental Constants +# NWTG = ap.constants.G.cgs.value #: Newton's Gravitational Constant [cm^3/g/s^2] +# SPLC = ap.constants.c.cgs.value #: Speed of light [cm/s] +# MELC = ap.constants.m_e.cgs.value #: Electron Mass [g] +# MPRT = ap.constants.m_p.cgs.value #: Proton Mass [g] +# QELC = ap.constants.e.gauss.value #: Fundamental unit of charge (electron charge) [fr] +# KBOLTZ = ap.constants.k_B.cgs.value #: Boltzmann constant [erg/K] +# HPLANCK = ap.constants.h.cgs.value #: Planck constant [erg/s] +# SIGMA_SB = ap.constants.sigma_sb.cgs.value #: Stefan-Boltzmann constant [erg/cm^2/s/K^4] +# SIGMA_T = ap.constants.sigma_T.cgs.value #: Thomson/Electron -Scattering cross-section [cm^2] + +# # ---- Typical astronomy units +# MSOL = ap.constants.M_sun.cgs.value #: Solar Mass [g] +# LSOL = ap.constants.L_sun.cgs.value #: Solar Luminosity [erg/s] +# RSOL = ap.constants.R_sun.cgs.value #: Solar Radius [cm] +# PC = ap.constants.pc.cgs.value #: Parsec [cm] +# AU = ap.constants.au.cgs.value #: Astronomical Unit [cm] +# ARCSEC = ap.units.arcsec.cgs.scale #: arcsecond in radians [] +# YR = ap.units.year.to(ap.units.s) #: year [s] +# EVOLT = ap.units.eV.to(ap.units.erg) #: Electronvolt in ergs +# JY = ap.units.jansky.to(ap.units.g/ap.units.s**2) #: Jansky [erg/s/cm^2/Hz] +# KMPERSEC = (ap.units.km / ap.units.s).to(ap.units.cm/ap.units.s) #: km/s [cm/s] + +# # ---- Derived Constants +# SCHW = 2*NWTG/(SPLC*SPLC) #: Schwarzschild Constant (2*G/c^2) [cm] +# EDDT = 4.0*np.pi*NWTG*SPLC*MPRT/SIGMA_T #: Eddington Luminosity prefactor factor [erg/s/g] + +# # Electron-Scattering Opacity ($\kappa_{es} = n_e \sigma_T / \rho = \mu_e \sigma_T / m_p$) +# # Where $\mu_e$ is the mean-mass per electron, for a total mass-density $\rho$. +# # KAPPA_ES = SIGMA_T / MPRT #: Electron scattering opacity [cm^2/g] + +# DAY = 86400.0 #: Day [s] +# MYR = 1.0e6*YR #: Mega-year [s] +# GYR = 1.0e9*YR #: Giga-year [s] +# KPC = 1.0e3*PC #: Kilo-parsec [cm] +# MPC = 1.0e6*PC #: Mega-parsec [cm] +# GPC = 1.0e9*PC #: Giga-parsec [cm] diff --git a/holodeck/logger.py b/holodeck/logger.py index c6d81b6a..17e1a8d6 100644 --- a/holodeck/logger.py +++ b/holodeck/logger.py @@ -9,6 +9,7 @@ from pathlib import Path import logging from logging import DEBUG, INFO, WARNING, ERROR # noqa import these for easier access internally +import os import sys from holodeck import LOG_SUFFIX, LOG_FILENAME_WITH_TIME_STAMP, _PATH_LOGS logging.getLogger().addHandler(logging.NullHandler()) @@ -50,13 +51,22 @@ def get_logger(name=None, level_stream=WARNING, tostr=sys.stdout, tofile=None, l """ comm_rank = None - try: - from mpi4py import MPI - comm = MPI.COMM_WORLD - if comm.size > 1: - comm_rank = comm.rank - except ModuleNotFoundError: - pass + + env_rank = os.environ.get("OMPI_COMM_WORLD_RANK") or os.environ.get("PMI_RANK") or os.environ.get("PMIX_RANK") + env_size = os.environ.get("OMPI_COMM_WORLD_SIZE") or os.environ.get("PMI_SIZE") or os.environ.get("PMIX_SIZE") + + if env_size is not None: + if int(env_size) > 1 and env_rank is not None: + comm_rank = int(env_rank) + + elif 'mpi4py' in sys.modules: + try: + from mpi4py import MPI + comm = MPI.COMM_WORLD + if comm.size > 1: + comm_rank = comm.rank + except Exception: + pass if name is None: name = 'holodeck' diff --git a/holodeck/sams/simple_sam.py b/holodeck/sams/simple_sam.py index 4af1ec0e..993bfa12 100644 --- a/holodeck/sams/simple_sam.py +++ b/holodeck/sams/simple_sam.py @@ -5,9 +5,6 @@ from holodeck import cosmo, utils, gravwaves from holodeck.constants import MSOL, GYR, MPC -_AGE_UNIVERSE_GYR = cosmo.age(0.0).to('Gyr').value # [Gyr] ~ 13.78 - - class Simple_SAM: def __init__( @@ -144,7 +141,8 @@ def _zprime(self, mgal, qgal, redz): age = cosmo.age(redz).to('s').value new_age = age + tau0 redz_prime = -1.0 * np.ones_like(new_age) - idx = (new_age < _AGE_UNIVERSE_GYR * GYR) + age_universe_gyr = cosmo.age(0.0).to('Gyr').value + idx = (new_age < age_universe_gyr * GYR) redz_prime[idx] = cosmo.tage_to_z(new_age[idx]) return redz_prime diff --git a/holodeck/utils.py b/holodeck/utils.py index 87630674..4e50f340 100644 --- a/holodeck/utils.py +++ b/holodeck/utils.py @@ -25,17 +25,27 @@ # except ImportError: # from typing_extensions import ParamSpec -import h5py -import numba import numpy as np import numpy.typing as npt -import scipy as sp -import scipy.stats # noqa -import scipy.special # noqa from holodeck import log, cosmo from holodeck.constants import NWTG, SCHW, SPLC, YR, GYR, MPC, PC, EDDT +class _LazyNJIT: + """A lazy proxy decorator for numba.njit to prevent compiler loading at import.""" + def __init__(self, func): + self._func = func + self._compiled_func = None + + def __call__(self, *args, **kwargs): + # The very first time the function is called, import numba and compile it + if self._compiled_func is None: + import numba + self._compiled_func = numba.njit(self._func) + return self._compiled_func(*args, **kwargs) +def lazy_njit(func): + return _LazyNJIT(func) + # [Sesana2004]_ Eq.36 _GW_SRC_CONST = 8 * np.power(NWTG, 5/3) * np.power(np.pi, 2/3) / np.sqrt(10) / np.power(SPLC, 4) _GW_DADT_SEP_CONST = - 64 * np.power(NWTG, 3) / 5 / np.power(SPLC, 5) @@ -43,7 +53,9 @@ # [EN2007]_, Eq.2.2 _GW_LUM_CONST = (32.0 / 5.0) * np.power(NWTG, 7.0/3.0) * np.power(SPLC, -5.0) -_AGE_UNIVERSE_GYR = cosmo.age(0.0).to('Gyr').value # [Gyr] ~ 13.78 +@functools.lru_cache(maxsize=1) +def get_age_universe_gyr(): + return cosmo.age(0.0).to('Gyr').value _DFDM_CONST = np.sqrt(NWTG) / (4.0 * np.pi) @@ -194,6 +206,7 @@ def load_hdf5(fname, keys=None): specifically everything returned from `hdf5.File.keys()`. """ + import h5py squeeze = False if (keys is not None) and np.isscalar(keys): keys = [keys] @@ -530,11 +543,12 @@ def scatter_redistribute_densities(cents, dens, dist=None, scatter=None, axis=0) Array with resitributed values. Same shape as input `dens`. """ + from scipy import stats as _stats if (dist is None) == (scatter is None): raise ValueError(f"One and only one of `dist` ({dist}) and `scatter` ({scatter}) must be provided!") if dist is None: - dist = sp.stats.norm(loc=0.0, scale=scatter) + dist = _stats.norm(loc=0.0, scale=scatter) log_cents = np.log10(cents) num = log_cents.size @@ -1009,7 +1023,8 @@ def quantiles( raise ValueError(err) if percs is None: - percs = sp.stats.norm.cdf(sigmas) + from scipy import stats as _stats + percs = _stats.norm.cdf(sigmas) if np.ndim(values) > 1: if axis is None: @@ -1052,7 +1067,7 @@ def quantiles( return percs def random_power(extr, pdf_index, size=1): - """Draw from a power-law PDF with the given index, between the given extrema. + r"""Draw from a power-law PDF with the given index, between the given extrema. NOTE: The power-law index must correspond to the power-law index of $\frac{dN}{dx}$. You may need to convert, e.g. $dN/dx = \frac{dN}{d \ln x} \frac{1}{x}$. @@ -1151,7 +1166,7 @@ def stats(vals: npt.ArrayLike, percs: Optional[npt.ArrayLike] = None, prec: int raise TypeError(f"`vals` (shape={np.shape(vals)}) is not iterable!") if percs is None: - percs = [sp.stats.norm.cdf(1), 0.95, 1.0] + percs = [0.8413447460685429, 0.95, 1.0] # 0.841... is percentile for 1 sigma above mean in Gaussian percs = np.array(percs) percs = np.concatenate([1-percs[::-1], [0.5], percs]) @@ -1267,7 +1282,8 @@ def trapz_loglog( log.error(err) raise ValueError(err) - newy = sp.interpolate.PchipInterpolator(np.log10(xx), np.log10(yy), extrapolate=False) + import scipy.interpolate as _interpolate + newy = _interpolate.PchipInterpolator(np.log10(xx), np.log10(yy), extrapolate=False) newy = newy(bounds) ii = np.searchsorted(xx, bounds) @@ -1483,12 +1499,13 @@ def fit_gaussian(xx, yy, guess=None): Covariance matrix of best fit parameters. """ + import scipy.optimize as _optimize if guess is None: amp = np.max(yy) mean = np.sum(xx * yy) / np.sum(yy) stdev = std(xx, yy) guess = [amp, mean, stdev] - popt, pcov = sp.optimize.curve_fit(_func_gaussian, xx, yy, p0=guess, maxfev=10000) + popt, pcov = _optimize.curve_fit(_func_gaussian, xx, yy, p0=guess, maxfev=10000) return popt, pcov @@ -1506,8 +1523,8 @@ def fit_powerlaw(xx, yy, init=[-15.0, -2.0/3.0]): plaw """ - - popt, pcov = sp.optimize.curve_fit(_func_line, np.log10(xx), np.log10(yy), p0=init, maxfev=10000) + import scipy.optimize as _optimize + popt, pcov = _optimize.curve_fit(_func_line, np.log10(xx), np.log10(yy), p0=init, maxfev=10000) # log10_amp = popt[0] # gamma = popt[1] @@ -1530,8 +1547,9 @@ def fit_func(xx, log10_amp, index): amp = 10.0 ** log10_amp yy = _func_powerlaw_psd(xx, fref, amp, index) return np.log10(yy) - - popt, pcov = sp.optimize.curve_fit( + + import scipy.optimize as _optimize + popt, pcov = _optimize.curve_fit( fit_func, xx, np.log10(yy), p0=init, maxfev=10000, full_output=False ) @@ -1553,8 +1571,9 @@ def fit_powerlaw_fixed_index(xx, yy, index=-2.0/3.0, init=[-15.0]): plaw """ + import scipy.optimize as _optimize _func_fixed = lambda xx, amp: _func_line(xx, amp, index) - popt, pcov = sp.optimize.curve_fit(_func_fixed, np.log10(xx), np.log10(yy), p0=init, maxfev=10000) + popt, pcov = _optimize.curve_fit(_func_fixed, np.log10(xx), np.log10(yy), p0=init, maxfev=10000) log10_amp = popt[0] return log10_amp @@ -1610,8 +1629,8 @@ def fit_func(xx, log10_amp, *args): amp = 10.0 ** log10_amp yy = _func_turnover_psd(xx, fref, amp, *args) return np.log10(yy) - - popt, pcov = sp.optimize.curve_fit( + import scipy.optimize as _optimize + popt, pcov = _optimize.curve_fit( fit_func, xx, np.log10(yy), p0=init, maxfev=10000, full_output=False ) @@ -1905,6 +1924,7 @@ def redz_after(time, redz=None, age=None): Redshift of the Universe after the given amount of time. """ + _AGE_UNIVERSE_GYR = get_age_universe_gyr() if (redz is None) == (age is None): raise ValueError("One of `redz` and `age` must be provided (and not both)!") @@ -2263,9 +2283,9 @@ def gw_freq_dist_func(nn, ee=0.0, recursive=True): GW Frequency distribution function g(n,e). """ - # Calculate with non-zero eccentrictiy - bessel = sp.special.jn + import scipy.special as _special + bessel = _special.jn ne = nn*ee n2 = np.square(nn) jn_m2 = bessel(nn-2, ne) @@ -2558,7 +2578,7 @@ def char_strain_to_strain_amp(hc, fc, df): return hs -@numba.njit +@lazy_njit def _gw_ecc_func(eccen): """GW Hardening rate eccentricitiy dependence F(e).