import base64
import io
from typing import Any, Dict

import PIL.Image as Image
import torch
from transformers import JanusForConditionalGeneration, JanusProcessor


class EndpointHandler:
    """Serve a Janus checkpoint for text and image generation.

    Supported request modes (selected with the ``generation_mode`` key):

    * ``"text"`` (default): text -> text chat completion.
    * ``"image"``: text -> image generation.

    Request payload keys read by :meth:`__call__`:

    * ``prompt`` or ``inputs``: the user prompt (required, non-empty string).
    * ``generation_mode``: ``"text"`` or ``"image"``.
    * ``max_new_tokens``: cap on generated tokens in text mode (default 128).
    """

    # Modes this handler knows how to post-process.
    _GENERATION_MODES = ("text", "image")

    def __init__(self, model_path: str):
        """Load the processor and the model from ``model_path``.

        Args:
            model_path: Local directory or hub id passed to ``from_pretrained``.
        """
        # NOTE(review): trust_remote_code=True allows code shipped with the
        # checkpoint to be executed at load time; confirm it is still needed
        # for this checkpoint and only point model_path at trusted sources.
        self.processor = JanusProcessor.from_pretrained(
            model_path,
            trust_remote_code=True,
        )
        self.model = JanusForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,  # fp16 also fine
            device_map="auto",
            load_in_4bit=True,  # drop this on GPUs with enough memory
        )

    # ---- each request lands here ----
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run one generation request.

        Args:
            data: Request payload; see the class docstring for the keys read.

        Returns:
            ``{"generated_text": str}`` in text mode, or
            ``{"images": [str, ...]}`` (base64-encoded PNGs) in image mode.

        Raises:
            ValueError: If the prompt is missing/empty or ``generation_mode``
                is not one of the supported modes.
        """
        prompt = data.get("prompt") or data.get("inputs")
        if not isinstance(prompt, str) or not prompt.strip():
            raise ValueError(
                "Request must contain a non-empty string under 'prompt' or 'inputs'."
            )

        gen_mode = data.get("generation_mode", "text")  # "text" | "image"
        if gen_mode not in self._GENERATION_MODES:
            raise ValueError(
                f"Unsupported generation_mode {gen_mode!r}; "
                f"expected one of {self._GENERATION_MODES}."
            )

        templ = self.processor.apply_chat_template(
            [{"role": "user", "content": [{"type": "text", "text": prompt}]}],
            add_generation_prompt=True,
        )
        inputs = self.processor(
            text=templ,
            generation_mode=gen_mode,
            return_tensors="pt",
        ).to(self.model.device)

        if gen_mode == "image":
            return {"images": self._generate_images(inputs)}

        with torch.no_grad():
            out = self.model.generate(
                **inputs,
                generation_mode=gen_mode,
                max_new_tokens=data.get("max_new_tokens", 128),
            )
        return {
            "generated_text": self.processor.decode(
                out[0], skip_special_tokens=True
            )
        }

    def _generate_images(self, inputs: Any) -> list:
        """Generate images for prepared ``inputs`` and return base64 PNGs."""
        with torch.no_grad():
            # max_new_tokens is deliberately not forwarded here: it is a
            # text-length control, and the documented image-generation call
            # does not pass it.
            tokens = self.model.generate(
                **inputs,
                generation_mode="image",
                do_sample=True,
                use_cache=True,
            )
            # processor.decode() is the text detokenizer; generated image
            # tokens have to go through the model's image decoder instead.
            pixels = self.model.decode_image_tokens(tokens)

        # Cast to float32 before post-processing (the model runs in bf16).
        processed = self.processor.postprocess(
            list(pixels.float()),
            return_tensors="PIL.Image.Image",
        )
        return [self._pil_to_base64(img) for img in processed["pixel_values"]]

    @staticmethod
    def _pil_to_base64(img: Image.Image) -> str:
        """Encode a PIL image as a base64 string of its PNG bytes."""
        buf = io.BytesIO()
        img.save(buf, format="PNG")
        return base64.b64encode(buf.getvalue()).decode()