sharath88 committed on
Commit
f06a85e
·
1 Parent(s): 75d3db4

Switch to Hugging Face InferenceClient for chat backend

Browse files

Replaced requests-based API calls with huggingface_hub's InferenceClient for model inference, updated model selection to Gemma, and refactored prompt construction and persona handling. Added CORS middleware and removed template/static serving for a pure API backend. Updated requirements.txt to include huggingface_hub.

Files changed (2) hide show
  1. main.py +98 -85
  2. requirements.txt +1 -0
main.py CHANGED
@@ -1,116 +1,129 @@
1
  import os
2
- import requests
3
- from typing import List, Literal, Optional
4
 
5
- from fastapi import FastAPI, Request
6
- from fastapi.responses import HTMLResponse, JSONResponse
7
- from fastapi.staticfiles import StaticFiles
8
- from fastapi.templating import Jinja2Templates
9
  from pydantic import BaseModel
 
10
 
11
- # -------------------- Config --------------------
12
 
13
- HF_MODEL_ID = "HuggingFaceH4/zephyr-7b-beta"
14
- HF_API_TOKEN = os.getenv("HF_API_TOKEN") # set in Space → Settings → Secrets
15
 
16
- HF_API_URL = (
17
- f"https://router.huggingface.co/hf-inference/models/"
18
- f"{HF_MODEL_ID}/v1/chat/completions"
19
- )
20
-
21
- DEFAULT_SYSTEM_PROMPT = (
22
- "You are a helpful, concise AI assistant. "
23
- "Answer clearly in plain English unless the user asks otherwise."
24
- )
25
-
26
- if HF_API_TOKEN is None:
27
  raise RuntimeError(
28
- "HF_API_TOKEN is not set. Add it in Space settings Variables & secrets."
 
29
  )
30
 
31
- # -------------------- FastAPI setup --------------------
 
32
 
33
- app = FastAPI()
34
 
35
- # serve /static and /templates
36
- app.mount("/static", StaticFiles(directory="static"), name="static")
37
- templates = Jinja2Templates(directory="templates")
38
 
 
 
 
 
 
 
 
39
 
40
- class ChatMessage(BaseModel):
41
- role: Literal["user", "assistant", "system"]
42
- content: str
43
 
 
 
 
44
 
45
  class ChatRequest(BaseModel):
46
- messages: List[ChatMessage]
47
  temperature: float = 0.7
48
- max_new_tokens: int = 256
49
- system_prompt: Optional[str] = None
50
-
51
 
52
- # ------------- Routes -------------
 
 
53
 
54
- @app.get("/", response_class=HTMLResponse)
55
- async def home(request: Request):
56
- # This renders templates/index.html instead of JSON
57
- return templates.TemplateResponse("index.html", {"request": request})
58
 
 
59
 
60
- # ------------- HF Router helper -------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
- def call_hf_chat(req: ChatRequest) -> str:
63
  """
64
- Call Zephyr via the new HF router chat-completions API
65
- (OpenAI-style).
66
  """
67
- system_prompt = req.system_prompt or DEFAULT_SYSTEM_PROMPT
68
-
69
- # prepend system message
70
- messages = [{"role": "system", "content": system_prompt}]
71
- for m in req.messages:
72
- messages.append({"role": m.role, "content": m.content})
73
-
74
- # clamp params to safe values
75
- temperature = max(0.1, min(req.temperature, 1.5))
76
- max_tokens = max(32, min(req.max_new_tokens, 512))
77
-
78
- headers = {
79
- "Authorization": f"Bearer {HF_API_TOKEN}",
80
- "Content-Type": "application/json",
81
- }
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
- body = {
84
- "messages": messages,
85
- "temperature": temperature,
86
- "max_tokens": max_tokens,
87
- "stream": False,
88
- }
89
 
90
- resp = requests.post(HF_API_URL, headers=headers, json=body, timeout=60)
 
 
91
 
92
- if not resp.ok:
93
- raise RuntimeError(f"Inference API error {resp.status_code}: {resp.text}")
 
 
 
 
 
94
 
95
- data = resp.json()
96
- # OpenAI-style: choices[0].message.content
97
- try:
98
- return data["choices"][0]["message"]["content"].strip()
99
- except Exception:
100
- raise RuntimeError(f"Unexpected response format: {data}")
101
 
 
 
102
 
103
- @app.post("/chat")
104
- async def chat_endpoint(payload: ChatRequest):
105
- if not payload.messages:
106
- return JSONResponse(
107
- {"reply": "", "error": "No messages provided."}, status_code=400
108
- )
109
 
110
- try:
111
- reply = call_hf_chat(payload)
112
- return {"reply": reply}
113
- except Exception as e:
114
- return JSONResponse(
115
- {"reply": "", "error": str(e)}, status_code=500
116
- )
 
1
  import os
2
+ from typing import List, Literal, Dict, Any
 
3
 
4
+ from fastapi import FastAPI, HTTPException
5
+ from fastapi.middleware.cors import CORSMiddleware
 
 
6
  from pydantic import BaseModel
