Skip to content

Commit

Permalink
Feat: Modify common graph properties (#184)
Browse files Browse the repository at this point in the history
* Add option to set height and width
* Set a default width and height
* Add option to select x and y range
* Rename name to label
* Remove option to create two plots with single command
  Instead of
  >>> plot((x1, y1), (x2, y2))
  you should use
  >>> plot(x1, y1) + plot(x2, y2)
  • Loading branch information
martin-schlipf authored Feb 4, 2025
1 parent 7f62186 commit 7cb5bf7
Show file tree
Hide file tree
Showing 18 changed files with 156 additions and 114 deletions.
2 changes: 1 addition & 1 deletion src/py4vasp/_calculation/energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def _init_selection_dict(self):
def _make_series(self, yaxes, tree):
steps = np.arange(len(self._raw_data.values))[self._slice] + 1
return [
graph.Series(x=steps, y=values, name=label, y2=yaxes.use_y2(label))
graph.Series(x=steps, y=values, label=label, y2=yaxes.use_y2(label))
for label, values in self._read_data(tree, self._slice)
]

Expand Down
2 changes: 1 addition & 1 deletion src/py4vasp/_calculation/pair_correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def _init_pair_correlation_dict(self):
def _make_series(self, selected_data):
distances = selected_data["distances"]
return [
graph.Series(x=distances, y=data, name=label)
graph.Series(x=distances, y=data, label=label)
for label, data in selected_data.items()
if label != "distances"
]
24 changes: 19 additions & 5 deletions src/py4vasp/_third_party/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,20 @@ class Graph(Sequence):
"One or more series shown in the graph."
xlabel: str = None
"Label for the x axis."
xrange: tuple = None
"Reduce the x axis to this interval."
xticks: dict = None
"A dictionary specifying positions and labels where ticks are placed on the x axis."
xsize: int = 720
"Width of the resulting figure."
ylabel: str = None
"Label for the y axis."
yrange: tuple = None
"Reduce the y axis to this interval."
y2label: str = None
"Label for the secondary y axis."
ysize: int = 540
"Height of the resulting figure."
title: str = None
"Title of the graph."
_frozen = False
Expand Down Expand Up @@ -103,8 +111,8 @@ def label(self, new_label):

def _make_label(self, series, new_label):
if len(self) > 1:
new_label = f"{new_label} {series.name}"
return replace(series, name=new_label)
new_label = f"{new_label} {series.label}"
return replace(series, label=new_label)

def _ipython_display_(self):
self.to_plotly()._ipython_display_()
Expand All @@ -120,6 +128,9 @@ def _make_plotly_figure(self):
self._set_xaxis_options(figure)
self._set_yaxis_options(figure)
figure.layout.title.text = self.title
if self.xsize:
figure.layout.width = self.xsize
figure.layout.height = self.ysize
figure.layout.legend.itemsizing = "constant"
return figure

Expand Down Expand Up @@ -147,6 +158,8 @@ def _set_xaxis_options(self, figure):
figure.layout.xaxis.tickmode = "array"
figure.layout.xaxis.tickvals = tuple(self.xticks.keys())
figure.layout.xaxis.ticktext = self._xtick_labels()
if self.xrange:
figure.layout.xaxis.range = self.xrange
if self._all_are_contour():
figure.layout.xaxis.visible = False

Expand All @@ -163,11 +176,12 @@ def _set_yaxis_options(self, figure):
figure.layout.yaxis.title.text = self.ylabel
if self.y2label:
figure.layout.yaxis2.title.text = self.y2label
if self.yrange:
figure.layout.yaxis.range = self.yrange
if self._all_are_contour():
figure.layout.yaxis.visible = False
if self._any_are_contour():
figure.layout.yaxis.scaleanchor = "x"
figure.layout.height = 500

def _all_are_contour(self):
return all(isinstance(series, Contour) for series in self)
Expand Down Expand Up @@ -220,8 +234,8 @@ def _create_and_populate_df(self, series):
return df

def _name_column(self, series, suffix, idx=None):
if series.name:
text_suffix = series.name.replace(" ", "_") + f".{suffix}"
if series.label:
text_suffix = series.label.replace(" ", "_") + f".{suffix}"
else:
text_suffix = "series_" + str(uuid.uuid1())
if series.y.ndim == 1 or idx is None:
Expand Down
36 changes: 17 additions & 19 deletions src/py4vasp/_third_party/graph/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from py4vasp._third_party.graph.series import Series


def plot(*args, **kwargs):
def plot(x, y, label=None, **kwargs):
"""Plot the given data, modifying the look with some optional arguments.
The intent of this function is not to provide a full fledged plotting functionality
Expand All @@ -13,6 +13,17 @@ def plot(*args, **kwargs):
minimal interface. Use a proper plotting library (e.g. matplotlib or plotly) to
realize more advanced plots.
Parameters
----------
x : np.ndarray
The x values of the coordinates.
y : np.ndarray
The y values of the coordinates.
label : str
If set this will be used to label the series.
**kwargs
All additional arguments will be passed to initialize Series and Graph.
Returns
-------
Graph
Expand All @@ -26,30 +37,17 @@ def plot(*args, **kwargs):
Plot two series in the same graph
>>> plot((x1, y1), (x2, y2))
>>> plot(x1, y1) + plot(x2, y2)
Attributes of the graph are modified by keyword arguments
>>> plot(x, y, xlabel="xaxis", ylabel="yaxis")
"""
series = _parse_series(x, y, label, **kwargs)
for_graph = {key: val for key, val in kwargs.items() if key in Graph._fields}
return Graph(_parse_series(*args, **kwargs), **for_graph)


def _parse_series(*args, **kwargs):
if series := _parse_multiple_series(*args, **kwargs):
return series
else:
return _parse_single_series(*args, **kwargs)


def _parse_multiple_series(*args, **kwargs):
try:
return [Series(*arg) for arg in args]
except TypeError:
return []
return Graph(series, **for_graph)


def _parse_single_series(*args, **kwargs):
def _parse_series(x, y, label, **kwargs):
for_series = {key: val for key, val in kwargs.items() if key in Series._fields}
return Series(*args, **for_series)
return Series(x, y, label, **for_series)
16 changes: 12 additions & 4 deletions src/py4vasp/_third_party/graph/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
class Series(trace.Trace):
"""Represents a single series in a graph.
Typically this corresponds to a single line of x-y data with an optional name used
Typically this corresponds to a single line of x-y data with an optional label used
in the legend of the figure. The look of the series is modified by some of the other
optional arguments.
"""
Expand All @@ -25,7 +25,7 @@ class Series(trace.Trace):
y: np.ndarray
"""The y coordinates of the series. If the data is 2-dimensional multiple lines are
generated with a common entry in the legend."""
name: str = None
label: str = None
"A label for the series used in the legend."
width: np.ndarray = None
"When a width is set, the series will be visualized as an area instead of a line."
Expand Down Expand Up @@ -56,6 +56,14 @@ def __setattr__(self, key, value):
assert not self._frozen or hasattr(self, key)
super().__setattr__(key, value)

def __eq__(self, other):
if not isinstance(other, Series):
return NotImplemented
return all(
np.array_equal(getattr(self, field.name), getattr(other, field.name))
for field in fields(self)
)

def to_plotly(self):
first_trace = True
for item in enumerate(np.atleast_2d(np.array(self.y))):
Expand Down Expand Up @@ -118,8 +126,8 @@ def _options_points(self, y, width, first_trace):

def _common_options(self, first_trace):
return {
"name": self.name,
"legendgroup": self.name,
"name": self.label,
"legendgroup": self.label,
"showlegend": first_trace,
"yaxis": "y2" if self.y2 else "y",
}
Expand Down
12 changes: 6 additions & 6 deletions tests/calculation/test_band.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,16 +236,16 @@ def test_spin_projectors_plot(spin_projectors, Assert):
width = 0.05
fig = spin_projectors.plot("O", width)
assert len(fig.series) == 2
assert fig.series[0].name == "O_up"
assert fig.series[0].label == "O_up"
check_data(fig.series[0], width, reference.bands_up, reference.O_up, Assert)
assert fig.series[1].name == "O_down"
assert fig.series[1].label == "O_down"
check_data(fig.series[1], width, reference.bands_down, reference.O_down, Assert)


def check_figure(fig, width, reference, Assert):
assert len(fig.series) == 2
assert fig.series[0].name == "Sr"
assert fig.series[1].name == "p"
assert fig.series[0].label == "Sr"
assert fig.series[1].label == "p"
check_data(fig.series[0], width, reference.bands, reference.Sr, Assert)
check_data(fig.series[1], width, reference.bands, reference.p, Assert)

Expand All @@ -260,9 +260,9 @@ def check_data(series, width, band, projection, Assert):
def test_spin_polarized_plot(spin_polarized, Assert):
fig = spin_polarized.plot()
assert len(fig.series) == 2
assert fig.series[0].name == "up"
assert fig.series[0].label == "up"
Assert.allclose(fig.series[0].y, spin_polarized.ref.bands_up.T)
assert fig.series[1].name == "down"
assert fig.series[1].label == "down"
Assert.allclose(fig.series[1].y, spin_polarized.ref.bands_down.T)


Expand Down
14 changes: 7 additions & 7 deletions tests/calculation/test_bandgap.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,19 +124,19 @@ def check_default_graph(bandgap, steps, Assert, graph):
assert graph.ylabel == "bandgap (eV)"
assert len(graph.series) == 2
fundamental = graph.series[0]
assert fundamental.name == "fundamental"
assert fundamental.label == "fundamental"
Assert.allclose(fundamental.x, xx)
Assert.allclose(fundamental.y, bandgap.ref.fundamental[steps, 0])
direct = graph.series[1]
assert direct.name == "direct"
assert direct.label == "direct"
Assert.allclose(direct.x, xx)
Assert.allclose(direct.y, bandgap.ref.direct[steps, 0])


def test_plot_selection_default(bandgap, steps, Assert):
graph = bandgap.plot("direct") if steps == -1 else bandgap[steps].plot("direct")
assert len(graph.series) == 1
assert graph.series[0].name == "direct"
assert graph.series[0].label == "direct"
Assert.allclose(graph.series[0].y, bandgap.ref.direct[steps, 0])


Expand All @@ -146,16 +146,16 @@ def test_plot_selection_spin_polarized(spin_polarized, steps, Assert):
graph = bandgap.plot(selection)
assert len(graph.series) == 4
fundamental_up = graph.series[0]
assert fundamental_up.name == "fundamental_up"
assert fundamental_up.label == "fundamental_up"
Assert.allclose(fundamental_up.y, bandgap.ref.fundamental[steps, 1])
direct_up = graph.series[1]
assert direct_up.name == "direct_up"
assert direct_up.label == "direct_up"
Assert.allclose(direct_up.y, bandgap.ref.direct[steps, 1])
fundamental_down = graph.series[2]
assert fundamental_down.name == "fundamental_down"
assert fundamental_down.label == "fundamental_down"
Assert.allclose(fundamental_down.y, bandgap.ref.fundamental[steps, 2])
direct = graph.series[3]
assert direct.name == "direct"
assert direct.label == "direct"
Assert.allclose(direct.y, bandgap.ref.direct[steps, 0])


Expand Down
Loading

0 comments on commit 7cb5bf7

Please sign in to comment.