Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from typing import Iterable | |
| import gradio as gr | |
| from gradio.themes.base import Base | |
| from gradio.themes.utils import colors, fonts, sizes | |
| import time | |
| import torch | |
| from transformers import pipeline | |
| import pandas as pd | |
# Module-level text-generation pipeline shared by every request handler below.
instruct_pipeline = pipeline(
    model="databricks/dolly-v2-7b",
    torch_dtype=torch.bfloat16,
    # allow the model repository's custom pipeline code to run
    trust_remote_code=True,
    # let transformers/accelerate decide where to place the weights
    device_map="auto",
)
def run_pipeline(prompt, *, pipe=None):
    """Run *prompt* through the Dolly pipeline and return the generated text.

    The dolly-v2 instruction pipeline returns a list of records such as
    ``[{"generated_text": "..."}]`` (the model card reads
    ``res[0]["generated_text"]``). Some revisions of the remote pipeline code
    may instead return a bare string. Both shapes are normalised to a plain
    ``str``, so the chat components receive text rather than a list or dict.

    Args:
        prompt: Instruction text to send to the model.
        pipe: Optional callable to use instead of the module-level
            ``instruct_pipeline``. Defaults to ``instruct_pipeline``.

    Returns:
        The generated text as a string. Returns ``""`` when the pipeline
        returns no records or a record without ``"generated_text"``.
    """
    generate = instruct_pipeline if pipe is None else pipe
    response = generate(prompt)
    # A single prompt yields a one-element batch: [{"generated_text": ...}].
    if isinstance(response, (list, tuple)):
        if not response:
            return ""
        response = response[0]
    if isinstance(response, dict):
        return response.get("generated_text", "")
    return str(response)
def get_user_input(input_question, history):
    """Append the user's message to the Dolly chat history and clear the textbox.

    Args:
        input_question: Text typed into the instruction box.
        history: Current chatbot value, a list of ``[user, bot]`` pairs.
            ``None`` is treated as an empty history; the Clear handler
            originally reset the chatbot to ``None``.

    Returns:
        ``("", new_history)``. The empty string clears the textbox.
        ``new_history`` is a new list with ``[input_question, None]``
        appended; the bot reply is filled in by the next chained step.
    """
    # Build a new list rather than mutating the component's value in place.
    return "", list(history or []) + [[input_question, None]]
def get_qa_user_input(input_question, history):
    """Append the user's question to the Q&A history and clear the textbox.

    Args:
        input_question: Text typed into the Q&A question box.
        history: Current Q&A chatbot value, a list of ``[user, bot]`` pairs.
            ``None`` is treated as an empty history; the Q&A Clear handler
            originally reset the chatbot to ``None``.

    Returns:
        ``("", new_history)``. The empty string clears the textbox.
        ``new_history`` is a new list with ``[input_question, None]``
        appended; ``qa_bot`` fills in the answer afterwards.
    """
    # Build a new list rather than mutating the component's value in place.
    return "", list(history or []) + [[input_question, None]]
def dolly_chat(history):
    """Generate the bot reply for the newest turn of the Dolly chat.

    Args:
        history: Chat history as a list of ``[user_message, bot_message]``
            pairs. The last pair was just appended by ``get_user_input``.

    Returns:
        The same ``history`` list, mutated in place so that its last pair
        holds the model's reply.
    """
    latest_turn = history[-1]
    latest_turn[1] = run_pipeline(latest_turn[0])
    return history
def qa_bot(context, history):
    """Answer the newest question in *history* using the supplied *context*.

    The prompt follows the instruction-plus-context template shown on the
    Dolly v2 model card: ``"{instruction}\\n\\nInput:\\n{context}"``. The
    previous ad-hoc ``instruction: ... \\ncontext: ...`` layout did not match
    it. When the context box is empty, the question is sent on its own.

    Args:
        context: Free text from the "Add context here" box (may be empty).
        history: Q&A history as a list of ``[question, answer]`` pairs. The
            last pair was just appended by ``get_qa_user_input``.

    Returns:
        The same ``history`` list, mutated in place so that its last pair
        holds the model's answer.
    """
    query = history[-1][0]
    context = (context or "").strip()
    if context:
        prompt = f"{query}\n\nInput:\n{context}"
    else:
        prompt = query
    history[-1][1] = run_pipeline(prompt)
    return history
