import torch
from typing import Dict, Any, List, Union
from transformers import VitsModel, AutoTokenizer
import numpy as np


class EndpointHandler:
    def __init__(self, path="joselobenitezg/mms-grn-tts", device=None):
        """Initialize the VITS TTS model and tokenizer.

        Args:
            path (str): HuggingFace model path
            device (str, optional): Device to run the model on
                ('cuda', 'cpu', or a specific CUDA device such as 'cuda:0')
        """
        # Device management
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(path)
            self.model = VitsModel.from_pretrained(path).to(self.device)
            self.sampling_rate = self.model.config.sampling_rate
        except Exception as e:
            raise RuntimeError(f"Failed to load model and tokenizer: {str(e)}")

        # Set maximum input length (characters)
        self.max_input_length = 200

        print(f"Model loaded on {self.device}")

    def validate_input(self, text: Union[str, List[str]]) -> List[str]:
        """Validate and preprocess input text.

        Args:
            text: Input text or list of texts

        Returns:
            List[str]: Validated and processed text list

        Raises:
            ValueError: If input validation fails
        """
        # Convert a single string to a one-element list
        if isinstance(text, str):
            text = [text]
        elif isinstance(text, list):
            if not all(isinstance(t, str) for t in text):
                raise ValueError("All elements in the input list must be strings")
        else:
            raise ValueError("Input must be a string or a list of strings")

        # Validate each text
        for t in text:
            if not t.strip():
                raise ValueError("Empty text is not allowed")
            if len(t) > self.max_input_length:
                raise ValueError(
                    f"Input text exceeds maximum length of {self.max_input_length}"
                )

        return text

    def batch_process(self, texts: List[str], batch_size: int = 8) -> List[Dict[str, Any]]:
        """Process multiple texts in batches.

        Args:
            texts (List[str]): List of texts to process
            batch_size (int): Size of each batch

        Returns:
            List[Dict[str, Any]]: List of results for each text
        """
        results = []
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]

            # Tokenize the batch; padding aligns inputs of different lengths
            inputs = self.tokenizer(batch_texts, padding=True, return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            try:
                with torch.no_grad():
                    outputs = self.model(**inputs).waveform

                for waveform in outputs:
                    # Move to CPU and convert to numpy. Note: with padded
                    # batches each row may carry trailing padding samples;
                    # see the _trim_waveforms sketch below.
                    waveform_np = waveform.cpu().numpy()
                    results.append({
                        "waveform": waveform_np.tolist(),
                        "sampling_rate": self.sampling_rate
                    })
            except Exception as e:
                raise RuntimeError(f"Error during batch processing: {str(e)}")

        return results
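    # NOTE: a hedged addition, not part of the original handler. With padded
    # batches, VitsModel returns one fixed-width waveform tensor per batch, so
    # shorter items can end in padding samples. This sketch assumes the
    # installed transformers version exposes `sequence_lengths` on the model
    # output; if it does not, the waveforms are returned untrimmed. To use it,
    # batch_process would need to keep the full model output rather than only
    # its `.waveform` attribute.
    def _trim_waveforms(self, outputs) -> List[torch.Tensor]:
        """Best-effort removal of per-sample padding from a batched output."""
        lengths = getattr(outputs, "sequence_lengths", None)
        if lengths is None:
            # No per-sample length information available; return rows as-is.
            return list(outputs.waveform)
        return [w[: int(n)] for w, n in zip(outputs.waveform, lengths)]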
    def __call__(self, data: Union[Dict[str, Any], str, List[str]]) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """Process the input text and generate audio.

        Args:
            data: Input data in one of these formats:
                - Dict[str, Any]: {"inputs": "text" or ["text1", "text2"], "batch_size": int}
                - str: Direct text input
                - List[str]: List of texts to process

        Returns:
            Union[Dict[str, Any], List[Dict[str, Any]]]: Dictionary or list of
            dictionaries containing the audio waveform(s) and sampling rate
        """
        try:
            # Handle different input types
            if isinstance(data, dict):
                text = data.get("inputs", "")
                batch_size = data.get("batch_size", 8)
            elif isinstance(data, (str, list)):
                text = data
                batch_size = 8
            else:
                raise ValueError(f"Unsupported input type: {type(data)}")

            # Validate input
            texts = self.validate_input(text)

            # Single input case
            if len(texts) == 1:
                inputs = self.tokenizer(texts[0], return_tensors="pt")
                inputs = {k: v.to(self.device) for k, v in inputs.items()}

                with torch.no_grad():
                    output = self.model(**inputs).waveform

                waveform = output.cpu().squeeze().numpy()
                return {
                    "waveform": waveform.tolist(),
                    "sampling_rate": self.sampling_rate
                }
            # Multiple inputs case
            else:
                return self.batch_process(texts, batch_size)

        except Exception as e:
            error_msg = f"Error processing input: {str(e)}"
            print(error_msg)  # Log the error
            raise RuntimeError(error_msg)

    def cleanup(self):
        """Cleanup resources when shutting down."""
        try:
            # Clear the CUDA cache if running on GPU
            if 'cuda' in self.device:
                torch.cuda.empty_cache()
        except Exception as e:
            print(f"Error during cleanup: {str(e)}")
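
# A minimal local usage sketch, added for illustration (not part of the
# deployed handler). It synthesizes one utterance and writes it to disk using
# only the standard-library `wave` module; "output.wav" and the sample text
# are arbitrary choices, and running this downloads the model weights on
# first use.
if __name__ == "__main__":
    import wave

    handler = EndpointHandler()
    result = handler({"inputs": "Mba'éichapa"})  # a short Guarani greeting

    # Convert the float waveform in [-1, 1] to 16-bit PCM
    samples = np.asarray(result["waveform"], dtype=np.float32)
    pcm = (np.clip(samples, -1.0, 1.0) * 32767.0).astype(np.int16)

    with wave.open("output.wav", "wb") as f:
        f.setnchannels(1)                        # MMS VITS output is mono
        f.setsampwidth(2)                        # 2 bytes = 16-bit samples
        f.setframerate(result["sampling_rate"])
        f.writeframes(pcm.tobytes())

    handler.cleanup()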