$ cat node-template.py

Video LUT Uniform

// Color-matches a video to a reference image using Reinhard color transfer. Extracts the color profile from a reference frame, generates a 3D LUT, and applies it to the target video via ffmpeg. Strength parameter controls transfer intensity.

Process
Utility
template.py
import json
import os
import shlex
import subprocess
import sys
import traceback
import uuid

import numpy as np

INPUT_DIR = "/data/input"
OUTPUT_DIR = "/data/output"
FFMPEG_IMAGE = "jrottenberg/ffmpeg:7-ubuntu"


# ---------------------------------------------------------------------------
# Sibling-container helpers (ffmpeg/ffprobe run in FFMPEG_IMAGE, no PIL needed)
# ---------------------------------------------------------------------------

def _run_in_ffmpeg_container(args, host_input, host_output, *, entrypoint,
                             memory, cpus, timeout):
    """Run FFMPEG_IMAGE as a throwaway sibling container and return the result.

    Args:
        args: argv passed to ``entrypoint`` inside the container. Because the
            arguments are handed to docker as a list, file names are never
            re-parsed by a shell unless ``entrypoint`` is ``"sh"``.
        host_input: host directory mounted read-only at ``/data/input``.
        host_output: host directory mounted read-write at ``/data/output``.
        entrypoint: executable to run in the container (``ffmpeg``, ``ffprobe``, ``sh``).
        memory, cpus: docker ``--memory`` / ``--cpus`` limits (strings).
        timeout: seconds before the run is abandoned.

    Returns:
        The ``subprocess.CompletedProcess`` (text mode, stdout/stderr captured).

    Raises:
        subprocess.TimeoutExpired: after force-removing the container. On
            timeout ``subprocess.run`` SIGKILLs the local docker CLI, which
            cannot forward that signal, so without the explicit ``docker rm -f``
            the container would keep running.
    """
    name = f"video-lut-{uuid.uuid4().hex[:12]}"
    cmd = [
        "docker", "run", "--rm",
        "--name", name,
        "--network", "none",
        "--memory", memory,
        "--cpus", cpus,
        "-v", f"{host_input}:/data/input:ro",
        "-v", f"{host_output}:/data/output:rw",
        "--entrypoint", entrypoint,
        FFMPEG_IMAGE,
        *args,
    ]
    try:
        return subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
    except subprocess.TimeoutExpired:
        try:
            subprocess.run(["docker", "rm", "-f", name],
                           capture_output=True, text=True, timeout=30)
        except (subprocess.SubprocessError, OSError):
            pass  # best effort; the original timeout is the error to report
        raise


def _get_image_dimensions(probe_cmd, host_input, host_output, input_path):
    """Return ``(width, height)`` of the first video stream of ``input_path``.

    Args:
        probe_cmd: probe executable inside the container (callers pass ``"ffprobe"``).
        host_input, host_output: host staging directories (see mounts above).
        input_path: path as seen inside the container, e.g. ``/data/input/x.png``.

    Raises:
        RuntimeError: if ffprobe exits non-zero or prints no ``WxH`` line.
    """
    result = _run_in_ffmpeg_container(
        ["-v", "error", "-select_streams", "v:0",
         "-show_entries", "stream=width,height", "-of", "csv=s=x:p=0",
         input_path],
        host_input, host_output,
        entrypoint=probe_cmd, memory="512m", cpus="0.5", timeout=30,
    )
    if result.returncode != 0:
        raise RuntimeError(f"ffprobe failed: {result.stderr[-1000:]}")
    try:
        first_line = result.stdout.strip().splitlines()[0]
        width_str, height_str = first_line.split("x")[:2]
        return int(width_str), int(height_str)
    except (IndexError, ValueError) as err:
        raise RuntimeError(
            f"Could not read dimensions of {input_path!r} from ffprobe output: "
            f"{result.stdout[-200:]!r}"
        ) from err


