import gradio as gr
from datasets import load_dataset
from transformers import T5ForConditionalGeneration, T5Tokenizer, AutoTokenizer, AutoModelForSequenceClassification
import random
import torch
import groq  # Assuming you are using the Groq library
import os
from dotenv import load_dotenv
from huggingface_hub import login
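# NOTE: these imports assume gradio, datasets, transformers, sentencepiece
# (required by T5Tokenizer), torch, groq, python-dotenv, and huggingface_hub
# are all listed in the Space's requirements.txt; a missing package is a
# common cause of a "Runtime error" at startup.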
# Load environment variables from .env file
load_dotenv()

# Read the token from the environment; never hard-code secrets in app code.
# os.getenv() takes the *name* of the variable, not the token value itself.
HUGGING_FACE_TOKEN = os.getenv("HUGGING_FACE_TOKEN")

# Authenticate with Hugging Face (set HUGGING_FACE_TOKEN in .env or in the Space secrets)
login(HUGGING_FACE_TOKEN)
# Load the mental health counseling conversations dataset
ds = load_dataset("Amod/mental_health_counseling_conversations")
context = ds["train"]["Context"]
response = ds["train"]["Response"]
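# Each row pairs a help-seeker's message ("Context") with a counselor's reply
# ("Response"); these two lists are used further down to show a related
# example from the dataset alongside the generated answer.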
# Groq client setup; the API key is read from the environment, never committed
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
client = groq.Groq(api_key=GROQ_API_KEY)
# Load FLAN-T5 model and tokenizer for primary RAG
flan_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
flan_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small")

# Load sentiment analysis model
sentiment_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
sentiment_model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
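# Both models are downloaded on first launch and cached by Transformers;
# flan-t5-small handles generation while the DistilBERT SST-2 checkpoint
# only scores sentiment.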
# Function for sentiment analysis
def analyze_sentiment(text):
    inputs = sentiment_tokenizer(text, return_tensors="pt")
    with torch.no_grad():  # inference only; no gradients needed
        outputs = sentiment_model(**inputs)
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
    # SST-2 convention: label 1 is "positive", label 0 is "negative"
    sentiment = "positive" if torch.argmax(probs) == 1 else "negative"
    confidence = probs.max().item()
    return sentiment, confidence
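# Illustrative usage (scores are examples, not actual model output):
# analyze_sentiment("I finally feel hopeful")    -> ("positive", ~0.99)
# analyze_sentiment("Everything feels pointless") -> ("negative", ~0.99)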
# Function to generate a response based on sentiment and user input
def generate_response(sentiment, user_input):
    prompt = f"The user feels {sentiment}. Respond with supportive advice based on: {user_input}"
    inputs = flan_tokenizer(prompt, return_tensors="pt")
    output_ids = flan_model.generate(**inputs, max_length=150)  # renamed to avoid shadowing the dataset `response` list
    return flan_tokenizer.decode(output_ids[0], skip_special_tokens=True)
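# The chatbot below displays a related (context, response) pair from the
# dataset, but context_text / dataset_response were referenced without being
# defined. A random pick is a minimal sketch of that retrieval step (it also
# matches the otherwise-unused `random` import); a real implementation might
# use embedding similarity against the user's input instead.
def retrieve_example():
    idx = random.randrange(len(context))
    return context[idx], response[idx]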
# Main chatbot function
def chatbot(user_input):
    if not user_input.strip():
        return "Please enter a question or concern to receive guidance."

    # Word count limit
    word_count = len(user_input.split())
    max_words = 50
    remaining_words = max_words - word_count
    if remaining_words < 0:
        return f"Your input is too long. Please limit it to {max_words} words."

    # Sentiment analysis
    sentiment, confidence = analyze_sentiment(user_input)
    # Try the Groq API first for a personalized response
    try:
        brief_response = client.chat.completions.create(
            messages=[{
                "role": "user",
                "content": user_input,
            }],
            model="llama3-8b-8192",  # Change model if needed
        )
        brief_response = brief_response.choices[0].message.content
    except Exception:
        # Network or auth failures fall through to the local FLAN-T5 model
        brief_response = None

    if brief_response:
        return f"**Personalized Response from Groq:** {brief_response}"
    # Fall back to the local FLAN-T5 model for response generation
    generated_response = generate_response(sentiment, user_input)
    if not generated_response:
        return "I'm sorry, I don't have information specific to your concern. Please consult a professional."

    # Pull a related example from the counseling dataset
    context_text, dataset_response = retrieve_example()

    # Final response combining the different sources
    complete_response = (
        f"**Sentiment Analysis:** {sentiment} (Confidence: {confidence:.2f})\n\n"
        f"**Generated Response:**\n{generated_response}\n\n"
        f"**Contextual Information:**\n{context_text}\n\n"
        f"**Additional Dataset Response:**\n{dataset_response}\n\n"
        f"Words entered: {word_count}, Words remaining: {remaining_words}"
    )
    return complete_response
# Example call to the function (runs once at startup; requires the models and,
# if configured, the Groq API to be reachable)
print(chatbot("This is an example input."))
# Set up Gradio interface
interface = gr.Interface(
    fn=chatbot,
    inputs=gr.Textbox(
        label="Ask your question:",
        placeholder="Describe how you're feeling today...",
        lines=4,
    ),
    outputs=gr.Markdown(label="Psychologist Assistant Response"),
    title="Virtual Psychiatrist Assistant",
    description="Enter your mental health concerns, and receive guidance and responses from a trained assistant.",
    theme="huggingface",  # accepted by older Gradio releases; on Gradio 4+ pass a gr.themes.* theme instead
)
# Launch the app
interface.launch()