# Hugging Face Space: plamo-2-1b CPU demo (app.py)
| import asyncio | |
| import gradio as gr | |
| import transformers | |
| from transformers import ( | |
| TextIteratorStreamer, | |
| AutoTokenizer, | |
| StoppingCriteria, | |
| StoppingCriteriaList, | |
| ) | |
| import threading | |
| import ctypes | |
# Checkpoint used for both the tokenizer and the generation pipeline.
_MODEL_ID = "pfnet/plamo-2-1b"

# trust_remote_code=True is required: plamo-2 ships custom modeling code on the Hub.
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID, trust_remote_code=True)
pipeline = transformers.pipeline(
    "text-generation",
    model=_MODEL_ID,
    trust_remote_code=True,
)
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once any of the given stop strings appears in the output.

    Only the last two generated tokens are decoded each step (cheap, and two
    tokens are enough to catch a stop string split across a token boundary).
    """

    def __init__(self, stops=None, encounters=1):
        """
        Args:
            stops: stop strings to look for (default: none).
            encounters: kept for interface compatibility; previously accepted
                but unused. Now stored so callers can inspect it.
        """
        super().__init__()
        # Fix: the original used a mutable default (`stops=[]`), which is
        # shared across all instances; copy into a fresh list instead.
        self.stops = list(stops) if stops is not None else []
        self.encounters = encounters

    def __call__(self, input_ids, scores):
        # Decode only the tail of the sequence; decoding the whole output on
        # every step would be quadratic over the generation length.
        tail = tokenizer.decode(input_ids[0][-2:])
        return any(stop in tail for stop in self.stops)
# Halt generation as soon as a blank line ("\n\n") shows up in the output.
_stop_on_blank_line = StoppingCriteriaSub(stops=["\n\n"])
stopping_criteria = StoppingCriteriaList([_stop_on_blank_line])
class CancelableThread(threading.Thread):
    """A thread that can be cancelled by injecting SystemExit into it.

    Cancellation uses ctypes.pythonapi.PyThreadState_SetAsyncExc, which
    expects the *Python-level* thread ident (threading.get_ident()).
    """

    def __init__(self, group=None, target=None, name=None, args=(), kwargs=None):
        threading.Thread.__init__(self, group=group, target=target, name=name)
        self.args = args
        # Fix: the original used a mutable default (`kwargs={}`), shared
        # across all instances.
        self.kwargs = kwargs if kwargs is not None else {}

    def run(self):
        # BUG FIX: the original stored threading.get_native_id() (the OS
        # thread id), but PyThreadState_SetAsyncExc takes the Python-level
        # ident — so cancellation silently never worked (the call returned 0
        # because no Python thread matched the native id).
        self.id = threading.get_ident()
        if self._target is not None:
            self._target(*self.args, **self.kwargs)

    def get_id(self):
        """Return the Python thread ident recorded when the thread started."""
        return self.id

    def raise_exception(self):
        """Asynchronously raise SystemExit inside this thread."""
        thread_id = self.get_id()
        res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
            ctypes.c_long(thread_id), ctypes.py_object(SystemExit)
        )
        if res == 0:
            # No thread with that ident exists (it already finished).
            print("Failure in raising exception")
        elif res > 1:
            # More than one thread state was modified: undo and report.
            ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(thread_id), 0)
            print("Failure in raising exception")
class ThreadManager:
    """Context manager that runs a CancelableThread for the `with` body.

    On exit, a thread that is still alive is asked to stop (SystemExit is
    injected via raise_exception) and then joined.
    """

    def __init__(self, thread: "CancelableThread", **kwargs):
        self.thread = thread

    def __enter__(self):
        # Start the worker when entering the context.
        self.thread.start()
        return self.thread

    def __exit__(self, exc_type, exc_value, traceback):
        # Wait for the worker to finish, cancelling it first if necessary.
        if self.thread.is_alive():
            print("trying to terminate thread")
            self.thread.raise_exception()
        self.thread.join()
        print("Thread has been successfully joined.")
def respond(prompt, max_tokens):
    """Stream generated text for `prompt`, yielding Gradio updates.

    Yields (partial_text, submit_button_update, clear_button_update): both
    buttons stay disabled while tokens stream in and are re-enabled once
    generation finishes.
    """
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(
        text_inputs=prompt,
        max_new_tokens=max_tokens,
        return_full_text=False,
        streamer=streamer,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        stopping_criteria=stopping_criteria,
    )
    worker = CancelableThread(target=pipeline, kwargs=generation_kwargs)

    pieces = []
    with ThreadManager(thread=worker):
        for chunk in streamer:
            # Skip empty chunks the streamer occasionally emits.
            if not chunk:
                continue
            pieces.append(chunk)
            yield "".join(pieces), gr.update(interactive=False), gr.update(interactive=False)
    # Generation is done and the worker is joined: re-enable the buttons.
    yield (
        "".join(pieces),
        gr.update(interactive=True),
        gr.update(interactive=True),
    )
def reset_textbox():
    """Clear both the input and output textboxes."""
    return (gr.update(value=""), gr.update(value=""))
def no_interactive():
    """Disable the two buttons wired to this callback."""
    return (gr.update(interactive=False), gr.update(interactive=False))
# Build the Gradio UI: an input textbox with Clear/Submit buttons, a
# streaming output textbox, and a collapsible parameters panel.
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">plamo-2-1b CPU demo</h1>""")
    gr.Markdown(
        "2 vCPU, 16 GB RAMでのデモです。10年前くらいのノートパソコンくらい。(GPUなしのHugging Faceの無料インスタンスで動いています。)vllmとかllama.cppが対応すればもっと高速に動くはず。"
    )
    with gr.Column(elem_id="col_container") as main_block:
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(
                    lines=15, label="input_text", placeholder="これからの人工知能技術は"
                )
                with gr.Row():
                    with gr.Column(scale=3):
                        clear_button = gr.Button("Clear")
                    with gr.Column(scale=5):
                        submit_button = gr.Button("Submit")
            outputs = gr.Textbox(lines=20, label="Output")
        # inputs, top_p, temperature, top_k, repetition_penalty
        with gr.Accordion("Parameters", open=False):
            max_tokens = gr.Slider(
                minimum=1, maximum=4096, value=32, step=1, label="Max new tokens"
            )
    # Disable both buttons immediately on click, then stream the response
    # (the second handler re-enables them when generation finishes).
    submit_button.click(no_interactive, [], [submit_button, clear_button])
    submit_button.click(
        respond,
        [input_text, max_tokens],
        [outputs, submit_button, clear_button],
    )
    clear_button.click(reset_textbox, [], [input_text, outputs], queue=False)

if __name__ == "__main__":
    # BUG FIX: the original launched the app twice — once unconditionally
    # (demo.queue(...).launch(...)) and again via demo.launch() under the
    # __main__ guard. Launch exactly once, keeping the queue and server
    # settings, and only when run as a script.
    demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=7860)