Skip to content

Commit

Permalink
Moved run_all() back to vetters after some discussion
Browse files Browse the repository at this point in the history
  • Loading branch information
m-dallas committed Mar 6, 2024
1 parent 06ac697 commit fe1cf0a
Show file tree
Hide file tree
Showing 8 changed files with 249 additions and 252 deletions.
144 changes: 1 addition & 143 deletions exovetter/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@

import sys
import warnings
import vetters as vet

import numpy as np

__all__ = ['sine', 'estimate_scatter', 'mark_transit_cadences', 'median_detrend',
'plateau', 'set_median_flux_to_zero', 'set_median_flux_to_one', 'sigmaClip',
'get_mast_tce', 'WqedLSF', 'compute_phases', 'first_epoch', 'run_all']
'get_mast_tce', 'WqedLSF', 'compute_phases', 'first_epoch']

def sine(x, order, period=1):
"""Sine function for SWEET vetter."""
Expand Down Expand Up @@ -654,144 +653,3 @@ def first_epoch(epoch, period, lc):
first_epoch = epoch + N*period

return first_epoch

def run_all(tces, lcs, vetters=None, plot=False, verbose=False, plot_dir=None):
    # TODO Add centroid, maybe rethink plotting in general since plotting uses vetter.plot which essentially doubles runtime,
    # probably should run initially with vet.run(plot=True) and not store them unless run_all plot=True
    """Run vetters on each TCE/lightcurve pair and pack results into a dataframe.

    Parameters
    ----------
    tces : list
        TCE objects to vet. Every TCE must contain a ``'target'`` key when
        ``plot`` or ``verbose`` is enabled; otherwise a missing target is
        recorded as ``None`` in the ``'tce'`` column.
    lcs : list
        Lightcurve objects, paired element-wise with ``tces``.
    vetters : list, optional
        Vetter instances to run; each must provide ``run(tce, lc)`` returning
        a dictionary (and ``plot()`` when ``plot`` is enabled). Defaults to
        VizTransits, ModShift, Lpp, OddEven, TransitPhaseCoverage, Sweet and
        LeoTransitEvents, instantiated fresh on every call.
    plot : bool
        Toggle diagnostic plots. One PDF named ``<target>.pdf`` is written
        per TCE.
    verbose : bool
        Toggle timing info and other print statements.
    plot_dir : str, optional
        Directory to store plots in, defaults to the current working
        directory.

    Returns
    -------
    results : pandas.DataFrame or None
        One row per TCE holding the merged vetter results, with the target
        name in the leading ``'tce'`` column. ``None`` is returned (after
        printing an error) when ``plot`` or ``verbose`` is set and a TCE has
        no ``'target'`` key.
    """
    # NOTE(review): this body relies on py_time, os, plt, PdfPages, pd, u,
    # exo_const and lightkurve_utils being available at module scope --
    # confirm these imports exist in this module.

    if vetters is None:
        # Built per call: a default list in the signature would be created
        # once at import time and its vetter instances shared by every call.
        vetters = [vet.VizTransits(), vet.ModShift(), vet.Lpp(),
                   vet.OddEven(), vet.TransitPhaseCoverage(), vet.Sweet(),
                   vet.LeoTransitEvents()]

    if plot_dir is None:
        plot_dir = os.getcwd()

    # Plot titles, PDF file names and verbose output all need a target name.
    if plot or verbose:
        if any('target' not in tce for tce in tces):
            print("ERROR: Please supply a 'target' key to all input tces to use the plot or verbose parameters")
            return None

    results_dicts = []  # one merged result dictionary per tce
    tce_names = []
    run_start = py_time.time()

    for tce, lc in zip(tces, lcs):
        # Without plot/verbose a target is optional, so do not index blindly.
        target = tce['target'] if 'target' in tce else None

        if verbose:
            print('Vetting', target, ':')

        tce_names.append(target)
        results_list = []  # result dictionaries from each vetter

        # Run each vetter; when plotting, collect the figures to save later.
        plot_figures = []
        for vetter in vetters:
            vetter_name = vetter.__class__.__name__
            time_start = py_time.time()
            vetter_results = vetter.run(tce, lc)

            # VizTransits generates 2 figures so it is handled later;
            # LeoTransitEvents just doesn't have a plot.
            if plot and vetter_name not in ('VizTransits', 'LeoTransitEvents'):
                vetter.plot()
                vetter_plot = plt.gcf()
                vetter_plot.suptitle(target + ' ' + vetter_name)
                vetter_plot.tight_layout()
                plt.close()
                plot_figures.append(vetter_plot)

            if verbose:
                time_end = py_time.time()
                print(vetter_name, 'finished in', time_end - time_start, 's.')

            results_list.append(vetter_results)

        if verbose:  # add some whitespace for readability
            print()

        if plot:
            # Join with a path separator so the PDF lands inside plot_dir;
            # the context manager closes the file even if a plot step fails.
            pdf_path = os.path.join(plot_dir, target + '.pdf')
            with PdfPages(pdf_path) as diagnostic_plot:
                # Plot the lightcurve with the epoch and transits overplotted.
                time, flux, time_offset_str = lightkurve_utils.unpack_lk_version(lc, "flux")
                period = tce["period"].to_value(u.day)
                dur = tce["duration"].to_value(u.day)

                time_offset_q = getattr(exo_const, time_offset_str)
                epoch = tce.get_epoch(time_offset_q).to_value(u.day)
                # mark_transit_cadences is defined in this module (it is
                # listed in __all__), so it is called directly.
                intransit = mark_transit_cadences(time, period, epoch, dur, num_durations=3, flags=None)

                fig, ax1 = plt.subplots(nrows=1, ncols=1, figsize=(9, 5))
                ax1.plot(time, flux, lw=0.4)
                ax1.axvline(x=epoch, lw=0.6, color='r', label='epoch')
                ax1.fill_between(time, 0, 1, where=intransit, transform=ax1.get_xaxis_transform(), color='r', alpha=0.15, label='in transit')

                ax1.set_ylabel('Flux')
                ax1.set_xlabel('Time ' + time_offset_str)
                ax1.set_title(target)
                ax1.legend()
                plt.close(fig)
                diagnostic_plot.savefig(fig)

                # Run the VizTransits plots, one figure per mode.
                vet.VizTransits(transit_plot=True, folded_plot=False).run(tce, lc)
                transit_plot = plt.gcf()
                transit_plot.suptitle(target + ' Transits')
                transit_plot.tight_layout()
                plt.close()
                diagnostic_plot.savefig(transit_plot)

                vet.VizTransits(transit_plot=False, folded_plot=True).run(tce, lc)
                folded_plot = plt.gcf()
                folded_plot.suptitle(target + ' Folded Transits')
                folded_plot.tight_layout()
                plt.close()
                diagnostic_plot.savefig(folded_plot)

                # Save each diagnostic plot made for this tce/lc. The loop
                # variable must not be called `plot`, which is a parameter.
                for figure in plot_figures:
                    diagnostic_plot.savefig(figure)

        # Put all values from each results dictionary into a single dictionary.
        results_dict = {k: v for d in results_list for k, v in d.items()}

        # Drop the bulky plotting arrays. pop() avoids truth-testing the
        # value, which raises ValueError for multi-element numpy arrays.
        results_dict.pop('plot_data', None)

        results_dicts.append(results_dict)

    results_df = pd.DataFrame(results_dicts)  # one row per tce

    results_df.insert(loc=0, column='tce', value=tce_names)
    if verbose:
        print('Execution time:', (py_time.time() - run_start), 's')

    return results_df
