Skip to content

Commit

Permalink
Plot gamma spec and fnu spec
Browse files Browse the repository at this point in the history
  • Loading branch information
lukeshingles committed Jan 30, 2025
1 parent c2df88a commit 62232d5
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 103 deletions.
182 changes: 113 additions & 69 deletions artistools/spectra/plotspectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,20 +170,14 @@ def plot_reference_spectrum(
z = metadata["z"]
bands = [(1.35e4, 1.44e4), (1.8e4, 1.94e4)] # [Angstroms]
bands_rest = [(band_low / (1 + z), band_high / (1 + z)) for band_low, band_high in bands]
# for band_low_rest, band_high_rest in bands:
# band_low = band_low_rest / (1 + z)
# band_high = band_high_rest / (1 + z)

specdata = specdata.with_columns(
f_lambda=pl.when(
pl.any_horizontal([
pl.col("lambda_angstroms").is_between(band_low_rest, band_high_rest, closed="both")
for band_low_rest, band_high_rest in bands_rest
])
)
.then(pl.lit(math.nan))
.otherwise(pl.col("f_lambda"))
expr_masked = pl.when(
pl.any_horizontal([
pl.col("lambda_angstroms").is_between(band_low_rest, band_high_rest, closed="both")
for band_low_rest, band_high_rest in bands_rest
])
)
specdata = specdata.with_columns(f_lambda=expr_masked.then(pl.lit(math.nan)).otherwise(pl.col("f_lambda")))

print(f"Reference spectrum '{plotkwargs['label']}' has {len(specdata)} points in the plot range")
print(f" file: {filename}")
Expand All @@ -194,15 +188,6 @@ def plot_reference_spectrum(

atspectra.print_integrated_flux(specdata["f_lambda"], specdata["lambda_angstroms"])

# if len(specdata) > 5000:
# # specdata = scipy.signal.resample(specdata, 10000)
# # specdata = specdata.iloc[::3, :].copy()
# print(f" downsampling to {len(specdata)} points")
# specdata = specdata.query("index % 3 == 0")

# clamp negative values to zero
# specdata['f_lambda'] = specdata['f_lambda'].apply(lambda x: max(0, x))

if fluxfilterfunc:
print(" applying filter to reference spectrum")
specdata = specdata.with_columns(cs.starts_with("f_lambda").map_batches(fluxfilterfunc))
Expand Down Expand Up @@ -255,9 +240,11 @@ def plot_artis_spectrum(
average_over_theta: bool = False,
usedegrees: bool = False,
maxpacketfiles: int | None = None,
xunit: str = "angstroms",
**plotkwargs,
) -> pl.DataFrame | None:
"""Plot an ARTIS output spectrum. The data plotted are also returned as a DataFrame."""
assert xunit.lower() in {"angstroms", "hz", "kev"}
modelpath = Path(modelpath)
if Path(modelpath).is_file(): # handle e.g. modelpath = 'modelpath/spec.out'
specfilename = Path(modelpath).parts[-1]
Expand Down Expand Up @@ -329,8 +316,8 @@ def plot_artis_spectrum(
if from_packets:
viewinganglespectra = atspectra.get_from_packets(
modelpath,
args.timemin,
args.timemax,
timelowdays=args.timemin,
timehighdays=args.timemax,
lambda_min=supxmin * 0.9,
lambda_max=supxmax * 1.1,
use_time=use_time,
Expand All @@ -341,6 +328,7 @@ def plot_artis_spectrum(
average_over_theta=average_over_theta,
fluxfilterfunc=filterfunc,
directionbins_are_vpkt_observers=args.plotvspecpol is not None,
gamma=args.gamma,
)

elif args.plotvspecpol is not None:
Expand Down Expand Up @@ -369,6 +357,7 @@ def plot_artis_spectrum(
average_over_phi=average_over_phi,
average_over_theta=average_over_theta,
fluxfilterfunc=filterfunc,
gamma=args.gamma,
)

dirbin_definitions = (
Expand Down Expand Up @@ -412,26 +401,38 @@ def plot_artis_spectrum(
if len(directionbins) > 1 or not linelabel_is_custom:
linelabel_withdirbin = f"{linelabel} {dirbin_definitions[dirbin]}"

atspectra.print_integrated_flux(dfspectrum["f_lambda"], dfspectrum["lambda_angstroms"])

if scale_to_peak:
dfspectrum = dfspectrum.with_columns(
f_lambda_scaled=pl.col("f_lambda") / pl.col("f_lambda").max() * scale_to_peak
)

ycolumnname = "f_lambda_scaled"
if args.x.lower() == "angstroms":
dfspectrum = dfspectrum.with_columns(x=pl.col("lambda_angstroms"), y=pl.col("f_lambda"))
elif args.x.lower() == "hz":
dfspectrum = dfspectrum.with_columns(x=pl.col("nu"), y=pl.col("f_nu"))
else:
ycolumnname = "f_lambda"
assert args.x.lower() == "kev"
h = 4.1356677e-15 # Planck's constant [eV s]
c = 2.99792458e18 # speed of light [angstroms/s]

dfspectrum = (
dfspectrum.with_columns(en_kev=h * c / pl.col("lambda_angstroms") / 1000.0)
.with_columns(f_en_kev=pl.col("f_nu") * pl.col("nu") / pl.col("en_kev"))
.with_columns(x=pl.col("en_kev"), y=pl.col("f_en_kev"))
)

if plotpacketcount:
ycolumnname = "packetcount"
dfspectrum = dfspectrum.with_columns(y=pl.col("packetcount"))

atspectra.print_integrated_flux(dfspectrum["y"], dfspectrum["x"])

if scale_to_peak:
dfspectrum = dfspectrum.with_columns(
y_scaled=pl.col("y") / pl.col("y").max() * scale_to_peak
).with_columns(y=pl.col("y_scaled"))

if args.binflux:
new_lambda_angstroms = []
binned_flux = []
assert args.x.lower() == "angstroms"

wavelengths = dfspectrum["lambda_angstroms"]
fluxes = dfspectrum[ycolumnname]
fluxes = dfspectrum["y"]
nbins = 5

for i in np.arange(0, len(wavelengths - nbins), nbins, dtype=int):
Expand All @@ -442,13 +443,12 @@ def plot_artis_spectrum(
sum_flux = sum(fluxes[j] for j in range(i, i_max))
binned_flux.append(sum_flux / ncontribs)

dfspectrum = pl.DataFrame({"lambda_angstroms": new_lambda_angstroms, ycolumnname: binned_flux})
dfspectrum = pl.DataFrame({"x": new_lambda_angstroms, "y": binned_flux})

dfspectrum = dfspectrum.sort("x", maintain_order=True)

axis.plot(
dfspectrum["lambda_angstroms"],
dfspectrum[ycolumnname],
label=linelabel_withdirbin if axindex == 0 else None,
**plotkwargs,
dfspectrum["x"], dfspectrum["y"], label=linelabel_withdirbin if axindex == 0 else None, **plotkwargs
)

return dfspectrum[["lambda_angstroms", "f_lambda"]]
Expand Down Expand Up @@ -499,6 +499,11 @@ def make_spectrum_plot(
if "linewidth" not in plotkwargs:
plotkwargs["linewidth"] = 1.1

if args.x.lower() != "angstroms":
# other units not implemented yet
msg = f"Unit {args.x} not implemented for reference spectra"
raise NotImplementedError(msg)

if args.multispecplot:
plotkwargs["color"] = "k"
supxmin, supxmax = axes[refspecindex].get_xlim()
Expand Down Expand Up @@ -558,6 +563,7 @@ def make_spectrum_plot(
average_over_phi=args.average_over_phi_angle,
average_over_theta=args.average_over_theta_angle,
usedegrees=args.usedegrees,
xunit=args.x,
**plotkwargs,
)
except FileNotFoundError as e:
Expand Down Expand Up @@ -607,13 +613,15 @@ def make_spectrum_plot(
# zorder=-1,
# )

if args.stokesparam == "I":
if args.stokesparam == "I" and not args.logscaley:
axis.set_ylim(bottom=0.0)
if args.normalised:
axis.set_ylim(top=1.25)
axis.set_ylabel(r"Scaled F$_\lambda$")

if args.plotpacketcount:
axis.set_ylabel(r"Monte Carlo packets per bin")
elif args.normalised:
axis.set_ylim(top=1.25)
axis.set_ylabel(r"Scaled F$_\lambda$")

if not args.notitle and args.title:
if args.inset_title:
axis.annotate(
Expand All @@ -639,6 +647,7 @@ def make_emissionabsorption_plot(
) -> tuple[list[Artist], list[str], pl.DataFrame | None]:
"""Plot the emission and absorption contribution spectra, grouped by ion/line/term for an ARTIS model."""
modelname = get_model_name(modelpath)
assert args.x.lower() == "angstroms"

print(f"====> {modelname}")
clamp_to_timesteps = not args.notimeclamp
Expand Down Expand Up @@ -1008,33 +1017,48 @@ def make_plot(args) -> tuple[mplfig.Figure, np.ndarray, pl.DataFrame]:
dfalldata: pl.DataFrame | None = pl.DataFrame()

if not args.hideyticklabels:
ylabel = None
if args.x.lower() == "angstroms":
ylabel = r"F$_\lambda$ at 1 Mpc [{}erg/s/cm$^2$/$\mathrm{{\AA}}$]"
elif args.x.lower() == "hz":
ylabel = r"F$_\nu$ at 1 Mpc [{}erg/s/cm$^2$/Hz]"
elif args.x.lower() == "kev":
ylabel = r"F$_en_kev$ at 1 Mpc [{}erg/s/cm$^2$/keV]"

assert ylabel is not None
if args.logscaley:
# don't include the {} that will be replaced with the power of 10 by the custom formatter
ylabel = ylabel.replace("{}", "")

if args.multispecplot:
for ax in axes:
ax.set_ylabel(r"F$_\lambda$ at 1 Mpc [{}erg/s/cm$^2$/$\mathrm{{\AA}}$]")

elif args.logscale:
# don't include the {} that will be replaced with the power of 10 by the custom formatter
axes[-1].set_ylabel(r"F$_\lambda$ at 1 Mpc [erg/s/cm$^2$/$\mathrm{{\AA}}$]")
ax.set_ylabel(ylabel)
else:
axes[-1].set_ylabel(r"F$_\lambda$ at 1 Mpc [{}erg/s/cm$^2$/$\mathrm{{\AA}}$]")
axes[-1].set_ylabel(ylabel)

for axis in axes:
if args.logscale:
if args.xmin is not None:
axis.set_xlim(left=args.xmin)
if args.xmax is not None:
axis.set_xlim(right=args.xmax)
if args.logscalex:
axis.set_xscale("log")
if args.logscaley:
axis.set_yscale("log")
axis.set_xlim(left=args.xmin, right=args.xmax)

if (args.xmax - args.xmin) < 2000:
axis.xaxis.set_major_locator(ticker.MultipleLocator(base=100))
axis.xaxis.set_minor_locator(ticker.MultipleLocator(base=10))
elif (args.xmax - args.xmin) < 11000:
axis.xaxis.set_major_locator(ticker.MultipleLocator(base=1000))
axis.xaxis.set_minor_locator(ticker.MultipleLocator(base=100))
elif (args.xmax - args.xmin) < 14000:
axis.xaxis.set_major_locator(ticker.MultipleLocator(base=2000))
axis.xaxis.set_minor_locator(ticker.MultipleLocator(base=500))
else:
axis.xaxis.set_major_locator(ticker.MultipleLocator(base=2000))
axis.xaxis.set_minor_locator(ticker.MultipleLocator(base=500))

if not args.gamma:
if (args.xmax - args.xmin) < 2000:
axis.xaxis.set_major_locator(ticker.MultipleLocator(base=100))
axis.xaxis.set_minor_locator(ticker.MultipleLocator(base=10))
elif (args.xmax - args.xmin) < 11000:
axis.xaxis.set_major_locator(ticker.MultipleLocator(base=1000))
axis.xaxis.set_minor_locator(ticker.MultipleLocator(base=100))
elif (args.xmax - args.xmin) < 14000:
axis.xaxis.set_major_locator(ticker.MultipleLocator(base=2000))
axis.xaxis.set_minor_locator(ticker.MultipleLocator(base=500))
else:
axis.xaxis.set_major_locator(ticker.MultipleLocator(base=2000))
axis.xaxis.set_minor_locator(ticker.MultipleLocator(base=500))

if densityplotyvars:
make_contrib_plot(axes[:-1], args.specpath[0], densityplotyvars, args)
Expand Down Expand Up @@ -1100,7 +1124,7 @@ def make_plot(args) -> tuple[mplfig.Figure, np.ndarray, pl.DataFrame]:
for index, ax in enumerate(axes):
# ax.xaxis.set_major_formatter(plt.NullFormatter())

if "{" in ax.get_ylabel() and not args.logscale:
if "{" in ax.get_ylabel() and not args.logscaley:
ax.yaxis.set_major_formatter(ExponentLabelFormatter(ax.get_ylabel(), decimalplaces=1))

if args.hidexticklabels:
Expand All @@ -1114,7 +1138,13 @@ def make_plot(args) -> tuple[mplfig.Figure, np.ndarray, pl.DataFrame]:
ax.text(5500, ymax * 0.9, f"{args.timedayslist[index]} days") # multispecplot text

if not args.hidexticklabels:
axes[-1].set_xlabel(r"Wavelength $\left[\mathrm{{\AA}}\right]$")
if args.x.lower() == "angstroms":
axes[-1].set_xlabel(r"Wavelength $\left[\mathrm{{\AA}}\right]$")
elif args.x.lower() == "kev":
axes[-1].set_xlabel(r"Energy $\left[\mathrm{{keV}}\right]$")
else:
msg = f"Unknown x-axis unit {args.x}"
raise AssertionError(msg)

if not args.outputfile:
args.outputfile = defaultoutputfile
Expand Down Expand Up @@ -1164,6 +1194,8 @@ def addargs(parser) -> None:

parser.add_argument("-dashes", default=[], nargs="*", help="Dashes property of lines")

parser.add_argument("--gamma", action="store_true", help="Make light curve from gamma rays instead of R-packets")

parser.add_argument("--greyscale", action="store_true", help="Plot in greyscale")

parser.add_argument(
Expand Down Expand Up @@ -1232,12 +1264,14 @@ def addargs(parser) -> None:
"--notimeclamp", action="store_true", help="When plotting from packets, don't clamp to timestep start/end"
)

parser.add_argument("-x", default=None, choices=["angstroms", "kev", "hz"], help="x (horizontal) axis unit")

parser.add_argument(
"-xmin", "-lambdamin", dest="xmin", type=int, default=2500, help="Plot range: minimum wavelength in Angstroms"
"-xmin", "-lambdamin", dest="xmin", type=float, default=None, help="Plot range: minimum x range"
)

parser.add_argument(
"-xmax", "-lambdamax", dest="xmax", type=int, default=19000, help="Plot range: maximum wavelength in Angstroms"
"-xmax", "-lambdamax", dest="xmax", type=float, default=None, help="Plot range: maximum x range"
)

parser.add_argument(
Expand Down Expand Up @@ -1304,7 +1338,9 @@ def addargs(parser) -> None:
"-figscale", type=float, default=1.8, help="Scale factor for plot area. 1.0 is for single-column"
)

parser.add_argument("--logscale", action="store_true", help="Use log scale")
parser.add_argument("--logscalex", action="store_true", help="Use log scale for x values")

parser.add_argument("--logscaley", action="store_true", help="Use log scale for y values")

parser.add_argument("--hidenetspectrum", action="store_true", help="Hide net spectrum")

Expand Down Expand Up @@ -1426,6 +1462,14 @@ def main(args: argparse.Namespace | None = None, argsraw: Sequence[str] | None =
print("WARNING: --average_every_tenth_viewing_angle is deprecated. use --average_over_phi_angle instead")
args.average_over_phi_angle = True

if args.x is None:
args.x = "kev" if args.gamma else "angstroms"

if args.xmin is None and not args.gamma:
args.xmin = 2500
if args.xmax is None and not args.gamma:
args.xmax = 19000

set_mpl_style()

assert (
Expand Down
Loading

0 comments on commit 62232d5

Please sign in to comment.