7
+ from huggingface_hub import InferenceClient
8
 
9
+ # ---------- Config ----------
10
 
11
+ HF_TOKEN = os.environ.get("HF_TOKEN") # Set this in Space secrets
12
+ MODEL_ID = "google/gemma-2-2b-it" # Medium-sized instruct model
13
 
14
+ if HF_TOKEN is None:
 
 
 
 
 
 
 
 
 
 
15
  raise RuntimeError(
16
+ "HF_TOKEN is not set. Go to Space → Settings Repository secrets and "
17
+ "add HF_TOKEN with your Hugging Face access token."
18
  )
19
 
20
+ # Inference client (uses HF Inference API / router under the hood)
21
+ hf_client = InferenceClient(model=MODEL_ID, token=HF_TOKEN)
22
 
23
+ # ---------- FastAPI setup ----------
24
 
25
+ app = FastAPI(title="Zephyr Chat Demo (Gemma backend)")
 
 
26
 
27
+ app.add_middleware(
28
+ CORSMiddleware,
29
+ allow_origins=["*"], # ok for demo
30
+ allow_credentials=True,
31
+ allow_methods=["*"],
32
+ allow_headers=["*"],
33
+ )
34
 
35
+ # ---------- Data models ----------
36
+
37
+ Role = Literal["user", "assistant", "system"]
38
 
39
+ class Message(BaseModel):
40
+ role: Role
41
+ content: str
42
 
43
  class ChatRequest(BaseModel):
44
+ messages: List[Message]
45
  temperature: float = 0.7
46
+ max_tokens: int = 256
47
+ persona: str = "General Assistant"
 
48
 
49
+ class ChatResponse(BaseModel):
50
+ reply: str
51
+ messages: List[Message]
52
 
53
+ # ---------- Simple in-memory sessions ----------
 
 
 
54
 
55
+ sessions: Dict[str, List[Message]] = {}
56
 
57
+ def build_system_prompt(persona: str) -> str:
58
+ if persona == "Code Helper":
59
+ return (
60
+ "You are a helpful coding assistant. Explain things clearly, "
61
+ "show small code snippets, and avoid hallucinating libraries or APIs."
62
+ )
63
+ elif persona == "Data Tutor":
64
+ return (
65
+ "You are a teacher who explains data, statistics, and ML concepts "
66
+ "with simple examples and step-by-step reasoning."
67
+ )
68
+ else:
69
+ return (
70
+ "You are a friendly, concise AI assistant. "
71
+ "Answer clearly and avoid unsafe or speculative advice."
72
+ )
73
 
74
+ def build_prompt(messages: List[Message], persona: str) -> str:
75
  """
76
+ Convert chat history into a single text-generation prompt.
 
77
  """
78
+ system_prompt = build_system_prompt(persona)
79
+ lines = [f"System: {system_prompt}", ""]
80
+ for m in messages:
81
+ prefix = "User" if m.role == "user" else "Assistant" if m.role == "assistant" else "System"
82
+ lines.append(f"{prefix}: {m.content}")
83
+ lines.append("Assistant:")
84
+ return "\n".join(lines)
85
+
86
+ def call_llm(prompt: str, temperature: float, max_tokens: int) -> str:
87
+ """
88
+ Call HF Inference text-generation endpoint via InferenceClient.
89
+ """
90
+ try:
91
+ text = hf_client.text_generation(
92
+ prompt,
93
+ max_new_tokens=max_tokens,
94
+ temperature=temperature,
95
+ do_sample=True,
96
+ repetition_penalty=1.1,
97
+ return_full_text=False, # only new assistant text
98
+ )
99
+ return text.strip()
100
+ except Exception as e:
101
+ raise HTTPException(
102
+ status_code=500,
103
+ detail=f"Inference API error: {e}"
104
+ )
105
 
106
+ # ---------- Routes ----------
 
 
 
 
 
107
 
108
+ @app.get("/")
109
+ def health():
110
+ return {"status": "ok", "message": "Zephyr Chat Demo backend running."}
111
 
112
+ @app.post("/chat", response_model=ChatResponse)
113
+ def chat(req: ChatRequest):
114
+ """
115
+ Main chat endpoint. Frontend sends full message list each time.
116
+ """
117
+ if not req.messages:
118
+ raise HTTPException(400, "No messages provided.")
119
 
120
+ # Build prompt from conversation
121
+ prompt = build_prompt(req.messages, req.persona)
 
 
 
 
122
 
123
+ # Call model
124
+ reply_text = call_llm(prompt, req.temperature, req.max_tokens)
125
 
126
+ # Append assistant reply to conversation
127
+ new_messages = req.messages + [Message(role="assistant", content=reply_text)]
 
 
 
 
128
 
129
+ return ChatResponse(reply=reply_text, messages=new_messages)
 
 
 
 
 
 
requirements.txt CHANGED
@@ -3,3 +3,4 @@ uvicorn[standard]
3
  jinja2
4
  requests
5
  python-dotenv
 
 
3
  jinja2
4
  requests
5
  python-dotenv
6
+ huggingface_hub