$ cat node-template.py

Image Edit

// Edits images using one or two reference images and a text prompt via a native GPU service. Uses Flux 2 Klein 9B FP8 with reference-latent conditioning. Supports configurable aspect ratio, output size (megapixels), inference steps, guidance scale, and a fixed or random seed. Outputs the edited image as PNG.

Process
Image
template.py
import contextlib
import json
import math
import os
import random
import subprocess
import sys
import time
import traceback

try:
    import requests
except ImportError:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "requests"])
    import requests

# Base URL of the edit service. The default hostname matches CONTAINER_NAME and is
# presumably resolved via the "emblema" Docker network the container joins.
NATIVE_IMAGE_EDIT_SERVICE_URL = os.getenv(
    "NATIVE_IMAGE_EDIT_SERVICE_URL", "http://native-image-edit-service:8101"
)
_EMBLEMA_VERSION = os.getenv("EMBLEMA_VERSION", "dev")
NATIVE_IMAGE_EDIT_SERVICE_IMAGE = os.getenv(
    "NATIVE_IMAGE_EDIT_SERVICE_IMAGE",
    f"emblema/native-image-edit-service:{_EMBLEMA_VERSION}",
)
# Host directory mounted into the container at /root/.cache/huggingface.
HF_CACHE_HOST_PATH = os.getenv("HF_CACHE_HOST_PATH", "/root/.cache/huggingface")
CONTAINER_NAME = "native-image-edit-service"
INPUT_DIR = "/data/input"
OUTPUT_DIR = "/data/output"


