Merge pull request #213 from madgik/fix/nla
Fix/nla
KFilippopolitis authored Apr 8, 2022
2 parents d624cd7 + 58cbac8 commit 00e3e56
Showing 3 changed files with 86 additions and 63 deletions.
4 changes: 2 additions & 2 deletions mipengine/controller/data_model_registry.py
@@ -11,8 +11,8 @@ def _have_common_elements(a: List[Any], b: List[Any]):
 
 class DataModelRegistry:
     def __init__(self):
-        self.data_models: Dict[str, CommonDataElements] = {}
-        self.datasets_location: Dict[str, Dict[str, List[str]]] = {}
+        self._data_models: Dict[str, CommonDataElements] = {}
+        self._datasets_location: Dict[str, Dict[str, List[str]]] = {}
 
     @property
     def data_models(self):
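
Note on the rename above: the same underscore-prefixed rename is applied to NodeRegistry in the last file of this commit, and the registries are now populated through their public attributes in NodeLandscapeAggregator.update(). A minimal sketch of the accessor pattern this implies, assuming a plain property/setter pair (only the __init__ lines are visible in this diff):

    # Sketch only - the property and setter bodies are assumptions, not part of this commit.
    from typing import Dict, List

    class DataModelRegistry:
        def __init__(self):
            # Backing storage is private; access goes through the properties below.
            self._data_models: Dict[str, "CommonDataElements"] = {}
            self._datasets_location: Dict[str, Dict[str, List[str]]] = {}

        @property
        def data_models(self) -> Dict[str, "CommonDataElements"]:
            return self._data_models

        @data_models.setter
        def data_models(self, value: Dict[str, "CommonDataElements"]) -> None:
            self._data_models = value
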
143 changes: 83 additions & 60 deletions mipengine/controller/node_landscape_aggregator.py
@@ -44,9 +44,7 @@ async def _get_nodes_info(nodes_socket_addr: List[str]) -> List[NodeInfo]:
     }
 
     tasks_coroutines = [
-        _task_to_async(task, connection=app.broker_connection())(
-            request_id=NODE_LANDSCAPE_AGGREGATOR_REQUEST_ID
-        )
+        _task_to_async(task, app=app)(request_id=NODE_LANDSCAPE_AGGREGATOR_REQUEST_ID)
         for app, task in nodes_task_signature.items()
     ]
     results = await asyncio.gather(*tasks_coroutines, return_exceptions=True)
@@ -55,6 +53,10 @@ async def _get_nodes_info(nodes_socket_addr: List[str]) -> List[NodeInfo]:
         for result in results
         if not isinstance(result, Exception)
     ]
+
+    for app in celery_apps:
+        app.close()
+
     return nodes_info
 
 
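Side note on the app.close() loop added above: each per-node Celery app created here holds its own broker connection pool, and closing the app is what releases it once the results have been gathered. Celery application objects can also be used as context managers, which guarantees the same cleanup; a tiny sketch under that assumption, with a made-up node address and an assumed signature constant:

    # Sketch only: the address is made up and GET_NODE_INFO_SIGNATURE is an
    # assumed constant name. Leaving the `with` block closes the app and
    # releases its broker connection pool.
    with get_node_celery_app("172.17.0.1:5670") as app:
        node_info_signature = app.signature(GET_NODE_INFO_SIGNATURE)
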
@@ -64,9 +66,9 @@ async def _get_node_datasets_per_data_model(
     celery_app = get_node_celery_app(node_socket_addr)
     task_signature = celery_app.signature(GET_NODE_DATASETS_PER_DATA_MODEL_SIGNATURE)
 
-    result = await _task_to_async(
-        task_signature, connection=celery_app.broker_connection()
-    )(request_id=NODE_LANDSCAPE_AGGREGATOR_REQUEST_ID)
+    result = await _task_to_async(task_signature, app=celery_app)(
+        request_id=NODE_LANDSCAPE_AGGREGATOR_REQUEST_ID
+    )
 
     datasets_per_data_model = {}
     if not isinstance(result, Exception):
@@ -80,15 +82,15 @@ async def _get_node_cdes(node_socket_addr: str, data_model: str) -> CommonDataEl
     celery_app = get_node_celery_app(node_socket_addr)
     task_signature = celery_app.signature(GET_DATA_MODEL_CDES_SIGNATURE)
 
-    result = await _task_to_async(
-        task_signature, connection=celery_app.broker_connection()
-    )(data_model=data_model, request_id=NODE_LANDSCAPE_AGGREGATOR_REQUEST_ID)
+    result = await _task_to_async(task_signature, app=celery_app)(
+        data_model=data_model, request_id=NODE_LANDSCAPE_AGGREGATOR_REQUEST_ID
+    )
 
     if not isinstance(result, Exception):
         return CommonDataElements.parse_raw(result)
 
 
-def _task_to_async(task, connection):
+def _task_to_async(task, app):
     """
     Converts a Celery task to an async function
     Celery doesn't currently support asyncio "await" while "getting" a result
@@ -106,18 +108,19 @@ async def wrapper(*args, **kwargs):
         delay = 0.1
         # Since apply_async is used instead of delay so that we can pass the connection as an argument,
         # the args and kwargs need to be passed as named arguments.
-        async_result = await sync_to_async(task.apply_async)(
-            args=args, kwargs=kwargs, connection=connection
-        )
-        while not async_result.ready():
-            total_delay += delay
-            if total_delay > CELERY_TASKS_TIMEOUT:
-                raise TimeoutError(
-                    f"Celery task: {task} didn't respond in {CELERY_TASKS_TIMEOUT}s."
-                )
-            await asyncio.sleep(delay)
-            delay = min(delay * 1.5, 2)  # exponential backoff, max 2 seconds
-        return async_result.get(timeout=CELERY_TASKS_TIMEOUT - total_delay)
+        with app.broker_connection() as conn:
+            async_result = await sync_to_async(task.apply_async)(
+                args=args, kwargs=kwargs, connection=conn
+            )
+            while not async_result.ready():
+                total_delay += delay
+                if total_delay > CELERY_TASKS_TIMEOUT:
+                    raise TimeoutError(
+                        f"Celery task: {task} didn't respond in {CELERY_TASKS_TIMEOUT}s."
+                    )
+                await asyncio.sleep(delay)
+                delay = min(delay * 1.5, 2)  # exponential backoff, max 2 seconds
+            return async_result.get(timeout=CELERY_TASKS_TIMEOUT - total_delay)
 
     return wrapper
 
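For orientation, a hypothetical call site of the new _task_to_async(task, app) signature; the socket address and data model value below are made up, and the real call sites are the ones changed in the hunks above. Passing the app instead of an already-open connection lets the wrapper scope the connection itself with `with app.broker_connection() as conn:`, so the connection is released even when the polling loop raises TimeoutError.

    # Hypothetical call site; address and data model are placeholders.
    async def fetch_cdes_example() -> str:
        celery_app = get_node_celery_app("172.17.0.1:5670")  # made-up address
        task_signature = celery_app.signature(GET_DATA_MODEL_CDES_SIGNATURE)
        try:
            return await _task_to_async(task_signature, app=celery_app)(
                data_model="dementia:0.1",  # made-up data model
                request_id=NODE_LANDSCAPE_AGGREGATOR_REQUEST_ID,
            )
        finally:
            celery_app.close()
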
@@ -157,25 +160,26 @@ async def update(self):
             local_nodes = [
                 node for node in nodes_info if node.role == NodeRole.LOCALNODE
             ]
-            datasets_locations = await _get_datasets_locations(local_nodes)
-            datasets_labels = await _get_datasets_labels(local_nodes)
-            data_model_cdes_across_nodes = await _get_cdes_across_nodes(local_nodes)
+            (
+                dataset_locations,
+                aggregated_datasets,
+            ) = await _gather_all_dataset_infos(local_nodes)
+            data_model_cdes_per_node = await _get_cdes_across_nodes(local_nodes)
             compatible_data_models = _get_compatible_data_models(
-                data_model_cdes_across_nodes
+                data_model_cdes_per_node
             )
-            data_models = get_updated_data_model_with_dataset_enumerations(
-                compatible_data_models, datasets_labels
+            _update_data_models_with_aggregated_datasets(
+                compatible_data_models, aggregated_datasets
             )
+            datasets_locations = _get_dataset_locations_of_compatible_data_models(
+                compatible_data_models, dataset_locations
+            )
-            datasets_locations = {
-                common_data_model: datasets_locations[common_data_model]
-                for common_data_model in data_models
-            }
 
-            self._node_registry._nodes = {
+            self._node_registry.nodes = {
                 node_info.id: node_info for node_info in nodes_info
             }
-            self._data_model_registry._data_models = data_models
-            self._data_model_registry._datasets_location = datasets_locations
+            self._data_model_registry.data_models = compatible_data_models
+            self._data_model_registry.datasets_location = datasets_locations
             logger.debug(f"Nodes:{[node for node in self._node_registry.nodes]}")
         except Exception as exc:
             logger.error(f"Node Landscape Aggregator exception: {type(exc)}:{exc}")
@@ -234,42 +238,50 @@ def get_node_specific_datasets(
     )
 
 
-async def _get_datasets_locations(nodes: List[NodeInfo]) -> Dict[str, Dict[str, str]]:
-    datasets_locations = {}
+async def _gather_all_dataset_infos(
+    nodes: List[NodeInfo],
+) -> Tuple[Dict[str, Dict[str, str]], Dict[str, Dict[str, str]]]:
+    """
+    Args:
+        nodes: The nodes available in the system
+    Returns:
+        A tuple with:
+        1. The location of each dataset.
+        2. The aggregated datasets, existing in all nodes
+    """
+    dataset_locations = {}
+    aggregated_datasets = {}
+
     for node_info in nodes:
         node_socket_addr = _get_node_socket_addr(node_info)
         datasets_per_data_model = await _get_node_datasets_per_data_model(
             node_socket_addr
         )
         for data_model, datasets in datasets_per_data_model.items():
-            current_datasets = (
-                datasets_locations[data_model]
-                if data_model in datasets_locations
+
+            current_labels = (
+                aggregated_datasets[data_model]
+                if data_model in aggregated_datasets
                 else {}
             )
+            current_datasets = (
+                dataset_locations[data_model] if data_model in dataset_locations else {}
+            )
+
             for dataset in datasets:
+                current_labels[dataset] = datasets[dataset]
+
                 if dataset in current_datasets:
                     current_datasets[dataset].append(node_info.id)
                 else:
                     current_datasets[dataset] = [node_info.id]
-            datasets_locations[data_model] = current_datasets
-
-    return datasets_locations
+            aggregated_datasets[data_model] = current_labels
+            dataset_locations[data_model] = current_datasets
 
 
-async def _get_datasets_labels(nodes: List[NodeInfo]) -> Dict[str, Dict[str, str]]:
-    datasets_labels = {}
-    for node_info in nodes:
-        node_socket_addr = _get_node_socket_addr(node_info)
-        datasets_per_data_model = await _get_node_datasets_per_data_model(
-            node_socket_addr
-        )
-        for data_model, datasets in datasets_per_data_model.items():
-            datasets_labels[data_model] = {}
-            for dataset in datasets:
-                datasets_labels[data_model][dataset] = datasets[dataset]
-    return datasets_labels
+    return dataset_locations, aggregated_datasets
 
 
 async def _get_cdes_across_nodes(
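
To make the two mappings returned by _gather_all_dataset_infos concrete, a hypothetical example of their shape (the data model, dataset codes, labels and node ids are invented):

    # Invented values, for illustration only.
    dataset_locations = {
        "dementia:0.1": {
            "edsd": ["localnode1", "localnode2"],  # dataset code -> nodes hosting it
            "ppmi": ["localnode2"],
        }
    }
    aggregated_datasets = {
        "dementia:0.1": {
            "edsd": "EDSD",  # dataset code -> label, merged across nodes
            "ppmi": "PPMI",
        }
    }
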
@@ -289,6 +301,15 @@ async def _get_cdes_across_nodes(
     return nodes_cdes
 
 
+def _get_dataset_locations_of_compatible_data_models(
+    compatible_data_models, dataset_locations
+):
+    return {
+        compatible_data_model: dataset_locations[compatible_data_model]
+        for compatible_data_model in compatible_data_models
+    }
+
+
 def _get_compatible_data_models(
     data_model_cdes_across_nodes: Dict[str, List[Tuple[str, CommonDataElements]]]
 ) -> Dict[str, CommonDataElements]:
@@ -323,20 +344,22 @@ def _get_compatible_data_models(
     return data_models
 
 
-def get_updated_data_model_with_dataset_enumerations(
+def _update_data_models_with_aggregated_datasets(
     data_models: Dict[str, CommonDataElements],
-    datasets_labels: Dict[str, Dict[str, str]],
-) -> Dict[str, CommonDataElements]:
+    aggregated_datasets: Dict[str, Dict[str, str]],
+):
+    """
+    Updates each data_model's 'dataset' enumerations with the aggregated datasets
+    """
     for data_model in data_models:
         dataset_cde = data_models[data_model].values["dataset"]
         new_dataset_cde = CommonDataElement(
             code=dataset_cde.code,
             label=dataset_cde.label,
             sql_type=dataset_cde.sql_type,
             is_categorical=dataset_cde.is_categorical,
-            enumerations=datasets_labels[data_model],
+            enumerations=aggregated_datasets[data_model],
             min=dataset_cde.min,
             max=dataset_cde.max,
         )
         data_models[data_model].values["dataset"] = new_dataset_cde
-    return data_models
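
Worth noting on the hunk above: the renamed helper now mutates data_models in place and no longer returns a value, which is why update() (earlier in this file) assigns compatible_data_models to the registry directly instead of a returned mapping. A condensed sketch of the resulting call order, using the names from the hunks above:

    # Condensed from the hunks above; `registry` stands in for
    # self._data_model_registry inside NodeLandscapeAggregator.update().
    compatible_data_models = _get_compatible_data_models(data_model_cdes_per_node)
    _update_data_models_with_aggregated_datasets(compatible_data_models, aggregated_datasets)
    datasets_locations = _get_dataset_locations_of_compatible_data_models(
        compatible_data_models, dataset_locations
    )
    registry.data_models = compatible_data_models
    registry.datasets_location = datasets_locations
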
2 changes: 1 addition & 1 deletion mipengine/controller/node_registry.py
@@ -6,7 +6,7 @@
 
 class NodeRegistry:
     def __init__(self):
-        self.nodes: Dict[str, NodeInfo] = {}
+        self._nodes: Dict[str, NodeInfo] = {}
 
     @property
     def nodes(self) -> Dict[str, NodeInfo]:
