JWST Disk Modeling (Experimental)
JWST point-spread functions (PSFs) are substantially more complex to model — for coronagraphic observations they vary with position relative to the coronagraphic mask — and many tools exist to process data from this unique observatory. GRaTeR-JAX incorporates functionality from `Winnie <https://github.com/kdlawson/Winnie>`__ for disk modeling and optimization using JWST data, including model-constrained reference differential imaging (MCRDI). This part of GRaTeR-JAX is currently experimental: use it at your own risk, and please report any issues or bugs on our GitHub issue tracker.
[1]:
# Environment configuration must happen before the libraries that read it are imported:
# STPSF / webbpsf_ext look up their data directories from these variables, and JAX reads the
# XLA allocator settings when it initializes its backend.
import os

os.environ["STPSF_PATH"] = 'stpsf-data'
os.environ["STPSF_EXT_PATH"] = 'stpsf-data'
os.environ["PYSYN_CDBS"] = "cdbs"
# Limit JAX's device-memory preallocation to 30%; set before importing jax so it is guaranteed to apply.
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.3'

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from astropy.io import fits

# Run the model in double precision.
jax.config.update("jax_enable_x64", True)
# Fail fast when a JAX computation produces a NaN. This adds per-operation checking overhead,
# so consider disabling it once the (experimental) workflow is stable.
jax.config.update("jax_debug_nans", True)
Loading Files
[2]:
import sys

# Make the directory two levels above the current working directory (normally the notebook's
# folder) importable, presumably so the local `grater_jax` package resolves.
# The membership check keeps the cell idempotent: re-running it does not append duplicate entries.
REPO_ROOT = '../..'
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)
Processing the FITS File
[3]:
from grater_jax.optimization.optimize_framework import OptimizeUtils

# NIRCam (NRCA2) calints product to model.
observation_name = 'jw04050015001_03106_00001_nrca2_calints'
fits_image_filepath = f"test_images/{observation_name}.fits"
# Left open on purpose: later cells read the 'SCI' and 'ERR' extensions from `hdul`.
hdul = fits.open(fits_image_filepath)

# Summary of the HDUs in the file.
hdul.info()

# First science integration, presumably cropped to rows/cols 120-360 (a 240x240 frame).
target_image = OptimizeUtils.process_image(hdul['SCI'].data[0], bounds=(120, 360, 120, 360))

