Skip to content

Commit

Permalink
Remove unnecessary codes added in bc36c8d
Browse files Browse the repository at this point in the history
  • Loading branch information
HolyWu committed Sep 8, 2021
1 parent 8e764fe commit 2524ce5
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 14 deletions.
2 changes: 1 addition & 1 deletion vsrife/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def RIFE(clip: vs.VideoNode, model_ver: float=3.5, scale: float=1.0, device_type
from .model38.RIFE_HDv3 import Model
model_dir = 'model38'

model = Model(device, fp16)
model = Model(device)
model.load_model(os.path.join(os.path.dirname(__file__), model_dir), -1)
model.eval()
model.device()
Expand Down
8 changes: 2 additions & 6 deletions vsrife/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,14 @@ def forward(self, flow, gt, loss_mask):


class Ternary(nn.Module):
def __init__(self, device, fp16):
def __init__(self, device):
super(Ternary, self).__init__()
patch_size = 7
out_channels = patch_size * patch_size
self.w = np.eye(out_channels).reshape(
(patch_size, patch_size, 1, out_channels))
self.w = np.transpose(self.w, (3, 2, 0, 1))
self.w = torch.tensor(self.w).float().to(device)
if fp16:
self.w = self.w.half()

def transform(self, img):
patches = F.conv2d(img, self.w, padding=3, bias=None)
Expand Down Expand Up @@ -56,15 +54,13 @@ def forward(self, img0, img1):


class SOBEL(nn.Module):
def __init__(self, device, fp16):
def __init__(self, device):
super(SOBEL, self).__init__()
self.kernelX = torch.tensor([
[1, 0, -1],
[2, 0, -2],
[1, 0, -1],
]).float()
if fp16:
self.kernelX = self.kernelX.half()
self.kernelY = self.kernelX.clone().T
self.kernelX = self.kernelX.unsqueeze(0).unsqueeze(0).to(device)
self.kernelY = self.kernelY.unsqueeze(0).unsqueeze(0).to(device)
Expand Down
6 changes: 3 additions & 3 deletions vsrife/model31/RIFE_HDv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def forward(self, img0, img1, flow, c0, c1, flow_gt):


class Model:
def __init__(self, device, fp16, local_rank=-1):
def __init__(self, device, local_rank=-1):
self.torch_device = device
self.flownet = IFNet(device)
self.contextnet = ContextNet(device)
Expand All @@ -125,8 +125,8 @@ def __init__(self, device, fp16, local_rank=-1):
self.schedulerG = optim.lr_scheduler.CyclicLR(
self.optimG, base_lr=1e-6, max_lr=1e-3, step_size_up=8000, cycle_momentum=False)
self.epe = EPE()
self.ter = Ternary(device, fp16)
self.sobel = SOBEL(device, fp16)
self.ter = Ternary(device)
self.sobel = SOBEL(device)
if local_rank != -1:
self.flownet = DDP(self.flownet, device_ids=[
local_rank], output_device=local_rank)
Expand Down
4 changes: 2 additions & 2 deletions vsrife/model35/RIFE_HDv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@
from vsrife.loss import *

class Model:
def __init__(self, device, fp16, local_rank=-1):
def __init__(self, device, local_rank=-1):
self.torch_device = device
self.flownet = IFNet(device)
self.device()
self.optimG = AdamW(self.flownet.parameters(), lr=1e-6, weight_decay=1e-4)
self.epe = EPE()
# self.vgg = VGGPerceptualLoss().to(device)
self.sobel = SOBEL(device, fp16)
self.sobel = SOBEL(device)
if local_rank != -1:
self.flownet = DDP(self.flownet, device_ids=[local_rank], output_device=local_rank)

Expand Down
4 changes: 2 additions & 2 deletions vsrife/model38/RIFE_HDv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@
from vsrife.loss import *

class Model:
def __init__(self, device, fp16, local_rank=-1):
def __init__(self, device, local_rank=-1):
self.torch_device = device
self.flownet = IFNet(device)
self.device()
self.optimG = AdamW(self.flownet.parameters(), lr=1e-6, weight_decay=1e-4)
self.epe = EPE()
# self.vgg = VGGPerceptualLoss().to(device)
self.sobel = SOBEL(device, fp16)
self.sobel = SOBEL(device)
if local_rank != -1:
self.flownet = DDP(self.flownet, device_ids=[local_rank], output_device=local_rank)

Expand Down

0 comments on commit 2524ce5

Please sign in to comment.