Skip to content

Commit

Permalink
Add CLI for analyze element result in a bar figure
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 3, 2023
1 parent 0cccd5e commit 4a39f6c
Show file tree
Hide file tree
Showing 2 changed files with 188 additions and 0 deletions.
162 changes: 162 additions & 0 deletions aiida_sssp_workflow/cli/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# -*- coding: utf-8 -*-
"""CLI to inspect the results of the workflow"""
import json
import random
from pathlib import Path

import click
Expand All @@ -14,6 +15,100 @@
from aiida_sssp_workflow.utils import HIGH_DUAL_ELEMENTS, get_protocol


def lighten_color(color, amount=0.5):
"""
Lightens the given color by multiplying (1-luminosity) by the given amount.
Input can be matplotlib color string, hex string, or RGB tuple.
REF from https://stackoverflow.com/questions/37765197/darken-or-lighten-a-color-in-matplotlib
Examples:
>> lighten_color('g', 0.3)
>> lighten_color('#F034A3', 0.6)
>> lighten_color((.3,.55,.1), 0.5)
"""
import colorsys

import matplotlib.colors as mc

try:
c = mc.cnames[color]
except Exception:
c = color
c = colorsys.rgb_to_hls(*mc.to_rgb(c))
return colorsys.hls_to_rgb(c[0], 1 - amount * (1 - c[1]), c[2])


def cmap(pseudo_info: dict) -> str:
"""Return RGB string of color for given pseudo info
Hardcoded at the momment.
"""
if pseudo_info["family"] == "sg15" and pseudo_info["version"] == "v0":
return "#000000"

if pseudo_info["family"] == "gbrv":
return "#4682B4"

if (
pseudo_info["family"] == "psl"
and pseudo_info["type"] == "us"
and pseudo_info["version"] == "v1.0.0-high"
):
return "#F50E02"

if (
pseudo_info["family"] == "psl"
and pseudo_info["type"] == "us"
and pseudo_info["version"] == "v1.0.0-low"
):
return lighten_color("#F50E02")

if (
pseudo_info["family"] == "psl"
and pseudo_info["type"] == "paw"
and pseudo_info["version"] == "v1.0.0-high"
):
return "#008B00"

if (
pseudo_info["family"] == "psl"
and pseudo_info["type"] == "paw"
and pseudo_info["version"] == "v1.0.0-low"
):
return lighten_color("#008B00")

if (
pseudo_info["family"] == "psl"
and pseudo_info["type"] == "paw"
and "v0." in pseudo_info["version"]
):
return "#FF00FF"

if (
pseudo_info["family"] == "psl"
and pseudo_info["type"] == "us"
and "v0." in pseudo_info["version"]
):
return lighten_color("#FF00FF")

if pseudo_info["family"] == "dojo" and pseudo_info["version"] == "v4-str":
return "#F9A501"

if pseudo_info["family"] == "dojo" and pseudo_info["version"] == "v4-std":
return lighten_color("#F9A501")

if pseudo_info["family"] == "jth" and pseudo_info["version"] == "v1.1-str":
return "#00C5ED"

if pseudo_info["family"] == "jth" and pseudo_info["version"] == "v1.1-std":
return lighten_color("#00C5ED")

# TODO: more mapping
# if a unknow type generate random color based on ascii sum
ascn = sum([ord(c) for c in pseudo_info["representive_label"]])
random.seed(ascn)
return f"#{random.randint(0, 16777215):06x}"


def birch_murnaghan(V, E0, V0, B0, B01):
"""
Return the energy for given volume (V - it can be a vector) according to
Expand Down Expand Up @@ -130,6 +225,73 @@ def convergence_plot(
ax.set_title(title, fontsize=8)


@cmd_root.command("analyze")
@click.argument("group") # The group to inspect, name or pk of the group
@click.option("--output", "-o", default="output", help="The output file name")
def analyze(group, output):
"""Render the plot for the given pseudos and measure type."""
from aiida_sssp_workflow.utils import ACWF_CONFIGURATIONS, parse_label

# landscape mode
fig, ax = plt.subplots(1, 1, figsize=(11.69, 8.27), dpi=100)

group_node = orm.load_group(group)

# Get all nodes in the output group (EOS workflows)
group_node_query = (
orm.QueryBuilder()
.append(orm.Group, filters={"id": group_node.id}, tag="groups")
.append(orm.Node, project="*", with_group="groups")
)

nodes_lst = group_node_query.all(flat=True)

for i, node in enumerate(nodes_lst):
width = 0.6 / len(nodes_lst)
pseudo_info = parse_label(node.base.extras.all["label"].split(" ")[-1])

# TODO assert element should not change
element = node.outputs.pseudo_info.get_dict()["element"]

y_nu = []
for configuration in ACWF_CONFIGURATIONS:
res = node.outputs.measure.precision[configuration]["output_parameters"]

nu = res["rel_errors_vec_length"]
y_nu.append(nu)

x = np.arange(len(ACWF_CONFIGURATIONS))

ax.bar(
x + width * i,
y_nu,
width,
color=cmap(pseudo_info),
edgecolor="black",
linewidth=1,
label=pseudo_info["representive_label"],
)
ax.set_title(f"X={element}")

ax.axhline(y=0.1, linestyle="--", color="green", label="~0.1 excellent")
ax.axhline(y=0.33, linestyle="--", color="red", label="~0.33 good")
ax.legend(loc="upper left", prop={"size": 10})
ax.set_ylabel("ν -facto")
ax.set_ylim([0, 0.6])
ax.set_xticks(range(len(ACWF_CONFIGURATIONS)))
ax.set_xticklabels(ACWF_CONFIGURATIONS)

# fig to pdf
fig.tight_layout()
fig.suptitle(
f"Comparison on {element}",
fontsize=10,
)
fig.subplots_adjust(top=0.92)
fpath = Path.cwd() / f"{output}_{element}_fight.pdf"
fig.savefig(fpath.name, bbox_inches="tight")


@cmd_root.command("inspect")
@click.argument("node") # The node to inspect, uuid or pk
@click.option("--output", "-o", default="output", help="The output file name")
Expand Down
26 changes: 26 additions & 0 deletions aiida_sssp_workflow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def get_protocol(category, name=None):

OXIDE_CONFIGURATIONS = ["XO", "XO2", "XO3", "X2O", "X2O3", "X2O5"]
UNARIE_CONFIGURATIONS = ["BCC", "FCC", "SC", "Diamond"]
ACWF_CONFIGURATIONS = OXIDE_CONFIGURATIONS + UNARIE_CONFIGURATIONS


def update_dict(d, u):
Expand Down Expand Up @@ -230,6 +231,31 @@ def get_standard_structure(
return structure


def parse_label(label):
"""parse standard pseudo label to dict of pseudo info"""
element, type, z, tool, family, *version = label.split(".")
version = ".".join(version)

if type == "nc":
full_type = "NC"
if type == "us":
full_type = "Ultrasoft"
if type == "paw":
full_type = "PAW"

return {
"element": element,
"type": type,
"z": z,
"tool": tool,
"family": family,
"version": version,
"representive_label": f"{z}|{full_type}|{family}|{tool}|{version}",
"concise_label": f"{z}|{type}|{family}|{version}",
"full_label": f"{element}|{z}|{full_type}|{family}|{tool}|{version}",
}


def parse_upf(upf_content: str) -> dict:
"""
Expand Down

0 comments on commit 4a39f6c

Please sign in to comment.