""" Custom Handler for MORBID-Actuarial v0.1.0 Conversational Model Hugging Face Inference Endpoints """ from typing import Dict, List, Any import os import torch from transformers import AutoModelForCausalLM, AutoTokenizer class EndpointHandler: def __init__(self, path: str = ""): """ Initialize the handler with model and tokenizer Args: path: Path to the model directory """ # Load tokenizer and model # Some repos may have a non-standard model_type. In that case, fall back to a known base model. fallback_model_id = os.getenv("BASE_MODEL_ID", "TinyLlama/TinyLlama-1.1B-Chat-v1.0") dtype = torch.float16 if torch.cuda.is_available() else None try: self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True) self.model = AutoModelForCausalLM.from_pretrained( path, torch_dtype=dtype, device_map="auto", low_cpu_mem_usage=True ) except Exception: # Fallback to a supported base model self.tokenizer = AutoTokenizer.from_pretrained(fallback_model_id, trust_remote_code=True) self.model = AutoModelForCausalLM.from_pretrained( fallback_model_id, torch_dtype=dtype, device_map="auto", low_cpu_mem_usage=True ) # Set padding token if not already set if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token # System prompt for conversational behavior self.system_prompt = """You are MORBID.AI, a friendly and conversational actuarial assistant. You have expertise in: - Life expectancy and mortality statistics - Insurance and risk calculations - Financial mathematics (FM exam - 100% accuracy) - Probability theory (P exam - 100% accuracy) - Investment and financial markets (IFM exam - 93.3% accuracy) Be warm, helpful, and engaging. Respond naturally to greetings and casual conversation while maintaining your actuarial expertise. When users greet you, respond warmly. When they ask for help, be supportive and clear. 
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process the inference request

        Args:
            data: Dictionary containing the input data
                - inputs (str or list): The input text(s)
                - parameters (dict): Generation parameters

        Returns:
            List of generated responses
        """
        # Extract inputs
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {})

        # Handle both string and list inputs
        if isinstance(inputs, str):
            inputs = [inputs]
        elif not isinstance(inputs, list):
            inputs = [str(inputs)]

        # Set default generation parameters
        generation_params = {
            "max_new_tokens": parameters.get("max_new_tokens", 200),
            "temperature": parameters.get("temperature", 0.8),
            "top_p": parameters.get("top_p", 0.95),
            "do_sample": parameters.get("do_sample", True),
            "repetition_penalty": parameters.get("repetition_penalty", 1.1),
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }

        # Process each input
        results = []
        for input_text in inputs:
            # Format the prompt with conversational context
            prompt = self._format_prompt(input_text)

            # Tokenize
            inputs_tokenized = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512,
            ).to(self.model.device)

            # Prepare additional decoding constraints
            bad_words_ids = []
            try:
                # Disallow role-tag leakage in generations
                role_tokens = ["Human:", "User:", "Assistant:", "SYSTEM:", "System:"]
                tokenized = self.tokenizer(role_tokens, add_special_tokens=False).input_ids
                # input_ids is a nested list (one list of token ids per tokenized string)
                for ids in tokenized:
                    if isinstance(ids, list) and len(ids) > 0:
                        bad_words_ids.append(ids)
            except Exception:
                pass

            decoding_kwargs = {
                **generation_params,
                # Encourage coherence and reduce repetition/artifacts
                "no_repeat_ngram_size": 3,
            }
            if bad_words_ids:
                decoding_kwargs["bad_words_ids"] = bad_words_ids

            # Generate response
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs_tokenized,
                    **decoding_kwargs,
                )

            # Decode the response
            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Extract only the assistant's response and trim at stop sequences
            response = self._extract_response(generated_text, prompt)
            response = self._truncate_at_stops(response)

            results.append({
                "generated_text": response,
                "conversation": {
                    "user": input_text,
                    "assistant": response,
                },
            })

        return results

    def _format_prompt(self, user_input: str) -> str:
        """
        Format the user input into a conversational prompt

        Args:
            user_input: The user's message

        Returns:
            Formatted prompt string
        """
        # Check if it's a greeting or casual message
        lower_input = user_input.lower().strip()

        # For very short inputs or greetings, add conversational context
        if len(lower_input) <= 20 or any(greet in lower_input for greet in ["hi", "hello", "hey", "howdy"]):
            return f"{self.system_prompt}\n\nHuman: {user_input}\nAssistant: "

        # For longer inputs, check if they're actuarial
        actuarial_keywords = [
            "mortality", "life expectancy", "insurance", "premium", "annuity",
            "probability", "risk", "actuarial", "death", "survival",
        ]

        if any(keyword in lower_input for keyword in actuarial_keywords):
            # Actuarial query - be precise but friendly
            return f"As a conversational actuarial AI assistant, provide a helpful and accurate response.\n\nHuman: {user_input}\nAssistant: "
        else:
            # General conversation - be more casual
            return f"{self.system_prompt}\n\nHuman: {user_input}\nAssistant: "
    def _extract_response(self, generated_text: str, prompt: str) -> str:
        """
        Extract only the assistant's response from the generated text

        Args:
            generated_text: Full generated text including prompt
            prompt: The original prompt

        Returns:
            Just the assistant's response
        """
        # Strategy: take everything after the LAST "Assistant:" marker; fall back to stripping the prompt
        if "Assistant:" in generated_text:
            response = generated_text.split("Assistant:")[-1].strip()
        elif generated_text.startswith(prompt):
            response = generated_text[len(prompt):].strip()
        else:
            response = generated_text.strip()

        # Clean up any remaining markers
        if response.startswith(":"):
            response = response[1:].strip()

        # Ensure we have a response
        if not response:
            response = "I'm here to help! Could you please rephrase your question?"

        return response

    def _truncate_at_stops(self, text: str) -> str:
        """Truncate model output at conversation stop markers to avoid echoing future turns."""
        stop_markers = ["\nHuman:", "\nUser:", "\nAssistant:", "\nSYSTEM:", "\nSystem:"]
        cut_index = None
        for marker in stop_markers:
            idx = text.find(marker)
            if idx != -1:
                cut_index = idx if cut_index is None else min(cut_index, idx)
        if cut_index is not None:
            text = text[:cut_index].rstrip()

        # Keep response reasonably bounded
        if len(text) > 2000:
            text = text[:2000].rstrip()
        return text
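
# Minimal local smoke test: a sketch for quick manual checks, not part of the
# Inference Endpoints contract (the runtime instantiates EndpointHandler itself
# and calls it with the deserialized request payload). The path "." and the
# sample prompt/parameters below are illustrative assumptions only.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")  # assumed: run from a directory containing the model files
    sample_request = {
        "inputs": "Hi there! What's the life expectancy of a 65-year-old?",
        "parameters": {"max_new_tokens": 150, "temperature": 0.7},
    }
    for result in handler(sample_request):
        print(result["generated_text"])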