"""Gradio demo that serves text continuations from custom GPT-2 and xLSTM checkpoints hosted on the Hugging Face Hub."""
import logging
import os
from typing import Literal, Optional, Tuple

import gradio as gr
from dacite import Config as DaciteConfig, from_dict
from huggingface_hub import PyTorchModelHubMixin, login
from omegaconf import OmegaConf
from transformers import GPT2Config, GPT2LMHeadModel

from llm_trainer import LLMTrainer
from xlstm import xLSTMLMModel, xLSTMLMModelConfig
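
# Authenticate with the Hugging Face Hub so the checkpoints below can be
# fetched; assumes the access token is provided via a "token" environment
# variable (e.g. a Space secret).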
login(token=os.getenv("token"))

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
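

# PyTorchModelHubMixin adds save_pretrained()/from_pretrained() to the plain
# PyTorch xLSTM model, so trained weights can be loaded straight from the Hub.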
class xLSTMWrapper(xLSTMLMModel, PyTorchModelHubMixin):
    pass
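

# Architecture of the trained GPT-2 checkpoints, kept in code for reference;
# the weights themselves are pulled from the Hub below.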
GPT2_CONFIG = GPT2Config(
    vocab_size=50304,
    n_positions=256,
    n_embd=768,
    n_layer=12,
    n_head=12,
    activation_function="gelu",
)
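
# Load the xLSTM hyperparameters from YAML and validate them strictly against
# the xLSTMLMModelConfig dataclass (strict=True rejects unknown keys).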
XLSTM_CONFIG = OmegaConf.load("xlstm_config.yaml")
XLSTM_CONFIG = from_dict(
    data_class=xLSTMLMModelConfig,
    data=OmegaConf.to_container(XLSTM_CONFIG),
    config=DaciteConfig(strict=True),
)

UI_CONFIG = {
    "title": "HSEAI",
    "description": "Enter your text below and the AI will continue it.",
    "port": 7860,
    "host": "0.0.0.0",
    "default_model": "xLSTM",
    "max_sequences": 3,
    "default_length": 64,
    "min_length": 16,
    "max_length": 128,
    "length_step": 16,
}
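
# All four model variants are loaded eagerly at import time so switching
# models in the UI never blocks on a download.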
# from_pretrained() is a classmethod, so call it on the class; instantiating
# a model first would only build and discard a randomly initialised copy.
xLSTM = xLSTMWrapper.from_pretrained("AlekMan/HSE_AI_Vanilla_XLSTM", config=XLSTM_CONFIG)
xLSTM_ft = xLSTMWrapper.from_pretrained("AlekMan/HSE_AI_Vanilla_XLSTM_FT", config=XLSTM_CONFIG)
gpt2 = GPT2LMHeadModel.from_pretrained("AlekMan/HSE_AI_GPT2", config=GPT2_CONFIG)
gpt2_lora = GPT2LMHeadModel.from_pretrained("AlekMan/HSE_AI_GPT2", config=GPT2_CONFIG)
gpt2_lora.load_adapter("AlekMan/HSE_AI_GPT2_LoRA")  # attach the LoRA adapter on top of the base GPT-2


class ModelManager:
    """Manages model initialization and caching."""

    def __init__(self):
        self._current_trainer: Optional[LLMTrainer] = None
        self._current_model: Optional[str] = None

    def get_trainer(self, model_name: Literal["xLSTM", "GPT2", "xLSTM_FT", "GPT2_FT"]):
        """Get trainer instance, creating it if necessary."""
        if self._current_trainer is None or self._current_model != model_name:
            logger.info(f"Loading model: {model_name}")
            self._current_trainer = self._load_model(model_name)
            self._current_model = model_name
            logger.info(f"Model {model_name} loaded successfully")
        return self._current_trainer

    def _load_model(self, model_name: Literal["xLSTM", "GPT2", "xLSTM_FT", "GPT2_FT"]) -> LLMTrainer:
        """Load and initialize the requested model."""
        try:
            if model_name == "GPT2":
                trainer = LLMTrainer(model=gpt2, model_returns_logits=False)
            elif model_name == "xLSTM":
                trainer = LLMTrainer(model=xLSTM, model_returns_logits=True)
            elif model_name == "GPT2_FT":
                trainer = LLMTrainer(model=gpt2_lora, model_returns_logits=False)
            elif model_name == "xLSTM_FT":
                trainer = LLMTrainer(model=xLSTM_ft, model_returns_logits=True)
            else:
                raise ValueError(f"Unsupported model: {model_name}")
            return trainer
        except Exception as e:
            logger.error(f"Failed to load model {model_name}: {e}")
            raise RuntimeError(f"Failed to load model {model_name}: {e}") from e
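

# Module-level singleton; every Gradio callback below routes through it so the
# most recently used model stays cached.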
model_manager = ModelManager()


def generate_text(
    user_input: str,
    model_choice: str = UI_CONFIG["default_model"],
    n_sequences: int = UI_CONFIG["max_sequences"],
    length: int = UI_CONFIG["default_length"],
) -> Tuple[str, str, str]:
    """Generate text continuations using the selected model."""
    if not user_input.strip():
        return "Please enter some text first.", "", ""
    try:
        logger.info(f"Generating text with {model_choice}, length: {length}")
        trainer = model_manager.get_trainer(model_choice)
        continuations = trainer.generate_text(
            prompt=user_input,
            n_return_sequences=n_sequences,
            length=length,
        )
        results = []
        for continuation in continuations[:n_sequences]:
            # The trainer returns prompt + continuation; strip the prompt.
            clean_continuation = continuation[len(user_input):].strip()
            if clean_continuation:
                results.append(clean_continuation + "...")
            else:
                results.append("(No continuation generated)")
        # Pad so the three output textboxes always receive a value.
        while len(results) < 3:
            results.append("")
        logger.info("Text generation completed successfully")
        return results[0], results[1], results[2]
    except Exception as e:
        error_msg = f"Error during generation: {str(e)}"
        logger.error(error_msg)
        return error_msg, "", ""
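

# The interface is assembled from small helpers: an input column, a row of
# three output boxes, and the event wiring that connects them to generate_text.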
def create_input_section() -> Tuple[gr.Textbox, gr.Dropdown, gr.Slider, gr.Button]:
    """Create the input section of the interface."""
    with gr.Column():
        user_input = gr.Textbox(
            label="Enter your text:",
            placeholder="Type your text here...",
            lines=3,
            max_lines=10,
        )
        with gr.Row():
            model_choice = gr.Dropdown(
                choices=["GPT2", "GPT2_FT", "xLSTM", "xLSTM_FT"],
                value=UI_CONFIG["default_model"],
                label="Model",
                interactive=True,
            )
            length = gr.Slider(
                minimum=UI_CONFIG["min_length"],
                maximum=UI_CONFIG["max_length"],
                value=UI_CONFIG["default_length"],
                step=UI_CONFIG["length_step"],
                label="Generation Length",
            )
        generate_btn = gr.Button("Generate Continuation", variant="primary")
    return user_input, model_choice, length, generate_btn


def create_output_section() -> Tuple[gr.Textbox, gr.Textbox, gr.Textbox]:
    """Create the output section of the interface."""
    gr.Markdown("### Generated Continuations:")
    with gr.Row():
        output1 = gr.Textbox(label="Continuation 1", lines=8, max_lines=15, interactive=False)
        output2 = gr.Textbox(label="Continuation 2", lines=8, max_lines=15, interactive=False)
        output3 = gr.Textbox(label="Continuation 3", lines=8, max_lines=15, interactive=False)
    return output1, output2, output3


def setup_event_handlers(
    user_input: gr.Textbox,
    model_choice: gr.Dropdown,
    length: gr.Slider,
    generate_btn: gr.Button,
    outputs: Tuple[gr.Textbox, gr.Textbox, gr.Textbox],
) -> None:
    """Set up event handlers for the interface."""
    inputs = [
        user_input,
        model_choice,
        # Hidden component that pins n_sequences to the maximum, since the UI
        # always renders three output boxes. precision=0 makes Gradio deliver
        # the value as an int rather than a float, which would break slicing.
        gr.Number(value=UI_CONFIG["max_sequences"], visible=False, precision=0),
        length,
    ]
    generate_btn.click(fn=generate_text, inputs=inputs, outputs=list(outputs))
    user_input.submit(fn=generate_text, inputs=inputs, outputs=list(outputs))


def create_interface() -> gr.Blocks:
    """Create and return the Gradio interface."""
    with gr.Blocks(title=UI_CONFIG["title"], theme=gr.themes.Soft()) as demo:
        gr.Markdown(f"# {UI_CONFIG['title']}")
        gr.Markdown(UI_CONFIG["description"])
        with gr.Row():
            user_input, model_choice, length, generate_btn = create_input_section()
        outputs = create_output_section()
        setup_event_handlers(user_input, model_choice, length, generate_btn, outputs)
    return demo


def initialize_model_on_startup() -> None:
    """Initialize the default model on startup."""
    try:
        logger.info(f"Initializing {UI_CONFIG['default_model']} model on startup...")
        model_manager.get_trainer(UI_CONFIG["default_model"])
        logger.info(f"{UI_CONFIG['default_model']} model initialized successfully!")
    except Exception as e:
        logger.warning(f"Could not initialize model on startup: {e}")
        logger.info("Model will be initialized when first used.")


def main() -> None:
    """Launch the Gradio app."""
    logger.info(f"Starting {UI_CONFIG['title']} application...")
    initialize_model_on_startup()
    demo = create_interface()
    logger.info(f"Launching interface on {UI_CONFIG['host']}:{UI_CONFIG['port']}")
    demo.launch(
        server_name=UI_CONFIG["host"],
        server_port=UI_CONFIG["port"],
        share=False,
        show_error=True,
    )


if __name__ == "__main__":
    main()