Skip to content

Commit

Permalink
Merge branch 'update-config-schema' into integrate-pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
nmvrs committed Jun 6, 2024
2 parents 65004c7 + 2677d68 commit f4f21c6
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 8 deletions.
44 changes: 38 additions & 6 deletions moai/action/archive.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import glob
import logging
import os
import re
import shutil
import subprocess
import sys
Expand Down Expand Up @@ -60,18 +61,49 @@ def dump_handlers(
handler_config: omegaconf.omegaconf.DictConfig,
):
empty = omegaconf.DictConfig({})
# parse nested lists
# Use regular expression to extract elements within square brackets
pattern = r"\[(.*?)\]"
with open("pre.yaml", "w") as f:
data = omegaconf.OmegaConf.to_container(handlers.get("pre") or empty)
data = list(
map(lambda e: {e.split(":")[0].strip(): e.split(":")[1].strip()}, data)
# data = list(
# map(lambda e: {e.split(":")[0].strip(): e.split(":")[1].strip()}, data)
# )
new_data = []
for e in data:
key, value = e.split(":")
match = re.search(pattern, value)
if match:
# Extract elements within square brackets
data_list = match.group(1).split(",")
# Remove leading/trailing spaces from elements
data_list = [item.strip() for item in data_list]
new_data.append({key.strip(): data_list})
else:
new_data.append({key.strip(): value.strip()})
yaml.dump(
{"defaults": new_data}, f, default_style=None, default_flow_style=False
)
yaml.dump({"defaults": data}, f, default_style=None, default_flow_style=False)
with open("post.yaml", "w") as f:
data = omegaconf.OmegaConf.to_container(handlers.get("post") or empty)
data = list(
map(lambda e: {e.split(":")[0].strip(): e.split(":")[1].strip()}, data)
# data = list(
# map(lambda e: {e.split(":")[0].strip(): e.split(":")[1].strip()}, data)
# )
new_data = []
for e in data:
key, value = e.split(":")
match = re.search(pattern, value)
if match:
# Extract elements within square brackets
data_list = match.group(1).split(",")
# Remove leading/trailing spaces from elements
data_list = [item.strip() for item in data_list]
new_data.append({key.strip(): data_list})
else:
new_data.append({key.strip(): value.strip()})
yaml.dump(
{"defaults": new_data}, f, default_style=None, default_flow_style=False
)
yaml.dump({"defaults": data}, f, default_style=None, default_flow_style=False)
if handler_config is not None:
with open("handler_overrides.yaml", "w") as f:
data = omegaconf.OmegaConf.to_container(handler_config)
Expand Down
4 changes: 4 additions & 0 deletions moai/conf/engine/serve/handlers/input/setup_guid.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# @package handlers.preprocess.setup_guid

_target_: moai.serve.handlers.logging_formatter.SetupLogging
input_key: guid
3 changes: 3 additions & 0 deletions moai/conf/engine/serve/handlers/output/reset_guid.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# @package handlers.postprocess.reset_guid

_target_: moai.serve.handlers.logging_formatter.ResetLogging
4 changes: 2 additions & 2 deletions moai/export/local/image2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def save_image(
transform: typing.Union[str, typing.Sequence[str]] = [None],
colormap: typing.Union[str, typing.Sequence[str]] = [None],
modality: typing.List[str] = ["color"],
step: typing.Optional[int] = None,
lightning_step: typing.Optional[int] = None,
# batch_idx: typing.Optional[int]=None,
# optimization_step: typing.Optional[int]=None,
# stage: typing.Optional[str]=None,
Expand All @@ -256,6 +256,6 @@ def save_image(
save_map[modality](
colorize_map[colormap](transform_map[transform](tensors[key].detach())),
key,
step,
lightning_step,
fmt,
)
86 changes: 86 additions & 0 deletions moai/serve/handlers/logging_formatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from collections.abc import Callable

import logging
import typing

import torch

log = logging.getLogger(__name__)

class GuidLogFilter(logging.Filter):
def __init__(self, guid: str, *args, **kwargs):
super().__init__(*args, **kwargs)
self.guid = guid

def filter(self, record: logging.LogRecord) -> bool:
record.guid = self.guid
return True


class SetupLogging(Callable):
def __init__(
self,
input_key: str = "guid",
) -> None:
super().__init__()
self.formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s [%(guid)s]: - %(message)s"
)
# formatter = logging.Formatter(
# '[%(asctime)s] %(name)-5s %(levelname)s [%(username)s]: %(message)s',
# datefmt='%Y-%m-%d %H:%M:%S',
# )
self.input_key = input_key

def __call__(
self,
data: typing.Mapping[str, typing.Any],
device: torch.device,
) -> torch.Tensor:
# get handler from log
handler = logging.StreamHandler()
handler.setFormatter(self.formatter)
# get guid from input request
guid = data.get(self.input_key)
# handler.addFilter(logging.Filter(guid))
# log.addHandler(handler)
# Add filter to handler
handler.addFilter(GuidLogFilter(guid=guid))
# Get the root logger
root_logger = logging.getLogger()
root_logger.addHandler(handler)
# Ensure propagation is enabled
root_logger.propagate = True
# log.handlers[0].addFilter(GuidLogFilter(guid=guid))
return data


class ResetLogging(Callable):
def __init__(
self,
) -> None:
super().__init__()

def __call__(
self,
data: typing.Mapping[str, typing.Any],
device: torch.device,
) -> torch.Tensor:
# reset guid from input request
handler = logging.StreamHandler()
handler.setFormatter(
logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s [%(guid)s]: - %(message)s"
)
)
# handler.addFilter(logging.Filter(''))
handler.addFilter(GuidLogFilter(guid=""))
# Get the root logger and remove existing handlers
root_logger = logging.getLogger()
for hdlr in root_logger.handlers[:]:
root_logger.removeHandler(hdlr)

# Add the new handler to the root logger
root_logger.addHandler(handler)
root_logger.propagate = True
return data

0 comments on commit f4f21c6

Please sign in to comment.