Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Linearisation #12

Merged
merged 4 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 39 additions & 10 deletions examples/01_anelasticity_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,11 @@ def build_solidus():
my_depths.extend(dpths)
my_solidus.extend(solidus_model.at_depth(dpths))

# Avoding unnecessary extrapolation by setting the solidus temperature at maximum depth
my_depths.extend([3000e3])
my_solidus.extend([solidus_model.at_depth(dpths[-1])])

# Since we might have values outside the range of the solidus curve, we are better off with extrapolating
ghelichkhan_et_al = SplineProfile(
depth=np.asarray(my_depths),
value=np.asarray(my_solidus),
extrapolate=True,
name="Ghelichkhan et al 2021")

return ghelichkhan_et_al
Expand Down Expand Up @@ -111,22 +109,46 @@ def Q_kappa(x):
anelasticity = build_anelasticity_model(solidus_ghelichkhan, q_profile=cammarano_q_model)
anelastic_slb_pyrolite = gdrift.apply_anelastic_correction(
slb_pyrolite, anelasticity)

pyrolite_anelastic_s_speed = anelastic_slb_pyrolite.compute_swave_speed()
pyrolite_anelastic_p_speed = anelastic_slb_pyrolite.compute_pwave_speed()

# A temperautre profile representing the mantle average temperature
# This is used to anchor the regularised thermodynamic table (we make sure the seismic speeds are the same at those temperature for the regularised and unregularised table)
temperature_spline = gdrift.SplineProfile(
depth=np.asarray([0., 500e3, 2700e3, 3000e3]),
value=np.asarray([300, 1000, 3000, 4000])
)


linear_slb_pyrolite = gdrift.mineralogy.regularise_thermodynamic_table(
slb_pyrolite, temperature_spline,
regular_range={"v_s": [-1.0, 0], "v_p": [-1.0, 0.], "rho": [-0.5, 0.]}
)

# Regularising the table
linear_anelastic_slb_pyrolite = gdrift.apply_anelastic_correction(
linear_slb_pyrolite, anelasticity
)

# linearised seismic speeds
linear_pyrolite_anelastic_s_speed = linear_anelastic_slb_pyrolite.compute_swave_speed()
linear_pyrolite_anelastic_p_speed = linear_anelastic_slb_pyrolite.compute_pwave_speed()

# contour lines to plot
cntr_lines = np.linspace(4000, 7000, 20)

plt.close("all")
fig, axes = plt.subplots(ncols=2)
axes[0].set_position([0.1, 0.1, 0.35, 0.8])
axes[1].set_position([0.5, 0.1, 0.35, 0.8])
fig, axes = plt.subplots(figsize=(12, 8), ncols=3)
axes[0].set_position([0.05, 0.1, 0.25, 0.8])
axes[1].set_position([0.35, 0.1, 0.25, 0.8])
axes[2].set_position([0.65, 0.1, 0.25, 0.8])
# Getting the coordinates
depths_x, temperatures_x = np.meshgrid(
slb_pyrolite.get_depths(), slb_pyrolite.get_temperatures(), indexing="ij")
img = []

for id, table in enumerate([pyrolite_elastic_s_speed, pyrolite_anelastic_s_speed]):
for id, table in enumerate([pyrolite_elastic_s_speed, pyrolite_anelastic_s_speed, linear_pyrolite_anelastic_s_speed]):
img.append(axes[id].contourf(
temperatures_x,
depths_x,
Expand All @@ -148,6 +170,10 @@ def Q_kappa(x):
axes[1].text(0.5, 1.05, s="With Anelastic Correction",
ha="center", va="center",
transform=axes[1].transAxes, bbox=dict(facecolor=(1.0, 1.0, 0.7)))
axes[1].text(0.5, 1.05, s="Linearised With Anelastic Correction",
ha="center", va="center",
transform=axes[1].transAxes, bbox=dict(facecolor=(1.0, 1.0, 0.7)))

fig.colorbar(img[-1], ax=axes[0], cax=fig.add_axes([0.88,
0.1, 0.02, 0.8]), orientation="vertical", label="Shear-Wave Speed [m/s]")

Expand All @@ -157,9 +183,11 @@ def Q_kappa(x):
plt.close(2)
fig_2 = plt.figure(num=2)
ax_2 = fig_2.add_subplot(111)
index = 100
index = 130
ax_2.plot(pyrolite_anelastic_s_speed.get_y(),
pyrolite_anelastic_s_speed.get_vals()[index, :], color="blue", label="With Anelastic Correction")
ax_2.plot(pyrolite_anelastic_s_speed.get_y(),
linear_pyrolite_anelastic_s_speed.get_vals()[index, :], color="green", label="Linear Anelastic Model")
ax_2.plot(pyrolite_anelastic_s_speed.get_y(),
pyrolite_elastic_s_speed.get_vals()[index, :], color="red", label="Elastic Model")
ax_2.vlines(
Expand All @@ -184,9 +212,10 @@ def Q_kappa(x):
plt.close(3)
fig_3 = plt.figure(num=3)
ax_3 = fig_3.add_subplot(111)
index = 100
ax_3.plot(pyrolite_anelastic_p_speed.get_y(),
pyrolite_anelastic_p_speed.get_vals()[index, :], color="blue", label="With Anelastic Correction")
ax_3.plot(pyrolite_anelastic_p_speed.get_y(),
linear_pyrolite_anelastic_p_speed.get_vals()[index, :], color="green", label="Linear Anelastic Model")
ax_3.plot(pyrolite_anelastic_p_speed.get_y(),
pyrolite_elastic_p_speed.get_vals()[index, :], color="red", label="Elastic Model")
ax_3.vlines(
Expand Down
90 changes: 90 additions & 0 deletions examples/linearisation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import gdrift
import numpy as np
import matplotlib.pyplot as plt

# -----------------------------------------------------------------------------
# Tutorial: Linearising and Regularising Thermodynamic Tables
#
# This tutorial demonstrates how to regularise thermodynamic properties of
# Earth's mantle using `gdrift`, with a focus on linearising wave speeds
# (S-wave and P-wave) and density (ρ). The process involves building a
# regularised thermodynamic model, comparing original and regularised tables,
# and visualising the results at specific depths.
# -----------------------------------------------------------------------------

# -----------------------------------------------------------------------------
# 1. Initial Setup
# Define the thermodynamic model for pyrolite composition.
# -----------------------------------------------------------------------------
slb_pyrolite = gdrift.ThermodynamicModel(
"SLB_16", "pyrolite", temps=np.linspace(300, 4000), depths=np.linspace(0, 2890e3)
)

# -----------------------------------------------------------------------------
# 2. Building a Temperature Profile
# Create a `SplineProfile` to represent a depth-dependent temperature profile.
# This profile serves as a baseline for regularising the thermodynamic tables.
# -----------------------------------------------------------------------------
temperature_spline = gdrift.SplineProfile(
depth=np.asarray([0., 500e3, 2700e3, 3000e3]),
value=np.asarray([300, 1000, 3000, 4000])
)

# -----------------------------------------------------------------------------
# 3. Regularising the Thermodynamic Table
# Use the `regularise_thermodynamic_table` function to produce a corrected
# version of the model, ensuring the properties align with the temperature
# profile.
# -----------------------------------------------------------------------------
regular_slb_pyrolite = gdrift.regularise_thermodynamic_table(slb_pyrolite, temperature_spline)

# -----------------------------------------------------------------------------
# 4. Extracting Data
# Extract the original and regularised tables for S-wave speed, P-wave speed,
# and density.
# -----------------------------------------------------------------------------
Vs_original = slb_pyrolite.compute_swave_speed().get_vals()
Vp_original = slb_pyrolite.compute_pwave_speed().get_vals()
rho_original = slb_pyrolite._tables["rho"].get_vals()

Vs_corrected = regular_slb_pyrolite.compute_swave_speed().get_vals()
Vp_corrected = regular_slb_pyrolite.compute_pwave_speed().get_vals()
rho_corrected = regular_slb_pyrolite._tables["rho"].get_vals()

# -----------------------------------------------------------------------------
# 5. Visualising the Results
# Visualise the S-wave speed, P-wave speed, and density for both the original
# and regularised models at specific depths. Results are presented in three
# columns, one for each property.
# -----------------------------------------------------------------------------
depths = np.asarray([410, 660, 1000, 2000]) * 1e3
indices = [abs(d - slb_pyrolite.get_depths()).argmin() for d in depths]

# Create figure with 3 columns for Vs, Vp, and rho
fig, axs = plt.subplots(len(indices), 3, figsize=(15, 10), constrained_layout=True)

# Labels for the columns
column_titles = [r"$v_s$", r"$v_p$", r"$\rho$"]

# Plotting data for each depth
for i, idx in enumerate(indices):
depth = slb_pyrolite.get_depths()[idx]
for j, (original, corrected, label) in enumerate(
zip(
[Vs_original, Vp_original, rho_original],
[Vs_corrected, Vp_corrected, rho_corrected],
column_titles,
)
):
axs[i, j].plot(slb_pyrolite.get_temperatures(), original[idx, :], color="blue", label="Original")
axs[i, j].plot(slb_pyrolite.get_temperatures(), corrected[idx, :], color="red", label="Regularised")
axs[i, j].axvline(
x=temperature_spline.at_depth(depth), color="green", linestyle="--", label="Temperature Anchor"
)
axs[i, j].set_title(f"{label} at Depth: {depth / 1e3:.0f} km")
axs[i, j].set_xlabel("Temperature (K)")
axs[i, j].grid()
if i == 0:
axs[i, j].legend()

plt.show()
6 changes: 4 additions & 2 deletions gdrift/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from .datasetnames import print_datasets_markdown
from .earthmodel3d import EarthModel3D
from .io import load_dataset, create_dataset_file
from .mineralogy import ThermodynamicModel, compute_pwave_speed, compute_swave_speed
from .profile import PreliminaryRefEarthModel, RadialEarthModelFromFile, HirschmannSolidus
from .mineralogy import ThermodynamicModel, compute_pwave_speed, compute_swave_speed, regularise_thermodynamic_table
from .profile import PreliminaryRefEarthModel, RadialEarthModelFromFile, HirschmannSolidus, SplineProfile
from .utility import compute_gravity, compute_mass, compute_pressure, geodetic_to_cartesian, cartesian_to_geodetic, dimensionalise_coords, nondimensionalise_coords, fibonacci_sphere
from .seismic import SeismicModel, AVAILABLE_SEISMIC_MODELS

Expand All @@ -20,9 +20,11 @@
"ThermodynamicModel",
"compute_pwave_speed",
"compute_swave_speed",
"regularise_thermodynamic_table",
"PreliminaryRefEarthModel",
"RadialEarthModelFromFile",
"HirschmannSolidus",
"SplineProfile",
"compute_gravity",
"compute_mass",
"compute_pressure",
Expand Down
144 changes: 141 additions & 3 deletions gdrift/mineralogy.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,22 @@
import numpy
from .profile import AbstractProfile
from .io import load_dataset
from scipy.interpolate import RectBivariateSpline
from scipy.optimize import minimize_scalar
from scipy.spatial import cKDTree
from numbers import Number
from typing import Optional, Tuple, Union
from typing import Optional, Tuple, Union, Dict
import numpy as np

# Default regular range for gradients
# This will be used in regularise_thermodynamic_table
# if nothing is provided
default_regular_range = {
"v_s": (-np.inf, 0.0),
"v_p": (-np.inf, 0.0),
"rho": (-np.inf, 0.0),
}


MODELS_AVAIL = ['SLB_16', "SLB_21"]
COMPOSITIONS_AVAIL = ['pyrolite', 'basalt']
Expand Down Expand Up @@ -154,7 +167,7 @@ def temperature_to_rho(self, temperature, depth):
return LinearRectBivariateSpline(
self._tables["rho"].get_x(),
self._tables["rho"].get_y(),
self._tables["rho"].get_vals).ev(depth, temperature)
self._tables["rho"].get_vals()).ev(depth, temperature)

def compute_swave_speed(self):
return type(self._tables["shear_mod"])(
Expand Down Expand Up @@ -285,7 +298,7 @@ def compute_swave_speed(shear_modulus, density):
return numpy.sqrt(numpy.divide(shear_modulus, density))


def compute_pwave_speed(bulk_modulus, shear_modulus, density):
def compute_pwave_speed(bulk_modulus: Number, shear_modulus: Number, density: Number) -> Number:
"""Calculate the P-wave (primary wave) speed in a material based on its bulk modulus,
shear modulus, and density. Inputs can be floats or numpy arrays of the same size.

Expand All @@ -309,6 +322,7 @@ def compute_pwave_speed(bulk_modulus, shear_modulus, density):
"""
# making sure that input is either array or float
is_either_float_or_array(bulk_modulus, shear_modulus, density)

return numpy.sqrt(
numpy.divide(
bulk_modulus + (4. / 3.) * shear_modulus,
Expand All @@ -324,3 +338,127 @@ def is_either_float_or_array(*args):
if any(isinstance(x, numpy.ndarray) for x in args) and not all(isinstance(x, float) for x in args):
if not all(x.shape == args[0].shape for x in args if isinstance(x, numpy.ndarray)):
raise ValueError("All input arrays must have the same size.")


def derive_then_integrate(table: Table, temperature_profile: AbstractProfile, regular_range: Dict[str, Tuple]) -> np.ndarray:
"""
Derives the temperature gradient, interpolates irregular values, and integrates again to obtain velocity.
The output is anchored (= 0.) at around velocity values that are associated at temperature_profile.
Args:
table (object): An object containing depth and temperature data with methods `get_x()`, `get_y()`, and `get_vals()`.
temperature_profile (object): An object with a method `at_depth(depths)` that returns temperature values at given depths.
regular_range (dict): A dictionary with keys corresponding to table names and values as tuples indicating the acceptable range for gradients.
Returns:
np.ndarray: A 2D array representing the integrated velocity values adjusted for the temperature profile.
"""

# Getting the name of the table
key = table._name
# Getting the depths and temperatures
depths = table.get_x()
temperatures = table.get_y()

# temperature gradient
dT = np.gradient(temperatures)

# Creating a mesh for the depths and temperatures
depths_x, temperatures_x = np.meshgrid(depths, temperatures, indexing="ij")

# Getting the gradients
dV_dT = np.gradient(table.get_vals(), depths, temperatures, axis=(0, 1))[1]

# Finding the regular range of values (No positive jumps, no high negative jumps)
within_range = np.logical_and(dV_dT < regular_range[key][1], dV_dT > regular_range[key][0])

# building a tree out of the regular values
my_tree = cKDTree(np.column_stack((depths_x[within_range].flatten(), temperatures_x[within_range].flatten())))

# Finding the closest values to the irregular values
distances, inds = my_tree.query(np.column_stack((depths_x[~ within_range].flatten(), temperatures_x[~ within_range].flatten())), k=3)

# Interpolating the irregular values
dV_dT[~within_range] = np.sum(1 / distances * dV_dT[within_range].flatten()[inds], axis=1) / np.sum(1 / distances, axis=1)

# Integrating the derivate again to get the velocity (note that a constant needs to be found)
V = np.cumsum(dV_dT * dT, axis=1)
# One D profile of vs that best describes the temperature profile
t_mean_array = np.asarray([V[i, j] for i, j in enumerate(abs(temperature_profile.at_depth(depths_x) - temperatures).argmin(axis=1))])

# Broadcasting to the correct shape
t_mean_array_x, _ = np.meshgrid(t_mean_array, temperatures, indexing="ij")

# Anchoring the V-T curve at each depth for acnhor T to have zero velocity
V -= t_mean_array_x

return V


def regularise_thermodynamic_table(slb_pyrolite: ThermodynamicModel, temperature_profile: AbstractProfile, regular_range: Dict[str, Tuple] = default_regular_range):
"""
Regularises the thermodynamic table by creating a regularised thermodynamic model that uses precomputed
regular tables for S-wave and P-wave speeds.

Args:
slb_pyrolite (ThermodynamicModel): The original thermodynamic model.
temperature_profile (AbstractProfile): The temperature profile to be used for regularisation. This is supposed to
be a 1D profile of average temperature profiles.
regular_range (Dict[str, Tuple], optional): Dictionary specifying the regularisation range for each
parameter. Defaults to `default_regular_range`.

Returns:
RegularisedThermodynamicModel: A regularised thermodynamic model with precomputed tables for S-wave
and P-wave speeds.
"""
class RegularisedThermodynamicModel(ThermodynamicModel):
"""
A wrapper class for a regularised thermodynamic model that uses precomputed regular tables
for S-wave and P-wave speed instead of the default methods.
"""

def __init__(self, original_model, regular_tables):
"""
Initialize the regularised model.

Args:
original_model (ThermodynamicModel): The original thermodynamic model (e.g., slb_pyrolite).
regular_tables (Dict[str, np.ndarray]): Dictionary containing regularised tables for 'v_s' and 'v_p'.
"""
# Inherit properties from the original model
super().__init__(original_model.model, original_model.composition,
temps=original_model.get_temperatures(),
depths=original_model.get_depths())
self.regular_tables = regular_tables
self._tables["rho"] = Table(self.get_temperatures(), self.get_depths(), self.regular_tables["rho"], name="Regularised Density")

def compute_swave_speed(self):
"""
Returns the regularised S-wave speed as a `Table` object.
"""
return Table(self.get_temperatures(), self.get_depths(), self.regular_tables["v_s"], name="Regularised S-wave Speed")

def compute_pwave_speed(self):
"""
Returns the regularised P-wave speed as a `Table` object.
"""
return Table(self.get_temperatures(), self.get_depths(), self.regular_tables["v_p"], name="Regularised P-wave Speed")

# All the combinations of depths

# regular tables are a dictaionary of tables
regular_tables = {}

# iterating over the tables
for table, convert_T2V in zip([slb_pyrolite._tables["rho"], slb_pyrolite.compute_swave_speed(), slb_pyrolite.compute_pwave_speed()],
[slb_pyrolite.temperature_to_rho, slb_pyrolite.temperature_to_vs, slb_pyrolite.temperature_to_vp]):
# Get name for the table
key = table._name

regular_tables[key] = derive_then_integrate(table, temperature_profile, regular_range)

# the velocity for the given temperature profile
v_average = convert_T2V(temperature=temperature_profile.at_depth(table.get_x()), depth=table.get_x())

# Subtracting the mean
regular_tables[key] += v_average[:, None]

return RegularisedThermodynamicModel(slb_pyrolite, regular_tables)
Loading