def _decode_first_frame_rgb(input_path, raw_file, host_input, host_output, *,
                            memory, cpus, timeout, error_label, stderr_tail):
    """Decode the first frame of ``input_path`` into an (H, W, 3) uint8 array.

    ffmpeg writes the frame as raw rgb24 to ``/data/output/<raw_file>``. That is
    ``OUTPUT_DIR/<raw_file>`` for this process, which reads it back. The
    temporary file is removed whether or not decoding succeeds.

    Raises:
        RuntimeError: ``"<error_label>: <stderr tail>"`` if ffmpeg fails, or a
            size-mismatch error if the raw file is not ``W*H*3`` bytes.
    """
    raw_local = os.path.join(OUTPUT_DIR, raw_file)
    w, h = _get_image_dimensions("ffprobe", host_input, host_output, input_path)
    try:
        result = _run_in_ffmpeg_container(
            # -noautorotate: keep the coded width/height that ffprobe reported;
            # autorotation of rotated inputs would swap them and scramble reshape.
            ["-noautorotate", "-i", input_path, "-vframes", "1",
             "-f", "rawvideo", "-pix_fmt", "rgb24",
             "-y", f"/data/output/{raw_file}"],
            host_input, host_output,
            entrypoint="ffmpeg", memory=memory, cpus=cpus, timeout=timeout,
        )
        if result.returncode != 0:
            raise RuntimeError(f"{error_label}: {result.stderr[-stderr_tail:]}")

        expected_size = w * h * 3
        actual_size = os.path.getsize(raw_local)
        if actual_size != expected_size:
            raise RuntimeError(
                f"Raw file size mismatch: expected {expected_size} bytes ({w}x{h}x3), got {actual_size}"
            )
        return np.fromfile(raw_local, dtype=np.uint8).reshape(h, w, 3)
    finally:
        try:
            os.remove(raw_local)
        except FileNotFoundError:
            pass


def load_image_as_numpy(filename, host_input, host_output):
    """Load ``INPUT_DIR/filename`` as an (H, W, 3) uint8 RGB array via ffmpeg."""
    return _decode_first_frame_rgb(
        f"/data/input/{filename}", "_ref_raw.rgb", host_input, host_output,
        memory="512m", cpus="0.5", timeout=60,
        error_label="ffmpeg raw decode failed", stderr_tail=1000,
    )


def extract_first_frame_as_numpy(video_filename, host_input, host_output):
    """Extract the first frame of ``INPUT_DIR/video_filename`` as an (H, W, 3) uint8 array."""
    return _decode_first_frame_rgb(
        f"/data/input/{video_filename}", "_target_raw.rgb", host_input, host_output,
        memory="1g", cpus="1.0", timeout=120,
        error_label="Frame extraction failed", stderr_tail=2000,
    )


# ---------------------------------------------------------------------------
# sRGB <-> CIE LAB conversion (no scipy needed)
# ---------------------------------------------------------------------------

# Linear sRGB -> XYZ (D65) and its inverse.
_SRGB_TO_XYZ_D65 = np.array([
    [0.4124564, 0.3575761, 0.1804375],
    [0.2126729, 0.7151522, 0.0721750],
    [0.0193339, 0.1191920, 0.9503041],
])
_XYZ_D65_TO_SRGB = np.array([
    [ 3.2404542, -1.5371385, -0.4985314],
    [-0.9692660,  1.8760108,  0.0415560],
    [ 0.0556434, -0.2040259,  1.0572252],
])
# D65 reference white used to normalise XYZ before the LAB transform.
_D65_WHITE = np.array([0.95047, 1.00000, 1.08883])
_LAB_DELTA = 6.0 / 29.0


def srgb_to_linear(rgb):
    """Convert sRGB [0,1] to linear RGB (input is clipped to [0,1] first)."""
    rgb = np.clip(rgb, 0.0, 1.0)
    return np.where(rgb <= 0.04045, rgb / 12.92, ((rgb + 0.055) / 1.055) ** 2.4)


def linear_to_srgb(rgb):
    """Convert linear RGB to sRGB, clipping both input and output to [0,1]."""
    rgb = np.clip(rgb, 0.0, 1.0)
    out = np.where(rgb <= 0.0031308, rgb * 12.92, 1.055 * (rgb ** (1.0 / 2.4)) - 0.055)
    return np.clip(out, 0.0, 1.0)


def rgb_to_xyz(rgb_linear):
    """Linear RGB (..., 3) to XYZ D65 (..., 3)."""
    return rgb_linear @ _SRGB_TO_XYZ_D65.T


def xyz_to_rgb(xyz):
    """XYZ D65 (..., 3) to linear RGB (..., 3); result is not clipped."""
    return xyz @ _XYZ_D65_TO_SRGB.T


