Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding a download image button for different formats #1056

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion src/aiidalab_qe/common/bands_pdos/bandpdoswidget.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,22 @@ def render(self):
)
self.download_button.on_click(self._model.download_data)

self.download_image = ipw.Button(
description="Download image",
button_style="primary",
icon="fa-image",
)
self.download_image.on_click(self._model.download_image)
self.image_format = ipw.Dropdown(
description="Format:",
layout=ipw.Layout(width="auto"),
)
ipw.dlink((self._model, "image_format_options"), (self.image_format, "options"))
ipw.link((self._model, "image_format"), (self.image_format, "value"))

self.download_buttons = ipw.HBox(
children=[self.download_button, self.download_image, self.image_format]
)
self.project_bands_box = ipw.Checkbox(
description="Add `fat bands` projections",
style={"description_width": "initial"},
Expand Down Expand Up @@ -240,7 +256,7 @@ def render(self):
</div>
"""),
self.pdos_options,
self.download_button,
self.download_buttons,
self.legend_interaction_description,
]

Expand Down
41 changes: 41 additions & 0 deletions src/aiidalab_qe/common/bands_pdos/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ class BandsPdosModel(Model):
bands_data = {}
bands_projections_data = {}

# Image format options
image_format_options = tl.List(
trait=tl.Unicode(), default_value=["png", "jpeg", "svg", "pdf"]
)
image_format = tl.Unicode("png")

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

Expand Down Expand Up @@ -284,6 +290,41 @@ def update_trace_color(self, color):
# Update the color picker to match the updated trace
self.color_picker = rgba_to_hex(self.plot.data[self.trace].line.color)

def download_image(self, _=None):
"""
Downloads the current plot as an image in the format specified by self.image_format.
"""
# Define the filename
if self.bands and self.pdos:
filename = f"bands_pdos.{self.image_format}"
else:
filename = f"{'bands' if self.bands else 'pdos'}.{self.image_format}"

# Generate the image in the specified format
image_payload = self.plot.to_image(format=self.image_format)
image_payload_base64 = base64.b64encode(image_payload).decode("utf-8")

self._download_image(payload=image_payload_base64, filename=filename)

@staticmethod
def _download_image(payload, filename):
from IPython.display import Javascript

# Safely format the JavaScript code
javas = Javascript(
"""
var link = document.createElement('a');
link.href = 'data:image/{format};base64,{payload}';
link.download = "{filename}";
document.body.appendChild(link);
link.click();
document.body.removeChild(link);
""".format(
payload=payload, filename=filename, format=filename.split(".")[-1]
)
)
edan-bainglass marked this conversation as resolved.
Show resolved Hide resolved
display(javas)

def download_data(self, _=None):
"""Function to download the data."""
if self.bands_data:
Expand Down
Loading