Skip to content

Commit

Permalink
frame definition in Density Plot
Browse files Browse the repository at this point in the history
  • Loading branch information
danielpastor97 committed Jul 25, 2024
1 parent eb8e36c commit 78c903e
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 24 deletions.
44 changes: 24 additions & 20 deletions prolint2/plotting/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def __init__(
self.universe = universe
self.metric = metric

def create_plot(self, res_ids=None, lipid_type=None, metric_name=None, **kwargs):
def create_plot(self, res_ids=None, lipid_type=None, metric_name=None, y_lim=None, **kwargs):
"""
Plot the distribution of a metric for each residue.
Expand Down Expand Up @@ -282,6 +282,9 @@ def create_plot(self, res_ids=None, lipid_type=None, metric_name=None, **kwargs)

ax.tick_params(axis="both", which="major", labelsize=12)

if y_lim is not None:
ax.set_ylim(y_lim)

# Set default labels, filename, and title if not provided
if self.xlabel is None:
self.xlabel = "Residue ID"
Expand All @@ -306,6 +309,7 @@ def create_plot(self, res_ids=None, lipid_type=None, metric_name=None, **kwargs)
weight="bold",
fontfamily="Arial Rounded MT Bold",
)

plt.tight_layout()


Expand Down Expand Up @@ -496,7 +500,7 @@ def __init__(
self.universe = universe

def create_plot(
self, lipid_type=None, bins=150, size_in_mb=50000, frame=None, **kwargs
self, lipid_type=None, bins=150, size_in_mb=50000, start=0, stop=-1, step=1, **kwargs
):
"""
Plot the preferential localization of lipids using 2D density maps.
Expand Down Expand Up @@ -527,7 +531,7 @@ def create_plot(
raise ValueError("Please specify a lipid_type.")
else:
# Compute the lipid coordinates
computed_coords = compute_density(self.universe, lipid_type, size_in_mb)
computed_coords = compute_density(self.universe, lipid_type, size_in_mb, start, stop, step)

# Compute the lipid density using a 2D histogram
H, xe, ye = np.histogram2d(
Expand All @@ -538,17 +542,9 @@ def create_plot(
fig, ax = plt.subplots(figsize=self.fig_size)

# Plot the density map using imshow
if frame is None:
im = ax.imshow(
im = ax.imshow(
H, origin="lower", extent=[xe[0], xe[-1], ye[0], ye[-1]], **kwargs
)
else:
im = ax.imshow(
H[frame:-frame, frame:-frame],
origin="lower",
extent=[xe[0], xe[-1], ye[0], ye[-1]],
**kwargs,
)
ax.grid(False)

# Create a colorbar of the same size as the plot
Expand Down Expand Up @@ -796,6 +792,7 @@ def create_plot(
color_logo="silver",
palette="Blues",
res_ids=None,
limits=None,
**kwargs
):
"""
Expand Down Expand Up @@ -865,9 +862,12 @@ def ceildiv(number):

# Create colormap based on Metric and normalize values
cmap = mpl.cm.get_cmap(palette)
norm = mpl.colors.Normalize(
vmin=df["Metric"].min(), vmax=df["Metric"].max()
)
if limits is not None:
norm = mpl.colors.Normalize(vmin=limits[0], vmax=limits[1])
else:
norm = mpl.colors.Normalize(
vmin=df["Metric"].min(), vmax=df["Metric"].max()
)
colors = [cmap(norm(value)) for value in df["Metric"]]

# Process the subplot grid based on the number of rows
Expand Down Expand Up @@ -963,15 +963,19 @@ def ceildiv(number):

# Add color bar to the bottom
cax = plt.axes([0.1, -0.3 / (n_rows - n_rows / 2), 0.8, 0.12 / n_rows])
if limits is not None:
cb_ticks=[limits[0], limits[0] + (limits[1] - limits[0]) / 2, limits[1]]
else:
cb_ticks=[
df["Metric"].min(),
df["Metric"].min() + (df["Metric"].max() - df["Metric"].min()) / 2,
df["Metric"].max(),
]
cbar = plt.colorbar(
mpl.cm.ScalarMappable(norm=norm, cmap=cmap),
cax=cax,
orientation="horizontal",
ticks=[
df["Metric"].min(),
df["Metric"].min() + (df["Metric"].max() - df["Metric"].min()) / 2,
df["Metric"].max(),
],
ticks=cb_ticks,
)
cbar.set_label(
label=metric_name, size=12, fontfamily="Arial Rounded MT Bold"
Expand Down
8 changes: 4 additions & 4 deletions prolint2/plotting/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,9 +438,9 @@ def get_metrics_for_radar(metrics, metric_names, resIDs=[], lipid=None):
return metrics_radar, metric_names


def compute_density(universe, lipid, size_in_mb):
def compute_density(universe, lipid, size_in_mb, start, stop, step):
# Get the total number of frames in the trajectory
frames = universe.trajectory.n_frames
frames = len(universe.trajectory[start: stop: step])

# Create a selection string for the specified lipid
selection_string = "resname {}".format(lipid)
Expand All @@ -449,8 +449,8 @@ def compute_density(universe, lipid, size_in_mb):
lipids_ag = universe.select_atoms(selection_string)

# Loop through the trajectory frames
for ts in universe.trajectory:
if ts.frame == 0:
for ts in universe.trajectory[start: stop: step]:
if ts.frame == start:
# Store the positions of the lipid atoms for the first frame
lipids_xyz = lipids_ag.positions
else:
Expand Down

0 comments on commit 78c903e

Please sign in to comment.