Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
ZooBeasts authored Mar 24, 2024
1 parent 12691fe commit 6b82d90
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 0 deletions.
83 changes: 83 additions & 0 deletions Model/Model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import torch
import torch.nn as nn


nc = 1
image_size = 64
ngpu = 1
features_d = 64
features_g = 64
Z_dim = 250
channels_noise = Z_dim



class Critic(nn.Module):
def __init__(self, ngpu):
super(Critic, self).__init__()
self.ngpu = ngpu
self.image_size = image_size
self.l1 = nn.Linear(200, image_size * image_size * nc)
self.disc = nn.Sequential(
nn.Conv2d(nc * 2, features_d, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2),
self._block(features_d, features_d * 2, 4, 2, 1),
self._block(features_d * 2, features_d * 4, 4, 2, 1),
self._block(features_d * 4, features_d * 8, 4, 2, 1),
nn.Conv2d(features_d * 8, 1, kernel_size=4, stride=1, padding=0),
# nn.Sigmoid(),
)

def _block(self, in_channels, out_channels, kernel_size, stride, padding):
return nn.Sequential(
nn.Conv2d(
in_channels, out_channels, kernel_size, stride, padding, bias=False,
),
nn.InstanceNorm2d(out_channels, affine=True),
nn.LeakyReLU(0.2, inplace=True),
)

def forward(self, img, points21):
x1 = img
x2 = self.l1(points21)
# x2 = x2.reshape(int(b_size / ngpu), nc, image_size, image_size)
x2 = x2.reshape(-1, nc, image_size, image_size)
combine = torch.cat((x1, x2), dim=1)
return self.disc(combine)


class Generator(nn.Module):
def __init__(self, ngpu):
super(Generator, self).__init__()
self.ngpu = ngpu
self.net = nn.Sequential(
# Input: N x channels_noise x 1 x 1
self._block(channels_noise, features_g * 16, 4, 1, 0), # img: 4x4
self._block(features_g * 16, features_g * 8, 4, 2, 1), # img: 8x8
self._block(features_g * 8, features_g * 4, 4, 2, 1), # img: 16x16
self._block(features_g * 4, features_g * 2, 4, 2, 1), # img: 32x32
nn.ConvTranspose2d(
features_g * 2, nc, kernel_size=4, stride=2, padding=1
),
# Output: N x channels_img x 64 x 64
nn.Tanh(),
)

def _block(self, in_channels, out_channels, kernel_size, stride, padding):
return nn.Sequential(
nn.ConvTranspose2d(
in_channels, out_channels, kernel_size, stride, padding, bias=False,
),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
)

def forward(self, points):

return self.net(points)


def initialize_weights(model):
for m in model.modules():
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
nn.init.normal_(m.weight.data, 0.0, 0.02)
25 changes: 25 additions & 0 deletions Model/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Disable TF deprecation warnings.
# Syntax from tf1 is not expected to be compatible with tf2.
import tensorflow as tf
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

# Algorithms
from spinup.algos.tf1.ddpg.ddpg import ddpg as ddpg_tf1
from spinup.algos.tf1.ppo.ppo import ppo as ppo_tf1
from spinup.algos.tf1.sac.sac import sac as sac_tf1
from spinup.algos.tf1.td3.td3 import td3 as td3_tf1
from spinup.algos.tf1.trpo.trpo import trpo as trpo_tf1
from spinup.algos.tf1.vpg.vpg import vpg as vpg_tf1

from spinup.algos.pytorch.ddpg.ddpg import ddpg as ddpg_pytorch
from spinup.algos.pytorch.ppo.ppo import ppo as ppo_pytorch
from spinup.algos.pytorch.sac.sac import sac as sac_pytorch
from spinup.algos.pytorch.td3.td3 import td3 as td3_pytorch
from spinup.algos.pytorch.trpo.trpo import trpo as trpo_pytorch
from spinup.algos.pytorch.vpg.vpg import vpg as vpg_pytorch

# Loggers
from spinup.utils.logx import Logger, EpochLogger

# Version
from spinup.version import __version__

0 comments on commit 6b82d90

Please sign in to comment.