# handler.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# 1) Load the model in 8-bit across GPU/CPU (requires the bitsandbytes package)
model = AutoModelForCausalLM.from_pretrained(
".", # chemin du repo
device_map="auto",
load_in_8bit=True,
trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(".", trust_remote_code=True)
# 2) Initialize the text-generation pipeline
text_gen = pipeline("text-generation", model=model, tokenizer=tokenizer)
def handle(inputs):
    # inputs can be {"inputs": "My prompt"} or just a plain string
    prompt = inputs.get("inputs") if isinstance(inputs, dict) else inputs
    return text_gen(prompt)
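
# --- Hypothetical usage sketch (not part of the original handler): a quick
# --- local smoke test, assuming the model weights sit in the working
# --- directory and bitsandbytes is installed for the 8-bit load above.
if __name__ == "__main__":
    print(handle({"inputs": "Once upon a time"}))  # dict-style payload
    print(handle("Once upon a time"))              # bare-string payload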