143 changes: 142 additions & 1 deletion exovetter/vetters.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

__all__ = ['BaseVetter', 'ModShift', 'Lpp', 'OddEven',
'TransitPhaseCoverage', 'Sweet', 'Centroid',
'VizTransits', 'LeoTransitEvents']
'VizTransits', 'LeoTransitEvents', 'run_all']

class BaseVetter(ABC):
"""Base class for vetters.
Expand Down Expand Up @@ -900,3 +900,144 @@ def run(self, tce, lightcurve, plot=False):

def plot(self):
    """Do nothing: this plot hook draws no figure and returns ``None``."""
    return None

def run_all(tces, lcs, vetters=None, plot=False, verbose=False, plot_dir=None):
    # TODO Add centroid, maybe rethink plotting in general since plotting uses vetter.plot which essentially doubles runtime,
    # probably should run initially with vet.run(plot=True) and not store them unless run_all plot=True
    """Run vetters on each TCE/lightcurve pair and pack results into a dataframe.

    Parameters
    ----------
    tces : list
        TCE objects to vet. Every TCE must contain a ``'target'`` key when
        ``plot`` or ``verbose`` is enabled; otherwise a missing target is
        recorded as ``None`` in the ``'tce'`` column.
    lcs : list
        Lightcurve objects, paired element-wise with ``tces``.
    vetters : list, optional
        Vetter instances to run; each must provide ``run(tce, lc)`` returning
        a dictionary (and ``plot()`` when ``plot`` is enabled). Defaults to
        VizTransits, ModShift, Lpp, OddEven, TransitPhaseCoverage, Sweet and
        LeoTransitEvents, instantiated fresh on every call.
    plot : bool
        Toggle diagnostic plots. One PDF named ``<target>.pdf`` is written
        per TCE.
    verbose : bool
        Toggle timing info and other print statements.
    plot_dir : str, optional
        Directory to store plots in, defaults to the current working
        directory.

    Returns
    -------
    results : pandas.DataFrame or None
        One row per TCE holding the merged vetter results, with the target
        name in the leading ``'tce'`` column. ``None`` is returned (after
        printing an error) when ``plot`` or ``verbose`` is set and a TCE has
        no ``'target'`` key.
    """
    if vetters is None:
        # Built per call: a default list in the signature would be created
        # once at import time and its vetter instances shared by every call.
        vetters = [VizTransits(), ModShift(), Lpp(), OddEven(),
                   TransitPhaseCoverage(), Sweet(), LeoTransitEvents()]

    if plot_dir is None:
        plot_dir = os.getcwd()

    # Plot titles, PDF file names and verbose output all need a target name.
    if plot or verbose:
        if any('target' not in tce for tce in tces):
            print("ERROR: Please supply a 'target' key to all input tces to use the plot or verbose parameters")
            return None

    results_dicts = []  # one merged result dictionary per tce
    tce_names = []
    run_start = py_time.time()

    for tce, lc in zip(tces, lcs):
        # Without plot/verbose a target is optional, so do not index blindly.
        target = tce['target'] if 'target' in tce else None

        if verbose:
            print('Vetting', target, ':')

        tce_names.append(target)
        results_list = []  # result dictionaries from each vetter

        # Run each vetter; when plotting, collect the figures to save later.
        plot_figures = []
        for vetter in vetters:
            vetter_name = vetter.__class__.__name__
            time_start = py_time.time()
            vetter_results = vetter.run(tce, lc)

            # VizTransits generates 2 figures so it is handled later;
            # LeoTransitEvents just doesn't have a plot.
            if plot and vetter_name not in ('VizTransits', 'LeoTransitEvents'):
                vetter.plot()
                vetter_plot = plt.gcf()
                vetter_plot.suptitle(target + ' ' + vetter_name)
                vetter_plot.tight_layout()
                plt.close()
                plot_figures.append(vetter_plot)

            if verbose:
                time_end = py_time.time()
                print(vetter_name, 'finished in', time_end - time_start, 's.')

            results_list.append(vetter_results)

        if verbose:  # add some whitespace for readability
            print()

        if plot:
            # Join with a path separator so the PDF lands inside plot_dir;
            # the context manager closes the file even if a plot step fails.
            pdf_path = os.path.join(plot_dir, target + '.pdf')
            with PdfPages(pdf_path) as diagnostic_plot:
                # Plot the lightcurve with the epoch and transits overplotted.
                time, flux, time_offset_str = lightkurve_utils.unpack_lk_version(lc, "flux")
                period = tce["period"].to_value(u.day)
                dur = tce["duration"].to_value(u.day)

                time_offset_q = getattr(exo_const, time_offset_str)
                epoch = tce.get_epoch(time_offset_q).to_value(u.day)
                intransit = utils.mark_transit_cadences(time, period, epoch, dur, num_durations=3, flags=None)

                fig, ax1 = plt.subplots(nrows=1, ncols=1, figsize=(9, 5))
                ax1.plot(time, flux, lw=0.4)
                ax1.axvline(x=epoch, lw=0.6, color='r', label='epoch')
                ax1.fill_between(time, 0, 1, where=intransit, transform=ax1.get_xaxis_transform(), color='r', alpha=0.15, label='in transit')

                ax1.set_ylabel('Flux')
                ax1.set_xlabel('Time ' + time_offset_str)
                ax1.set_title(target)
                ax1.legend()
                plt.close(fig)
                diagnostic_plot.savefig(fig)

                # Run the VizTransits plots, one figure per mode.
                VizTransits(transit_plot=True, folded_plot=False).run(tce, lc)
                transit_plot = plt.gcf()
                transit_plot.suptitle(target + ' Transits')
                transit_plot.tight_layout()
                plt.close()
                diagnostic_plot.savefig(transit_plot)

                VizTransits(transit_plot=False, folded_plot=True).run(tce, lc)
                folded_plot = plt.gcf()
                folded_plot.suptitle(target + ' Folded Transits')
                folded_plot.tight_layout()
                plt.close()
                diagnostic_plot.savefig(folded_plot)

                # Save each diagnostic plot made for this tce/lc. The loop
                # variable must not be called `plot`, which is a parameter.
                for figure in plot_figures:
                    diagnostic_plot.savefig(figure)

        # Put all values from each results dictionary into a single dictionary.
        results_dict = {k: v for d in results_list for k, v in d.items()}

        # Drop the bulky plotting arrays. pop() avoids truth-testing the
        # value, which raises ValueError for multi-element numpy arrays.
        results_dict.pop('plot_data', None)

        results_dicts.append(results_dict)

    results_df = pd.DataFrame(results_dicts)  # one row per tce

    results_df.insert(loc=0, column='tce', value=tce_names)
    if verbose:
        print('Execution time:', (py_time.time() - run_start), 's')

    return results_df
Loading

0 comments on commit fe1cf0a

Please sign in to comment.