Spaces:
Runtime error
Runtime error
Commit
·
8eda766
1
Parent(s):
d099347
Update app.py
Browse files
app.py
CHANGED
|
@@ -11,81 +11,6 @@ os.system("git clone https://github.com/FrozenBurning/SceneDreamer.git")
|
|
| 11 |
os.system("cp -r SceneDreamer/* ./")
|
| 12 |
os.system("bash install.sh")
|
| 13 |
|
| 14 |
-
pretrained_model = dict(file_url='https://drive.google.com/uc?id=1IFu1vNrgF1EaRqPizyEgN_5Vt7Fyg0Mj',
|
| 15 |
-
alt_url='', file_size=330571863,
|
| 16 |
-
file_path='./scenedreamer_released.pt',)
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
def download_file(session, file_spec, use_alt_url=False, chunk_size=128, num_attempts=10):
|
| 20 |
-
file_path = file_spec['file_path']
|
| 21 |
-
if use_alt_url:
|
| 22 |
-
file_url = file_spec['alt_url']
|
| 23 |
-
else:
|
| 24 |
-
file_url = file_spec['file_url']
|
| 25 |
-
|
| 26 |
-
file_dir = os.path.dirname(file_path)
|
| 27 |
-
tmp_path = file_path + '.tmp.' + uuid.uuid4().hex
|
| 28 |
-
if file_dir:
|
| 29 |
-
os.makedirs(file_dir, exist_ok=True)
|
| 30 |
-
|
| 31 |
-
progress_bar = tqdm(total=file_spec['file_size'], unit='B', unit_scale=True)
|
| 32 |
-
for attempts_left in reversed(range(num_attempts)):
|
| 33 |
-
data_size = 0
|
| 34 |
-
progress_bar.reset()
|
| 35 |
-
try:
|
| 36 |
-
# Download.
|
| 37 |
-
data_md5 = hashlib.md5()
|
| 38 |
-
with session.get(file_url, stream=True) as res:
|
| 39 |
-
res.raise_for_status()
|
| 40 |
-
with open(tmp_path, 'wb') as f:
|
| 41 |
-
for chunk in res.iter_content(chunk_size=chunk_size<<10):
|
| 42 |
-
progress_bar.update(len(chunk))
|
| 43 |
-
f.write(chunk)
|
| 44 |
-
data_size += len(chunk)
|
| 45 |
-
data_md5.update(chunk)
|
| 46 |
-
|
| 47 |
-
# Validate.
|
| 48 |
-
if 'file_size' in file_spec and data_size != file_spec['file_size']:
|
| 49 |
-
raise IOError('Incorrect file size', file_path)
|
| 50 |
-
if 'file_md5' in file_spec and data_md5.hexdigest() != file_spec['file_md5']:
|
| 51 |
-
raise IOError('Incorrect file MD5', file_path)
|
| 52 |
-
break
|
| 53 |
-
|
| 54 |
-
except Exception as e:
|
| 55 |
-
# print(e)
|
| 56 |
-
# Last attempt => raise error.
|
| 57 |
-
if not attempts_left:
|
| 58 |
-
raise
|
| 59 |
-
|
| 60 |
-
# Handle Google Drive virus checker nag.
|
| 61 |
-
if data_size > 0 and data_size < 8192:
|
| 62 |
-
with open(tmp_path, 'rb') as f:
|
| 63 |
-
data = f.read()
|
| 64 |
-
links = [html.unescape(link) for link in data.decode('utf-8').split('"') if 'confirm=t' in link]
|
| 65 |
-
if len(links) == 1:
|
| 66 |
-
file_url = requests.compat.urljoin(file_url, links[0])
|
| 67 |
-
continue
|
| 68 |
-
|
| 69 |
-
progress_bar.close()
|
| 70 |
-
|
| 71 |
-
# Rename temp file to the correct name.
|
| 72 |
-
os.replace(tmp_path, file_path) # atomic
|
| 73 |
-
|
| 74 |
-
# Attempt to clean up any leftover temps.
|
| 75 |
-
for filename in glob.glob(file_path + '.tmp.*'):
|
| 76 |
-
try:
|
| 77 |
-
os.remove(filename)
|
| 78 |
-
except:
|
| 79 |
-
pass
|
| 80 |
-
|
| 81 |
-
print('Downloading SceneDreamer pretrained model...')
|
| 82 |
-
with requests.Session() as session:
|
| 83 |
-
try:
|
| 84 |
-
download_file(session, pretrained_model)
|
| 85 |
-
except:
|
| 86 |
-
print('Google Drive download failed.\n')
|
| 87 |
-
|
| 88 |
-
|
| 89 |
|
| 90 |
import os
|
| 91 |
import torch
|
|
|
|
| 11 |
os.system("cp -r SceneDreamer/* ./")
|
| 12 |
os.system("bash install.sh")
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
import os
|
| 16 |
import torch
|