fig, ax = plt.subplots()
ax.imshow(target_image, origin='lower', cmap='inferno')
ax.set_title("Image from Fits File")
WARNING: Environment variable $WEBBPSF_EXT_PATH is not set!
Setting WEBBPSF_EXT_PATH to STPSF_PATH directory:
stpsf-data
webbpsf_ext log messages of level WARN and above will be shown.
webbpsf_ext log outputs will be directed to the screen.
Could not import CuPy. Setting: use_gpu=False (i.e., using CPU operations).
Filename: test_images/jw04050015001_03106_00001_nrca2_calints.fits
No. Name Ver Type Cards Dimensions Format
0 PRIMARY 1 PrimaryHDU 346 ()
1 SCI 1 ImageHDU 113 (480, 480, 1) float32
2 ERR 1 ImageHDU 11 (480, 480, 1) float64
3 DQ 1 ImageHDU 12 (480, 480, 1) int32 (rescales to uint32)
4 AREA 1 ImageHDU 9 (320, 320) float32
5 INT_TIMES 1 BinTableHDU 24 5R x 7C [J, D, D, D, D, D, D]
6 VAR_POISSON 1 ImageHDU 10 (320, 320, 5) float32
7 VAR_RNOISE 1 ImageHDU 10 (320, 320, 5) float32
8 VAR_FLAT 1 ImageHDU 10 (320, 320, 5) float32
9 ASDF 1 BinTableHDU 11 1R x 1C [27184B]
10 IMSHIFTS 1 ImageHDU 8 (2, 1) float64
11 MASKOFFS 1 ImageHDU 8 (2, 1) float64
[3]:
Text(0.5, 1.0, 'Image from Fits File')
Importing Necessary Modules
[4]:
from grater_jax.optimization.optimize_framework import Optimizer, OptimizeUtils
from grater_jax.disk_model.objective_functions import objective_model, objective_ll, objective_fit, Parameter_Index
from grater_jax.disk_model.SLD_ojax import ScatteredLightDisk
from grater_jax.disk_model.SLD_utils import *
Obtaining the Target Image and Generating the Error Map
[5]:
# Same crop of the first SCI integration as in the FITS-preview cell.
target_image = OptimizeUtils.process_image(hdul['SCI'].data[0], bounds=(120, 360, 120, 360))
# Empirical error map built from the first ERR integration, then cropped with a 240-pixel window.
# NOTE(review): these crop bounds (40, 280, 40, 280) differ from the science crop (120, 360, 120, 360).
# They line up only if create_empirical_err_map returns an already-trimmed map (e.g. 80 px removed per
# side, i.e. 320x320 like the AREA/VAR_* extensions) -- confirm the two crops are pixel-aligned.
err_map = OptimizeUtils.process_image(OptimizeUtils.create_empirical_err_map(hdul['ERR'].data[0]), bounds=(40, 280, 40, 280)) #, outlier_pixels=[(57, 68)]))
[6]:
# Load every stellar PSF reference frame (first SCI integration of each FITS file) and crop it
# the same way as the target image.
# The file list is sorted because os.listdir returns entries in arbitrary, filesystem-dependent
# order; the position of each reference fixes the index of its fitted stellar weight / x / y
# parameters (sw1, sw2, ... in the MCMC labels), so the order must be reproducible.
folder = "stellar_psf_images/"
fits_files = sorted(f for f in os.listdir(folder) if f.endswith(".fits"))
image_list = []
for fname in fits_files:
    path = os.path.join(folder, fname)
    with fits.open(path) as hdul_stellar:
        sci_data = hdul_stellar['SCI'].data[0]  # first integration, shape (H, W)
        # Process inside the `with` so the data is read before the file is closed.
        image_list.append(OptimizeUtils.process_image(sci_data, bounds=(120, 360, 120, 360)))
if not image_list:
    raise FileNotFoundError(f"No .fits reference images found in {folder!r}")
