"""Module for running forward models across multiple t-T paths and VTK meshes.
This module includes a function for running forward models across multiple
user-provided t-T paths (run_tt_paths) and a function to run forward models
across the t-T paths experienced by different particles across a series of
VTK meshes (run_vtk).
"""
import gc
import os
import shutil
import warnings
from contextlib import suppress
import numpy as np
import pyvista as pv
from joblib import Parallel, delayed
from scipy.spatial import KDTree
from tqdm import tqdm
from gdtchron import aft, he
def run_particle_he(particle_id, inputs, calc_age, interpolate_vals,
                    dtype=np.float32):
    """Calculate profile of x values for a particular ASPECT particle.

    Function to calculate the profile of x values (He concentrations times
    node positions) across all nodes within a hypothetical grain found in
    a given particle. This function's primary purpose is to be run in
    parallel by run_vtk.

    Parameters
    ----------
    particle_id : int
        ID corresponding to the particle to get the He profile of
    inputs : tuple
        k : any
            Unused parameter (included here for symmetry with the inputs of
            run_particle_ft)
        positions : PyVista array
            x, y, z coordinates of each particle in current mesh
        tree : SciPy KDTree or None
            K-d tree containing the positions of particles from the previous
            timestep. Unused (and typically set to None) if interpolate_vals
            is False.
        ids : PyVista array
            IDs for all particles from the current timestep
        old_ids : PyVista array
            IDs for all particles from the previous timestep
        tree_ids : PyVista array or None
            IDs of the particles stored in ``tree`` (same order as the tree's
            points). Unused (and typically set to None) if interpolate_vals
            is False.
        mesh_temps : PyVista array
            Temperatures (K) for all particles from the current timestep
        old_temps : PyVista array
            Temperatures (K) for all particles from the previous timestep
        old_profiles : NumPy ndarray
            Profiles of x values for all particles from the previous
            timestep, one row per particle (aligned with old_ids).
        time_interval : float
            Time elapsed between mesh files (Myr)
        system : str
            Isotopic system to model. Valid options are 'AHe' (apatite He) or
            'ZHe' (zircon He)
        num_nodes : int
            Number of nodes to use within each profile
        model_inputs : tuple
            u : float
                U concentration (ppm).
            th : float
                Th concentration (ppm).
            radius : float
                Radius of the grain (micrometers).
    calc_age : bool
        Boolean indicating whether to calculate age of particle. If False,
        age is returned as np.nan. Note that setting this to False will not
        substantially improve the speed of calculations for this function.
    interpolate_vals : bool
        Boolean indicating whether to interpolate He data from nearest neighbor
        of the particle if the particle itself lacks He data. If False and the
        particle is missing He data, an age of np.nan and a profile filled with
        np.inf are returned.
    dtype : type
        Type of numbers used for calculations (default: np.float32). 32-bit
        floats are preferred to save memory.

    Returns
    -------
    age : float
        Age of the particle. This equals np.nan if calc_age is False or there
        is an issue obtaining the age of the particle.
    x : NumPy array of floats
        Matrix x solved for using Equation 21 in Ketcham (2005). In that
        equation, x is referred to as u. We use x here to avoid confusing this
        variable for uranium (u).
        Equivalent to the He concentration (mol / g) times the node position
        (micrometers). A profile filled with np.inf marks a particle that
        could not be modeled; such profiles are passed through unchanged in
        later timesteps.

    Notes
    -----
    If a particle ID appears more than once in the previous timestep, only
    the first matching profile/temperature is used.
    """
    # Unpack inputs (k is unused here; kept so both run_particle_* functions
    # accept the same tuple layout)
    (k, positions, tree, ids, old_ids, tree_ids, mesh_temps, old_temps,
     old_profiles, time_interval, system, num_nodes, (u, th, radius)) = inputs

    def _unmodeled():
        # NaN age plus an inf-filled profile flags this particle as bugged
        return (np.nan, np.full(num_nodes, dtype(np.inf), dtype=dtype))

    # Previous-timestep profile and temperature for this particle, if any.
    # A boolean mask may match several rows when IDs are duplicated, so only
    # the first matching row is used (a multi-row selection would otherwise
    # be flattened into a profile of the wrong length).
    old_match = particle_id == old_ids
    old_rows = old_profiles[old_match]
    particle_start_temp = old_temps[old_match]
    missing = old_rows.size == 0
    if not missing:
        profile = old_rows[0]
        # Profiles already filled with inf belong to bugged particles
        if profile[0] == dtype(np.inf):
            return (np.nan, profile)

    # Particle temperature in the current timestep
    particle_end_temp = mesh_temps[ids == particle_id]
    # If particle not found, don't attempt to calculate profile or age
    if particle_end_temp.size == 0:
        return _unmodeled()

    if missing:
        if not interpolate_vals:
            return _unmodeled()
        # Borrow He data from the nearest neighbor in the previous timestep
        particle_position = positions[ids == particle_id]
        _, index = tree.query(particle_position)
        neighbor_id = tree_ids[index]
        try:
            # Duplicate IDs in the current mesh give several positions and
            # hence several neighbor IDs: the comparison then fails on shape
            # mismatch (ValueError) or selects nothing (IndexError on [0])
            neighbor_match = neighbor_id == old_ids
            profile = old_profiles[neighbor_match][0]
        except (ValueError, IndexError):
            warnings.warn("Warning: 2+ particles likely have the same ID",
                          stacklevel=2)
            return _unmodeled()
        particle_start_temp = old_temps[neighbor_match]
        # Double checking that interpolated particle isn't bugged
        if profile[0] == dtype(np.inf):
            return (np.nan, profile)

    # Passing start and end temperatures to forward model
    particle_temps = np.array([particle_start_temp[0], particle_end_temp[0]])
    particle_tsteps = np.array([time_interval, 0])
    age, _age_unc, _he_tot, _pos, _v, x = he.forward_model_he(
        temps=particle_temps,
        tsteps=particle_tsteps,
        system=system,
        u=u,
        th=th,
        radius=radius,
        nodes=num_nodes,
        initial_x=profile.flatten(),
        return_all=True)
    if not calc_age:
        age = np.nan
    return (age, x)
