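# Gradio Space demo for HelpingAI/Dhanishtha-2.0-preview: streams responses
# token by token and renders the model's <think> / <ser> blocks as styled
# HTML panels in the chat window.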
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import spaces
import re

# Model configuration
model_name = "HelpingAI/Dhanishtha-2.0-preview"

# Global variables for model and tokenizer
model = None
tokenizer = None
def load_model():
    """Load the model and tokenizer"""
    global model, tokenizer

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Ensure pad token is set
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=True
    )
    print("Model loaded successfully!")
def format_thinking_text(text):
    """Format text to properly display <think> and <ser> tags in Gradio with styled borders"""
    if not text:
        return text

    # More sophisticated formatting for thinking and ser blocks
    formatted_text = text

    # Handle thinking blocks with blue styling
    thinking_pattern = r'<think>(.*?)</think>'

    def replace_thinking_block(match):
        thinking_content = match.group(1).strip()
        return f'''
<div style="border-left: 4px solid #4a90e2; background: linear-gradient(135deg, #f0f8ff 0%, #e6f3ff 100%); padding: 16px 20px; margin: 16px 0; border-radius: 12px; font-family: 'Segoe UI', sans-serif; box-shadow: 0 2px 8px rgba(74, 144, 226, 0.15); border: 1px solid rgba(74, 144, 226, 0.2);">
    <div style="color: #4a90e2; font-weight: 600; margin-bottom: 10px; display: flex; align-items: center; font-size: 14px;">
        <span style="margin-right: 8px;">🧠</span> Think
    </div>
    <div style="color: #2c3e50; line-height: 1.6; font-size: 14px;">
        {thinking_content}
    </div>
</div>
'''

    # Handle ser blocks with green styling
    ser_pattern = r'<ser>(.*?)</ser>'

    def replace_ser_block(match):
        ser_content = match.group(1).strip()
        return f'''
<div style="border-left: 4px solid #28a745; background: linear-gradient(135deg, #f0fff4 0%, #e6ffed 100%); padding: 16px 20px; margin: 16px 0; border-radius: 12px; font-family: 'Segoe UI', sans-serif; box-shadow: 0 2px 8px rgba(40, 167, 69, 0.15); border: 1px solid rgba(40, 167, 69, 0.2);">
    <div style="color: #28a745; font-weight: 600; margin-bottom: 10px; display: flex; align-items: center; font-size: 14px;">
        <span style="margin-right: 8px;">💖</span> Ser
    </div>
    <div style="color: #155724; line-height: 1.6; font-size: 14px;">
        {ser_content}
    </div>
</div>
'''

    # Apply both patterns
    formatted_text = re.sub(thinking_pattern, replace_thinking_block, formatted_text, flags=re.DOTALL)
    formatted_text = re.sub(ser_pattern, replace_ser_block, formatted_text, flags=re.DOTALL)

    # Clean up any remaining raw tags
    formatted_text = re.sub(r'</?(?:think|ser)>', '', formatted_text)

    return formatted_text.strip()
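# Illustrative example (not part of the app's runtime flow):
#   format_thinking_text("<think>check the units</think>The answer is 42")
# returns the styled "Think" panel followed by "The answer is 42".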
# On a ZeroGPU ("Running on Zero") Space the GPU is allocated per call to a
# @spaces.GPU-decorated function; without the decorator the imported `spaces`
# package would go unused.
@spaces.GPU
def generate_response(message, history, max_tokens, temperature, top_p):
    """Generate streaming response without threading"""
    global model, tokenizer

    if model is None or tokenizer is None:
        yield "Model is still loading. Please wait..."
        return

    # Prepare conversation history
    messages = []

    # Handle both old tuple format and new message format
    for item in history:
        if isinstance(item, dict):
            # New message format
            messages.append(item)
        elif isinstance(item, (list, tuple)) and len(item) == 2:
            # Old tuple format
            user_msg, assistant_msg = item
            messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})

    # Add current message
    messages.append({"role": "user", "content": message})

    # Apply chat template
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # Tokenize input
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    try:
        with torch.no_grad():
            # Manual token-by-token sampling loop (the full sequence is
            # re-encoded on every step; the returned KV cache is not reused)
            generated_text = ""
            current_input_ids = model_inputs["input_ids"]
            current_attention_mask = model_inputs["attention_mask"]

            for _ in range(max_tokens):
                # Generate next token
                outputs = model(
                    input_ids=current_input_ids,
                    attention_mask=current_attention_mask,
                    use_cache=True
                )

                # Get logits for the last token
                logits = outputs.logits[0, -1, :]

                # Apply temperature
                if temperature != 1.0:
                    logits = logits / temperature

                # Apply top-p (nucleus) sampling: mask out the lowest-probability
                # tokens once the cumulative probability exceeds top_p
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift the mask right so the first token crossing the threshold is kept
                    sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
                    sorted_indices_to_remove[0] = 0
                    indices_to_remove = sorted_indices[sorted_indices_to_remove]
                    logits[indices_to_remove] = float('-inf')

                # Sample next token
                probs = torch.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

                # Check for EOS token
                if next_token.item() == tokenizer.eos_token_id:
                    break

                # Decode the new token (preserve special tokens like <think>)
                new_token_text = tokenizer.decode(next_token, skip_special_tokens=False)
                generated_text += new_token_text

                # Format and yield the current text
                formatted_text = format_thinking_text(generated_text)
                yield formatted_text

                # Update inputs for next iteration
                current_input_ids = torch.cat([current_input_ids, next_token.unsqueeze(0)], dim=-1)
                current_attention_mask = torch.cat(
                    [current_attention_mask,
                     torch.ones((1, 1), dtype=current_attention_mask.dtype, device=model.device)],
                    dim=-1
                )

    except Exception as e:
        yield f"Error generating response: {str(e)}"
        return

    # Final yield with complete formatted text
    final_text = format_thinking_text(generated_text) if generated_text else "No response generated."
    yield final_text
def chat_interface(message, history, max_tokens, temperature, top_p):
    """Main chat interface with improved streaming"""
    if not message.strip():
        # Yield (not return a value) so Gradio still receives an update for empty input
        yield history, ""
        return

    # Add user message to history in the new message format
    history.append({"role": "user", "content": message})

    # Add placeholder for assistant response
    history.append({"role": "assistant", "content": ""})

    # Generate response with streaming
    for partial_response in generate_response(message, history[:-2], max_tokens, temperature, top_p):
        history[-1]["content"] = partial_response
        yield history, ""
# Load model on startup
print("Initializing model...")
load_model()

# Minimal CSS - only for think and ser blocks
custom_css = """
/* Only essential styling for think and ser blocks */
.chatbot {
    font-family: system-ui, -apple-system, sans-serif;
}
"""
# Create advanced Gradio interface with professional design
with gr.Blocks(
    title="Dhanishtha-2.0-preview | Advanced Reasoning AI",
    theme=gr.themes.Soft(),
    css=custom_css,
    head="""
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <meta name="description" content="Chat with Dhanishtha-2.0-preview - The world's first LLM with multi-step reasoning capabilities">
    """
) as demo:
    # Simple Header
    gr.Markdown(
        """
        # 🧠 Dhanishtha-2.0-preview Chat

        Chat with the **HelpingAI/Dhanishtha-2.0-preview** model - Advanced Reasoning AI with Multi-Step Thinking

        ### Features:
        - 🧠 **Think Blocks**: Internal reasoning process (blue styling)
        - 💖 **Ser Blocks**: Emotional understanding (green styling)
        - ⚡ **Real-time Streaming**: Token-by-token generation
        - 🎯 **Step-by-step Solutions**: Detailed problem solving
        """
    )
    # Main Chat Interface
    with gr.Row():
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                [],
                elem_id="chatbot",
                type='messages',
                height=600,
                show_copy_button=True,
                show_share_button=True,
                avatar_images=("👤", "🤖"),
                render_markdown=True,
                sanitize_html=False,  # Allow HTML for thinking and ser blocks
                latex_delimiters=[
                    {"left": "$$", "right": "$$", "display": True},
                    {"left": "$", "right": "$", "display": False}
                ]
            )

            # Simple input section
            with gr.Row():
                msg = gr.Textbox(
                    container=False,
                    placeholder="Ask me anything! I'll show you my thinking and reasoning process...",
                    label="Message",
                    autofocus=True,
                    lines=1,
                    max_lines=3,
                    scale=7
                )
                send_btn = gr.Button("Send", variant="primary", scale=1)
                clear_btn = gr.Button("Clear", variant="secondary", scale=1)

        with gr.Column(scale=1, min_width=300):
            gr.Markdown("### ⚙️ Generation Parameters")

            max_tokens = gr.Slider(
                minimum=50,
                maximum=8192,
                value=2048,
                step=50,
                label="Max Tokens",
                info="Maximum number of tokens to generate"
            )

            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.7,
                step=0.1,
                label="Temperature",
                info="Higher = more creative, Lower = more focused"
            )

            top_p = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                step=0.05,
                label="Top-p",
                info="Nucleus sampling threshold"
            )

            gr.Markdown("### 📊 Model Info")
            gr.Markdown(
                """
                **Model**: HelpingAI/Dhanishtha-2.0-preview
                **Type**: Reasoning LLM with thinking blocks
                **Features**: Multi-step reasoning, self-evaluation
                **Blocks**: Think (blue) + Ser (green)
                """
            )
    # Examples Section
    gr.Examples(
        examples=[
            ["Solve this step by step: What is 15% of 240?"],
            ["How many letter 'r' are in the words 'strawberry' and 'raspberry'?"],
            ["Hello! Can you introduce yourself and show me how you think?"],
            ["Explain quantum entanglement in simple terms"],
            ["Write a Python function to find the factorial of a number"],
            ["What are the pros and cons of renewable energy?"],
            ["What's the difference between AI and machine learning?"],
            ["Create a haiku about artificial intelligence"],
            ["Why is the sky blue? Explain using physics principles"],
            ["Compare bubble sort and quick sort algorithms"]
        ],
        inputs=msg,
        label="Example Prompts - Try these to see the thinking process!",
        examples_per_page=5
    )
    # Event handlers
    def clear_chat():
        """Clear the chat history"""
        return [], ""

    # Message submission events
    msg.submit(
        chat_interface,
        inputs=[msg, chatbot, max_tokens, temperature, top_p],
        outputs=[chatbot, msg],
        concurrency_limit=1,
        show_progress="minimal"
    )

    send_btn.click(
        chat_interface,
        inputs=[msg, chatbot, max_tokens, temperature, top_p],
        outputs=[chatbot, msg],
        concurrency_limit=1,
        show_progress="minimal"
    )

    # Clear chat event
    clear_btn.click(
        clear_chat,
        outputs=[chatbot, msg],
        show_progress=False
    )
    # Simple Footer
    gr.Markdown(
        """
        ---
        ### 🔧 Technical Details

        - **Model**: HelpingAI/Dhanishtha-2.0-preview
        - **Reasoning**: Multi-step thinking with `<think>` and `<ser>` blocks

        **Note**: This interface streams responses token by token and formats thinking blocks for better readability.
        """
    )
if __name__ == "__main__":
    # Launch with enhanced configuration
    demo.queue(
        max_size=20,
        default_concurrency_limit=1
    ).launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        quiet=False
    )
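# Assumed local setup (not specified in the file itself): running this outside
# a Hugging Face ZeroGPU Space needs gradio, torch, transformers and accelerate,
# plus the `spaces` package for the @spaces.GPU decorator, e.g.
#   pip install gradio torch transformers accelerate spaces
#   python app.py   # serves on http://0.0.0.0:7860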