Source code for tardis.montecarlo.montecarlo_numba.base

from numba import prange, njit, jit
import logging
import numpy as np

from tardis.montecarlo.montecarlo_numba.r_packet import (
    RPacket,
    PacketStatus,
)
from tardis.montecarlo.montecarlo_numba.utils import MonteCarloException

from tardis.montecarlo.montecarlo_numba.numba_interface import (
    PacketCollection,
    VPacketCollection,
    NumbaModel,
    numba_plasma_initialize,
    Estimators,
    configuration_initialize,
)

from tardis.montecarlo import (
    montecarlo_configuration as montecarlo_configuration,
)

from tardis.montecarlo.montecarlo_numba.single_packet_loop import (
    single_packet_loop,
)
from tardis.montecarlo.montecarlo_numba import njit_dict
from numba.typed import List

def montecarlo_radial1d(model, plasma, runner):
    """
    Run the numba Monte Carlo packet transport and store the results on
    ``runner``.

    Packs the runner's packet arrays, shell radii, plasma state and
    estimator arrays into numba-compatible containers, calls
    ``montecarlo_main_loop`` and copies the per-packet results back onto
    ``runner``.

    Parameters
    ----------
    model :
        Must provide ``time_explosion`` as an astropy Quantity; it is
        converted to seconds.
    plasma :
        Passed to ``numba_plasma_initialize`` together with
        ``runner.line_interaction_type``.
    runner :
        Supplies the input packet, radius, spectrum-frequency and estimator
        arrays and receives the outputs. The values of
        ``runner._montecarlo_virtual_luminosity`` are overwritten in place.
        When ``montecarlo_configuration.VPACKET_LOGGING`` is set and
        vpackets are spawned, the ``runner.virt_packet_*`` attributes are
        set to flat 1D arrays.
    """
    packet_collection = PacketCollection(
        runner.input_r,
        runner.input_nu,
        runner.input_mu,
        runner.input_energy,
        runner._output_nu,
        runner._output_energy,
    )
    numba_model = NumbaModel(
        runner.r_inner_cgs,
        runner.r_outer_cgs,
        model.time_explosion.to("s").value,
    )
    numba_plasma = numba_plasma_initialize(plasma, runner.line_interaction_type)
    estimators = Estimators(
        runner.j_estimator,
        runner.nu_bar_estimator,
        runner.j_blue_estimator,
        runner.Edotlu_estimator,
    )
    packet_seeds = montecarlo_configuration.packet_seeds
    number_of_vpackets = montecarlo_configuration.number_of_vpackets

    (
        v_packets_energy_hist,
        last_interaction_type,
        last_interaction_in_nu,
        last_line_interaction_in_id,
        last_line_interaction_out_id,
        virt_packet_nus,
        virt_packet_energies,
        virt_packet_last_interaction_in_nu,
        virt_packet_last_interaction_type,
        virt_packet_last_line_interaction_in_id,
        virt_packet_last_line_interaction_out_id,
    ) = montecarlo_main_loop(
        packet_collection,
        numba_model,
        numba_plasma,
        estimators,
        runner.spectrum_frequency.value,
        number_of_vpackets,
        packet_seeds,
    )

    runner._montecarlo_virtual_luminosity.value[:] = v_packets_energy_hist
    runner.last_interaction_type = last_interaction_type
    runner.last_interaction_in_nu = last_interaction_in_nu
    runner.last_line_interaction_in_id = last_line_interaction_in_id
    runner.last_line_interaction_out_id = last_line_interaction_out_id

    if montecarlo_configuration.VPACKET_LOGGING and number_of_vpackets > 0:
        # montecarlo_main_loop returns, for each logged quantity, a list with
        # one array per packet; flatten each list into a single array.
        runner.virt_packet_nus = _flatten_vpacket_log(virt_packet_nus)
        runner.virt_packet_energies = _flatten_vpacket_log(virt_packet_energies)
        runner.virt_packet_last_interaction_in_nu = _flatten_vpacket_log(
            virt_packet_last_interaction_in_nu
        )
        runner.virt_packet_last_interaction_type = _flatten_vpacket_log(
            virt_packet_last_interaction_type
        )
        runner.virt_packet_last_line_interaction_in_id = _flatten_vpacket_log(
            virt_packet_last_line_interaction_in_id
        )
        runner.virt_packet_last_line_interaction_out_id = _flatten_vpacket_log(
            virt_packet_last_line_interaction_out_id
        )