def run_particle_ft(particle_id, inputs, calc_age, interpolate_vals):
    """Calculate FT reduced lengths for a particular ASPECT particle.

    Function to calculate the reduced lengths (unitless) within a hypothetical
    grain found in a given particle. This function's primary purpose is to be
    run in parallel by run_vtk.

    Parameters
    ----------
    particle_id : int
        ID corresponding to the particle to get the reduced lengths of
    inputs : tuple
        k : int
            Index of the current timestep/mesh being processed.
        positions : PyVista array
            x, y, z coordinates of each particle in current mesh
        tree : SciPy KDTree or None
            K-d tree containing the positions of particles from the previous
            timestep. Unused (and typically set to None) if interpolate_vals
            is False.
        ids : PyVista array
            IDs for all particles from the current timestep
        old_ids : PyVista array
            IDs for all particles from the previous timestep
        tree_ids : PyVista array or None
            IDs of the particles stored in ``tree`` (same order as the tree's
            points). Unused (and typically set to None) if interpolate_vals
            is False.
        mesh_temps : PyVista array
            Temperatures (K) for all particles from the current timestep
        old_temps : PyVista array
            Temperatures (K) for all particles from the previous timestep
        old_annealing_arrays : NumPy ndarray
            r values for all particles from the previous timestep, one row
            per particle (aligned with old_ids)
        time_interval : float
            Time elapsed between mesh files (Myr)
        system : str
            Isotopic system to model. Not used for FT system (but included as a
            parameter for symmetry with run_particle_he)
        r_length : int
            Length of the r arrays for particles that have them
        model_inputs : tuple
            dpar : float
                Etch figure length (micrometers).
            annealing_model : str
                Annealing model to use. Currently, the only acceptable value is
                'Ketcham99', which corresponds to the fanning curvilinear model
                from Ketcham et al. (1999).
    calc_age : bool
        Boolean indicating whether to calculate age of particle. If False,
        age is returned as np.nan.
    interpolate_vals : bool
        Boolean indicating whether to interpolate FT data from nearest neighbor
        of the particle if the particle itself lacks FT data. If False and the
        particle is missing FT data, an age of np.nan and a profile filled with
        np.inf are returned.

    Returns
    -------
    age : float
        Age of the particle. This equals np.nan if calc_age is False or there
        is an issue obtaining the age of the particle
    r : NumPy array of floats
        Updated reduced lengths (unitless) for a hypothetical grain located
        within this particle. An array filled with np.inf marks a particle
        that could not be modeled.
    """
    # float64 for the freshly created inf/NaN arrays
    dtype = np.float64
    # Unpack inputs
    (k, positions, tree, ids, old_ids, tree_ids, mesh_temps, old_temps,
     old_annealing_arrays, time_interval, system,
     r_length, (dpar, annealing_model)) = inputs

    def _unmodeled():
        # NaN age plus an inf-filled array flags this particle as bugged
        return (np.nan, np.full(r_length, dtype(np.inf), dtype=dtype))

    # Previous-timestep r values and temperature for this particle, if any.
    # Only the first matching row is used if IDs are duplicated.
    old_match = particle_id == old_ids
    old_rows = old_annealing_arrays[old_match]
    particle_start_temp = old_temps[old_match]
    missing = old_rows.size == 0
    if not missing:
        r_initial = old_rows[0]
        # Arrays already filled with inf belong to bugged particles
        if r_initial[0] == dtype(np.inf):
            return (np.nan, r_initial)

    # Final particle temperature
    particle_end_temp = mesh_temps[ids == particle_id]
    # If particle not found, don't attempt to calculate lengths or age
    if particle_end_temp.size == 0:
        return _unmodeled()

    if missing:
        # Without interpolation, a particle lacking FT data is flagged as bugged
        if not interpolate_vals:
            return _unmodeled()
        # Use annealing from nearest neighbor in previous timestep
        particle_position = positions[ids == particle_id]
        _, index = tree.query(particle_position)
        neighbor_id = tree_ids[index]
        try:
            # Duplicate IDs in the current mesh give several neighbor IDs:
            # the comparison then fails on shape mismatch (ValueError) or
            # selects nothing (IndexError on [0])
            neighbor_match = neighbor_id == old_ids
            r_initial = old_annealing_arrays[neighbor_match][0]
        except (ValueError, IndexError):
            warnings.warn("Warning: 2+ particles likely have the same ID",
                          stacklevel=2)
            return _unmodeled()
        particle_start_temp = old_temps[neighbor_match]
        # Double checking that interpolated particle isn't bugged
        if r_initial[0] == dtype(np.inf):
            return (np.nan, r_initial)

    # Annealing constants are only needed once the model actually runs
    ft_constants = {'Ketcham99': aft.KETCHAM_99_FC}
    constants = ft_constants[annealing_model]

    # Average temperature over the timestep
    particle_temp = (particle_start_temp[0] + particle_end_temp[0]) / 2
    # For basic annealing calculations, absolute time doesn't matter -
    # we just need to maintain difference between start and end times
    # (start time > end time because the function measures time in yrs BP)
    r = aft.calc_annealing(r_initial, particle_temp, start=time_interval,
                           end=0, next_nan_index=k - 1,
                           constants=constants)
    # Only perform age calculations if necessary
    if not calc_age:
        return (np.nan, r)
    r_so_far = aft.dpar_conversion(r_mr=r[~np.isnan(r)],
                                   dpar=dpar,
                                   constants=constants)
    # k + 1 times from k * time_interval down to 0
    tsteps = np.arange(start=k * time_interval,
                       stop=-0.5 * time_interval,
                       step=-1 * time_interval)
    age = aft.calc_aft_age(r_so_far, tsteps)
    return (age, r)
