Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import pathlib | |
| import subprocess | |
| import gradio as gr | |
| import spaces | |
| import torch | |
| # ---------- 权重下载:强制在 code_depth 下执行你的脚本 ---------- | |
| BASE_DIR = pathlib.Path(__file__).resolve().parent | |
| SCRIPT_DIR = BASE_DIR / "code_depth" | |
| GET_WEIGHTS_SH = SCRIPT_DIR / "get_weights.sh" | |
| def ensure_executable(path: pathlib.Path): | |
| if not path.exists(): | |
| raise FileNotFoundError(f"Download script not found: {path}") | |
| os.chmod(path, os.stat(path).st_mode | 0o111) | |
| def ensure_weights() -> str: | |
| """ | |
| 在 code_depth 目录下运行 get_weights.sh。 | |
| 该脚本会在 code_depth/ 下创建 checkpoints/ 并下载权重。 | |
| 返回绝对路径:<repo_root>/code_depth/checkpoints | |
| """ | |
| ensure_executable(GET_WEIGHTS_SH) | |
| # 你脚本的工作目录需要是 code_depth | |
| subprocess.run( | |
| ["bash", str(GET_WEIGHTS_SH)], | |
| check=True, | |
| cwd=str(SCRIPT_DIR), | |
| env={**os.environ, "HF_HUB_DISABLE_TELEMETRY": "1"}, | |
| ) | |
| ckpt_dir = SCRIPT_DIR / "checkpoints" | |
| return str(ckpt_dir) | |
| # 启动时先拉权重(不开 Persistent Storage 时,重建环境会清空;重启后会自动再拉一次) | |
| try: | |
| CKPT_DIR = ensure_weights() | |
| print(f"✅ Weights ready in: {CKPT_DIR}") | |
| except Exception as e: | |
| print(f"⚠️ Failed to prepare weights: {e}") | |
| CKPT_DIR = str(SCRIPT_DIR / "checkpoints") # 仍然给个路径,后续可检查是否存在 | |
| # ---------- Gradio 推理函数 ---------- | |
| def greet(n: float): | |
| # 在 GPU worker 里拿 device | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| zero = torch.tensor([0.0], device=device) | |
| # 仅示例输出,你可以在这里用 CKPT_DIR 加载你的模型 | |
| print(f"Device in greet(): {device}") | |
| print(f"Using checkpoints from: {CKPT_DIR}") | |
| return f"Hello {(zero + n).item()} Tensor (device={device})" | |
| demo = gr.Interface(fn=greet, inputs=gr.Number(label="n"), outputs=gr.Text()) | |
| if __name__ == "__main__": | |
| demo.launch(server_name="0.0.0.0", server_port=7860) | |