Skip to content

Commit

Permalink
fixed final plot
Browse files Browse the repository at this point in the history
  • Loading branch information
menelaoszetas1990 authored Apr 16, 2023
1 parent 5707383 commit 0ef16c8
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions LSTM.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# Preprocessing

# Importing the dataset
# Using columns: sog, stw, wspeedbf, wdir, me_power
# Using columns: sog, stw, wspeedbf, wdir, me_power
dataset = pd.read_csv('data/cape' + str(WHICH_CAPE) + '.csv', usecols=[0, 1, 4, 5, 9])

X = dataset.values
Expand Down Expand Up @@ -69,7 +69,7 @@ def to_sequences(_dataset, _seq_size=1):
model.summary()

# fit the model
history = model.fit(trainX, trainY, epochs=5, batch_size=16, validation_split=0.1, verbose=1)
history = model.fit(trainX, trainY, epochs=2, batch_size=16, validation_split=0.1, verbose=1)

plt.plot(history.history['loss'], label='Training loss')
plt.plot(history.history['val_loss'], label='Validation loss')
Expand Down Expand Up @@ -106,12 +106,12 @@ def to_sequences(_dataset, _seq_size=1):
# shift test predictions for plotting
testPredictPlot = np.empty_like(dataset)
testPredictPlot[:, :] = np.nan
testPredictPlot[seq_size:len(testPredict) + seq_size, :] = testPredict
testPredictPlot[len(dataset)-len(testPredict):, :] = testPredict

# plot baseline and predictions
plt.plot(dataset['me_power'], label='dataset')
plt.plot(trainPredictPlot, label='trainPredict')
plt.plot(testPredictPlot, label='testPredict')
plt.plot(trainPredictPlot[:, 4:], label='trainPredict')
plt.plot(testPredictPlot[:, 4:], label='testPredict')
plt.legend()
plt.show()
print('end')

0 comments on commit 0ef16c8

Please sign in to comment.