request list and other adaptations for a more efficient experience #110
base: master
Changes from all commits
.gitignore

@@ -43,4 +43,5 @@ vlf_table.fits
 tests/.oda-token
 .ipynb_checkpoints
 .venv
+cache
 .venv
oda_api/api.py

@@ -2,6 +2,7 @@
 from collections import OrderedDict
 from json.decoder import JSONDecodeError
+from typing import Any, Callable, List, Tuple, Dict
 from astropy.table import Table
 from astropy.coordinates import Angle

@@ -15,11 +16,11 @@
 from .data_products import NumpyDataProduct, BinaryData, ApiCatalog, GWContoursDataProduct

 from builtins import (bytes, str, open, super, range,
                       zip, round, input, int, pow, object, map, zip)

 __author__ = "Andrea Tramacere, Volodymyr Savchenko"

+import hashlib
+import gzip
 import warnings
 import requests
 import ast

@@ -195,6 +196,9 @@ class DispatcherAPI:
     # allowing token discovery by default changes the user interface in some cases,
     # but in desirable way
     token_discovery_methods = None
+    use_local_cache = False
+    skip_parameter_check = False
+    raise_on_failure = True

     def __init__(self,
                  instrument='mock',
@@ -374,6 +378,12 @@ def selected_request_method(self):
         return self.preferred_request_method

     def request_to_json(self, verbose=False):
+        try:
+            return self.load_result()
+        except Exception as e:
+            logger.debug('unable to load result from %s: will need to compute', self.unique_response_json_fn)
+
         self.progress_logger.info(
             f'- waiting for remote response (since {time.strftime("%Y-%m-%d %H:%M:%S")}), please wait for {self.url}/{self.run_analysis_handle}')
@@ -458,6 +468,9 @@ def request_to_json(self, verbose=False):
             self.returned_analysis_parameters = response_json['products'].get('analysis_parameters', None)

+            if response_json.get('query_status') in ['done', 'failed']:
+                self.save_result(response_json)
+
             return response_json
         except json.decoder.JSONDecodeError as e:
             self.logger.error(
@@ -662,7 +675,8 @@ def poll(self, verbose=None, silent=None):
         if self.is_complete:
             # TODO: something raising here does not help
             self.logger.debug("poll returing data: complete")
-            return DataCollection.from_response_json(self.response_json, self.instrument, self.product)
+            self.stored_result = DataCollection.from_response_json(self.response_json, self.instrument, self.product)
+            return self.stored_result

     def show_progress(self):
         full_report_dict_list = self.response_json['job_monitor'].get(
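With this change poll() keeps the parsed result on the client as well as returning it, so a caller can come back to it later without re-parsing; a brief sketch, where disp is a hypothetical already-submitted request:

    data = disp.poll()

    if disp.is_complete:
        # the returned DataCollection is also retained on the client
        assert data is disp.stored_result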
@@ -741,10 +755,14 @@ def process_failure(self):
                      self.response_json['exit_status']['message'])
         logger.error("have exception message: keys \"%s\"",
                      exception_by_message.keys())
-        raise exception_by_message.get(self.response_json['exit_status']['message'], RemoteException)(
-            message=self.response_json['exit_status']['message'],
-            debug_message=self.response_json['exit_status']['error_message']
-        )
+        if self.raise_on_failure:
+            raise exception_by_message.get(self.response_json['exit_status']['message'], RemoteException)(
+                message=self.response_json['exit_status']['message'],
+                debug_message=self.response_json['exit_status']['error_message']
+            )
+        else:
+            self.exception_json = self.response_json['exit_status']

     def failure_report(self, res_json):
         self.logger.error('query failed!')
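A sketch of the non-raising failure mode this enables; instrument and product names are placeholders, and the exact return value on failure is whatever get_product then produces:

    disp.raise_on_failure = False
    data = disp.get_product(instrument="mock", product="dummy")

    if getattr(disp, "exception_json", None) is not None:
        # the server-side exit_status is recorded instead of being raised
        print("request failed:", disp.exception_json["message"])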
@@ -874,6 +892,42 @@ def report_last_request(self):
         self.logger.info(
             f"{C.GREY}last request completed in {self.last_request_t_complete - self.last_request_t0} seconds{C.NC}")

+    def parameter_check(self, instrument, product, kwargs):
+
+        res = requests.get("%s/api/par-names" % self.url, params=dict(
+            instrument=instrument, product_type=product), cookies=self.cookies)
+
+        if res.status_code != 200:
+            warnings.warn(
+                'parameter check not available on remote server, check carefully parameters name')
+        else:
+            _ignore_list = ['instrument', 'product_type', 'query_type',
+                            'off_line', 'query_status', 'verbose', 'session_id']
+            validation_dict = copy.deepcopy(kwargs)
+
+            for _i in _ignore_list:
+                del validation_dict[_i]
+
+            valid_names = self._decode_res_json(res)
+            for n in validation_dict.keys():
+                if n not in valid_names:
+                    if self.strict_parameter_check:
+                        raise UserError(f'the parameter: {n} is not among the valid ones: {valid_names}'
+                                        f'(you can set {self}.strict_parameter_check=False, but beware!')
+                    else:
+                        msg = '\n'
+                        msg += '----------------------------------------------------------------------------\n'
+                        msg += 'the parameter: %s ' % n
+                        msg += ' is not among valid ones:'
+                        msg += '\n'
+                        msg += '%s' % valid_names
+                        msg += '\n'
+                        msg += 'this will throw an error in a future version \n'
+                        msg += 'and might break the current request!\n '
+                        msg += '----------------------------------------------------------------------------\n'
+                        warnings.warn(msg)

     def get_list_terms_gallery(self,
                                group: str = None,
                                parent: str = None,
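Now that the check is an extracted method, it can in principle be invoked on its own before submitting a request. One caveat visible in the code: every key in _ignore_list is deleted unconditionally, so the dict passed in must contain all of them. A sketch with placeholder values:

    candidate = dict(instrument="mock", product_type="dummy", query_type="Real",
                     off_line=False, query_status="new", verbose=False,
                     session_id="000",
                     some_parameter=1.0)   # the name actually being validated

    # warns (or raises, if disp.strict_parameter_check is set) when
    # 'some_parameter' is not in the server's /api/par-names list
    disp.parameter_check("mock", "dummy", candidate)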
@@ -1141,38 +1195,8 @@ def get_product(self,
                 'However the oda_api will perform a check of the list of valid parameters for your request.')
             del kwargs['dry_run']

-        res = requests.get("%s/api/par-names" % self.url, params=dict(
-            instrument=instrument, product_type=product), cookies=self.cookies)
-
-        if res.status_code != 200:
-            warnings.warn(
-                'parameter check not available on remote server, check carefully parameters name')
-        else:
-            _ignore_list = ['instrument', 'product_type', 'query_type',
-                            'off_line', 'query_status', 'verbose', 'session_id']
-            validation_dict = copy.deepcopy(kwargs)
-
-            for _i in _ignore_list:
-                del validation_dict[_i]
-
-            valid_names = self._decode_res_json(res)
-            for n in validation_dict.keys():
-                if n not in valid_names:
-                    if self.strict_parameter_check:
-                        raise UserError(f'the parameter: {n} is not among the valid ones: {valid_names}'
-                                        f'(you can set {self}.strict_parameter_check=False, but beware!')
-                    else:
-                        msg = '\n'
-                        msg += '----------------------------------------------------------------------------\n'
-                        msg += 'the parameter: %s ' % n
-                        msg += ' is not among valid ones:'
-                        msg += '\n'
-                        msg += '%s' % valid_names
-                        msg += '\n'
-                        msg += 'this will throw an error in a future version \n'
-                        msg += 'and might break the current request!\n '
-                        msg += '----------------------------------------------------------------------------\n'
-                        warnings.warn(msg)
+        if not self.skip_parameter_check:
+            self.parameter_check(instrument, product, kwargs)

         if kwargs.get('token', None) is None and self.token_discovery_methods is not None:
             discovered_token = oda_api.token.discover_token(self.token_discovery_methods)
@@ -1202,12 +1226,10 @@ def get_product(self,
         d = DataCollection.from_response_json(
             res_json, instrument, product)

-        del (res)
-
         return d

     @staticmethod
-    def set_api_code(query_dict, url="www.astro.unige.ch/mmoda/dispatch-data"):
+    def set_api_code(query_dict, url="www.astro.unige.ch/mmoda/dispatch-data") -> str:

         query_dict = OrderedDict(sorted(query_dict.items()))
@@ -1247,14 +1269,45 @@ def __repr__(self):
         return f"[ {self.__class__.__name__}: {self.url} ]"

+    def save_result(self, response_json):
+        fn = self.unique_response_json_fn
+
+        os.makedirs(os.path.dirname(fn), exist_ok=True)
+
+        json.dump(response_json, gzip.open(fn, "wt"))
+        logger.info('saved result in %s', fn)
+
+    def load_result(self):
+        fn = self.unique_response_json_fn
+        logger.info('trying to load result from %s', fn)
+
+        t0 = time.time()
+        r = json.load(gzip.open(fn, 'rb'))
+
+        logger.info('\033[32mmanaged to load result\033[0m from %s in %.2f seconds', fn, time.time() - t0)
+        return r
+
+    @property
+    def unique_response_json_fn(self):
+        request_hash = hashlib.md5(self.set_api_code(self.parameters_dict).encode()).hexdigest()[:16]
+
+        return f"cache/oda_api_data_collection_{request_hash}.json.gz"
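The cache key is the first 16 hex digits of the MD5 of the reproducible API code for the request, so identical parameter sets always map to the same file. A standalone sketch of the same scheme:

    import hashlib

    def cache_filename(api_code: str) -> str:
        # mirrors unique_response_json_fn: truncated MD5 of the API code
        request_hash = hashlib.md5(api_code.encode()).hexdigest()[:16]
        return f"cache/oda_api_data_collection_{request_hash}.json.gz"

    print(cache_filename("disp.get_product(instrument='mock')"))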
 class DataCollection(object):

     def __init__(self, data_list, add_meta_to_name=['src_name', 'product'], instrument=None, product=None):
         self._p_list = []
         self._n_list = []

+    def __init__(self, data_list, add_meta_to_name=['src_name', 'product'], instrument=None, product=None, request_job_id=None):
+        self._p_list = []
+        self._n_list = []
+        self.request_job_id = request_job_id

         for ID, data in enumerate(data_list):

             name = ''
             if hasattr(data, 'name'):
                 name = data.name

Review comment: duplication
Review comment on lines +1302 to +1304: why having two __init__?
@@ -1296,10 +1349,42 @@ def as_list(self):
             meta_data = ''

         L.append({
-            'ID': ID, 'prod_name': prod_name, 'meta_data:': meta_data
+            'ID': ID,
+            'prod_name': prod_name,
+            'metadata': meta_data,
+            'meta_data:': meta_data  # ???
         })

         return L

Review comment: why both metadata and meta_data:?

+    @property
+    def product_indexer(self):
+        return getattr(self,
+                       '_product_indexer',
+                       lambda p: (p['metadata']['src_name'], p['metadata']['product'])
+                       )
+
+    @product_indexer.setter
+    def product_indexer(self, v):
+        self._product_indexer = v
+        if len(self.keys()) != len(self.as_list()):
+            raise RuntimeError("duplicate index in metadata, this should be impossible, please check if product_indexer if good!")
+
+    def __getitem__(self, key):
+        return getattr(self, self.as_dict()[key]['prod_name'])
+
+    def as_dict(self):
+        return {self.product_indexer(product): product
+                for product in self.as_list()}
+
+    def keys(self):
+        return list(self.as_dict().keys())

     def _build_prod_name(self, prod, name, add_meta_to_name):
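These accessors make a collection addressable by metadata rather than by positional attribute names; a sketch of the intended use, assuming a fetched collection whose products carry src_name and product metadata (all values are placeholders):

    # default indexer keys each product by (src_name, product)
    print(data.keys())                    # e.g. [('Crab', 'isgri_image'), ...]
    image = data[('Crab', 'isgri_image')]

    # a custom indexer can be installed; the setter raises if two products
    # would collapse onto the same key
    data.product_indexer = lambda p: p['metadata']['src_name']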
@@ -1321,10 +1406,13 @@ def save_all_data(self, prenpend_name=None):
             file_name = file_name + '.fits'
             prod.write_fits_file(file_name)

     def save(self, file_name):
-        pickle.dump(self, open(file_name, 'wb'),
+        pickle.dump(self,
+                    gzip.open(file_name, 'wb'),
                     protocol=pickle.HIGHEST_PROTOCOL)

     def new_from_metadata(self, key, val):
         dc = None
         _l = []
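save() now writes a gzip-compressed pickle, so the read side has to go through gzip too; a sketch of a matching loader (load_collection is hypothetical, not part of this PR):

    import gzip
    import pickle

    def load_collection(file_name):
        # counterpart of DataCollection.save(); note that collections saved
        # before this change (uncompressed pickle) will not load through gzip
        with gzip.open(file_name, "rb") as f:
            return pickle.load(f)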
@@ -1336,6 +1424,7 @@ def new_from_metadata(self, key, val):
         dc = DataCollection(_l)

         return dc
+
     @classmethod
     def from_response_json(cls, res_json, instrument, product):

@@ -1401,3 +1490,64 @@ def from_response_json(cls, res_json, instrument, product):
         p.meta_data = p.meta

         return d
+
+
+class DispatcherAPICollection:
+
+    def __init__(self, wait_between_poll_sequences_s=None, **kwargs) -> None:
+        self.wait_between_poll_sequences_s = wait_between_poll_sequences_s
+        self.constructor_kwargs = kwargs
+
+    def get_product_list(
+            self,
+            parameter_dict_list: List[Dict[str, Any]],
+            **kwargs):
+
+        self.client_list = []
+        product_list = []
+
+        for parameter_dict in parameter_dict_list:
+            disp = DispatcherAPI(**self.constructor_kwargs)
+
+            disp.use_local_cache = True
+            disp.wait = False
+            disp.skip_parameter_check = True
+            disp.raise_on_failure = False
+
+            product_list.append(disp.get_product(**kwargs, **parameter_dict))
+
+            self.client_list.append(disp)
+
+        logger.info('prepared %s clients, %s are done',
+                    len(self.client_list),
+                    len([c for c in self.client_list if c.is_complete])
+                    )
+
+        if self.wait_between_poll_sequences_s is None:
+            logger.info('not waiting for poll, please do not forget to come back for your results!')
+        else:
+            while True:
+                product_list = []
+
+                for client in self.client_list:
+                    client.poll()
+
+                    logger.info('polled %s clients, %s are done',
+                                len(self.client_list),
+                                len([c for c in self.client_list if c.is_complete]),
+                                )
+
+                    product_list.append(client.poll())
+
+                if all([c.is_complete for c in self.client_list]):
+                    logger.info('all done!')
+                    break
+                else:
+                    logger.info('will sleep %s s',
+                                self.wait_between_poll_sequences_s
+                                )
+                    time.sleep(self.wait_between_poll_sequences_s)
+
+        return product_list
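A sketch of how this new batch client is meant to be driven; the URL is the module's default dispatcher, the parameter names are placeholders:

    disp_collection = DispatcherAPICollection(
        url="https://www.astro.unige.ch/mmoda/dispatch-data",  # forwarded to each DispatcherAPI
        wait_between_poll_sequences_s=60,  # re-poll every minute until all jobs finish
    )

    products = disp_collection.get_product_list(
        parameter_dict_list=[
            dict(E1_keV=20., E2_keV=40.),   # hypothetical per-request parameters
            dict(E1_keV=40., E2_keV=80.),
        ],
        instrument="isgri",                 # shared kwargs applied to every request
        product="isgri_image",
    )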
Review comment: Here and below, isn't it better to use a context manager for open?
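For reference, the context-manager variant the reviewer suggests would look roughly like this for the save path (a standalone sketch, not what the PR currently does; the function name is hypothetical):

    import gzip
    import json
    import os

    def save_result_with_context_manager(fn, response_json):
        # same behaviour as DispatcherAPI.save_result, but the with-block
        # guarantees the gzip stream is flushed and closed even if dump raises
        os.makedirs(os.path.dirname(fn) or ".", exist_ok=True)
        with gzip.open(fn, "wt") as f:
            json.dump(response_json, f)

    save_result_with_context_manager("cache/example.json.gz", {"query_status": "done"})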