Integrate the Logger into code.
DinoBektesevic committed Feb 5, 2024
1 parent 00a9d0a commit 6f35a95
Showing 10 changed files with 151 additions and 156 deletions.
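The change follows one pattern throughout: ad-hoc print calls are replaced by a module-level logger obtained from the new Logging class. A minimal sketch of that pattern, assuming kb.Logging.getLogger mirrors the standard library's logging.getLogger (the load_images helper below is illustrative, not part of kbmod):

    import kbmod.search as kb

    # One logger per module, named after the module itself.
    logger = kb.Logging.getLogger(__name__)

    def load_images(paths):
        # Progress prints become logger.info; problem reports become logger.warning.
        logger.info("Loading %i images", len(paths))
        for path in paths:
            if not path.endswith(".fits"):
                logger.warning("Skipping non-FITS file %s", path)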
5 changes: 2 additions & 3 deletions src/kbmod/__init__.py
@@ -10,8 +10,7 @@
import time

logging.Formatter.converter = time.gmtime
logging.basicConfig(level=logging.DEBUG,
format='[%(asctime)s %(levelname)s %(name)s] %(message)s')
logging.basicConfig(format='[%(asctime)s %(levelname)s %(name)s] %(message)s')

from . import (
analysis,
@@ -24,6 +23,6 @@
run_search,
)

from .search import PSF, RawImage, LayeredImage, ImageStack, StackSearch
from .search import PSF, RawImage, LayeredImage, ImageStack, StackSearch, Logging
from .standardizers import Standardizer, StandardizerConfig
from .image_collection import ImageCollection
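Dropping level=logging.DEBUG from basicConfig leaves the root logger at the stdlib default of WARNING, so the INFO messages added in this commit stay silent until a caller opts in. With the handler format above (and the gmtime converter), a record renders like [2024-02-05 18:00:00,123 INFO kbmod.run_search] Starting search. A sketch of the opt-in, assuming the Logging class registers its loggers in the standard hierarchy:

    import logging

    # Raise verbosity for the whole kbmod logger subtree without touching handlers.
    logging.getLogger("kbmod").setLevel(logging.DEBUG)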
33 changes: 15 additions & 18 deletions src/kbmod/analysis_utils.py
@@ -14,6 +14,9 @@
from .result_list import ResultList, ResultRow


logger = kb.Logging.getLogger(__name__)


class PostProcess:
"""This class manages the post-processing utilities used to filter out and
otherwise remove false positives from the KBMOD search. This includes,
@@ -82,17 +85,13 @@ def load_and_filter_results(
else:
stats_filter = CombinedStatsFilter(min_obs=self.num_obs)

print("---------------------------------------")
print("Retrieving Results")
print("---------------------------------------")
logger.info("Retrieving Results")
while likelihood_limit is False:
print("Getting results...")
logger.info("Getting results...")
results = search.get_results(res_num, chunk_size)
print("---------------------------------------")
print("Chunk Start = %i" % res_num)
print("Chunk Max Likelihood = %.2f" % results[0].lh)
print("Chunk Min. Likelihood = %.2f" % results[-1].lh)
print("---------------------------------------")
logger.info("Chunk Start = %i" % res_num)
logger.info("Chunk Max Likelihood = %.2f" % results[0].lh)
logger.info("Chunk Min. Likelihood = %.2f" % results[-1].lh)

result_batch = ResultList(self._mjds)
for i, trj in enumerate(results):
Expand All @@ -111,7 +110,7 @@ def load_and_filter_results(
total_count += 1

batch_size = result_batch.num_results()
print("Extracted batch of %i results for total of %i" % (batch_size, total_count))
logger.info("Extracted batch of %i results for total of %i" % (batch_size, total_count))
if batch_size > 0:
apply_clipped_sigma_g(clipper, result_batch, self.num_cores)
result_batch.apply_filter(stats_filter)
@@ -208,16 +207,14 @@ def apply_stamp_filter(
all_valid_inds = []

# Run the stamp creation and filtering in batches of chunk_size.
print("---------------------------------------")
print("Applying Stamp Filtering")
print("---------------------------------------", flush=True)
logger.info("Applying Stamp Filtering")
start_time = time.time()
start_idx = 0
if result_list.num_results() <= 0:
print("Skipping. Nothing to filter.")
logger.warning("Skipping stamp filtering. Nothing to filter.")
return

print("Stamp filtering %i results" % result_list.num_results())
logger.info("Stamp filtering %i results" % result_list.num_results())
while start_idx < result_list.num_results():
end_idx = min([start_idx + chunk_size, result_list.num_results()])

@@ -258,11 +255,11 @@ def apply_stamp_filter(

# Do the actual filtering of results
result_list.filter_results(all_valid_inds)
print("Keeping %i results" % result_list.num_results(), flush=True)
logger.info("Keeping %i results" % result_list.num_results())

end_time = time.time()
time_elapsed = end_time - start_time
print("{:.2f}s elapsed".format(time_elapsed))
logger.info("{:.2f}s elapsed".format(time_elapsed))

def apply_clustering(self, result_list, cluster_params):
"""This function clusters results that have similar trajectories.
@@ -279,7 +276,7 @@ def apply_clustering(self, result_list, cluster_params):
# Skip clustering if there is nothing to cluster.
if result_list.num_results() == 0:
return
print("Clustering %i results" % result_list.num_results(), flush=True)
logger.info("Clustering %i results" % result_list.num_results())

# Do the clustering and the filtering.
f = DBSCANFilter(
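A stylistic note on the converted calls in this file: logger.info("Chunk Start = %i" % res_num) builds the message eagerly even when INFO is disabled. Standard library loggers accept the arguments separately and defer formatting until a record is actually emitted, so an equivalent lazier form, assuming the Logging loggers keep the stdlib call signature, would be:

    logger.info("Chunk Start = %i", res_num)
    logger.info("Chunk Max Likelihood = %.2f", results[0].lh)
    logger.info("Chunk Min. Likelihood = %.2f", results[-1].lh)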
30 changes: 12 additions & 18 deletions src/kbmod/data_interface.py
@@ -11,6 +11,9 @@
from kbmod.work_unit import WorkUnit


logger = kb.Logging.getLogger(__name__)


def load_input_from_individual_files(
im_filepath,
time_file,
@@ -47,21 +50,17 @@ def load_input_from_individual_files(
visit_times : `list`
A list of MJD times.
"""
print("---------------------------------------")
print("Loading Images")
print("---------------------------------------")
logger.info("Loading Images")

# Load a mapping from visit numbers to the visit times. This dictionary stays
# empty if no time file is specified.
image_time_dict = FileUtils.load_time_dictionary(time_file)
if verbose:
print(f"Loaded {len(image_time_dict)} time stamps.")
logger.info(f"Loaded {len(image_time_dict)} time stamps.")

# Load a mapping from visit numbers to PSFs. This dictionary stays
# empty if no PSF file is specified.
image_psf_dict = FileUtils.load_psf_dictionary(psf_file)
if verbose:
print(f"Loaded {len(image_psf_dict)} image PSFs stamps.")
logger.info(f"Loaded {len(image_psf_dict)} image PSFs stamps.")

# Retrieve the list of visits (file names) in the data directory.
patch_visits = sorted(os.listdir(im_filepath))
@@ -73,8 +72,7 @@
for visit_file in np.sort(patch_visits):
# Skip non-fits files.
if not ".fits" in visit_file:
if verbose:
print(f"Skipping non-FITS file {visit_file}")
logger.info(f"Skipping non-FITS file {visit_file}")
continue

# Compute the full file path for loading.
@@ -95,8 +93,7 @@

# Skip files without a valid visit ID.
if visit_id is None:
if verbose:
print(f"WARNING: Unable to extract visit ID for {visit_file}.")
logger.warning(f"WARNING: Unable to extract visit ID for {visit_file}.")
continue

# Check if the image has a specific PSF.
@@ -105,8 +102,7 @@
psf = kb.PSF(image_psf_dict[visit_id])

# Load the image file and set its time.
if verbose:
print(f"Loading file: {full_file_path}")
logger.info(f"Loading file: {full_file_path}")
img = kb.LayeredImage(full_file_path, psf)
time_stamp = img.get_obstime()

@@ -116,22 +112,20 @@
img.set_obstime(time_stamp)

if time_stamp <= 0.0:
if verbose:
print(f"WARNING: No valid timestamp provided for {visit_file}.")
logger.warning(f"WARNING: No valid timestamp provided for {visit_file}.")
continue

# Check if we should filter the record based on the time bounds.
if mjd_lims is not None and (time_stamp < mjd_lims[0] or time_stamp > mjd_lims[1]):
if verbose:
print(f"Pruning file {visit_file} by timestamp={time_stamp}.")
logger.info(f"Pruning file {visit_file} by timestamp={time_stamp}.")
continue

# Save image, time, and WCS information.
visit_times.append(time_stamp)
images.append(img)
wcs_list.append(curr_wcs)

print(f"Loaded {len(images)} images")
logger.info(f"Loaded {len(images)} images")
stack = kb.ImageStack(images)

return (stack, wcs_list, visit_times)
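With the verbose guards removed above, every converted message is now emitted unconditionally at INFO. If the old opt-in behavior matters, one way to recover it is to key the module's log level off the flag; a sketch, assuming the __name__-based logger naming used in this commit and stdlib-compatible loggers (set_verbosity is a hypothetical helper):

    import logging

    def set_verbosity(verbose):
        # INFO reproduces the old verbose=True output; WARNING keeps only problems.
        level = logging.INFO if verbose else logging.WARNING
        logging.getLogger("kbmod.data_interface").setLevel(level)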
50 changes: 30 additions & 20 deletions src/kbmod/run_search.py
@@ -1,14 +1,10 @@
import os
import time
import warnings

import astropy.coordinates as astroCoords
import astropy.units as u
import koffi
import numpy as np
from astropy.coordinates import solar_system_ephemeris
from astropy.time import Time
from numpy.linalg import lstsq

import kbmod.search as kb

@@ -21,6 +17,9 @@
from .work_unit import WorkUnit


logger = kb.Logging.getLogger(__name__)


class SearchRunner:
"""A class to run the KBMOD grid search."""

@@ -76,15 +75,18 @@ def do_gpu_search(self, config, search):
search.set_start_bounds_y(-config["y_pixel_buffer"], height + config["y_pixel_buffer"])

search_start = time.time()
print("Starting Search")
print("---------------------------------------")
print(f"Average Angle = {config['average_angle']}")
print(f"Search Angle Limits = {ang_lim}")
print(f"Velocity Limits = {config['v_arr']}")
logger.info(f"Starting search with config: {config}")
#print("")
#print("Starting Search")
#print("---------------------------------------")
#print(f"Average Angle = {config['average_angle']}")
#print(f"Search Angle Limits = {ang_lim}")
#print(f"Velocity Limits = {config['v_arr']}")

# If we are using gpu_filtering, enable it and set the parameters.
if config["gpu_filter"]:
print("Using in-line GPU sigmaG filtering methods", flush=True)
logger.info("Using in-line GPU sigmG filtering.")
#print("Using in-line GPU sigmaG filtering methods", flush=True)
coeff = SigmaGClipping.find_sigma_g_coeff(
config["sigmaG_lims"][0],
config["sigmaG_lims"][1],
@@ -114,7 +116,8 @@ def do_gpu_search(self, config, search):
int(config["num_obs"]),
)

print("Search finished in {0:.3f}s".format(time.time() - search_start), flush=True)
logger.info(f"Search finished in {time.time()-search_start:0.3f} seconds.")
#print("Search finished in {0:.3f}s".format(time.time() - search_start), flush=True)
return search

def run_search(self, config, stack):
@@ -190,7 +193,8 @@ def run_search(self, config, stack):
# _count_known_matches(keep, search)

# Save the results and the configuration information used.
print(f"Found {keep.num_results()} potential trajectories.")
# print(f"Found {keep.num_results()} potential trajectories.")
logger.info(f"Found {keep.num_results()} potential trajectories.")
if config["res_filepath"] is not None and config["ind_output_files"]:
keep.save_to_files(config["res_filepath"], config["output_suffix"])

Expand All @@ -200,7 +204,8 @@ def run_search(self, config, stack):
keep.write_table(config["result_filename"])

end = time.time()
print("Time taken for patch: ", end - start)
logger.info("Time taken for patch: {time.time() - start:0.3f}")
#print("Time taken for patch: ", end - start)

return keep

@@ -223,7 +228,8 @@ def run_search_from_work_unit(self, work):
if work.get_wcs(0) is not None:
work.config.set("average_angle", self._calc_suggested_angle(work.get_wcs(0), center_pixel))
else:
print("WARNING: average_angle is unset and no WCS provided. Using 0.0.")
logger.warning("Average angle not set and no WCS provided. Setting average_angle=0.0")
#print("WARNING: average_angle is unset and no WCS provided. Using 0.0.")
work.config.set("average_angle", 0.0)

# Run the search.
@@ -297,20 +303,22 @@ def _count_known_matches(self, result_list, search):
ps.build_from_images_and_xy_positions(PixelPositions, metadata)
ps_list.append(ps)

print("-----------------")
#print("-----------------")
matches = {}
known_obj_thresh = config["known_obj_thresh"]
min_obs = config["known_obj_obs"]
if config["known_obj_jpl"]:
print("Quering known objects from JPL")
logger.info("Querying known objects from JPL.")
#print("Quering known objects from JPL")
matches = koffi.jpl_query_known_objects_stack(
potential_sources=ps_list,
images=metadata,
min_observations=min_obs,
tolerance=known_obj_thresh,
)
else:
print("Quering known objects from SkyBoT")
logger.info("Querying known objects from SkyBoT.")
#print("Quering known objects from SkyBoT")
matches = koffi.skybot_query_known_objects_stack(
potential_sources=ps_list,
images=metadata,
@@ -324,11 +332,13 @@
if len(matches[ps_id]) > 0:
num_found += 1
matches_string += f"result id {ps_id}:" + str(matches[ps_id])[1:-1] + "\n"
print("Found %i objects with at least %i potential observations." % (num_found, config["num_obs"]))
logger.info(f"Found {num_found} objects with at least {config['num_obs']} potential observations.")
#print("Found %i objects with at least %i potential observations." % (num_found, config["num_obs"]))

if num_found > 0:
print(matches_string)
print("-----------------")
logger.info(matches_string)

def _calc_suggested_angle(self, wcs, center_pixel=(1000, 2000), step=12):
"""Projects an unit-vector parallel with the ecliptic onto the image
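The elapsed-time messages in this file all follow the same start/stop bookkeeping; a small context manager keeps them consistent across search phases. A sketch under the same stdlib-logger assumption (phase_timer is a hypothetical helper, not part of kbmod):

    import time
    from contextlib import contextmanager

    @contextmanager
    def phase_timer(logger, label):
        # Logs "<label> finished in N.NNNs" when the block exits, even on error.
        start = time.time()
        try:
            yield
        finally:
            logger.info("%s finished in %.3fs", label, time.time() - start)

Wrapped around a phase as with phase_timer(logger, "Search"): ..., it replaces the paired time.time() calls.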
2 changes: 1 addition & 1 deletion src/kbmod/search/bindings.cpp
@@ -6,6 +6,7 @@

namespace py = pybind11;

#include "logging.h"
#include "common.h"
#include "geom.h"

@@ -17,7 +18,6 @@ namespace py = pybind11;
#include "stamp_creator.cpp"
#include "kernel_testing_helpers.cpp"
#include "psi_phi_array.cpp"
#include "logging.cpp"


PYBIND11_MODULE(search, m) {
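With logging.h now pulled in as a header rather than compiling logging.cpp into the bindings, and with Logging exported from the compiled module per the __init__.py change above, the same logger is reachable from Python. A minimal usage sketch (the logger name is illustrative):

    from kbmod.search import Logging

    # The same name-based lookup the converted Python modules use.
    logger = Logging.getLogger("kbmod.demo")
    logger.info("Hello from the bound logger")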