def xyz_to_lab(xyz):
    """XYZ (..., 3) to CIE LAB (..., 3) relative to the D65 white point."""
    xyz_n = xyz / _D65_WHITE
    # Cube root above the (6/29)^3 threshold, linear segment below it.
    f = np.where(
        xyz_n > _LAB_DELTA ** 3,
        xyz_n ** (1.0 / 3.0),
        xyz_n / (3.0 * _LAB_DELTA ** 2) + 4.0 / 29.0,
    )
    L = 116.0 * f[..., 1] - 16.0
    a = 500.0 * (f[..., 0] - f[..., 1])
    b = 200.0 * (f[..., 1] - f[..., 2])
    return np.stack([L, a, b], axis=-1)


def lab_to_xyz(lab):
    """CIE LAB (..., 3) to XYZ (..., 3) relative to the D65 white point."""
    L, a, b = lab[..., 0], lab[..., 1], lab[..., 2]
    fy = (L + 16.0) / 116.0
    fx = a / 500.0 + fy
    fz = fy - b / 200.0
    f_vals = np.stack([fx, fy, fz], axis=-1)
    xyz_n = np.where(
        f_vals > _LAB_DELTA,
        f_vals ** 3,
        3.0 * (_LAB_DELTA ** 2) * (f_vals - 4.0 / 29.0),
    )
    return xyz_n * _D65_WHITE


def srgb_to_lab(srgb):
    """sRGB [0,1] -> LAB."""
    return xyz_to_lab(rgb_to_xyz(srgb_to_linear(srgb)))


def lab_to_srgb(lab):
    """LAB -> sRGB [0,1] (out-of-gamut values are clipped)."""
    return linear_to_srgb(xyz_to_rgb(lab_to_xyz(lab)))


# ---------------------------------------------------------------------------
# Reinhard color transfer
# ---------------------------------------------------------------------------

def compute_lab_stats(image_array):
    """Return per-channel LAB ``(mean, std)`` of an (H, W, 3) uint8 image."""
    pixels = image_array.reshape(-1, 3).astype(np.float64) / 255.0
    lab = srgb_to_lab(pixels)
    return lab.mean(axis=0), lab.std(axis=0)


def generate_lut_cube(ref_mean, ref_std, tgt_mean, tgt_std, strength, lut_size):
    """Build the text of a .cube 3D LUT encoding a Reinhard colour transfer.

    Each node of a ``lut_size``^3 sRGB grid is converted to LAB. Each channel
    is re-centred and re-scaled from the target statistics to the reference
    statistics, blended with the original value by ``strength`` (0 = unchanged,
    1 = full transfer), and converted back to clipped sRGB.

    Args:
        ref_mean, ref_std: per-channel LAB mean/std of the reference image.
        tgt_mean, tgt_std: per-channel LAB mean/std of the target frame.
        strength: blend factor between the original and transferred colour.
        lut_size: number of grid points per axis.

    Returns:
        The .cube contents: ``LUT_3D_SIZE n``, a blank line, then ``n**3``
        ``R G B`` lines (6 decimals) with red varying fastest.
    """
    n = lut_size
    steps = np.linspace(0.0, 1.0, n)
    # With indexing='ij' axis 0 (blue) varies slowest and axis 2 (red) fastest,
    # which is the node order .cube files use.
    b, g, r = np.meshgrid(steps, steps, steps, indexing="ij")
    grid = np.stack([r, g, b], axis=-1).reshape(-1, 3)

    lab = srgb_to_lab(grid)
    for ch in range(3):
        src_std = tgt_std[ch]
        if src_std < 1e-6:
            # Flat channel in the target: shift the mean only, scaling would divide by ~0.
            shifted = lab[:, ch] - tgt_mean[ch] + ref_mean[ch]
        else:
            shifted = (lab[:, ch] - tgt_mean[ch]) * (ref_std[ch] / src_std) + ref_mean[ch]
        lab[:, ch] = lab[:, ch] * (1.0 - strength) + shifted * strength

    rgb_out = np.clip(lab_to_srgb(lab), 0.0, 1.0)
    rows = (f"{rv:.6f} {gv:.6f} {bv:.6f}" for rv, gv, bv in rgb_out)
    return "\n".join([f"LUT_3D_SIZE {n}", "", *rows])