image_stack_np = np.stack(image_list)  # shape (N, H, W)
image_stack_jax = jnp.array(image_stack_np)  # shape (N, H, W)
Displaying the Stellar PSF Reference Images
[7]:
# Show every stellar PSF reference image on a 3-column grid.
# Rows use ceiling division and squeeze=False so any number of references works: a partial last row
# still gets axes, and a single row still yields a 2-D axes array.
n_refs = len(image_stack_jax)
n_cols = 3
n_rows = max(1, -(-n_refs // n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 12), squeeze=False)
for i, ax in enumerate(axes.flat):
    if i < n_refs:
        ax.imshow(image_stack_jax[i], origin='lower', cmap='inferno')
        ax.set_title(f"Reference {i + 1}")
    else:
        ax.axis('off')  # hide unused grid cells
Creating WinniePSF NIRCam Object
[8]:
# Build the NIRCam coronagraphic PSF model used by the disk model:
# configure the instrument with webbpsf_ext, generate a grid of PSFs,
# and wrap them in a Winnie_PSF object.
from grater_jax.disk_model.winnie_jwst_fm import generate_nircam_psf_grid
from grater_jax.disk_model.SLD_utils import Winnie_PSF
import numpy as np  # already imported in the setup cell; the re-import is harmless
import webbpsf_ext
# 1. Star center in aligned image (shifted to nominal coronagraph center)
#    (120, 120) is the center of the 240x240 cropped frame used throughout this notebook.
cent = np.array([120, 120])
# 2. No offset between star and coronagraph center for each roll
#    NOTE(review): two roll entries are defined here, but test_parangs below holds a single angle --
#    confirm Winnie_PSF.init does not expect one parallactic angle per roll.
dc_rolls = np.array([[0, 0], [0, 0]])
c_coron_rolls = cent + dc_rolls # Final per-roll star positions (pixels)
# 3. PSF grid settings
fov_pix = 153 # Field of view in detector pixels
osamp = 2 # PSF oversampling factor
source_spectrum = None # Flat spectrum (default stellar) -- NOTE(review): None defers to the library default; confirm what that default is
# 4. Define the NIRCam instrument
inst = webbpsf_ext.NIRCam_ext(
fov_pix=fov_pix,
oversample=osamp,
filter='F200W',
pupil_mask='MASK210R',
module='A',
channel='SHORT',
detector='NRCA2'
)
# 5. Set optical path differences and pupil misalignments
#    TODO: document the source of these values (presumably a prior wavefront/alignment fit).
inst.options['pupil_shift_x'] = -0.012
inst.options['pupil_shift_y'] = -0.008
inst.options['defocus_waves'] = -0.00865
inst.options['pupil_rotation'] = -0.595
# 6. Precompute interpolation coefficients for faster PSF generation
#    (must run after the options above are set)
inst.gen_psf_coeff()
inst.gen_wfemask_coeff(large_grid=True)
# 7. Generate PSF grid
#    Presumably nr radial samples (log-spaced since log_rscale=True) out to rmax, times ntheta
#    azimuthal samples -- confirm against generate_nircam_psf_grid.
print("Generating PSFs...")
psfs, psf_inds_rolls, im_mask_rolls, psf_offsets = generate_nircam_psf_grid(
inst,
c_coron_rolls,
source_spectrum=source_spectrum,
normalize=True,
log_rscale=True,
nr=5,
rmax=10**0.5,
ntheta=4,
use_coeff=True
)
# 8. Wrap the grid in a Winnie_PSF object for a single parallactic angle of 0.
#    NOTE(review): the variable name says F200M, but the instrument above uses filter='F200W'.
test_parangs = [0.]
n_unique_inds = len(jnp.unique(psf_inds_rolls))
psf_params_F200M = Winnie_PSF.init(psfs, psf_inds_rolls, im_mask_rolls, psf_offsets, test_parangs, n_unique_inds)
Generating PSFs...
Initializing Disk Model Parameters
[9]:
# Starting parameters for every model component. Each dict is a shallow copy of the class-level
# defaults, so overriding keys below does not mutate those defaults.
start_disk_params = Parameter_Index.disk_params.copy()
start_spf_params = InterpolatedUnivariateSpline_SPF.params.copy()
start_psf_params = psf_params_F200M
start_stellar_psf_params = PositionalStellarPSF.params.copy()
start_misc_params = Parameter_Index.misc_params.copy()
# Disk geometry and density power-law initial guesses
# (units as expected by ScatteredLightDisk -- presumably au and degrees; confirm).
start_disk_params['sma'] = 70.12
start_disk_params['inclination'] = 72.40
start_disk_params['position_angle'] = 133.22
start_disk_params['alpha_in'] = 11.99
start_disk_params['alpha_out'] = -4.25
start_disk_params['gamma'] = 1.11
# Spline scattering phase function: number of knots, then knot values seeded from three
# parameters via convert_dhg_params_to_spline_params ("dhg" presumably = double Henyey-Greenstein).
# NOTE(review): the printed params after this cell show knots of order 1, so these 1e-8-scaled
# values are presumably replaced by opt.initialize_knots below -- confirm.
start_spf_params['num_knots'] = 7 # number of spline knots
start_spf_params['knot_values'] = 1e-8 * OptimizeUtils.convert_dhg_params_to_spline_params(0.83, -0.73, 0.94, start_spf_params)
# One weight and one (x, y) position per stellar PSF reference image, all starting at the frame
# center; the weights are re-estimated by initialize_stellar_psf_weights at the end of this cell.
start_stellar_psf_params['stellar_weights'] = 0.004*jnp.ones(len(image_stack_jax))
start_stellar_psf_params['stellar_xs'] = 120*jnp.ones(len(image_stack_jax))
start_stellar_psf_params['stellar_ys'] = 120*jnp.ones(len(image_stack_jax))
# Frame size in pixels, plus (presumably) distance in pc and pixel scale in arcsec/pixel.
start_misc_params['nx'] = 240
start_misc_params['ny'] = 240
start_misc_params['distance'] = 80.2066
start_misc_params['pxInArcsec'] = 0.032
# Disk center starts at the middle of the frame.
start_disk_params['x_center'] = start_misc_params['nx']/2
start_disk_params['y_center'] = start_misc_params['ny']/2
opt = Optimizer(ScatteredLightDisk, DustEllipticalDistribution2PowerLaws, InterpolatedUnivariateSpline_SPF, Winnie_PSF,
start_disk_params, start_spf_params, start_psf_params, start_misc_params, StellarPSFModel=PositionalStellarPSF,
stellar_psf_params=start_stellar_psf_params)
# Register the stellar PSF reference stack used by the PositionalStellarPSF component.
opt.define_reference_images(image_stack_jax)
opt.inc_bound_knots() # call this method only if you have a good predefined inclination; otherwise skip it
# Initialize the knots from the target with the current stellar PSF model subtracted.
opt.initialize_knots(target_image-opt.compute_stellar_psf_image())
opt.initialize_stellar_psf_weights(target_image) # <-- Necessary step in order to get good initial stellar psf weight estimates (can manually do this as well)
JIT-Compiling the Model and Gradient Functions
[10]:
# JIT-compile the forward model and (presumably) the gradient of the log-likelihood for this
# target image / error map pair, so later optimization and sampling calls are fast.
opt.jit_compile_model() # <-- opt.log_likelihood is also jit compiled by this
opt.jit_compile_gradient(target_image, err_map)
[11]:
# Initial disk model alone, without the off-axis or on-axis (stellar) PSF components.
fig, ax = plt.subplots()
ax.imshow(opt.get_disk(), cmap='inferno', origin='lower')
ax.set_title("Initial disk model (no PSF)");
[11]:
<matplotlib.image.AxesImage at 0x7f54f0560050>
[12]:
# Side-by-side check of the starting point: data, initial model, and their difference.
fig, axes = plt.subplots(1, 3, figsize=(20, 10))
init_image = opt.get_model()
mask = OptimizeUtils.get_mask(target_image)
masked_pixels = target_image[mask]
vmin, vmax = np.nanpercentile(masked_pixels, 1), np.nanpercentile(masked_pixels, 99)

# (panel title, image, whether to apply the data's 1st-99th percentile color limits)
panels = (
    ("Target Image", target_image, True),
    ("Initial Fit", init_image, True),
    ("Residual", target_image - init_image, False),  # residual keeps autoscaled limits
)
for ax, (panel_title, panel_image, use_data_clim) in zip(axes, panels):
    im = ax.imshow(panel_image, origin='lower', cmap='inferno')
    ax.set_title(panel_title)
    plt.colorbar(im, ax=ax, shrink=0.75)
    if use_data_clim:
        im.set_clim(vmin, vmax)
[12]:
<matplotlib.colorbar.Colorbar at 0x7f52dc4691d0>
[13]:
# Show the starting parameters of every model component before optimizing.
opt.print_params()
Disk Params: {'accuracy': 0.005, 'alpha_in': 11.99, 'alpha_out': -4.25, 'sma': 70.12, 'e': 0.0, 'ksi0': 3.0, 'gamma': 1.11, 'beta': 1.0, 'rmin': 0.0, 'dens_at_r0': 1.0, 'inclination': 72.4, 'position_angle': 133.22, 'x_center': 120.0, 'y_center': 120.0, 'halfNbSlices': 25, 'omega': 0.0}
SPF Params: {'backscatt_bound': Array(-0.95319067, dtype=float64, weak_type=True), 'forwardscatt_bound': Array(0.95319067, dtype=float64, weak_type=True), 'num_knots': 7, 'knot_values': Array([8.64268514, 2.90095009, 1.55257508, 1. , 0.71196398,
0.53980746, 0.42735314], dtype=float64)}
PSF Params: <grater_jax.disk_model.winnie_class.WinniePSF object at 0x7f55fc4cbd90>
Stellar PSF Params: {'stellar_weights': array([-2.8942404e-03, -1.2769293e-03, 3.4055905e-03, 9.9285717e-05,
-9.0930314e-04, 1.6361177e-03, 1.8586181e-02, 8.2717324e-04,
4.5218747e-03], dtype=float32), 'stellar_xs': Array([120., 120., 120., 120., 120., 120., 120., 120., 120.], dtype=float64), 'stellar_ys': Array([120., 120., 120., 120., 120., 120., 120., 120., 120.], dtype=float64)}
Misc Params: {'distance': 80.2066, 'pxInArcsec': 0.032, 'nx': 240, 'ny': 240, 'halfNbSlices': 25, 'flux_scaling': Array(5.86315494, dtype=float64)}
[14]:
# Demonstration only: _get_raw_gradient is a private Optimizer method.
# It returns a tuple of per-component gradient arrays (from the output: disk params, spline knots, and a
# 2-D array -- presumably for the PSF/stellar component; confirm).
print(opt._get_raw_gradient(target_image, err_map)) # <-- demonstrates raw gradient output
(Array([-19980.4929142 , -987.92820553, 26836.43973602, 1217.75076655,
92838.51588009, 42620.7472783 , -44227.47006375, 26694.94876916,
0. , 119919.08951712, 3702.45548638, -430.44375083,
11710.70403315, -13541.98545437, 0. , 0. ], dtype=float64), Array([-4935.68858732, 13313.58739363, 18903.71357588, 29542.11974516,
18687.35831691, 67315.61379436, 36084.65056183], dtype=float64), Array([[ -0. , -0. , -0. , ..., -1522.4349 ,
968.89606 , 1181.5336 ],
[ -0. , 0. , -0. , ..., 2209.86 ,
3402.3254 , 4419.3076 ],
[ 0. , -0. , 0. , ..., -2475.6865 ,
1525.8789 , -1213.2253 ],
...,
[ 3052.0679 , 537.2665 , -435.12085 , ..., 430.57645 ,
-605.4521 , -325.32025 ],
[-6026.1196 , 2004.4668 , 4401.2227 , ..., 219.3026 ,
325.14252 , 288.58676 ],
[ 187.14015 , -5901.0493 , 3676.369 , ..., -378.35062 ,
8.064567, 1305.953 ]], dtype=float32))
[15]:
# Number of stellar psf images
num_refs = len(image_stack_jax)
# Logscaling helps when you want to fit parameters across orders of magnitude
fit_keys = ['alpha_in', 'alpha_out', 'sma', 'inclination', 'position_angle', 'x_center', 'y_center', 'knot_values',
'stellar_weights', 'stellar_xs', 'stellar_ys', 'e']
logscaled_params = ['knot_values', 'stellar_weights']
array_params = ['knot_values', 'stellar_weights', 'stellar_xs', 'stellar_ys']
bounds = ([0.1, -15, 0, 0, 0, 110, 110, np.zeros(7), np.zeros(num_refs), 110*np.ones(num_refs), 110*np.ones(num_refs), 0],
[20, -0.1, 150, 180, 400, 130, 130, 1e4*np.ones(7), np.ones(num_refs), 130*np.ones(num_refs), 130*np.ones(num_refs), 1])
opt.scipy_bounded_optimize(fit_keys, bounds, logscaled_params, array_params, target_image, err_map, use_grad = False,
disp_soln = True, iters = 2000, eps = 1e-40, gtol=1e-40)
optimal_image = opt.get_model()
message: ABNORMAL:
success: False
status: 2
fun: 30079764.036093008
x: [ 1.201e+01 -4.313e+00 ... 1.200e+02 3.275e-26]
nit: 7
jac: [ 1.515e+03 -5.822e+04 ... 1.561e+07 -3.818e+31]
nfev: 2279
njev: 53
hess_inv: <42x42 LbfgsInvHessProduct with dtype=float64>
Results
[16]:
# Final comparison after optimization: data, model, residual, and residual normalized by the error map.
fig, axes = plt.subplots(1,4, figsize=(20,10))
mask = OptimizeUtils.get_mask(target_image)
vmin = np.nanpercentile(target_image[mask], 1)
vmax = np.nanpercentile(target_image[mask], 99)
im = axes[0].imshow(target_image, origin='lower', cmap='inferno')
axes[0].set_title("Original Image")
plt.colorbar(im, ax=axes[0], shrink=0.75)
im.set_clim(vmin, vmax)
im = axes[1].imshow(optimal_image, origin='lower', cmap='inferno')
axes[1].set_title("Disk Fit")
plt.colorbar(im, ax=axes[1], shrink=0.75)
im.set_clim(vmin, vmax)
# The two residual panels keep their own autoscaled color limits.
residual = np.asarray(target_image - optimal_image, dtype=float)
im = axes[2].imshow(residual, origin='lower', cmap='inferno')
axes[2].set_title("Residual")
plt.colorbar(im, ax=axes[2], shrink=0.75)
# Divide only where the error is non-zero; zero-error pixels are shown as 0. np.where would still
# evaluate the division everywhere, producing inf/NaN intermediates (warnings, or a
# FloatingPointError under jax_debug_nans if the arrays are JAX arrays).
err = np.asarray(err_map, dtype=float)
normalized_residual = np.divide(residual, err, out=np.zeros_like(residual), where=err != 0.)
im = axes[3].imshow(normalized_residual, origin='lower', cmap='inferno')
# Residual in units of the per-pixel error -- not a signal-to-noise ratio of the disk itself.
axes[3].set_title("Residual / Error")
plt.colorbar(im, ax=axes[3], shrink=0.75)
[16]:
<matplotlib.colorbar.Colorbar at 0x7f52dc12c2d0>
[17]:
# Optional, disabled: rescale the spline phase function to a fixed point (exact semantics of the
# (0, 1) arguments not visible here -- confirm before enabling).
#opt.scale_spline_to_fixed_point(0, 1)
# Show the optimized parameters and the log-likelihood of the current model.
opt.print_params()
opt.log_likelihood(target_image, err_map)
Disk Params: {'accuracy': 0.005, 'alpha_in': np.float64(12.00647568868851), 'alpha_out': np.float64(-4.313420175122997), 'sma': np.float64(70.9785599373245), 'e': np.float64(3.275425824845555e-26), 'ksi0': 3.0, 'gamma': 1.11, 'beta': 1.0, 'rmin': 0.0, 'dens_at_r0': 1.0, 'inclination': np.float64(72.25153989407819), 'position_angle': np.float64(133.10696544363904), 'x_center': np.float64(119.89251809059586), 'y_center': np.float64(120.10748246408362), 'halfNbSlices': 25, 'omega': 0.0}
SPF Params: {'backscatt_bound': Array(-0.95319067, dtype=float64, weak_type=True), 'forwardscatt_bound': Array(0.95319067, dtype=float64, weak_type=True), 'num_knots': 7, 'knot_values': array([6.9034761 , 2.3445699 , 1.26327209, 0.81751751, 0.58417272,
0.44423601, 0.35257508])}
PSF Params: <grater_jax.disk_model.winnie_class.WinniePSF object at 0x7f55fc4cbd90>
Stellar PSF Params: {'stellar_weights': array([1.00007244e-14, 1.00000051e-14, 2.85926357e-03, 8.67697704e-05,
1.00009749e-14, 1.38535321e-03, 1.51098340e-02, 7.06240169e-04,
3.78046051e-03]), 'stellar_xs': array([119.9999479 , 120. , 119.97261139, 120.03755954,
120. , 120.03810242, 119.97578763, 119.97178155,
119.97301516]), 'stellar_ys': array([120. , 120.00055637, 119.96970243, 120.03745583,
119.99998068, 120.03636524, 120.02508112, 119.97107541,
120.03442554])}
Misc Params: {'distance': 80.2066, 'pxInArcsec': 0.032, 'nx': 240, 'ny': 240, 'halfNbSlices': 25, 'flux_scaling': Array(5.86315494, dtype=float64)}
[17]:
Array(-30079764.03617625, dtype=float64)
Final Disk Results
[18]:
# Optimized disk model alone, without the off-axis or on-axis (stellar) PSF components.
fig, ax = plt.subplots()
ax.imshow(opt.get_disk(), cmap='inferno', origin='lower')
ax.set_title("Optimized disk model (no PSF)");
[18]:
<matplotlib.image.AxesImage at 0x7f5288070e10>
[19]:
# Full model with the stellar PSF component subtracted (presumably the PSF-convolved disk).
fig, ax = plt.subplots()
ax.imshow(opt.get_model() - opt.compute_stellar_psf_image(), cmap='inferno', origin='lower')
ax.set_title("Model minus stellar PSF");
[19]:
<matplotlib.image.AxesImage at 0x7f52880bdd10>
Running MCMC
[20]:
# Running the MCMC sampler over the same parameters (same fit_keys order) as the bounded optimization.
# NOTE(review): alpha_in's upper bound is 200 here but 20 in the optimizer cell -- confirm which is intended.
# NOTE(review): no random seed is passed here; unless opt.mcmc seeds internally, chains differ between runs.
fit_keys = ['alpha_in', 'alpha_out', 'sma', 'inclination', 'position_angle', 'x_center', 'y_center', 'knot_values', 'stellar_weights',
'stellar_xs', 'stellar_ys', 'e']
logscaled_params = ['knot_values', 'stellar_weights']
array_params = ['knot_values', 'stellar_weights', 'stellar_xs', 'stellar_ys']
# (lower, upper) bounds in fit_keys order; arrays for the entries in array_params.
bounds = ([0.1, -15, 0, 0, 0, 110, 110, np.zeros(7), np.zeros(num_refs), 110*np.ones(num_refs), 110*np.ones(num_refs), 0],
[200, -0.1, 150, 180, 400, 130, 130, 1e4*np.ones(7), np.ones(num_refs), 130*np.ones(num_refs), 130*np.ones(num_refs), 1])
# 100 walkers; 20 burn-in steps followed by 100 production steps (see the progress output).
mc_model = opt.mcmc(fit_keys, logscaled_params, array_params, target_image, err_map, bounds, nwalkers=100, niter=100, burns=20,
scale_objective_function_for_shape = True)
Running burn-in...
100%|██████████| 20/20 [00:29<00:00, 1.48s/it]
Running production...
100%|██████████| 100/100 [03:04<00:00, 1.85s/it]
[21]:
# Per-parameter median of the flattened chain, in the sampler's parameter space. Logscaled
# parameters appear to stay in (natural-)log space: the knot medians printed at the end are ~ln of
# the optimized knot values.
mc_soln = np.median(mc_model.sampler.flatchain, axis=0)
# NOTE(review): confirm opt.mcmc updates opt's parameters (e.g. to the median); otherwise this is
# still the pre-MCMC model.
img = opt.get_model()
Results
[22]:
# Compare the target image, the post-MCMC model, and their residual.
fig, axes = plt.subplots(1,3, figsize=(20,10))
mask = OptimizeUtils.get_mask(target_image)
vmin = np.nanpercentile(target_image[mask], 1)
vmax = np.nanpercentile(target_image[mask], 99)
im = axes[0].imshow(target_image, origin='lower', cmap='inferno')
axes[0].set_title("Original Image")
plt.colorbar(im, ax=axes[0], shrink=0.75)
im.set_clim(vmin, vmax)
im = axes[1].imshow(img, origin='lower', cmap='inferno')
axes[1].set_title("MCMC Fit")
plt.colorbar(im, ax=axes[1], shrink=0.75)
im.set_clim(vmin, vmax)
# The residual panel keeps its own autoscaled color limits.
im = axes[2].imshow(target_image-img, origin='lower', cmap='inferno')
axes[2].set_title("Residual")
plt.colorbar(im, ax=axes[2], shrink=0.75)
# Title the whole figure. plt.title() would target the current axes (the residual panel) and
# overwrite its "Residual" title.
fig.suptitle("Post MCMC")
[22]:
Text(0.5, 1.0, 'Post MCMC')
Plotting MCMC Results
[23]:
# Parameter labels for the chain and corner plots. The order MUST match the MCMC `fit_keys`
# ('alpha_in', 'alpha_out', 'sma', 'inclination', 'position_angle', 'x_center', 'y_center',
# 'knot_values', 'stellar_weights', 'stellar_xs', 'stellar_ys', 'e'), with array-valued
# parameters expanded element by element. Knot values and stellar weights are logscaled
# parameters, so their samples are presumably in log space.
labels = ['alpha_in', 'alpha_out', 'sma', 'inclination', 'position_angle', 'xc', 'yc']
labels += [f'k{i + 1}' for i in range(opt.spf_params['num_knots'])]
labels += [f'sw{i + 1}' for i in range(len(opt.stellar_psf_params['stellar_weights']))]  # stellar PSF weights
labels += [f'sx{i + 1}' for i in range(len(opt.stellar_psf_params['stellar_xs']))]  # stellar PSF x positions
labels += [f'sy{i + 1}' for i in range(len(opt.stellar_psf_params['stellar_ys']))]  # stellar PSF y positions
labels.append('eccentricity')
# Guard against label/parameter mismatches, which would silently mislabel the plots.
assert len(labels) == len(set(labels)), "parameter labels must be unique"
assert len(labels) == np.shape(mc_model.sampler.flatchain)[-1], "need exactly one label per sampled parameter"
mc_model.plot_chains(labels)
Corner Plot
[24]:
# Corner plot of the posterior samples, marking the per-parameter chain medians (mc_soln) as "truths".
fig = mc_model.show_corner_plot(labels, truths=mc_soln, quiet = True)
[25]:
# Median of each sampled parameter, keyed by label. `labels` must be unique and in fit_keys order;
# otherwise dict(zip(...)) silently mislabels entries or drops duplicates.
print(dict(zip(labels, mc_model.get_theta_median())))
{'alpha_in': np.float64(12.006460238659642), 'alpha_out': np.float64(-4.3171309957277195), 'inclination': np.float64(70.97965116543598), 'position_angle': np.float64(72.25152625815001), 'radius': np.float64(133.10618024638114), 'xc': np.float64(119.89468431578625), 'yc': np.float64(120.10775385350357), 'k1': np.float64(1.9330375876390566), 'k2': np.float64(0.8539938599656085), 'k3': np.float64(0.23469746085704563), 'k4': np.float64(-0.20376625082441088), 'k5': np.float64(-0.53405968974975), 'k6': np.float64(-0.8113969066159489), 'k7': np.float64(-1.0411228400177766), 'sw1': np.float64(119.99961205393964), 'sw2': np.float64(120.0002066140501), 'sw3': np.float64(119.96981614683405), 'sw4': np.float64(120.04028604249146), 'sw5': np.float64(119.99972133021576), 'sw6': np.float64(120.03729442965675), 'sw7': np.float64(120.02090752139796), 'sw8': np.float64(119.96924099883408), 'sw9': np.float64(120.03891989099102), 'eccentricity': np.float64(0.000666214184201438)}