import logging
import os
from typing import Literal, Optional, Tuple

import gradio as gr
from dacite import Config as DaciteConfig, from_dict
from huggingface_hub import PyTorchModelHubMixin, login
from omegaconf import OmegaConf
from transformers import GPT2Config, GPT2LMHeadModel

from llm_trainer import LLMTrainer
from xlstm import xLSTMLMModel, xLSTMLMModelConfig

login(token=os.getenv("token"))

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class xLSTMWrapper(xLSTMLMModel, PyTorchModelHubMixin):
    """xLSTM language model with Hugging Face Hub load/save support."""
    pass


GPT2_CONFIG = GPT2Config(
    vocab_size=50304,
    n_positions=256,
    n_embd=768,
    n_layer=12,
    n_head=12,
    activation_function="gelu",
)

XLSTM_CONFIG = OmegaConf.load("xlstm_config.yaml")
XLSTM_CONFIG = from_dict(
    data_class=xLSTMLMModelConfig,
    data=OmegaConf.to_container(XLSTM_CONFIG),
    config=DaciteConfig(strict=True),
)

UI_CONFIG = {
    "title": "HSEAI",
    "description": "Enter your text below and the AI will continue it.",
    "port": 7860,
    "host": "0.0.0.0",
    "default_model": "xLSTM",
    "max_sequences": 3,
    "default_length": 64,
    "min_length": 16,
    "max_length": 128,
    "length_step": 16,
}

# Load checkpoints from the Hub. `from_pretrained` is a classmethod, so the
# models are built directly from the remote weights instead of being
# instantiated once and then overwritten.
xLSTM = xLSTMWrapper.from_pretrained("AlekMan/HSE_AI_Vanilla_XLSTM", config=XLSTM_CONFIG)
xLSTM_ft = xLSTMWrapper.from_pretrained("AlekMan/HSE_AI_Vanilla_XLSTM_FT", config=XLSTM_CONFIG)
gpt2 = GPT2LMHeadModel.from_pretrained("AlekMan/HSE_AI_GPT2", config=GPT2_CONFIG)
gpt2_lora = GPT2LMHeadModel.from_pretrained("AlekMan/HSE_AI_GPT2", config=GPT2_CONFIG)
gpt2_lora.load_adapter("AlekMan/HSE_AI_GPT2_LoRA")


class ModelManager:
    """Manages model initialization and caching."""

    def __init__(self):
        self._current_trainer: Optional[LLMTrainer] = None
        self._current_model: Optional[str] = None

    def get_trainer(self, model_name: Literal["xLSTM", "GPT2", "xLSTM_FT", "GPT2_FT"]) -> LLMTrainer:
        """Return a trainer for `model_name`, reusing the cached one when possible."""
        if self._current_trainer is None or self._current_model != model_name:
            logger.info(f"Loading model: {model_name}")
            self._current_trainer = self._load_model(model_name)
            self._current_model = model_name
            logger.info(f"Model {model_name} loaded successfully")
        return self._current_trainer

    def _load_model(self, model_name: Literal["xLSTM", "GPT2", "xLSTM_FT", "GPT2_FT"]) -> LLMTrainer:
        """Wrap the requested model in an LLMTrainer."""
        try:
            if model_name == "GPT2":
                trainer = LLMTrainer(model=gpt2, model_returns_logits=False)
            elif model_name == "xLSTM":
                trainer = LLMTrainer(model=xLSTM, model_returns_logits=True)
            elif model_name == "GPT2_FT":
                trainer = LLMTrainer(model=gpt2_lora, model_returns_logits=False)
            elif model_name == "xLSTM_FT":
                trainer = LLMTrainer(model=xLSTM_ft, model_returns_logits=True)
            else:
                raise ValueError(f"Unsupported model: {model_name}")
            return trainer
        except Exception as e:
            logger.error(f"Failed to load model {model_name}: {e}")
            raise RuntimeError(f"Failed to load model {model_name}: {e}") from e


model_manager = ModelManager()


def generate_text(
    user_input: str,
    model_choice: str = UI_CONFIG["default_model"],
    n_sequences: int = UI_CONFIG["max_sequences"],
    length: int = UI_CONFIG["default_length"],
) -> Tuple[str, str, str]:
    """Generate text continuations using the selected model."""
    if not user_input.strip():
        return "Please enter some text first.", "", ""

    try:
        logger.info(f"Generating text with {model_choice}, length: {length}")
        trainer = model_manager.get_trainer(model_choice)
        continuations = trainer.generate_text(
            prompt=user_input,
            n_return_sequences=n_sequences,
            length=length,
        )

        results = []
        for continuation in continuations[:n_sequences]:
            # Strip the echoed prompt so only the newly generated text is shown.
            clean_continuation = continuation[len(user_input):].strip()
            if clean_continuation:
                results.append(clean_continuation + "...")
            else:
                results.append("(No continuation generated)")

        # Pad so each of the three output boxes always receives a value.
        while len(results) < 3:
            results.append("")

        logger.info("Text generation completed successfully")
        return results[0], results[1], results[2]
    except Exception as e:
        error_msg = f"Error during generation: {str(e)}"
        logger.error(error_msg)
        return error_msg, "", ""
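# Minimal headless usage sketch (kept as a comment so it never runs at import
# time). It assumes only the trainer API already exercised above; the prompt
# string and model choice are illustrative:
#
#   trainer = model_manager.get_trainer("GPT2")
#   samples = trainer.generate_text(prompt="Once upon a time",
#                                   n_return_sequences=3, length=64)
#   # Each returned string echoes the prompt followed by the continuation;
#   # generate_text() above slices off the echoed prompt before display.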
def create_input_section() -> Tuple[gr.Textbox, gr.Dropdown, gr.Slider, gr.Button]:
    """Create the input section of the interface."""
    with gr.Column():
        user_input = gr.Textbox(
            label="Enter your text:",
            placeholder="Type your text here...",
            lines=3,
            max_lines=10,
        )
        with gr.Row():
            model_choice = gr.Dropdown(
                choices=["GPT2", "GPT2_FT", "xLSTM", "xLSTM_FT"],
                value=UI_CONFIG["default_model"],
                label="Model",
                interactive=True,
            )
            length = gr.Slider(
                minimum=UI_CONFIG["min_length"],
                maximum=UI_CONFIG["max_length"],
                value=UI_CONFIG["default_length"],
                step=UI_CONFIG["length_step"],
                label="Generation Length",
            )
        generate_btn = gr.Button("Generate Continuation", variant="primary")
    return user_input, model_choice, length, generate_btn


def create_output_section() -> Tuple[gr.Textbox, gr.Textbox, gr.Textbox]:
    """Create the output section of the interface."""
    gr.Markdown("### Generated Continuations:")
    with gr.Row():
        output1 = gr.Textbox(label="Continuation 1", lines=8, max_lines=15, interactive=False)
        output2 = gr.Textbox(label="Continuation 2", lines=8, max_lines=15, interactive=False)
        output3 = gr.Textbox(label="Continuation 3", lines=8, max_lines=15, interactive=False)
    return output1, output2, output3


def setup_event_handlers(
    user_input: gr.Textbox,
    model_choice: gr.Dropdown,
    length: gr.Slider,
    generate_btn: gr.Button,
    outputs: Tuple[gr.Textbox, gr.Textbox, gr.Textbox],
) -> None:
    """Set up event handlers for the interface."""
    inputs = [
        user_input,
        model_choice,
        # Hidden component pinning n_sequences to the configured maximum;
        # precision=0 makes Gradio pass an int, which the slicing in
        # generate_text() requires.
        gr.Number(value=UI_CONFIG["max_sequences"], visible=False, precision=0),
        length,
    ]
    generate_btn.click(fn=generate_text, inputs=inputs, outputs=list(outputs))
    user_input.submit(fn=generate_text, inputs=inputs, outputs=list(outputs))


def create_interface() -> gr.Blocks:
    """Create and return the Gradio interface."""
    with gr.Blocks(title=UI_CONFIG["title"], theme=gr.themes.Soft()) as demo:
        gr.Markdown(f"# {UI_CONFIG['title']}")
        gr.Markdown(UI_CONFIG["description"])
        with gr.Row():
            user_input, model_choice, length, generate_btn = create_input_section()
        outputs = create_output_section()
        setup_event_handlers(user_input, model_choice, length, generate_btn, outputs)
    return demo


def initialize_model_on_startup() -> None:
    """Warm up the default model so the first request does not pay the load cost."""
    try:
        logger.info(f"Initializing {UI_CONFIG['default_model']} model on startup...")
        model_manager.get_trainer(UI_CONFIG["default_model"])
        logger.info(f"{UI_CONFIG['default_model']} model initialized successfully!")
    except Exception as e:
        logger.warning(f"Could not initialize model on startup: {e}")
        logger.info("Model will be initialized when first used.")


def main() -> None:
    """Launch the Gradio app."""
    logger.info(f"Starting {UI_CONFIG['title']} application...")
    initialize_model_on_startup()
    demo = create_interface()
    logger.info(f"Launching interface on {UI_CONFIG['host']}:{UI_CONFIG['port']}")
    demo.launch(
        server_name=UI_CONFIG["host"],
        server_port=UI_CONFIG["port"],
        share=False,
        show_error=True,
    )


if __name__ == "__main__":
    main()
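# Running locally (a sketch; the script filename is assumed): the `token`
# environment variable must hold a Hugging Face access token, since login()
# and the checkpoint downloads above execute at import time:
#
#   token=<hf_access_token> python app.py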