Update app.py
app.py CHANGED
@@ -6,6 +6,7 @@ import asyncio
 import json
 import time
 import logging
+import markdown
 from typing import List, Dict, Any, Optional, AsyncGenerator, Tuple, Union
 from fastapi import FastAPI, HTTPException, Depends, status
 from fastapi.responses import StreamingResponse, PlainTextResponse, HTMLResponse, JSONResponse
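
Note: the new dependency is presumably the Python-Markdown package (pip install markdown); its markdown.markdown() call converts Markdown source into an HTML fragment:

    import markdown

    html = markdown.markdown("**Hello**, *world*!")
    print(html)  # <p><strong>Hello</strong>, <em>world</em>!</p>
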
@@ -115,6 +116,7 @@ class GenerateRequest(BaseModel):
     use_cache: bool = Field(True)
     do_sample: bool = Field(True)
     tokenizer_kwargs: Optional[Dict[str, Any]] = None
+    return_only_text: bool = Field(False)
     max_time: Optional[float] = Field(None, ge=0.0)
     length_penalty: float = Field(1.0, ge=0.0)
     no_repeat_ngram_size: int = Field(0, ge=0)
@@ -134,6 +136,7 @@ class GenerateRequest(BaseModel):
     length_normalization_factor: Optional[float] = Field(None)
     min_new_tokens: int = Field(0, ge=0)
     do_normalize_logits: bool = Field(False)
+    return_full_text: bool = Field(False)
     @validator('stop_sequences')
     def validate_stop_sequences(cls, v):
         if v is not None:
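
Note: despite its name, return_full_text does not toggle prompt-plus-completion output here; per the hunks below, it gates Markdown-to-HTML rendering of the generated text, while return_only_text only changes how streaming errors are formatted. A minimal sketch of a request setting both flags, assuming the endpoint is mounted at /generate and takes prompt and stream fields (none of which are visible in this diff):

    import requests

    resp = requests.post(
        "http://localhost:8000/generate",  # route path assumed, not shown in the diff
        json={
            "prompt": "Write a haiku about autumn.",  # field name assumed
            "stream": False,                          # field name assumed
            "return_full_text": True,   # render output as an HTML fragment
            "return_only_text": False,  # keep structured JSON error payloads
        },
    )
    print(resp.status_code, resp.text[:200])
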
@@ -314,6 +317,8 @@ async def stream_generation_logic(req: GenerateRequest, initial_ids: torch.Tenso
                 final_text_raw = final_text_raw.split(stop_seq, 1)[0]
                 break
         final_text_processed = post_process_text(final_text_raw, req.strip_trailing_whitespace, req.remove_incomplete_sentences)
+        if req.return_full_text:
+            final_text_processed = markdown.markdown(final_text_processed)
         final_payload: Dict[str, Any] = {
             "type": "done",
             "total_prompt_tokens": initial_ids.shape[-1],
@@ -324,8 +329,11 @@
         }
         yield json.dumps(final_payload) + "\n"
     except Exception as e:
-        error_payload = {"type": "error", "message": str(e)}
-        yield json.dumps(error_payload) + "\n"
+        if req.return_only_text:
+            yield f"Error: {e}\n"
+        else:
+            error_payload = {"type": "error", "message": str(e)}
+            yield json.dumps(error_payload) + "\n"
     finally:
         await cleanup()
 async def non_stream_generation_logic(req: GenerateRequest, initial_ids: torch.Tensor, gen_cfg: GenerationConfig, device: str) -> Dict[str, Any]:
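
Note: on the streaming path every line is now a self-contained JSON object, except that with return_only_text errors arrive as bare text lines. A minimal consumer sketch, under the same hypothetical /generate route and request fields:

    import json
    import requests

    with requests.post("http://localhost:8000/generate",
                       json={"prompt": "...", "stream": True},  # field names assumed
                       stream=True) as resp:
        for line in resp.iter_lines(decode_unicode=True):
            if not line:
                continue
            try:
                event = json.loads(line)
            except json.JSONDecodeError:
                print("stream error:", line)  # bare-text error (return_only_text)
                break
            if event.get("type") == "error":
                print("stream error:", event["message"])
                break
            if event.get("type") == "done":
                print("prompt tokens:", event["total_prompt_tokens"])
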
@@ -675,6 +683,10 @@ async def generate_endpoint(req: GenerateRequest):
             return StreamingResponse(stream_generation_logic(req, ids, gen_cfg, device), media_type="application/json")
         else:
             response_payload = await non_stream_generation_logic(req, ids, gen_cfg, device)
+            if req.return_full_text and response_payload.get("generated_sequences"):
+                first_sequence = response_payload["generated_sequences"][0].get("text", "")
+                markdown_text = markdown.markdown(first_sequence)
+                return PlainTextResponse(markdown_text)
             return JSONResponse(response_payload)
     except Exception as e:
         raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Generation error: {e}")
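
Note: the non-streaming path now returns two different content types: application/json by default, and text/plain (PlainTextResponse's default media type, even though the body is HTML) when return_full_text selects the markdown-rendered first sequence. Clients therefore need to branch on the Content-Type header; a sketch, again assuming a /generate route and a prompt field:

    import requests

    resp = requests.post("http://localhost:8000/generate",           # path assumed
                         json={"prompt": "...", "return_full_text": True})
    if resp.headers.get("content-type", "").startswith("application/json"):
        data = resp.json()         # full structured payload
    else:
        html_fragment = resp.text  # markdown-rendered text of the first sequence
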