Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: improve compatibility with pytorch < 1.10 #5

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions dpvo/data_readers/rgbd_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import torch
from ..lietorch import SE3
from ..utils import meshgrid

from scipy.spatial.transform import Rotation

Expand Down Expand Up @@ -111,7 +112,7 @@ def compute_distance_matrix_flow(poses, disps, intrinsics):

N = poses.shape[1]

ii, jj = torch.meshgrid(torch.arange(N), torch.arange(N))
ii, jj = meshgrid(torch.arange(N), torch.arange(N))
ii = ii.reshape(-1).cuda()
jj = jj.reshape(-1).cuda()

Expand Down Expand Up @@ -151,7 +152,7 @@ def compute_distance_matrix_flow2(poses, disps, intrinsics, beta=0.4):

N = poses.shape[1]

ii, jj = torch.meshgrid(torch.arange(N), torch.arange(N))
ii, jj = meshgrid(torch.arange(N), torch.arange(N))
ii = ii.reshape(-1)
jj = jj.reshape(-1)

Expand Down
12 changes: 6 additions & 6 deletions dpvo/dpvo.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,24 +296,24 @@ def update(self):
self.points_[:len(points)] = points[:]

def __edges_all(self):
return flatmeshgrid(
return meshgrid(
torch.arange(0, self.m, device="cuda"),
torch.arange(0, self.n, device="cuda"), indexing='ij')
torch.arange(0, self.n, device="cuda"), indexing="ij", flat=True)

def __edges_forw(self):
r=self.cfg.PATCH_LIFETIME
t0 = self.M * max((self.n - r), 0)
t1 = self.M * max((self.n - 1), 0)
return flatmeshgrid(
return meshgrid(
torch.arange(t0, t1, device="cuda"),
torch.arange(self.n-1, self.n, device="cuda"), indexing='ij')
torch.arange(self.n-1, self.n, device="cuda"), indexing="ij", flat=True)

def __edges_back(self):
r=self.cfg.PATCH_LIFETIME
t0 = self.M * max((self.n - 1), 0)
t1 = self.M * max((self.n - 0), 0)
return flatmeshgrid(torch.arange(t0, t1, device="cuda"),
torch.arange(max(self.n-r, 0), self.n, device="cuda"), indexing='ij')
return meshgrid(torch.arange(t0, t1, device="cuda"),
torch.arange(max(self.n-r, 0), self.n, device="cuda"), indexing="ij", flat=True)

def __call__(self, tstamp, image, intrinsics):
""" track new frame """
Expand Down
6 changes: 3 additions & 3 deletions dpvo/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def forward(self, images, poses, disps, intrinsics, M=1024, STEPS=12, P=1, struc
d = patches[..., 2, p//2, p//2]
patches = set_depth(patches, torch.rand_like(d))

kk, jj = flatmeshgrid(torch.where(ix < 8)[0], torch.arange(0,8, device="cuda"))
kk, jj = meshgrid(torch.where(ix < 8)[0], torch.arange(0,8, device="cuda"), flat=True)
ii = ix[kk]

imap = imap.view(b, -1, DIM)
Expand All @@ -222,8 +222,8 @@ def forward(self, images, poses, disps, intrinsics, M=1024, STEPS=12, P=1, struc
n = ii.max() + 1
if len(traj) >= 8 and n < images.shape[1]:
if not structure_only: Gs.data[:,n] = Gs.data[:,n-1]
kk1, jj1 = flatmeshgrid(torch.where(ix < n)[0], torch.arange(n, n+1, device="cuda"))
kk2, jj2 = flatmeshgrid(torch.where(ix == n)[0], torch.arange(0, n+1, device="cuda"))
kk1, jj1 = meshgrid(torch.where(ix < n)[0], torch.arange(n, n+1, device="cuda"), flat=True)
kk2, jj2 = meshgrid(torch.where(ix == n)[0], torch.arange(0, n+1, device="cuda"), flat=True)

ii = torch.cat([ix[kk1], ix[kk2], ii])
jj = torch.cat([jj1, jj2, jj])
Expand Down
3 changes: 2 additions & 1 deletion dpvo/projective_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
import torch.nn.functional as F

from .lietorch import SE3, Sim3
from .utils import meshgrid

MIN_DEPTH = 0.2

def extract_intrinsics(intrinsics):
return intrinsics[...,None,None,:].unbind(dim=-1)

def coords_grid(ht, wd, **kwargs):
y, x = torch.meshgrid(
y, x = meshgrid(
torch.arange(ht).to(**kwargs).float(),
torch.arange(wd).to(**kwargs).float())

Expand Down
18 changes: 11 additions & 7 deletions dpvo/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
import torch.nn.functional as F

TORCH_VERSION = tuple(int(x) for x in torch.__version__.split(".")[:2])

all_times = []

Expand Down Expand Up @@ -32,7 +33,7 @@ def coords_grid(b, n, h, w, **kwargs):
""" coordinate grid """
x = torch.arange(0, w, dtype=torch.float, **kwargs)
y = torch.arange(0, h, dtype=torch.float, **kwargs)
coords = torch.stack(torch.meshgrid(y, x, indexing="ij"))
coords = torch.stack(meshgrid(y, x, indexing="ij"))
return coords[[1,0]].view(1, 1, 2, h, w).repeat(b, n, 1, 1, 1)

def coords_grid_with_index(d, **kwargs):
Expand All @@ -42,7 +43,7 @@ def coords_grid_with_index(d, **kwargs):
x = torch.arange(0, w, dtype=torch.float, **kwargs)
y = torch.arange(0, h, dtype=torch.float, **kwargs)

y, x = torch.stack(torch.meshgrid(y, x, indexing="ij"))
y, x = torch.stack(meshgrid(y, x, indexing="ij"))
y = y.view(1, 1, h, w).repeat(b, n, 1, 1)
x = x.view(1, 1, h, w).repeat(b, n, 1, 1)

Expand Down Expand Up @@ -73,15 +74,18 @@ def pyramidify(fmap, lvls=[1]):
return pyramid

def all_pairs_exclusive(n, **kwargs):
ii, jj = torch.meshgrid(torch.arange(n, **kwargs), torch.arange(n, **kwargs))
ii, jj = meshgrid(torch.arange(n, **kwargs), torch.arange(n, **kwargs))
k = ii != jj
return ii[k].reshape(-1), jj[k].reshape(-1)

def set_depth(patches, depth):
patches[...,2,:,:] = depth[...,None,None]
return patches

def flatmeshgrid(*args, **kwargs):
grid = torch.meshgrid(*args, **kwargs)
return (x.reshape(-1) for x in grid)

def meshgrid(x, y, indexing="ij", flat=False):
if TORCH_VERSION < (1, 10):
# "ij" if default in older torch
grid = torch.meshgrid(x, y)
else:
grid = torch.meshgrid(x, y, indexing)
return (x.reshape(-1) for x in grid) if flat else grid
4 changes: 3 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from dpvo.net import VONet
from evaluate_tartan import evaluate as validate

from dpvo.utils import meshgrid


def show_image(image):
image = image.permute(1, 2, 0).cpu().numpy()
Expand Down Expand Up @@ -88,7 +90,7 @@ def train(args):
e = e.reshape(-1, 9)[(v > 0.5).reshape(-1)].min(dim=-1).values

N = P1.shape[1]
ii, jj = torch.meshgrid(torch.arange(N), torch.arange(N))
ii, jj = meshgrid(torch.arange(N), torch.arange(N))
ii = ii.reshape(-1).cuda()
jj = jj.reshape(-1).cuda()

Expand Down