def _flatten_vpacket_log(per_packet_arrays):
    """
    Join a sequence of per-packet 1D arrays into one flat array.

    The per-packet arrays generally have different lengths, so they are
    passed straight to ``np.concatenate``. Wrapping them in ``np.array``
    first would create a ragged array, which NumPy >= 1.24 rejects with a
    ``ValueError``.
    """
    return np.concatenate(per_packet_arrays).ravel()
@njit(**njit_dict)
def montecarlo_main_loop(
    packet_collection,
    numba_model,
    numba_plasma,
    estimators,
    spectrum_frequency,
    number_of_vpackets,
    packet_seeds,
):
    """
    This is the main loop of the MonteCarlo routine that generates packets
    and sends them through the ejecta.

    Parameters
    ----------
    packet_collection : PacketCollection
        Input packet properties (radius unused here, mu, nu, energy). Its
        ``packets_output_nu`` and ``packets_output_energy`` arrays are
        overwritten in place with the final packet properties.
    numba_model : NumbaModel
        Packets start at ``numba_model.r_inner[0]``.
    numba_plasma :
        Plasma state, passed through to ``single_packet_loop``.
    estimators : Estimators
        Passed through to ``single_packet_loop``.
    spectrum_frequency : numpy.ndarray
        Plain (unitless) frequency grid used to bin virtual-packet energies.
        Only the first spacing is used as the bin width, so the grid is
        assumed to be uniform.
    number_of_vpackets : int
        VPackets released per interaction
    packet_seeds : numpy.ndarray
        RNG seed for each packet. If
        ``montecarlo_configuration.single_packet_seed != -1``, every packet
        uses ``packet_seeds[single_packet_seed]`` instead.

    Returns
    -------
    tuple
        ``(v_packets_energy_hist, last_interaction_types,
        last_interaction_in_nus, last_line_interaction_in_ids,
        last_line_interaction_out_ids)``, followed by six lists of
        per-packet vpacket arrays (nu, energy, last interaction in-nu,
        last interaction type, last line in-id, last line out-id). The six
        lists are populated only when
        ``montecarlo_configuration.VPACKET_LOGGING`` is set.
    """
    output_nus = np.empty_like(packet_collection.packets_output_nu)
    last_interaction_types = (
        np.ones_like(packet_collection.packets_output_nu, dtype=np.int64) * -1
    )
    output_energies = np.empty_like(packet_collection.packets_output_nu)
    last_interaction_in_nus = np.empty_like(packet_collection.packets_output_nu)
    last_line_interaction_in_ids = (
        np.ones_like(packet_collection.packets_output_nu, dtype=np.int64) * -1
    )
    last_line_interaction_out_ids = (
        np.ones_like(packet_collection.packets_output_nu, dtype=np.int64) * -1
    )

    v_packets_energy_hist = np.zeros_like(spectrum_frequency)
    delta_nu = spectrum_frequency[1] - spectrum_frequency[0]

    # Pre-allocate one vpacket collection per packet so that each parallel
    # iteration writes only into its own collection.
    vpacket_collections = List()
    for i in range(len(output_nus)):
        vpacket_collections.append(
            VPacketCollection(
                i,
                spectrum_frequency,
                montecarlo_configuration.v_packet_spawn_start_frequency,
                montecarlo_configuration.v_packet_spawn_end_frequency,
                number_of_vpackets,
                montecarlo_configuration.temporary_v_packet_bins,
            )
        )

    # Arrays for vpacket logging
    virt_packet_nus = []
    virt_packet_energies = []
    virt_packet_last_interaction_in_nu = []
    virt_packet_last_interaction_type = []
    virt_packet_last_line_interaction_in_id = []
    virt_packet_last_line_interaction_out_id = []

    for i in prange(len(output_nus)):
        if montecarlo_configuration.single_packet_seed != -1:
            seed = packet_seeds[montecarlo_configuration.single_packet_seed]
        else:
            seed = packet_seeds[i]
        np.random.seed(seed)

        r_packet = RPacket(
            numba_model.r_inner[0],
            packet_collection.packets_input_mu[i],
            packet_collection.packets_input_nu[i],
            packet_collection.packets_input_energy[i],
            seed,
            i,
        )
        vpacket_collection = vpacket_collections[i]

        single_packet_loop(
            r_packet, numba_model, numba_plasma, estimators, vpacket_collection
        )

        # Each iteration writes only to index i, so these stores are race-free.
        output_nus[i] = r_packet.nu
        last_interaction_in_nus[i] = r_packet.last_interaction_in_nu
        last_line_interaction_in_ids[i] = r_packet.last_line_interaction_in_id
        last_line_interaction_out_ids[i] = r_packet.last_line_interaction_out_id

        if r_packet.status == PacketStatus.REABSORBED:
            # Reabsorbed packets are flagged with a negative output energy.
            output_energies[i] = -r_packet.energy
            last_interaction_types[i] = r_packet.last_interaction_type
        elif r_packet.status == PacketStatus.EMITTED:
            output_energies[i] = r_packet.energy
            last_interaction_types[i] = r_packet.last_interaction_type

    # Bin the virtual packets serially. An indexed ``+=`` into the shared
    # histogram inside the prange body would let parallel iterations update
    # the same bin concurrently (a data race).
    for collection in vpacket_collections:
        nus = collection.nus[: collection.idx]
        energies = collection.energies[: collection.idx]
        bin_indices = np.floor((nus - spectrum_frequency[0]) / delta_nu).astype(
            np.int64
        )
        for j, bin_index in enumerate(bin_indices):
            # Skip vpackets that fall outside the spectrum grid.
            if (nus[j] < spectrum_frequency[0]) or (
                nus[j] > spectrum_frequency[-1]
            ):
                continue
            v_packets_energy_hist[bin_index] += energies[j]

    if montecarlo_configuration.VPACKET_LOGGING:
        for vpacket_collection in vpacket_collections:
            vpackets_nu = vpacket_collection.nus[: vpacket_collection.idx]
            vpackets_energy = vpacket_collection.energies[
                : vpacket_collection.idx
            ]
            virt_packet_nus.append(np.ascontiguousarray(vpackets_nu))
            virt_packet_energies.append(np.ascontiguousarray(vpackets_energy))
            virt_packet_last_interaction_in_nu.append(
                np.ascontiguousarray(
                    vpacket_collection.last_interaction_in_nu[
                        : vpacket_collection.idx
                    ]
                )
            )
            virt_packet_last_interaction_type.append(
                np.ascontiguousarray(
                    vpacket_collection.last_interaction_type[
                        : vpacket_collection.idx
                    ]
                )
            )
            virt_packet_last_line_interaction_in_id.append(
                np.ascontiguousarray(
                    vpacket_collection.last_interaction_in_id[
                        : vpacket_collection.idx
                    ]
                )
            )
            virt_packet_last_line_interaction_out_id.append(
                np.ascontiguousarray(
                    vpacket_collection.last_interaction_out_id[
                        : vpacket_collection.idx
                    ]
                )
            )

    packet_collection.packets_output_energy[:] = output_energies[:]
    packet_collection.packets_output_nu[:] = output_nus[:]

    return (
        v_packets_energy_hist,
        last_interaction_types,
        last_interaction_in_nus,
        last_line_interaction_in_ids,
        last_line_interaction_out_ids,
        virt_packet_nus,
        virt_packet_energies,
        virt_packet_last_interaction_in_nu,
        virt_packet_last_interaction_type,
        virt_packet_last_line_interaction_in_id,
        virt_packet_last_line_interaction_out_id,
    )