def run_vtk(files, system, time_interval,
            u=100, th=100, radius=50, num_nodes=513,
            dpar=1.75, annealing_model='Ketcham99',
            file_prefix='meshes_tchron', path='./',
            temp_dir='~/dump',
            batch_size=100, processes=None, interpolate_vals=True,
            all_timesteps=True, overwrite=False):
    """Perform parallel He or FT forward modeling of ASPECT VTK data.

    This code performs forward modeling of the AHe, ZHe, or AFT systems
    across ASPECT VTK data. Output is written as one .vtu file per timestep,
    named ``<file_prefix>_<kkk>.vtu``, in the directory
    ``os.path.join(path, file_prefix)``. Each output file stores the modeled
    ages in a point-data array named after ``system``.

    Parameters
    ----------
    files : list of str
        List of paths to VTK files to run forward model on. Files are
        processed in the order they are given in the list.
    system : str
        Isotopic system to model. Current options include 'AHe': Apatite
        (U-Th)/He, 'ZHe': Zircon (U-Th)/He, 'AFT': Apatite Fission Track
    time_interval : float
        Interval (Myrs) between times when each mesh was produced
    u : float, optional
        U concentration (ppm) (default: 100). Only used if system is 'AHe' or
        'ZHe'.
    th : float, optional
        Th concentration (ppm) (default: 100). Only used if system is 'AHe' or
        'ZHe'.
    radius : float, optional
        Radius of the grain (micrometers) (default: 50). Only used
        if system is 'AHe' or 'ZHe'.
    num_nodes : int, optional
        Number of nodes within the grain (for He) (default: 513). Unused for FT
        system.
    dpar : float, optional
        Etch figure length (micrometers) (default: 1.75). Only used
        if system is 'AFT'.
    annealing_model : str, optional
        Annealing model to use. Currently, the only acceptable value is
        'Ketcham99', which corresponds to the fanning curvilinear model from
        Ketcham et al. (1999) (default: 'Ketcham99'). Only used if system is
        'AFT'.
    file_prefix : str, optional
        Prefix to give output files (default: 'meshes_tchron')
    path : str, optional
        Path to output directory (default: './')
    temp_dir : str, optional
        Path to a directory used by joblib to temporarily dump data that does
        not fit in memory (default: '~/dump'). WARNING: any existing
        directory at this path is deleted when the run starts and is purged
        again during and after the run.
    batch_size : int or 'auto', optional
        Number of jobs to be dispatched to each worker at a time during
        parallel computation (default: 100). If set to 'auto', this value is
        dynamically adjusted during computations to try to optimize efficiency.
        However, on most test systems, better efficiency is gained by manually
        setting batch size to 100 or 1000 than using 'auto'.
    processes : int or None, optional
        Maximum number of processes that can run concurrently. If None, this
        parameter is internally set to two less than the number of CPUs on the
        user's system, but never less than 1 (default: None).
    interpolate_vals : bool, optional
        Boolean indicating whether to interpolate particle data from nearest
        neighbor if the particle itself lacks He or FT data. If False and the
        particle is missing data, an age of np.nan is returned for that
        particle. (default: True)
    all_timesteps : bool, optional
        Boolean indicating whether to calculate ages at each tstep. Ages are
        always calculated for the last timestep. (default: True)
    overwrite : bool, optional
        Boolean indicating whether to overwrite old meshes that already have
        thermochronometric data for this system for a given timestep. If
        False, timesteps that already have data for this system are skipped,
        and the cached internal values of the last completed timestep are
        reloaded so that subsequent timesteps continue from them.
        (default: False)

    Returns
    -------
    None
        Results are written to disk.

    Raises
    ------
    FileNotFoundError
        If a partially completed run has to be resumed but its cache of
        internal values ('cache_internal_vals.npy' in the output directory)
        is missing.
    """
    dtype = np.float32
    # Setting variables for parallel computation. Leave two CPUs free, but
    # never request fewer than one worker (joblib rejects n_jobs=0 and
    # os.cpu_count() may return None).
    if processes is None:
        processes = max(1, (os.cpu_count() or 1) - 2)
    pre_dispatch = 2 * processes if batch_size == 'auto' else 2 * batch_size
    particle_fn = {'AHe': run_particle_he,
                   'ZHe': run_particle_he,
                   'AFT': run_particle_ft}

    # Setting up directory for temporary memory dumps (any existing contents
    # are removed)
    temp_dir = os.path.expanduser(temp_dir)
    with suppress(FileNotFoundError):
        shutil.rmtree(temp_dir)
    os.makedirs(temp_dir)

    # Setting up output directory
    output_dir = os.path.join(path, file_prefix)
    os.makedirs(output_dir, exist_ok=True)
    # Path for dump of cached internal values (used to resume a run)
    cache_path = os.path.join(output_dir, 'cache_internal_vals.npy')

    # internal_len represents the length of the "internal values" array
    # for the input annealing system. For the (U-Th)/He system, this
    # array is a profile of x values (He concentration times node position)
    # across all nodes in a grain. For the AFT system, this array contains mean
    # reduced lengths of fission tracks formed at each timestep.
    # For AFT system, can use number of files to determine internal_len
    # (-1 is to account for missing tstep from using averages)
    # For He systems, can just use the input num_nodes
    internal_len = len(files) - 1 if system == "AFT" else num_nodes
    last_k = len(files) - 1

    with Parallel(n_jobs=processes,
                  batch_size=batch_size,
                  pre_dispatch=pre_dispatch,
                  temp_folder=temp_dir) as parallel:
        # Loop through timesteps
        for k, file in enumerate(files):
            # Setting up output file path
            outfile_name = file_prefix + '_' + str(k).zfill(3) + '.vtu'
            outfile_path = os.path.join(output_dir, outfile_name)

            # Check if target mesh already exists
            if os.path.exists(outfile_path):
                original_mesh = pv.read(outfile_path)
                # If the mesh only has data from other systems or we're
                # willing to overwrite old data from the current
                # thermochronologic system, use that mesh as our input/output
                # mesh (to avoid overwriting other system data)
                if overwrite or system not in original_mesh.point_data:
                    mesh = original_mesh
                else:
                    # This timestep is already complete. State only needs
                    # to be reloaded if it is the last completed mesh AND
                    # more timesteps remain to be computed; on the final
                    # timestep there is nothing left to do (and the cache
                    # may already have been deleted by a finished run).
                    if k < last_k:
                        next_filename = file_prefix + '_' + \
                            str(k + 1).zfill(3) + '.vtu'
                        next_filepath = os.path.join(output_dir,
                                                     next_filename)
                        if (not os.path.exists(next_filepath)) or \
                                (system not in
                                 pv.read(next_filepath).point_data):
                            ids = original_mesh['id']
                            positions = original_mesh.points
                            mesh_temps = original_mesh['T']
                            try:
                                new_internal_vals = np.load(cache_path)
                            except FileNotFoundError as err:
                                raise FileNotFoundError(
                                    f"Cannot resume {system} run at "
                                    f"timestep {k}: cached values "
                                    f"'{cache_path}' are missing. Rerun "
                                    f"with overwrite=True.") from err
                    # Nothing left to do with a completed mesh
                    continue
            else:
                # If no output mesh exists for this timestep, use the provided
                # file as our mesh for this timestep
                mesh = pv.read(file)

            if k == 0:
                num_particles = len(mesh['T'])
                # Set up empty arrays for first timestep
                new_internal_vals = np.full((num_particles, internal_len),
                                            np.nan, dtype=dtype)
                np.save(cache_path, new_internal_vals)
                # Publish ages at 0
                mesh[system] = np.zeros(num_particles)
                mesh.save(outfile_path)
            else:
                # Take the "new" data from the previous timestep and rename
                # it as the "old" data
                old_internal_vals = new_internal_vals
                old_ids = ids
                old_positions = positions
                old_temps = mesh_temps
                # Memory management: periodically purge joblib's dump folder
                gc.collect()
                if k % 5 == 0:
                    with suppress(FileNotFoundError):
                        shutil.rmtree(temp_dir)
                    os.makedirs(temp_dir)

            # Extracting data from mesh
            mesh_temps = mesh['T']
            ids = mesh['id']
            positions = mesh.points

            # Run the forward model once we've seen at least 2 timesteps
            if k > 0:
                # Set up KDTree for timestep if doing interpolation
                if interpolate_vals:
                    if k == 1:
                        # At k=0 every row of internal_vals is NaN, so the
                        # "has values" filter below would exclude every
                        # particle; use all previous particles instead
                        tree_ids = old_ids
                        other_positions = old_positions
                    else:
                        # Only particles with internal_vals can donate them
                        has_vals = ~np.isnan(old_internal_vals).all(axis=1)
                        tree_ids = old_ids[has_vals]
                        other_positions = old_positions[has_vals]
                    # Set up KDTree to find closest particle
                    tree = KDTree(other_positions)
                else:
                    tree = None
                    tree_ids = None

                # Set up inputs for run_particle function
                if system == 'AFT':
                    model_inputs = (dpar, annealing_model)
                else:
                    model_inputs = (u, th, radius)
                inputs = (k, positions, tree, ids, old_ids, tree_ids,
                          mesh_temps, old_temps, old_internal_vals,
                          time_interval, system, internal_len, model_inputs)

                # Calculate ages if indicated or if on last timestep
                calc_age = all_timesteps or k == last_k
                prog_bar_txt = "Timestep " + str(k)

                # Calculate new x profiles/annealed lengths and ages in
                # parallel
                output = parallel(
                    delayed(particle_fn[system])
                    (particle, inputs, calc_age, interpolate_vals)
                    for particle in tqdm(ids, position=0,
                                         desc=prog_bar_txt, leave=False)
                )
                ages, new_internal_vals = zip(*output)

                # Convert new_internal_vals to array and save for reload
                new_internal_vals = np.array(new_internal_vals, dtype=dtype)
                np.save(cache_path, new_internal_vals)

                # Assign ages to mesh and save it
                mesh.point_data[system] = np.array(ages, dtype=dtype)
                mesh.save(outfile_path)

    # Purge the temp folder
    with suppress(FileNotFoundError):
        shutil.rmtree(temp_dir)
    os.makedirs(temp_dir)

    # Print completion message
    tqdm.write('All ' + system + ' timesteps complete')

    # Delete cached values when all finished (may be absent if every
    # timestep was already complete, or if files was empty)
    with suppress(FileNotFoundError):
        os.remove(cache_path)
    gc.collect()
    return
