Skip to content

Commit

Permalink
Update Dataloader.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ZooBeasts authored Aug 28, 2024
1 parent d8f515b commit 8602dd0
Showing 1 changed file with 2 additions and 5 deletions.
7 changes: 2 additions & 5 deletions Dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
from sklearn import preprocessing

Scaler = preprocessing.MinMaxScaler()
Scaler = preprocessing.MinMaxScaler(feature_range=(-1,1))

dataindex = 201

Expand Down Expand Up @@ -46,16 +46,13 @@ def __getitem__(self, index):
img = img.transpose(2, 0, 1)
img = torch.from_numpy(img)
points21 = item[1:dataindex].astype(np.float64).reshape(-1, 1)
# points21 = item[1:dataindex].astype(np.float64)
points21 = Scaler.fit_transform(points21)
points21 = torch.from_numpy(points21).flatten(0)
# points21 = torch.from_numpy(points21)

points = item[1:dataindex].astype(np.float64).reshape(-1,1)
# points = item[1:dataindex].astype(np.float64)
# points = Scaler.fit_transform(points)
points = Scaler.fit_transform(points)
points = torch.from_numpy(points).flatten(0)
# points = torch.from_numpy(points)
assert len(points) <= self.z_dim
points = torch.hstack([points, torch.randn(self.z_dim - len(points))])
points = points.reshape([self.z_dim, 1, 1])
Expand Down

0 comments on commit 8602dd0

Please sign in to comment.