def build_lut_apply_script(src, lut, out):
    """Return the ``sh`` script that applies ``lut`` to ``src`` and writes ``out``.

    Paths are container paths and are shell-quoted, so file names containing
    quotes or spaces cannot break out of the command line. If ``src`` has an
    audio stream it is re-encoded to AAC. Otherwise a silent stereo track is
    added, so the output always has audio.
    """
    q_src = shlex.quote(src)
    q_out = shlex.quote(out)
    q_vf = shlex.quote(f"lut3d={lut}")
    encode = (
        "-c:v libx264 -preset medium -crf 23 -pix_fmt yuv420p "
        "-c:a aac -ar 44100 -ac 2 -b:a 192k"
    )
    return "\n".join([
        "set -e",
        "",
        f"HAS_AUDIO=$(ffprobe -v quiet -select_streams a -show_entries stream=codec_type -of csv=p=0 {q_src} | head -1)",
        'if [ -n "$HAS_AUDIO" ]; then',
        f"  ffmpeg -i {q_src} -vf {q_vf} {encode} -movflags +faststart -y {q_out}",
        "else",
        f"  ffmpeg -i {q_src} -f lavfi -i anullsrc=r=44100:cl=stereo -vf {q_vf} {encode} "
        f"-map 0:v:0 -map 1:a:0 -shortest -movflags +faststart -y {q_out}",
        "fi",
    ])


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
    """Read the execution JSON from stdin, colour-match the video, print the result JSON.

    Expects ``inputs.video`` and ``inputs.reference`` (file names under
    INPUT_DIR), plus optional ``inputs.strength`` (default 1.0) and
    ``inputs.lut_size`` (17, 33 or 65; anything else falls back to 33).
    Writes ``OUTPUT_DIR/color_matched.mp4`` and prints ``{"video": ...}``.

    Raises:
        ValueError: missing inputs or a non-finite strength.
        RuntimeError: any ffprobe/ffmpeg stage fails.
    """
    execution_input = json.loads(sys.stdin.read())
    inputs = execution_input.get("inputs", {})

    video = inputs.get("video", "")
    reference = inputs.get("reference", "")
    if not video:
        raise ValueError("Video input is required")
    if not reference:
        raise ValueError("Reference image input is required")

    strength = float(inputs.get("strength", 1.0))
    if not np.isfinite(strength):
        # A NaN/inf strength would write "nan" entries into the .cube file.
        raise ValueError(f"strength must be a finite number, got {strength!r}")
    lut_size = int(inputs.get("lut_size", 33))
    if lut_size not in (17, 33, 65):
        lut_size = 33

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    host_input = os.environ.get("HOST_STAGING_INPUT", INPUT_DIR)
    host_output = os.environ.get("HOST_STAGING_OUTPUT", OUTPUT_DIR)

    # Stage 1: load reference image and target's first frame, derive the LUT.
    print("Loading reference image...", file=sys.stderr)
    ref_array = load_image_as_numpy(reference, host_input, host_output)

    print("Extracting first frame from target video...", file=sys.stderr)
    tgt_array = extract_first_frame_as_numpy(video, host_input, host_output)

    print("Computing color statistics...", file=sys.stderr)
    ref_mean, ref_std = compute_lab_stats(ref_array)
    tgt_mean, tgt_std = compute_lab_stats(tgt_array)

    print(f"Generating {lut_size}x{lut_size}x{lut_size} 3D LUT (strength={strength})...", file=sys.stderr)
    cube_content = generate_lut_cube(ref_mean, ref_std, tgt_mean, tgt_std, strength, lut_size)

    lut_path = os.path.join(OUTPUT_DIR, "_transfer.cube")
    try:
        with open(lut_path, "w", encoding="utf-8") as f:
            f.write(cube_content)

        # Stage 2: apply the LUT to the whole video in a sibling ffmpeg container.
        print("Applying LUT to video...", file=sys.stderr)
        script = build_lut_apply_script(
            f"/data/input/{video}",
            "/data/output/_transfer.cube",
            "/data/output/color_matched.mp4",
        )
        result = _run_in_ffmpeg_container(
            ["-c", script], host_input, host_output,
            entrypoint="sh", memory="2g", cpus="2.0", timeout=1800,
        )
        if result.returncode != 0:
            raise RuntimeError(f"ffmpeg LUT application failed (exit {result.returncode}): {result.stderr[-2000:]}")
    finally:
        # The LUT is an intermediate; never leave it in the output directory.
        try:
            os.remove(lut_path)
        except FileNotFoundError:
            pass

    print(json.dumps({"video": "color_matched.mp4"}, indent=2))


if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        print(json.dumps({
            "error": str(e),
            "errorType": type(e).__name__,
            "traceback": traceback.format_exc(),
        }), file=sys.stderr)
        sys.exit(1)