GeoRemover / app.py
zixinz
Add application file
5a0778e
raw
history blame
2.05 kB
import os
import pathlib
import subprocess
import gradio as gr
import spaces
import torch
# ---------- Weight download: force-run the fetch script from inside code_depth ----------
BASE_DIR = pathlib.Path(__file__).resolve().parent  # directory containing this app.py
SCRIPT_DIR = BASE_DIR / "code_depth"  # required working directory for the download script
GET_WEIGHTS_SH = SCRIPT_DIR / "get_weights.sh"  # shell script that downloads the model weights
def ensure_executable(path: pathlib.Path):
    """Ensure *path* exists and carries the executable bits (u+x, g+x, o+x).

    Raises:
        FileNotFoundError: if *path* does not exist.
    """
    if not path.exists():
        raise FileNotFoundError(f"Download script not found: {path}")
    current_mode = path.stat().st_mode
    path.chmod(current_mode | 0o111)
def ensure_weights() -> str:
    """
    Run get_weights.sh with code_depth as the working directory.

    The script creates checkpoints/ under code_depth/ and downloads the
    model weights into it.

    Returns:
        Absolute path string: <repo_root>/code_depth/checkpoints

    Raises:
        FileNotFoundError: if the download script is missing (via ensure_executable).
        subprocess.CalledProcessError: if the script exits non-zero (check=True).
    """
    ensure_executable(GET_WEIGHTS_SH)
    # The script expects its working directory to be code_depth
    subprocess.run(
        ["bash", str(GET_WEIGHTS_SH)],
        check=True,
        cwd=str(SCRIPT_DIR),
        # Inherit the current environment but opt out of HF Hub telemetry
        env={**os.environ, "HF_HUB_DISABLE_TELEMETRY": "1"},
    )
    ckpt_dir = SCRIPT_DIR / "checkpoints"
    return str(ckpt_dir)
# Fetch the weights at startup. Without Persistent Storage the environment is
# wiped on rebuild; after a restart this pulls them again automatically.
try:
    CKPT_DIR = ensure_weights()
    print(f"✅ Weights ready in: {CKPT_DIR}")
except Exception as e:
    print(f"⚠️ Failed to prepare weights: {e}")
    CKPT_DIR = str(SCRIPT_DIR / "checkpoints")  # still provide a path; callers can check existence later
# ---------- Gradio inference function ----------
@spaces.GPU
def greet(n: float):
    """Demo GPU handler: add *n* to a zero tensor on the worker's device.

    Prints the resolved device and the checkpoint directory, then returns a
    status string embedding the computed value and device name. The real
    model could be loaded here from CKPT_DIR instead.
    """
    # The device must be resolved inside the GPU worker, not at import time.
    has_cuda = torch.cuda.is_available()
    dev = "cuda" if has_cuda else "cpu"
    base = torch.tensor([0.0], device=dev)
    print(f"Device in greet(): {dev}")
    print(f"Using checkpoints from: {CKPT_DIR}")
    result = (base + n).item()
    return f"Hello {result} Tensor (device={dev})"
# Minimal UI: one number input wired to greet(), plain text output.
demo = gr.Interface(fn=greet, inputs=gr.Number(label="n"), outputs=gr.Text())
if __name__ == "__main__":
    # Bind to all interfaces on 7860, the standard HF Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860)