def run_tt_paths(temp_paths, tsteps, system,
                 u=100, th=100, radius=50,
                 dpar=1.75, annealing_model='Ketcham99',
                 batch_size=100, processes=None,
                 **kwargs):
    """Run forward model of a given isotopic system across multiple t-T paths.

    Parameters
    ----------
    temp_paths : list of NumPy arrays of floats
        List of NumPy arrays of floats containing the temperatures (K) at each
        timestep in tsteps. Each array corresponds to a different grain
        to obtain a thermochronometric age for.
    tsteps : Numpy array of floats
        Array of times (Ma) in chronological (descending) order. First
        time is start of first timestep, last time is end of last timestep.
        Each pair of adjacent times composes a timestep. The time at a given
        index i corresponds to the temperatures at index i of each of the NumPy
        arrays in temp_paths.
    system : str
        Isotopic system to model. Current options include 'AHe': Apatite
        (U-Th)/He, 'ZHe': Zircon (U-Th)/He, 'AFT': Apatite Fission Track
    u : float, optional
        U concentration (ppm) (default: 100). Only used if system is 'AHe'
        or 'ZHe'.
    th : float, optional
        Th concentration (ppm) (default: 100). Only used if system is 'AHe'
        or 'ZHe'.
    radius : float, optional
        Radius of the grain (micrometers) (default: 50). Only used
        if system is 'AHe' or 'ZHe'.
    dpar : float, optional
        Etch figure length (micrometers) (default: 1.75). Only used
        if system is 'AFT'.
    annealing_model : str, optional
        Annealing model to use. Currently, the only acceptable value is
        'Ketcham99', which corresponds to the fanning curvilinear model from
        Ketcham et al. (1999) (default: 'Ketcham99'). Only used (and only
        validated) if system is 'AFT'.
    batch_size : int or 'auto', optional
        Number of jobs to be dispatched to each worker at a time during
        parallel computation (default: 100). If set to 'auto', this value is
        dynamically adjusted during computations to try to optimize efficiency.
        However, on most test systems, better efficiency is gained by manually
        setting batch size to 100 or 1000 than using 'auto'.
    processes : int or None, optional
        Maximum number of processes that can run concurrently. If None, this
        parameter is internally set to two less than the number of CPUs on the
        user's system, but never less than 1 (default: None).
    **kwargs : optional
        Additional arguments to pass to the forward model function of the
        corresponding isotopic system

    Returns
    -------
    ages : list
        One forward-model result per entry of temp_paths, in the same order.
        These are the thermochronometric ages for the given isotopic system;
        all (U-Th)/He ages returned are corrected ages.

    Raises
    ------
    KeyError
        If system (or, for 'AFT', annealing_model) is not recognized.
    """
    if processes is None:
        # Leave two CPUs free, but never request fewer than one worker
        # (joblib rejects n_jobs=0 and os.cpu_count() may return None)
        processes = max(1, (os.cpu_count() or 1) - 2)
    # Setting how many tasks to initially dispatch to workers
    pre_dispatch = 2 * processes if batch_size == 'auto' else 2 * batch_size

    model_fn = {'AHe': he.forward_model_he,
                'ZHe': he.forward_model_he,
                'AFT': aft.forward_model_aft}
    # Fail fast on an unknown system, before any work is dispatched
    forward_model = model_fn[system]

    # Positional arguments passed after each temperature path; the annealing
    # model is only looked up for the FT system that actually uses it
    if system == 'AFT':
        ft_constants = {'Ketcham99': aft.KETCHAM_99_FC}
        model_inputs = (tsteps, dpar, ft_constants[annealing_model])
    else:
        model_inputs = (tsteps, system, u, th, radius)

    output = Parallel(n_jobs=processes,
                      batch_size=batch_size,
                      pre_dispatch=pre_dispatch)(
        delayed(forward_model)(temp_path, *model_inputs, **kwargs)
        for temp_path in tqdm(temp_paths, position=0))
    return output