jnjj committed
Commit 48acb1a · verified · 1 Parent(s): a9504c8

Update app.py

Files changed (1):
  1. app.py (+14 -2)
app.py CHANGED
@@ -6,6 +6,7 @@ import asyncio
 import json
 import time
 import logging
+import markdown
 from typing import List, Dict, Any, Optional, AsyncGenerator, Tuple, Union
 from fastapi import FastAPI, HTTPException, Depends, status
 from fastapi.responses import StreamingResponse, PlainTextResponse, HTMLResponse, JSONResponse
@@ -115,6 +116,7 @@ class GenerateRequest(BaseModel):
     use_cache: bool = Field(True)
     do_sample: bool = Field(True)
     tokenizer_kwargs: Optional[Dict[str, Any]] = None
+    return_only_text: bool = Field(False)
     max_time: Optional[float] = Field(None, ge=0.0)
     length_penalty: float = Field(1.0, ge=0.0)
     no_repeat_ngram_size: int = Field(0, ge=0)
@@ -134,6 +136,7 @@ class GenerateRequest(BaseModel):
     length_normalization_factor: Optional[float] = Field(None)
     min_new_tokens: int = Field(0, ge=0)
     do_normalize_logits: bool = Field(False)
+    return_full_text: bool = Field(False)
     @validator('stop_sequences')
     def validate_stop_sequences(cls, v):
         if v is not None:
@@ -314,6 +317,8 @@ async def stream_generation_logic(req: GenerateRequest, initial_ids: torch.Tensor
             final_text_raw = final_text_raw.split(stop_seq, 1)[0]
             break
         final_text_processed = post_process_text(final_text_raw, req.strip_trailing_whitespace, req.remove_incomplete_sentences)
+        if req.return_full_text:
+            final_text_processed = markdown.markdown(final_text_processed)
         final_payload: Dict[str, Any] = {
             "type": "done",
             "total_prompt_tokens": initial_ids.shape[-1],
@@ -324,8 +329,11 @@ async def stream_generation_logic(req: GenerateRequest, initial_ids: torch.Tensor
         }
         yield json.dumps(final_payload) + "\n"
     except Exception as e:
-        error_payload = {"type": "error", "message": str(e)}
-        yield json.dumps(error_payload) + "\n"
+        if req.return_only_text:
+            yield f"Error: {e}\n"
+        else:
+            error_payload = {"type": "error", "message": str(e)}
+            yield json.dumps(error_payload) + "\n"
     finally:
         await cleanup()
 async def non_stream_generation_logic(req: GenerateRequest, initial_ids: torch.Tensor, gen_cfg: GenerationConfig, device: str) -> Dict[str, Any]:
@@ -675,6 +683,10 @@ async def generate_endpoint(req: GenerateRequest):
             return StreamingResponse(stream_generation_logic(req, ids, gen_cfg, device), media_type="application/json")
         else:
             response_payload = await non_stream_generation_logic(req, ids, gen_cfg, device)
+            if req.return_full_text and response_payload.get("generated_sequences"):
+                first_sequence = response_payload["generated_sequences"][0].get("text", "")
+                markdown_text = markdown.markdown(first_sequence)
+                return PlainTextResponse(markdown_text)
             return JSONResponse(response_payload)
     except Exception as e:
         raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Generation error: {e}")
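
For context on what the new return_full_text flag changes: the streaming path now runs the final processed text through markdown.markdown() just before the closing "done" payload is built, and the non-stream path returns the converted first sequence as a PlainTextResponse instead of the usual JSON envelope. A minimal standalone sketch of the conversion step, using only the markdown package this commit imports (the sample text is illustrative):

    import markdown

    # markdown.markdown() renders a Markdown string to an HTML fragment (str).
    sample = "# Result\n\nThe model said **hello**."
    html = markdown.markdown(sample)
    print(html)
    # <h1>Result</h1>
    # <p>The model said <strong>hello</strong>.</p>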
 
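
A note on the changed streaming error path: when return_only_text is set, the generator now yields a bare "Error: ..." line instead of the JSON error object, so any line-by-line consumer of this stream has to tolerate both shapes. A small, hypothetical client-side sketch (this consumer is an assumption for illustration, not code from this repo):

    import json

    def parse_stream_line(line: str) -> dict:
        """Parse one line of the generation stream, accepting either the
        JSON error object or the plain-text error introduced in this commit."""
        try:
            return json.loads(line)
        except json.JSONDecodeError:
            # Plain-text branch: the server yielded f"Error: {e}\n"
            return {"type": "error", "message": line.strip().removeprefix("Error: ")}

    # The two failure shapes the stream can now emit:
    print(parse_stream_line('{"type": "error", "message": "CUDA out of memory"}'))
    print(parse_stream_line("Error: CUDA out of memory\n"))

Keeping the JSON envelope for every frame would avoid the dual format, though the plain-text branch does make raw curl output easier to read.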