# handler.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# 1) Load the model in 8-bit, spread across GPU/CPU as available
model = AutoModelForCausalLM.from_pretrained(
    ".",  # path to the repo
    device_map="auto",
    load_in_8bit=True,
    trust_remote_code=True,
)
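# Note: on newer transformers releases the bare load_in_8bit kwarg is
# deprecated in favor of passing a BitsAndBytesConfig; a rough sketch:
#
#   from transformers import BitsAndBytesConfig
#   model = AutoModelForCausalLM.from_pretrained(
#       ".",
#       device_map="auto",
#       quantization_config=BitsAndBytesConfig(load_in_8bit=True),
#       trust_remote_code=True,
#   )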
tokenizer = AutoTokenizer.from_pretrained(".", trust_remote_code=True)

# 2) Initialize the text-generation pipeline
text_gen = pipeline("text-generation", model=model, tokenizer=tokenizer)

def handle(inputs):
    # inputs can be {"inputs": "My prompt"} or just a string
    prompt = inputs.get("inputs") if isinstance(inputs, dict) else inputs
    # The pipeline returns a list like [{"generated_text": "..."}]
    return text_gen(prompt)
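
# Quick local smoke test (a sketch: assumes the model weights sit in the
# current directory and bitsandbytes is installed for 8-bit loading).
if __name__ == "__main__":
    # Both payload shapes the handler accepts:
    print(handle({"inputs": "Once upon a time"}))
    print(handle("Once upon a time"))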