This notebook demonstrates how to use a RAIL Creator to calculate true posteriors for galaxy samples drawn from that same Creator. Note that this notebook assumes you have already read through degradation-demo.ipynb.
Calculating posteriors is more complicated than drawing samples, because it requires more knowledge of the engine that underlies the Creator. In this example, we will use the same engine we used in the Degradation demo: FlowEngine, which wraps a normalizing flow from the pzflow package.
This notebook will cover three scenarios of increasing complexity:

1. calculating posteriors without errors
2. calculating posteriors while convolving photometric errors
3. calculating posteriors with missing bands
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pzflow.examples import get_example_flow
from rail.creation import Creator, engines
from rail.creation.degradation import (
InvRedshiftIncompleteness,
LineConfusion,
LSSTErrorModel,
QuantityCut,
)
For a basic first example, let's make a Creator with no degradation and draw a sample.
# create the FlowEngine
flowEngine = engines.FlowEngine(get_example_flow())
# create the Creator for the true distribution
creator_truth = Creator(flowEngine)
# draw a few samples
samples_truth = creator_truth.sample(6, seed=0)
Now, let's calculate true posteriors for this sample. Note the important fact here: these are literally the true posteriors for the sample because pzflow gives us direct access to the probability distribution from which the sample was drawn!
When calculating posteriors, the Creator will always require data, a pandas DataFrame of the galaxies for which we are calculating posteriors. Because we are using a Creator built on FlowEngine, we must also provide grid, because FlowEngine calculates posteriors over a grid of redshift values.
Let's calculate posteriors for every galaxy in our sample:
grid = np.linspace(0, 2.5, 100)
pdfs = creator_truth.get_posterior(samples_truth, grid=grid)
Note that Creator returns the pdfs as a qp Ensemble:
pdfs
<qp.ensemble.Ensemble at 0x7fa076fcdeb0>
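If you want the raw pdf values rather than plots, a qp Ensemble exposes a scipy-like interface. A minimal sketch (assuming, as used later in this notebook, that the Ensemble's pdf method evaluates every pdf on the grid):

# evaluate every pdf in the ensemble on the redshift grid
pdf_values = pdfs.pdf(grid)  # array of shape (n_galaxies, n_grid)
# npdf reports how many pdfs the ensemble holds
print(pdfs.npdf)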
Let's plot these pdfs:
fig, axes = plt.subplots(2, 3, constrained_layout=True, dpi=120)

for i, ax in enumerate(axes.flatten()):
    # plot the pdf
    pdfs[i].plot_native(axes=ax)
    # plot the true redshift
    ax.axvline(samples_truth["redshift"][i], c="k", ls="--")
    # remove x-ticks on top row
    if i < 3:
        ax.set(xticks=[])
    # set x-label on bottom row
    else:
        ax.set(xlabel="redshift")
    # set y-label on far left column
    if i % 3 == 0:
        ax.set(ylabel="p(z)")
The true posteriors are in blue, and the true redshifts are marked by the vertical black lines.
Now, let's get a little more sophisticated.
Let's recreate the Creator we were using at the end of the Degradation demo.
I will make one change, however: the LSST Error Model sometimes results in non-detections for faint galaxies. These non-detections are flagged with NaN. Calculating posteriors for galaxies with missing magnitudes is more complicated, so for now I will add one additional QuantityCut to remove any galaxies with missing magnitudes. To see how to calculate posteriors for galaxies with missing magnitudes, see Section 3.
# set up the error model
errorModel = LSSTErrorModel()
def degrader_cut_nondetects(data: pd.DataFrame, seed: int = None) -> pd.DataFrame:
    # apply the error model
    data = errorModel(data, seed)
    # make a cut on the observed i band
    data = QuantityCut({"i": 25.3})(data, seed)
    # introduce redshift incompleteness
    data = InvRedshiftIncompleteness(0.8)(data, seed)
    # introduce spectroscopic errors
    # Oxygen lines (in angstroms)
    OII = 3727
    OIII = 5007
    # 2% OII -> OIII confusion
    data = LineConfusion(true_wavelen=OII, wrong_wavelen=OIII, frac_wrong=0.02)(data, seed)
    # 1% OIII -> OII confusion
    data = LineConfusion(true_wavelen=OIII, wrong_wavelen=OII, frac_wrong=0.01)(data, seed)
    # remove all galaxies with non-detections
    # note that since non-detections are flagged with np.nan, it is enough
    # to specify infinity as the maximum magnitude, since np.nan < np.inf
    # evaluates to False
    data = QuantityCut({band: np.inf for band in "ugrizy"})(data, seed)
    return data
creator_degraded_wo_nondetects = Creator(flowEngine, degrader=degrader_cut_nondetects)
Now let's draw a degraded sample:
samples_degraded_wo_nondetects = creator_degraded_wo_nondetects.sample(6, seed=1)
samples_degraded_wo_nondetects
| | redshift | u | u_err | g | g_err | r | r_err | i | i_err | z | z_err | y | y_err |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1.064106 | 27.407302 | 0.578927 | 26.489333 | 0.099019 | 25.751852 | 0.050750 | 25.231646 | 0.047423 | 24.549443 | 0.045670 | 24.306260 | 0.083875 |
| 1 | 0.482760 | 26.595178 | 0.312716 | 25.912547 | 0.059545 | 25.013940 | 0.026455 | 24.684454 | 0.029236 | 24.469733 | 0.042553 | 24.367206 | 0.088499 |
| 2 | 0.780574 | 27.155707 | 0.482011 | 26.221823 | 0.078271 | 25.620341 | 0.045157 | 24.835177 | 0.033380 | 24.626450 | 0.048901 | 24.355207 | 0.087569 |
| 3 | 0.892465 | 25.988383 | 0.189967 | 25.830684 | 0.055380 | 25.049733 | 0.027294 | 24.289026 | 0.020761 | 23.760510 | 0.022844 | 23.647692 | 0.046803 |
| 4 | 0.578259 | 24.181043 | 0.039464 | 23.452045 | 0.008251 | 22.653687 | 0.005911 | 22.135608 | 0.005809 | 21.957911 | 0.006685 | 21.826727 | 0.010308 |
| 5 | 1.888028 | 27.148641 | 0.479486 | 26.140485 | 0.072849 | 25.566301 | 0.043043 | 25.039548 | 0.039992 | 24.422766 | 0.040817 | 24.238597 | 0.079016 |
This sample has photometric errors that we would like to convolve into the redshift posteriors, so that the posteriors are fully consistent with the errors. We can perform this convolution by sampling from the error distributions, calculating posteriors for each sample, and averaging.
FlowEngine has this functionality built in; we just have to provide err_samples to the get_posterior method.
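Conceptually, this convolution is just a Monte Carlo average over noisy realizations of the photometry. Here is a minimal illustrative sketch, not FlowEngine's actual implementation: it assumes Gaussian magnitude errors and a generic posterior_fn that evaluates p(z | mags) on the grid.

def convolve_errors(posterior_fn, mags, mag_errs, grid, n_samples, rng):
    # accumulate posteriors computed from noisy realizations of the magnitudes
    pdf = np.zeros_like(grid)
    for _ in range(n_samples):
        noisy_mags = rng.normal(mags, mag_errs)  # assumed Gaussian errors
        pdf += posterior_fn(noisy_mags, grid)
    # average and renormalize on the grid
    pdf /= n_samples
    return pdf / np.trapz(pdf, grid)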
Let's calculate posteriors with a variable number of error samples.
pdfs_errs_convolved = {
    err_samples: creator_degraded_wo_nondetects.get_posterior(
        samples_degraded_wo_nondetects,
        grid=grid,
        err_samples=err_samples,
        seed=0,
        batch_size=2,
    )
    for err_samples in [1, 10, 100, 1000]
}
fig, axes = plt.subplots(2, 3, dpi=120)

for i, ax in enumerate(axes.flatten()):
    # set dummy values for xlim
    xlim = [np.inf, -np.inf]
    for pdfs_ in pdfs_errs_convolved.values():
        # plot the pdf
        pdfs_[i].plot_native(axes=ax)
        # get the x value where the pdf first rises above 2
        xmin = grid[np.argmax(pdfs_[i].pdf(grid)[0] > 2)]
        if xmin < xlim[0]:
            xlim[0] = xmin
        # get the x value where the pdf finally falls below 2
        xmax = grid[-np.argmax(pdfs_[i].pdf(grid)[0, ::-1] > 2)]
        if xmax > xlim[1]:
            xlim[1] = xmax
    # plot the true redshift
    z_true = samples_degraded_wo_nondetects["redshift"][i]
    ax.axvline(z_true, c="k", ls="--")
    # set x-label on bottom row
    if i >= 3:
        ax.set(xlabel="redshift")
    # set y-label on far left column
    if i % 3 == 0:
        ax.set(ylabel="p(z)")
    # set the x-limits so we can see more detail
    xlim[0] -= 0.2
    xlim[1] += 0.2
    ax.set(xlim=xlim, yticks=[])

# create the legend
axes[0, 1].plot([], [], c="C0", label="1 sample")
for i, n in enumerate([10, 100, 1000]):
    axes[0, 1].plot([], [], c=f"C{i+1}", label=f"{n} samples")
axes[0, 1].legend(
    bbox_to_anchor=(0.5, 1.3),
    loc="upper center",
    ncol=4,
)
plt.show()
You can see the effect of convolving the errors. In particular, notice that without error convolution (1 sample), the redshift posterior is often totally inconsistent with the true redshift (marked by the vertical black line). As you convolve more samples, the posterior generally broadens and becomes consistent with the true redshift.
Also notice how the posterior continues to change as you convolve more and more samples. This suggests that you need to do a little testing to ensure that you have convolved enough samples.
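One quick way to test this is to compare the pdfs computed with successive numbers of error samples, e.g. via their maximum absolute difference on the grid (this sketch assumes the Ensemble's pdf method returns an array of shape (n_galaxies, n_grid)):

# crude convergence check: how much does each pdf change as the
# number of error samples increases?
for n_lo, n_hi in [(1, 10), (10, 100), (100, 1000)]:
    diff = np.abs(
        pdfs_errs_convolved[n_lo].pdf(grid) - pdfs_errs_convolved[n_hi].pdf(grid)
    ).max()
    print(f"{n_lo} vs {n_hi} samples: max |Δp| = {diff:.3f}")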
Let's plot these same posteriors with even more samples to make sure they have converged:
WARNING: Running the next cell on your computer may exhaust your memory
pdfs_errs_convolved_more_samples = {
    err_samples: creator_degraded_wo_nondetects.get_posterior(
        samples_degraded_wo_nondetects,
        grid=grid,
        err_samples=err_samples,
        seed=0,
        batch_size=2,
    )
    for err_samples in [1000, 2000, 5000, 10000]
}
fig, axes = plt.subplots(2, 3, dpi=120)

for i, ax in enumerate(axes.flatten()):
    # set dummy values for xlim
    xlim = [np.inf, -np.inf]
    for pdfs_ in pdfs_errs_convolved_more_samples.values():
        # plot the pdf
        pdfs_[i].plot_native(axes=ax)
        # get the x value where the pdf first rises above 2
        xmin = grid[np.argmax(pdfs_[i].pdf(grid)[0] > 2)]
        if xmin < xlim[0]:
            xlim[0] = xmin
        # get the x value where the pdf finally falls below 2
        xmax = grid[-np.argmax(pdfs_[i].pdf(grid)[0, ::-1] > 2)]
        if xmax > xlim[1]:
            xlim[1] = xmax
    # plot the true redshift
    z_true = samples_degraded_wo_nondetects["redshift"][i]
    ax.axvline(z_true, c="k", ls="--")
    # set x-label on bottom row
    if i >= 3:
        ax.set(xlabel="redshift")
    # set y-label on far left column
    if i % 3 == 0:
        ax.set(ylabel="p(z)")
    # set the x-limits so we can see more detail
    xlim[0] -= 0.2
    xlim[1] += 0.2
    ax.set(xlim=xlim, yticks=[])

# create the legend
for i, n in enumerate([1000, 2000, 5000, 10000]):
    axes[0, 1].plot([], [], c=f"C{i}", label=f"{n} samples")
axes[0, 1].legend(
    bbox_to_anchor=(0.5, 1.3),
    loc="upper center",
    ncol=4,
)
plt.show()
Notice that two of these galaxies may take upwards of 10,000 samples to converge (convolving over 10,000 samples takes about 0.5 seconds per galaxy on my computer).
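If you want to measure this on your own machine, here is a quick wall-clock sketch (timings will of course vary with hardware):

import time

start = time.time()
_ = creator_degraded_wo_nondetects.get_posterior(
    samples_degraded_wo_nondetects,
    grid=grid,
    err_samples=10000,
    seed=0,
    batch_size=2,
)
elapsed = time.time() - start
print(f"{elapsed / len(samples_degraded_wo_nondetects):.2f} s per galaxy")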
Now let's finally tackle posterior calculation with missing bands.
First, let's make a sample that has missing bands. We will use the same degrader as above, except without the final QuantityCut that removed non-detections:
# set up the error model
errorModel = LSSTErrorModel()
def degrader(data: pd.DataFrame, seed: int = None) -> pd.DataFrame:
    # apply the error model
    data = errorModel(data, seed)
    # make a cut on the observed i band
    data = QuantityCut({"i": 25.3})(data, seed)
    # introduce redshift incompleteness
    data = InvRedshiftIncompleteness(0.8)(data, seed)
    # introduce spectroscopic errors
    # Oxygen lines (in angstroms)
    OII = 3727
    OIII = 5007
    # 2% OII -> OIII confusion
    data = LineConfusion(true_wavelen=OII, wrong_wavelen=OIII, frac_wrong=0.02)(data, seed)
    # 1% OIII -> OII confusion
    data = LineConfusion(true_wavelen=OIII, wrong_wavelen=OII, frac_wrong=0.01)(data, seed)
    return data
creator_degraded = Creator(flowEngine, degrader=degrader)
samples_degraded = creator_degraded.sample(6, seed=0)
samples_degraded
| | redshift | u | u_err | g | g_err | r | r_err | i | i_err | z | z_err | y | y_err |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.733024 | 25.727907 | 0.152306 | 23.893915 | 0.010915 | 22.347567 | 0.005557 | 21.081667 | 0.005150 | 20.602454 | 0.005196 | 20.334081 | 0.005562 |
| 1 | 0.476285 | 26.427803 | 0.273279 | 26.374752 | 0.089551 | 25.534399 | 0.041842 | 25.270974 | 0.049108 | 25.224371 | 0.083048 | 24.902562 | 0.141114 |
| 2 | 0.260037 | 25.860308 | 0.170483 | 24.792525 | 0.022275 | 23.927480 | 0.010950 | 23.612819 | 0.012031 | 23.328087 | 0.015890 | 23.303974 | 0.034520 |
| 3 | 1.358129 | NaN | NaN | 27.100217 | 0.167924 | 25.727751 | 0.049675 | 23.919856 | 0.015275 | 23.233189 | 0.014721 | 22.342916 | 0.015189 |
| 4 | 0.579832 | 27.167752 | 0.486339 | 26.700148 | 0.119012 | 25.746879 | 0.050526 | 25.218503 | 0.046873 | 25.258271 | 0.085566 | 25.075233 | 0.163634 |
| 5 | 1.321674 | 26.801005 | 0.367880 | 25.484899 | 0.040770 | 24.783604 | 0.021680 | 23.914985 | 0.015215 | 23.239677 | 0.014797 | 22.622919 | 0.019132 |
You can see that galaxy 3 has a non-detection in the u band. FlowEngine can handle missing values by marginalizing over that value. By default, FlowEngine will marginalize over NaNs in the u band, using the grid u = np.linspace(25, 31, 10). This default grid should work in most cases, but you may want to change the flag for non-detections, use a different grid for the u band, or marginalize over non-detections in other bands. To do these things, you must supply FlowEngine with marginalization rules in the form of the marg_rules dictionary.
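For reference, the default behavior described above corresponds to a marg_rules dictionary like the following sketch (the actual defaults live inside FlowEngine):

# the default marginalization, written out explicitly: NaN flags a
# non-detection, and u is marginalized over a 10-point grid
default_marg_rules = {
    "flag": np.nan,
    "u": lambda row: np.linspace(25, 31, 10),
}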
Let's imagine we want to use a different grid for the u band marginalization. To determine what grid to use, we will create a histogram of u band non-detections vs. true u band magnitude (assuming year 10 LSST errors). This will tell us which u values are reasonable to marginalize over.
# get true u band magnitudes
true_u = Creator(flowEngine, degrader=None).sample(10000, seed=0)["u"].to_numpy()
# get the observed u band magnitudes
obs_u = Creator(flowEngine, degrader=errorModel).sample(10000, seed=0)["u"].to_numpy()
# create the figure
fig, ax = plt.subplots(constrained_layout=True, dpi=100)
# plot the u band detections
ax.hist(true_u[~np.isnan(obs_u)], bins="fd", label="detected")
# plot the u band non-detections
ax.hist(true_u[np.isnan(obs_u)], bins="fd", label="non-detected")
ax.legend()
ax.set(xlabel="true u magnitude")
plt.show()
Based on this histogram, I will marginalize over u band values from 27 to 31. Just as I tested different numbers of error samples above, here I will test different resolutions for the u band grid.
I will provide our new u band grid in the marg_rules dictionary, which will also include "flag", which tells FlowEngine what my flag for non-detections is.
In this simple example, we are using a fixed grid for the u band, but notice that the u band rule takes the form of a function: this is because the grid over which to marginalize can be a function of any of the other variables in the row, as in the sketch below.
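For instance, a rule could adapt the marginalization grid to each galaxy. The rule below is purely hypothetical, placing the u grid in a window above each galaxy's observed g magnitude:

# hypothetical row-dependent rule: the u grid follows the observed g magnitude
adaptive_marg_rules = {
    "flag": np.nan,
    "u": lambda row: np.linspace(row["g"], row["g"] + 4, 10),
}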
If I wanted to marginalize over any other bands, I would need to include corresponding rules in marg_rules too.
For this example, I will only calculate pdfs for galaxy 3, which is the galaxy with a non-detection in the u band. Also, just as I tested the error convolution with a variable number of samples, I will test the marginalization with varying resolutions for the marginalization grid.
# dict to save the marginalized posteriors
pdfs_u_marginalized = {}

# iterate over variable grid resolution
for nbins in [10, 20, 50, 100]:
    # set up the marginalization rules for this grid resolution;
    # bind nbins as a default argument so the lambda captures the
    # current value instead of relying on late-binding closures
    marg_rules = {
        "flag": errorModel.settings["ndFlag"],
        "u": lambda row, nbins=nbins: np.linspace(27, 31, nbins),
    }
    # calculate the posterior by marginalizing over u and sampling
    # from the error distributions of the other bands
    pdfs_u_marginalized[nbins] = creator_degraded.get_posterior(
        samples_degraded.iloc[3:4],
        grid=grid,
        err_samples=10000,
        seed=0,
        marg_rules=marg_rules,
    )
fig, ax = plt.subplots(dpi=100)

for i in [10, 20, 50, 100]:
    pdfs_u_marginalized[i][0].plot_native(axes=ax, label=f"{i} bins")

ax.axvline(samples_degraded.iloc[3]["redshift"], label="True redshift", c="k")
ax.legend()
ax.set(xlabel="Redshift")
plt.show()
Notice that a resolution of only 10 bins is sufficient for this marginalization.
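To make "sufficient" quantitative, you can compare the coarsest grid against the finest, as in the error-convolution convergence check above (again assuming the Ensemble's pdf method returns grid evaluations):

# compare the 10-bin result against the 100-bin result on the redshift grid
diff = np.abs(
    pdfs_u_marginalized[10].pdf(grid) - pdfs_u_marginalized[100].pdf(grid)
).max()
print(f"max |Δp| between 10 and 100 u-grid bins: {diff:.3f}")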
In this example, only one of the bands featured a non-detection, but you can easily marginalize over more bands by including corresponding rules in the marg_rules dict. For example, let's artificially force a non-detection in the y band as well:
sample_double_degraded = samples_degraded.iloc[3:4].copy()
# columns 11 and 12 hold y and y_err; set them to NaN to fake a y band non-detection
sample_double_degraded.iloc[0, 11:] *= np.nan
sample_double_degraded
| | redshift | u | u_err | g | g_err | r | r_err | i | i_err | z | z_err | y | y_err |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 3 | 1.358129 | NaN | NaN | 27.100217 | 0.167924 | 25.727751 | 0.049675 | 23.919856 | 0.015275 | 23.233189 | 0.014721 | NaN | NaN |
# set up the marginalization rules for u and y marginalization
marg_rules = {
"flag": errorModel.settings["ndFlag"],
"u": lambda row: np.linspace(27, 31, 10),
"y": lambda row: np.linspace(21, 25, 10),
}
# calculate the posterior by marginalizing over u and y, and sampling
# from the error distributions of the other bands
pdf_double_marginalized = creator_degraded.get_posterior(
sample_double_degraded,
grid=grid,
err_samples=10000,
seed=0,
marg_rules=marg_rules,
)
fig, ax = plt.subplots(dpi=100)
pdf_double_marginalized[0].plot_native(axes=ax)
ax.axvline(sample_double_degraded.iloc[0]["redshift"], label="True redshift", c="k")
ax.legend()
ax.set(xlabel="Redshift")
plt.show()
Note that marginalizing over multiple bands quickly gets expensive: the marginalization grids combine multiplicatively, so with a 10-point u grid and a 10-point y grid, each posterior must be evaluated over 100 grid combinations, on top of the 10,000 error samples.