def compute_dimensions(aspect_ratio: float, megapixel: float) -> tuple:
    """Compute width and height from aspect ratio and megapixel, divisible by 64.

    Args:
        aspect_ratio: Desired width / height ratio (must be > 0).
        megapixel: Target pixel count in millions.

    Returns:
        ``(width, height)``, each floored to a multiple of 64 and at least 64.
    """
    total_pixels = megapixel * 1_000_000
    height = int(math.sqrt(total_pixels / aspect_ratio))
    width = int(height * aspect_ratio)
    width = max(64, (width // 64) * 64)
    height = max(64, (height // 64) * 64)
    return width, height


def start_container(timeout: float = 180, interval: float = 3) -> None:
    """Create and start native-image-edit-service, removing any stale container first.

    Runs the service image detached on the "emblema" network with all GPUs and
    the host Hugging Face cache mounted. Then polls ``/health`` until it
    answers HTTP 200.

    Args:
        timeout: Maximum wall-clock seconds to wait for a healthy response.
        interval: Seconds to sleep between health-check attempts.

    Raises:
        RuntimeError: If ``docker run`` fails or the service does not become
            healthy within ``timeout`` seconds.
    """
    # A leftover container with the same name would make `docker run --name`
    # fail, so force-remove it first. The result is deliberately ignored.
    subprocess.run(
        ["docker", "rm", "-f", CONTAINER_NAME],
        capture_output=True, text=True,
    )

    hf_token = os.getenv("HUGGINGFACE_TOKEN", "")
    print(f"Creating container {CONTAINER_NAME}...", file=sys.stderr)
    run_cmd = [
        "docker", "run", "-d",
        "--name", CONTAINER_NAME,
        "--network", "emblema",
        "--gpus", "all",
        "-e", "PORT=8101",
        "-e", "DEVICE=cuda",
        "-e", f"HF_TOKEN={hf_token}",
        "-v", f"{HF_CACHE_HOST_PATH}:/root/.cache/huggingface",
        NATIVE_IMAGE_EDIT_SERVICE_IMAGE,
    ]
    result = subprocess.run(run_cmd, capture_output=True, text=True)
    if result.returncode != 0:
        print(f"docker run failed (exit {result.returncode}): {result.stderr}", file=sys.stderr)
        raise RuntimeError(f"Failed to start container: {result.stderr}")

    # Poll against a monotonic deadline so the time spent inside each
    # (up to 5 s) request counts toward the timeout, not just the sleeps.
    health_url = f"{NATIVE_IMAGE_EDIT_SERVICE_URL}/health"
    started = time.monotonic()
    deadline = started + timeout
    while time.monotonic() < deadline:
        try:
            r = requests.get(health_url, timeout=5)
            if r.status_code == 200:
                waited = time.monotonic() - started
                print(f"Container healthy (waited {waited:.0f}s).", file=sys.stderr)
                return
        except requests.RequestException:
            # Not answering yet (connection refused, read timeout, ...): a
            # starting service may do either, so keep polling until the deadline.
            pass
        time.sleep(interval)

    raise RuntimeError(f"Container did not become healthy within {timeout}s")


def stop_container() -> None:
    """Force-remove the service container, logging instead of raising on failure.

    This is best-effort cleanup. It is called from ``main``'s ``finally`` block
    and must not mask the exception that is already propagating.
    """
    try:
        result = subprocess.run(
            ["docker", "rm", "-f", CONTAINER_NAME],
            capture_output=True, text=True, timeout=30,
        )
    except Exception as e:  # docker binary missing, timeout, ...
        print(f"Warning: failed to remove container: {e}", file=sys.stderr)
        return

    if result.returncode != 0:
        print(
            f"Warning: failed to remove container: {result.stderr.strip()}",
            file=sys.stderr,
        )
    else:
        print(f"Container {CONTAINER_NAME} removed.", file=sys.stderr)


def main():
    """Read an execution request from stdin, run the image edit, and emit JSON.

    Expects a JSON object on stdin whose ``inputs`` mapping may contain:

    - ``images``: one filename (str) or a list of 1-2 filenames under INPUT_DIR.
    - ``prompt``: required, non-empty.
    - ``aspect_ratio``: in [0.25, 4.0]; default 1.667.
    - ``megapixel``: in [0.25, 4.0]; default 1.0.
    - ``num_inference_steps``: default 4.
    - ``guidance_scale``: default 4.0.
    - ``seed_mode`` / ``seed``: "fixed" with a seed >= 0 uses that seed;
      otherwise a random seed is drawn.

    On success, writes the service response body to
    ``OUTPUT_DIR/edited_image.png`` and prints ``{"image": "edited_image.png"}``
    to stdout. On any failure, prints an error object to stderr and exits with
    status 1.
    """
    try:
        execution_input = json.loads(sys.stdin.read())
        inputs = execution_input.get("inputs", {})

        images = inputs.get("images", [])
        prompt = inputs.get("prompt", "")
        aspect_ratio = float(inputs.get("aspect_ratio", 1.667))
        megapixel = float(inputs.get("megapixel", 1.0))
        num_inference_steps = int(inputs.get("num_inference_steps", 4))
        guidance_scale = float(inputs.get("guidance_scale", 4.0))

        # Seed mode handling: only "fixed" with a non-negative seed is honoured.
        seed_mode = inputs.get("seed_mode", "random")
        seed_input = int(inputs.get("seed", -1))
        if seed_mode == "fixed" and seed_input >= 0:
            seed_value = seed_input
        else:
            seed_value = random.randint(0, 2**31 - 1)

        # Normalize images to a list (single edge gives a string, array edge gives a list)
        if isinstance(images, str):
            images = [images]

        if not prompt:
            raise ValueError("Prompt is required")
        if not (0.25 <= aspect_ratio <= 4.0):
            raise ValueError(f"Aspect ratio must be between 0.25 and 4.0, got {aspect_ratio}")
        if not (0.25 <= megapixel <= 4.0):
            raise ValueError(f"Megapixel must be between 0.25 and 4.0, got {megapixel}")
        if not images:
            raise ValueError("At least one input image is required")
        if len(images) > 2:
            raise ValueError("Maximum of 2 input images supported")

        # Validate all input images exist before paying for a container start.
        image_paths = []
        for img_filename in images:
            local_path = os.path.join(INPUT_DIR, img_filename)
            if not os.path.exists(local_path):
                raise FileNotFoundError(f"Input image not found: {local_path}")
            image_paths.append(local_path)

        os.makedirs(OUTPUT_DIR, exist_ok=True)

        try:
            # Started inside the try so that a container which was created but
            # never became healthy is still removed by the finally below.
            start_container()

            # Computed locally for logging only. The request carries
            # aspect_ratio/megapixel, so the service presumably sizes the output itself.
            width, height = compute_dimensions(aspect_ratio, megapixel)

            data = {
                "prompt": prompt,
                "aspect_ratio": str(aspect_ratio),
                "megapixel": str(megapixel),
                "num_inference_steps": str(num_inference_steps),
                "guidance_scale": str(guidance_scale),
                "seed": str(seed_value),
            }

            print(
                f"Requesting edit: images={len(image_paths)}, {width}x{height}, "
                f"steps={num_inference_steps}, cfg={guidance_scale}, seed={seed_value}",
                file=sys.stderr,
            )

            # ExitStack closes every opened image even if a later open() or the
            # POST itself (timeout, connection error) raises.
            with contextlib.ExitStack() as stack:
                files = [
                    (
                        "images",
                        (
                            os.path.basename(local_path),
                            stack.enter_context(open(local_path, "rb")),
                            "image/png",
                        ),
                    )
                    for local_path in image_paths
                ]
                resp = requests.post(
                    f"{NATIVE_IMAGE_EDIT_SERVICE_URL}/edit",
                    files=files,
                    data=data,
                    timeout=600,
                )

            if resp.status_code != 200:
                try:
                    error_detail = resp.json()
                except Exception:
                    error_detail = resp.text
                raise RuntimeError(
                    f"Native image edit service returned {resp.status_code}: {error_detail}"
                )

            # Save result
            out_filename = "edited_image.png"
            out_path = os.path.join(OUTPUT_DIR, out_filename)
            with open(out_path, "wb") as f:
                f.write(resp.content)

            seed_used = resp.headers.get("X-Seed", str(seed_value))
            inference_time = resp.headers.get("X-Inference-Time-Ms", "unknown")
            output_size = resp.headers.get("X-Output-Size", "unknown")
            print(
                f"Edited: seed={seed_used}, time={inference_time}ms, size={output_size}, images={len(image_paths)}",
                file=sys.stderr,
            )

            # Flat output — keys match OUTPUT_SCHEMA
            output = {
                "image": out_filename,
            }
            print(json.dumps(output, indent=2))

        finally:
            stop_container()

    except Exception as e:
        error_output = {
            "error": str(e),
            "errorType": type(e).__name__,
            "traceback": traceback.format_exc(),
        }
        print(json.dumps(error_output), file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()