File size: 2,049 Bytes
5a0778e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import os
import pathlib
import subprocess
import gradio as gr
import spaces
import torch

# ---------- 权重下载:强制在 code_depth 下执行你的脚本 ----------
BASE_DIR = pathlib.Path(__file__).resolve().parent
SCRIPT_DIR = BASE_DIR / "code_depth"
GET_WEIGHTS_SH = SCRIPT_DIR / "get_weights.sh"

def ensure_executable(path: pathlib.Path):
    if not path.exists():
        raise FileNotFoundError(f"Download script not found: {path}")
    os.chmod(path, os.stat(path).st_mode | 0o111)

def ensure_weights() -> str:
    """
    在 code_depth 目录下运行 get_weights.sh。
    该脚本会在 code_depth/ 下创建 checkpoints/ 并下载权重。
    返回绝对路径:<repo_root>/code_depth/checkpoints
    """
    ensure_executable(GET_WEIGHTS_SH)
    # 你脚本的工作目录需要是 code_depth
    subprocess.run(
        ["bash", str(GET_WEIGHTS_SH)],
        check=True,
        cwd=str(SCRIPT_DIR),
        env={**os.environ, "HF_HUB_DISABLE_TELEMETRY": "1"},
    )
    ckpt_dir = SCRIPT_DIR / "checkpoints"
    return str(ckpt_dir)

# 启动时先拉权重(不开 Persistent Storage 时,重建环境会清空;重启后会自动再拉一次)
try:
    CKPT_DIR = ensure_weights()
    print(f"✅ Weights ready in: {CKPT_DIR}")
except Exception as e:
    print(f"⚠️ Failed to prepare weights: {e}")
    CKPT_DIR = str(SCRIPT_DIR / "checkpoints")  # 仍然给个路径,后续可检查是否存在

# ---------- Gradio 推理函数 ----------
@spaces.GPU
def greet(n: float):
    # 在 GPU worker 里拿 device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    zero = torch.tensor([0.0], device=device)
    # 仅示例输出,你可以在这里用 CKPT_DIR 加载你的模型
    print(f"Device in greet(): {device}")
    print(f"Using checkpoints from: {CKPT_DIR}")
    return f"Hello {(zero + n).item()} Tensor (device={device})"

demo = gr.Interface(fn=greet, inputs=gr.Number(label="n"), outputs=gr.Text())
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)