File size: 665 Bytes
2d96ef6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# handler.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# 1) Charge en 8-bits sur GPU/CPU
model = AutoModelForCausalLM.from_pretrained(
    ".",                      # chemin du repo
    device_map="auto",
    load_in_8bit=True,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(".", trust_remote_code=True)

# 2) Initialise la pipeline
text_gen = pipeline("text-generation", model=model, tokenizer=tokenizer)

def handle(inputs):
    # inputs peut être {"inputs": "Mon prompt"} ou juste une string
    prompt = inputs.get("inputs") if isinstance(inputs, dict) else inputs
    return text_gen(prompt)