def reset_chatbot():
    """Clear the Q&A chatbot history.

    Returns:
        A Gradio update that sets the chatbot value to an empty list. A
        Chatbot's value is a list of message pairs, so ``[]`` (not ``""``)
        is the empty value of the correct type.
    """
    return gr.update(value=[])
def _load_example(row, path="examples.csv"):
    """Return the ``(doc, question)`` pair stored at positional *row* of a CSV.

    Args:
        row: Zero-based positional row index.
        path: CSV file containing ``doc`` and ``question`` columns.

    Returns:
        A ``(doc, question)`` tuple used to pre-fill the context box and the
        question textbox.
    """
    examples = pd.read_csv(path)
    return examples["doc"].iloc[row], examples["question"].iloc[row]


def load_customer_support_example():
    """Load the customer-support example (row 0 of ``examples.csv``)."""
    return _load_example(0)


def load_databricks_doc_example():
    """Load the Databricks documentation example (row 1 of ``examples.csv``)."""
    return _load_example(1)
# Adapted from the Gradio theming guide: https://gradio.app/theming-guide/
class SeafoamCustom(Base):
    """Custom Gradio theme: emerald/blue palette, Quicksand text, IBM Plex Mono code.

    A variant of the "Seafoam" theme from the Gradio theming guide. Every
    constructor argument is keyword-only and forwarded unchanged to
    ``gradio.themes.base.Base``. A fixed set of theme-variable overrides
    (gradient primary buttons, drop shadows, input styling) is then applied.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.emerald,
        secondary_hue: colors.Color | str = colors.blue,
        neutral_hue: colors.Color | str = colors.blue,
        spacing_size: sizes.Size | str = sizes.spacing_md,
        radius_size: sizes.Size | str = sizes.radius_md,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Quicksand"),
            "ui-sans-serif",
            "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        base_kwargs = dict(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            spacing_size=spacing_size,
            radius_size=radius_size,
            font=font,
            font_mono=font_mono,
        )
        super().__init__(**base_kwargs)

        overrides = {
            # Primary buttons: horizontal gradient from the primary to the secondary hue.
            "button_primary_background_fill": "linear-gradient(90deg, *primary_300, *secondary_400)",
            "button_primary_background_fill_hover": "linear-gradient(90deg, *primary_200, *secondary_300)",
            "button_primary_text_color": "white",
            "button_primary_background_fill_dark": "linear-gradient(90deg, *primary_600, *secondary_800)",
            # Shadows on blocks and buttons.
            "block_shadow": "*shadow_drop_lg",
            "button_shadow": "*shadow_drop_lg",
            # Inputs.
            # NOTE(review): "zinc" is not a CSS named color, so browsers will
            # probably discard this value. Confirm the intended fill (e.g. a
            # "*neutral_..." theme variable).
            "input_background_fill": "zinc",
            "input_border_color": "*secondary_300",
            "input_shadow": "*shadow_drop",
            "input_shadow_focus": "*shadow_drop_lg",
        }
        super().set(**overrides)
# Build the two-tab demo UI: a free-form "Dolly Chat" tab and a
# "Q&A with Context" tab, followed by the event wiring for both.
seafoam = SeafoamCustom()

with gr.Blocks(theme=seafoam) as demo:
    # Header: logo on the left, title and description on the right.
    with gr.Row(variant='panel'):
        with gr.Column():
            # Attributes are whitespace-separated; commas between attributes are invalid HTML.
            gr.HTML(
                """<img src='file/dolly.jpg' alt='dolly logo' width='150' height='150' /><br>"""
            )
        with gr.Column():
            gr.Markdown("# **<p align='center'>Dolly 2.0: World's First Truly Open Instruction-Tuned LLM</p>**")
            # The description must match the checkpoint actually loaded by
            # instruct_pipeline ("databricks/dolly-v2-7b"), not the 12B variant.
            gr.Markdown(
                "Dolly 2.0 is the first open source, instruction-following LLM fine-tuned on a "
                "human-generated instruction dataset licensed for research and commercial use. "
                "It is based on the EleutherAI pythia model family and fine-tuned exclusively on "
                "a new, high-quality human-generated instruction-following dataset, crowdsourced "
                "among Databricks employees. The largest Dolly 2.0 model has 12B parameters; this "
                "demo runs the 7B-parameter `databricks/dolly-v2-7b` checkpoint."
            )

    # NOTE(review): not wired to any event handler in this file; kept in case
    # other code references it.
    qa_bot_state = gr.State(value=[])

    with gr.Tabs():
        with gr.TabItem("Dolly Chat"):
            with gr.Row():
                with gr.Column():
                    chatbot = gr.Chatbot(label="Chat History")
                    input_question = gr.Text(
                        label="Instruction",
                        placeholder="Type prompt and hit enter.",
                    )
                    clear = gr.Button("Clear", variant="primary")
            with gr.Row():
                with gr.Accordion("Show example inputs I can load:", open=False):
                    gr.Examples(
                        examples=[
                            ["Explain to me the difference between nuclear fission and fusion."],
                            ["Give me a list of 5 science fiction books I should read next."],
                            ["I'm selling my Nikon D-750, write a short blurb for my ad."],
                            ["Write a song about sour donuts"],
                            ["Write a tweet about a new book launch by J.K. Rowling."],
                        ],
                        inputs=[input_question],
                        outputs=[],
                        fn=None,
                        cache_examples=False,
                    )

        with gr.TabItem("Q&A with Context"):
            with gr.Row():
                with gr.Column():
                    input_context = gr.Text(label="Add context here", lines=10)
                with gr.Column():
                    qa_chatbot = gr.Chatbot(label="Q&A History")
                    qa_input_question = gr.Text(
                        label="Input Question",
                        placeholder="Type question here and hit enter.",
                    )
                    qa_clear = gr.Button("Clear", variant="primary")
            with gr.Row():
                with gr.Accordion("Show example inputs I can load:", open=False):
                    example_1 = gr.Button("Load Customer support example")
                    example_2 = gr.Button("Load Databricks documentation example")

    # Dolly Chat: first record the user's turn and clear the textbox, then
    # generate the reply in a chained step.
    input_question.submit(
        get_user_input,
        [input_question, chatbot],
        [input_question, chatbot],
    ).then(dolly_chat, [chatbot], chatbot)
    # Reset to an empty list (the Chatbot's empty value) rather than None,
    # so the next submit receives a list history.
    clear.click(lambda: [], None, chatbot)

    # Q&A: same two-step flow, with the context box passed to qa_bot.
    qa_input_question.submit(
        get_qa_user_input,
        [qa_input_question, qa_chatbot],
        [qa_input_question, qa_chatbot],
    ).then(qa_bot, [input_context, qa_chatbot], qa_chatbot)
    qa_clear.click(lambda: [], None, qa_chatbot)

    # Reset the Q&A history whenever the context changes, since earlier
    # answers were based on the old context.
    input_context.change(fn=reset_chatbot, inputs=[], outputs=qa_chatbot)

    # Example buttons pre-fill both the context and the question.
    example_1.click(
        load_customer_support_example,
        [],
        [input_context, qa_input_question],
    )
    example_2.click(
        load_databricks_doc_example,
        [],
        [input_context, qa_input_question],
    )
if __name__ == "__main__":
    # Requests go through Gradio's queue: concurrency_count=1 runs them one
    # at a time (presumably because a single model instance is shared), and
    # max_size=100 caps the number of waiting requests.
    queued_demo = demo.queue(concurrency_count=1, max_size=100)
    queued_demo.launch(max_threads=5, debug=True)