Skip to content

Commit

Permalink
[stat_update]
Browse files Browse the repository at this point in the history
-new func of plotting rcvived/missed batches.
-improve the plot of get_loss_ts
  • Loading branch information
NoaShapira8 committed Aug 8, 2024
1 parent 3942784 commit d7d227d
Showing 1 changed file with 58 additions and 2 deletions.
60 changes: 58 additions & 2 deletions src_py/apiServer/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,25 @@ def get_loss_ts(self , plot : bool = False , saveToFile : bool = False):
self.loss_ts_pd = df

if plot:
sns.set(style="whitegrid")
plt.figure(figsize=(12, 8))
plt.gca().set_facecolor('lightblue') # Set the background color to light blue

# Customize the grid lines to be black
plt.grid(color='black', linestyle='-', linewidth=0.5)

sns.lineplot(data=df)
plt.xlabel('Batch Num.')
plt.ylabel('Loss Value')
plt.title('Training Loss Function')
return df
plt.title(f'Training Loss Function of ({self.experiment_phase.get_name()})')

# Move legend outside of the plot
plt.legend(title="Worker", loc='center left', bbox_to_anchor=(1, 0.5), ncol=1)

if saveToFile:
plt.savefig('training_loss_function.png', bbox_inches='tight')

plt.show()

def get_min_loss(self , plot : bool = False , saveToFile : bool = False): # Todo change it
"""
Expand Down Expand Up @@ -331,6 +345,48 @@ def missed_batches_key(phase_name, source_name, worker_name):
break
return missed_batches_dict

def plot_batches_status(self, plot=False):
workers_model_db_list = self.nerl_model_db.get_workers_model_db_list()
workers_names = [worker_model_db.get_worker_name() for worker_model_db in workers_model_db_list]
received_batches = self.get_recieved_batches()
missed_batches = self.get_missed_batches()

# Initialize dictionaries to store batch counts for each worker
batches_received_train = {worker: 0 for worker in workers_names}
batches_dropped_train = {worker: 0 for worker in workers_names}

# Fill the dictionaries with the counts of received and missed batches
for key, batches in received_batches.items():
worker = key.split('->')[-1]
batches_received_train[worker] += len(batches)

for key, batches in missed_batches.items():
worker = key.split('->')[-1]
batches_dropped_train[worker] += len(batches)

# Create a DataFrame for plotting
workers_comm_dict = {
'Worker': list(batches_received_train.keys()),
'batches_received_train': list(batches_received_train.values()),
'batches_dropped_train': list(batches_dropped_train.values())
}
df_train = pd.DataFrame(workers_comm_dict)

# Sort the DataFrame by the worker names
df_train = df_train.sort_values(by='Worker')

# Plotting
if plot:
plt.figure(figsize=(10, 6))
data_train = pd.melt(df_train, id_vars=['Worker'], value_vars=['batches_received_train', 'batches_dropped_train'])
batches_stats = sns.barplot(x='Worker', y='value', hue='variable', data=data_train, order=sorted(workers_names))
plt.ylabel('Number Of Batches')
plt.xlabel('Worker')
plt.title(f"Received & Dropped Batches At Freq. 5B/s ({self.experiment_phase.get_name()})")

batches_stats.legend(loc='upper right', bbox_to_anchor=(1.5, 0.2), shadow=True, ncol=1)
plt.show()

def get_communication_stats_workers(self):
# return dictionary of {worker : {communication_stats}}
communication_stats_workers_dict = OrderedDict()
Expand Down

0 comments on commit d7d227d

Please sign in to comment.