pytorch version fix
eps696 committed Nov 12, 2021
1 parent 22201e8 commit bcf3daa
Showing 6 changed files with 16 additions and 17 deletions.
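The fix targets a version check repeated across the repo: float(torch.__version__[:3]) < 1.8 misreads two-digit minor versions, since for torch 1.10 the slice gives float('1.1') = 1.1 < 1.8 and the old-PyTorch path is taken by mistake. The commit replaces it with an old_torch() helper in utils.py that compares major and minor as integers. A minimal sketch of the difference, with illustrative version strings:

    def old_check(version):   # previous check: string prefix, '1.10.0'[:3] == '1.1'
        return float(version[:3]) < 1.8

    def new_check(version):   # the commit's approach: integer major/minor
        ver = [int(i) for i in version.split('.')[:2]]
        return ver[0] < 2 and ver[1] < 8

    for v in ('1.7.1', '1.8.0', '1.10.0'):
        print(v, old_check(v), new_check(v))
    # 1.7.1   True  True
    # 1.8.0   False False
    # 1.10.0  True  False   <- the misfire this commit fixes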
5 changes: 2 additions & 3 deletions Aphantasia.ipynb
@@ -130,7 +130,7 @@
 "!git clone https://github.com/eps696/aphantasia\n",
 "%cd aphantasia/\n",
 "from clip_fft import to_valid_rgb, fft_image, img2fft, dwt_image, img2dwt, init_dwt, dwt_scale\n",
-"from utils import slice_imgs, derivat, pad_up_to, basename, img_list, img_read, plot_text, txt_clean, checkout\n",
+"from utils import slice_imgs, derivat, pad_up_to, basename, img_list, img_read, plot_text, txt_clean, checkout, old_torch\n",
 "import transforms\n",
 "from progress_bar import ProgressIPy as ProgressBar\n",
 "\n",
@@ -321,8 +321,7 @@
 " samples = int(samples * 0.75)\n",
 "print(' using %d samples' % samples)\n",
 "\n",
-"use_jit = True if float(torch.__version__[:3]) < 1.8 else False\n",
-"model_clip, _ = clip.load(model, jit=use_jit)\n",
+"model_clip, _ = clip.load(model, jit=old_torch())\n",
 "modsize = model_clip.visual.input_resolution\n",
 "xmem = {'ViT-B/16':0.25, 'RN50':0.5, 'RN50x4':0.16, 'RN50x16':0.06, 'RN101':0.33}\n",
 "if model in xmem.keys():\n",
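The same one-line replacement recurs in every notebook and script below. A sketch of the pattern, assuming the clip package from openai/CLIP and the old_torch() helper added to utils.py in this commit:

    import clip
    from utils import old_torch

    # load the TorchScript (JIT) CLIP model only on old PyTorch (< 1.8), eager otherwise
    model_clip, preprocess = clip.load('ViT-B/32', jit=old_torch())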
5 changes: 2 additions & 3 deletions CLIP_VQGAN.ipynb
@@ -125,7 +125,7 @@
 "!rm -rf aphantasia\n",
 "!git clone https://github.com/eps696/aphantasia\n",
 "%cd aphantasia/\n",
-"from utils import slice_imgs, pad_up_to, basename, img_list, img_read, plot_text, txt_clean\n",
+"from utils import slice_imgs, pad_up_to, basename, img_list, img_read, plot_text, txt_clean, old_torch\n",
 "import transforms\n",
 "from progress_bar import ProgressIPy as ProgressBar\n",
 "\n",
@@ -339,8 +339,7 @@
 " samples = int(samples * 0.5)\n",
 "print(' using %d samples' % samples)\n",
 "\n",
-"use_jit = True if float(torch.__version__[:3]) < 1.8 else False\n",
-"model_clip, _ = clip.load(model, jit=use_jit)\n",
+"model_clip, _ = clip.load(model, jit=old_torch())\n",
 "modsize = model_clip.visual.input_resolution\n",
 "xmem = {'ViT-B/16':0.25, 'RN50':0.5, 'RN50x4':0.16, 'RN50x16':0.06, 'RN101':0.33}\n",
 "if model in xmem.keys():\n",
5 changes: 2 additions & 3 deletions Illustra.ipynb
@@ -121,7 +121,7 @@
 "!git clone https://github.com/eps696/aphantasia\n",
 "%cd aphantasia/\n",
 "from clip_fft import to_valid_rgb, fft_image\n",
-"from utils import slice_imgs, derivat, pad_up_to, basename, file_list, img_list, img_read, txt_clean, plot_text, checkout \n",
+"from utils import slice_imgs, derivat, pad_up_to, basename, file_list, img_list, img_read, txt_clean, plot_text, checkout, old_torch \n",
 "import transforms\n",
 "from progress_bar import ProgressIPy as ProgressBar\n",
 "\n",
@@ -260,8 +260,7 @@
 "fps = 25\n",
 "if multilang: model = 'ViT-B/32' # sbert model is trained with ViT\n",
 "\n",
-"use_jit = True if float(torch.__version__[:3]) < 1.8 else False\n",
-"model_clip, _ = clip.load(model, jit=use_jit)\n",
+"model_clip, _ = clip.load(model, jit=old_torch())\n",
 "modsize = model_clip.visual.input_resolution\n",
 "xmem = {'ViT-B/16':0.25, 'RN50':0.5, 'RN50x4':0.16, 'RN50x16':0.06, 'RN101':0.33}\n",
 "if model in xmem.keys():\n",
9 changes: 4 additions & 5 deletions clip_fft.py
@@ -20,7 +20,7 @@
 from sentence_transformers import SentenceTransformer
 import lpips
 
-from utils import slice_imgs, derivat, basename, img_list, img_read, plot_text, txt_clean, checkout
+from utils import slice_imgs, derivat, basename, img_list, img_read, plot_text, txt_clean, checkout, old_torch
 import transforms
 try: # progress bar for notebooks
     get_ipython().__class__.__name__
@@ -234,7 +234,7 @@ def inner(shift=None, contrast=1.):
         scaled_spectrum_t = scale * spectrum_real_imag_t
         if shift is not None:
             scaled_spectrum_t += scale * shift
-        if float(torch.__version__[:3]) < 1.8:
+        if old_torch():
             image = torch.irfft(scaled_spectrum_t, 2, normalized=True, signal_sizes=(h, w))
         else:
             if type(scaled_spectrum_t) is not torch.complex64:
@@ -279,7 +279,7 @@ def img2fft(img_in, decay=1., colors=1.):
     img_in = un_rgb(img_in, colors=colors)
 
     with torch.no_grad():
-        if float(torch.__version__[:3]) < 1.8:
+        if old_torch():
             spectrum = torch.rfft(img_in, 2, normalized=True) # 1.7
         else:
             spectrum = torch.fft.rfftn(img_in, s=(h, w), dim=[2,3], norm='ortho') # 1.8
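Both clip_fft.py hunks above sit on PyTorch's FFT API split: torch.rfft was deprecated in 1.7 and removed in 1.8 in favor of the torch.fft module, which returns complex tensors instead of real tensors with a trailing (real, imag) dimension. A rough sketch of the correspondence, with an illustrative input shape:

    import torch

    x = torch.randn(1, 3, 64, 64)
    # torch < 1.8 (legacy API): real tensor of shape (1, 3, 64, 33, 2)
    #   spec = torch.rfft(x, 2, normalized=True)
    # torch >= 1.8: complex tensor of shape (1, 3, 64, 33)
    spec = torch.fft.rfftn(x, s=(64, 64), dim=[2, 3], norm='ortho')
    spec_ri = torch.view_as_real(spec)  # recover the old (..., 2) layout when needed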
@@ -352,8 +352,7 @@ def train(i):
         pbar.upd()
 
     # Load CLIP models
-    use_jit = True if float(torch.__version__[:3]) < 1.8 else False
-    model_clip, _ = clip.load(a.model, jit=use_jit)
+    model_clip, _ = clip.load(a.model, jit=old_torch())
     try:
         a.modsize = model_clip.visual.input_resolution
     except:
5 changes: 2 additions & 3 deletions illustra.py
@@ -15,7 +15,7 @@
 from sentence_transformers import SentenceTransformer
 
 from clip_fft import to_valid_rgb, fft_image
-from utils import slice_imgs, derivat, checkout, cvshow, pad_up_to, basename, file_list, img_list, img_read, txt_clean, plot_text
+from utils import slice_imgs, derivat, checkout, cvshow, pad_up_to, basename, file_list, img_list, img_read, txt_clean, plot_text, old_torch
 import transforms
 try: # progress bar for notebooks
     get_ipython().__class__.__name__
@@ -82,8 +82,7 @@ def main():
     a = get_args()
 
     # Load CLIP models
-    use_jit = True if float(torch.__version__[:3]) < 1.8 else False
-    model_clip, _ = clip.load(a.model, jit=use_jit)
+    model_clip, _ = clip.load(a.model, jit=old_torch())
     try:
         a.modsize = model_clip.visual.input_resolution
     except:
4 changes: 4 additions & 0 deletions utils.py
@@ -25,6 +25,10 @@ def plot_text(txt, size=224):
 def txt_clean(txt):
     return txt.translate(str.maketrans(dict.fromkeys(list("\n',.—|!?/:;\\"), ""))).replace(' ', '_').replace('"', '')
 
+def old_torch():
+    ver = [int(i) for i in torch.__version__.split('.')[:2]]
+    return True if (ver[0] < 2 and ver[1] < 8) else False
+
 def basename(file):
     return os.path.splitext(os.path.basename(file))[0]
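For reference, a behavior sketch of the new helper with hypothetical version strings (torch is assumed to be imported at the top of utils.py):

    from utils import old_torch  # parses torch.__version__ major/minor as integers
    # '1.7.1'        -> True   (1 < 2 and 7 < 8)
    # '1.8.0'        -> False  (8 < 8 fails)
    # '1.10.0+cu113' -> False  (the old float-prefix check wrongly returned True here)
    # '2.0.0'        -> False  (2 < 2 fails)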
