$ cat node-template.py
Video LUT Uniform
// Color-matches a video to a reference image using Reinhard color transfer. Extracts the color profile from a reference frame, generates a 3D LUT, and applies it to the target video via ffmpeg. Strength parameter controls transfer intensity.
Process
Utility
template.py
"""Color-match a video to a reference image using Reinhard color transfer.

The script reads a JSON document from stdin, decodes the reference image and
the first frame of the target video to raw RGB via ffmpeg running in sibling
Docker containers, computes per-channel CIE LAB mean/std for both, encodes the
Reinhard transfer as a ``.cube`` 3D LUT, and applies that LUT to the video
with ffmpeg's ``lut3d`` filter.  The result is written to
``/data/output/color_matched.mp4`` and a small JSON document naming it is
printed on stdout.
"""

import json
import math
import os
import shlex
import subprocess
import sys
import traceback

import numpy as np

INPUT_DIR = "/data/input"
OUTPUT_DIR = "/data/output"
FFMPEG_IMAGE = "jrottenberg/ffmpeg:7-ubuntu"


# ---------------------------------------------------------------------------
# Image loading helpers (ffmpeg-based, no PIL needed)
# ---------------------------------------------------------------------------

def _run_in_ffmpeg_container(shell_script, host_input, host_output, *,
                             memory, cpus, timeout):
    """Run *shell_script* with ``sh -c`` inside a throw-away ffmpeg container.

    The container has no network, mounts *host_input* read-only at
    ``/data/input`` and *host_output* read-write at ``/data/output``.

    Returns the ``subprocess.CompletedProcess`` (stdout/stderr captured as
    text).  ``subprocess.TimeoutExpired`` propagates if *timeout* seconds
    elapse.
    """
    cmd = [
        "docker", "run", "--rm",
        "--network", "none",
        "--memory", memory,
        "--cpus", cpus,
        "-v", f"{host_input}:/data/input:ro",
        "-v", f"{host_output}:/data/output:rw",
        "--entrypoint", "sh",
        FFMPEG_IMAGE,
        "-c", shell_script,
    ]
    return subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)


def _remove_quietly(path):
    """Delete *path*, ignoring the case where it does not exist."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


def _get_image_dimensions(probe_cmd, host_input, host_output, input_path):
    """Run ffprobe in a sibling container and return (width, height).

    *probe_cmd* is accepted for backward compatibility but is not used; the
    probe binary is always ``ffprobe``.  *input_path* is the path as seen
    inside the container.

    Raises RuntimeError if ffprobe fails or its output cannot be parsed.
    """
    # shlex.quote keeps a filename containing quotes or shell metacharacters
    # from terminating the argument early.
    shell_script = (
        "ffprobe -v error -select_streams v:0 "
        "-show_entries stream=width,height -of csv=s=x:p=0 "
        f"{shlex.quote(input_path)}"
    )
    result = _run_in_ffmpeg_container(
        shell_script, host_input, host_output,
        memory="512m", cpus="0.5", timeout=30,
    )
    if result.returncode != 0:
        raise RuntimeError(f"ffprobe failed: {result.stderr[-1000:]}")

    # Only the first non-empty line is used, so extra lines cannot corrupt
    # the height field.
    first_line = next(
        (ln.strip() for ln in result.stdout.splitlines() if ln.strip()), ""
    )
    parts = first_line.split("x")
    try:
        width, height = int(parts[0]), int(parts[1])
    except (IndexError, ValueError) as err:
        raise RuntimeError(
            f"ffprobe returned unexpected dimensions: {result.stdout!r}"
        ) from err
    if width <= 0 or height <= 0:
        raise RuntimeError(
            f"ffprobe returned unexpected dimensions: {result.stdout!r}"
        )
    return width, height


def _decode_first_frame(filename, raw_file, host_input, host_output, *,
                        memory, cpus, timeout, error_label, stderr_tail):
    """Decode the first frame of ``/data/input/<filename>`` to an RGB array.

    The frame is written by ffmpeg as headerless rgb24 to
    ``OUTPUT_DIR/<raw_file>``, read back with numpy, and the temp file is
    removed whether or not decoding succeeds.

    Returns a (H, W, 3) uint8 array.  Raises RuntimeError when ffmpeg fails
    or the raw file size does not equal ``W * H * 3``.
    """
    input_path = f"/data/input/{filename}"
    raw_local = os.path.join(OUTPUT_DIR, raw_file)

    w, h = _get_image_dimensions("ffprobe", host_input, host_output, input_path)

    # -noautorotate: ffprobe reports the stored width/height, so rotation
    # metadata must not be applied or the (h, w) reshape would be transposed.
    shell_script = (
        f"ffmpeg -noautorotate -i {shlex.quote(input_path)} "
        "-vframes 1 -f rawvideo -pix_fmt rgb24 "
        f"-y {shlex.quote('/data/output/' + raw_file)}"
    )
    try:
        result = _run_in_ffmpeg_container(
            shell_script, host_input, host_output,
            memory=memory, cpus=cpus, timeout=timeout,
        )
        if result.returncode != 0:
            raise RuntimeError(f"{error_label}: {result.stderr[-stderr_tail:]}")

        expected_size = w * h * 3
        actual_size = os.path.getsize(raw_local)
        if actual_size != expected_size:
            raise RuntimeError(
                f"Raw file size mismatch: expected {expected_size} bytes "
                f"({w}x{h}x3), got {actual_size}"
            )
        return np.fromfile(raw_local, dtype=np.uint8).reshape(h, w, 3)
    finally:
        # Never leave the (potentially large) raw frame behind in the
        # output directory, even on failure.
        _remove_quietly(raw_local)


def load_image_as_numpy(filename, host_input, host_output):
    """Load an image from INPUT_DIR as a (H, W, 3) uint8 numpy array via ffmpeg.

    Raises RuntimeError if probing or decoding fails.
    """
    return _decode_first_frame(
        filename, "_ref_raw.rgb", host_input, host_output,
        memory="512m", cpus="0.5", timeout=60,
        error_label="ffmpeg raw decode failed", stderr_tail=1000,
    )


def extract_first_frame_as_numpy(video_filename, host_input, host_output):
    """Extract the first frame from a video as a (H, W, 3) uint8 numpy array.

    Raises RuntimeError if probing or frame extraction fails.
    """
    return _decode_first_frame(
        video_filename, "_target_raw.rgb", host_input, host_output,
        memory="1g", cpus="1.0", timeout=120,
        error_label="Frame extraction failed", stderr_tail=2000,
    )


# ---------------------------------------------------------------------------
# sRGB <-> CIE LAB conversion (no scipy needed)
# ---------------------------------------------------------------------------

def srgb_to_linear(rgb):
    """Convert sRGB [0,1] to linear RGB.  Input is clipped to [0, 1] first."""
    rgb = np.clip(rgb, 0.0, 1.0)
    return np.where(rgb <= 0.04045, rgb / 12.92, ((rgb + 0.055) / 1.055) ** 2.4)


def linear_to_srgb(rgb):
    """Convert linear RGB to sRGB [0,1].  Input is clipped to [0, 1] first."""
    rgb = np.clip(rgb, 0.0, 1.0)
    out = np.where(
        rgb <= 0.0031308,
        rgb * 12.92,
        1.055 * (rgb ** (1.0 / 2.4)) - 0.055,
    )
    return np.clip(out, 0.0, 1.0)


def rgb_to_xyz(rgb_linear):
    """Linear RGB to XYZ D65.  The last axis of *rgb_linear* must have size 3."""
    # sRGB -> XYZ matrix (D65)
    M = np.array([
        [0.4124564, 0.3575761, 0.1804375],
        [0.2126729, 0.7151522, 0.0721750],
        [0.0193339, 0.1191920, 0.9503041],
    ])
    return rgb_linear @ M.T


def xyz_to_rgb(xyz):
    """XYZ D65 to linear RGB.  The last axis of *xyz* must have size 3."""
    M_inv = np.array([
        [3.2404542, -1.5371385, -0.4985314],
        [-0.9692660, 1.8760108, 0.0415560],
        [0.0556434, -0.2040259, 1.0572252],
    ])
    return xyz @ M_inv.T


def xyz_to_lab(xyz):
    """XYZ to CIE LAB (D65 illuminant).  Returns an array shaped like *xyz*."""
    # D65 reference white
    ref = np.array([0.95047, 1.00000, 1.08883])
    xyz_n = xyz / ref

    delta = 6.0 / 29.0
    # np.where evaluates both branches; np.cbrt is defined for negative
    # values, so masked-out entries never produce NaN or a RuntimeWarning.
    f = np.where(
        xyz_n > delta ** 3,
        np.cbrt(xyz_n),
        xyz_n / (3.0 * delta ** 2) + 4.0 / 29.0,
    )

    L = 116.0 * f[..., 1] - 16.0
    a = 500.0 * (f[..., 0] - f[..., 1])
    b = 200.0 * (f[..., 1] - f[..., 2])
    return np.stack([L, a, b], axis=-1)


def lab_to_xyz(lab):
    """CIE LAB to XYZ (D65 illuminant).  Returns an array shaped like *lab*."""
    ref = np.array([0.95047, 1.00000, 1.08883])
    delta = 6.0 / 29.0

    L, a, b = lab[..., 0], lab[..., 1], lab[..., 2]
    fy = (L + 16.0) / 116.0
    fx = a / 500.0 + fy
    fz = fy - b / 200.0

    f_vals = np.stack([fx, fy, fz], axis=-1)
    xyz_n = np.where(
        f_vals > delta,
        f_vals ** 3,
        3.0 * (delta ** 2) * (f_vals - 4.0 / 29.0),
    )
    return xyz_n * ref


def srgb_to_lab(srgb):
    """sRGB [0,1] -> LAB."""
    return xyz_to_lab(rgb_to_xyz(srgb_to_linear(srgb)))


def lab_to_srgb(lab):
    """LAB -> sRGB [0,1]."""
    return linear_to_srgb(xyz_to_rgb(lab_to_xyz(lab)))


# ---------------------------------------------------------------------------
# Reinhard color transfer
# ---------------------------------------------------------------------------

def compute_lab_stats(image_array):
    """Compute mean and std per LAB channel for an image (H,W,3 uint8).

    Returns ``(mean, std)``, each an array of three floats (L, a, b).
    """
    pixels = image_array.reshape(-1, 3).astype(np.float64) / 255.0
    lab = srgb_to_lab(pixels)
    return lab.mean(axis=0), lab.std(axis=0)


def generate_lut_cube(ref_mean, ref_std, tgt_mean, tgt_std, strength, lut_size):
    """Generate a .cube 3D LUT encoding the Reinhard transfer.

    Each LUT grid point is converted to LAB, shifted/scaled so the target
    statistics (*tgt_mean*, *tgt_std*) map onto the reference statistics
    (*ref_mean*, *ref_std*), blended with the untouched value by *strength*
    (0 = identity, 1 = full transfer), and converted back to sRGB.

    Returns the LUT as a string: a ``LUT_3D_SIZE`` header, a blank line and
    ``lut_size ** 3`` rows of ``R G B`` with red varying fastest.

    Raises ValueError if *lut_size* is smaller than 2.
    """
    n = int(lut_size)
    if n < 2:
        raise ValueError(f"lut_size must be at least 2, got {lut_size}")

    ref_mean = np.asarray(ref_mean, dtype=np.float64)
    ref_std = np.asarray(ref_std, dtype=np.float64)
    tgt_mean = np.asarray(tgt_mean, dtype=np.float64)
    tgt_std = np.asarray(tgt_std, dtype=np.float64)

    steps = np.linspace(0.0, 1.0, n)
    # With indexing='ij' the first output varies along the slowest axis, so
    # after the C-order reshape red changes fastest, as .cube requires.
    b, g, r = np.meshgrid(steps, steps, steps, indexing='ij')
    grid = np.stack([r, g, b], axis=-1).reshape(-1, 3)

    lab = srgb_to_lab(grid)

    # Near-zero target std: skip scaling (scale = 1) to avoid division by
    # zero; the channel is then only shifted to the reference mean.
    flat = tgt_std < 1e-6
    scale = np.where(flat, 1.0, ref_std / np.where(flat, 1.0, tgt_std))
    shifted = (lab - tgt_mean) * scale + ref_mean
    lab = lab * (1.0 - strength) + shifted * strength

    rgb_out = np.clip(lab_to_srgb(lab), 0.0, 1.0)

    rows = [
        f"{red:.6f} {green:.6f} {blue:.6f}"
        for red, green, blue in rgb_out.tolist()
    ]
    return "\n".join([f"LUT_3D_SIZE {n}", "", *rows])


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
    """Read the job description from stdin and produce the color-matched video.

    Expects JSON of the form ``{"inputs": {"video": ..., "reference": ...,
    "strength": ..., "lut_size": ...}}``.  ``strength`` defaults to 1.0 and
    ``lut_size`` to 33 (any value other than 17, 33 or 65 falls back to 33).
    Progress goes to stderr; the final ``{"video": "color_matched.mp4"}``
    document goes to stdout.

    Raises ValueError for missing inputs or a non-finite strength, and
    RuntimeError when any ffmpeg/ffprobe step fails.
    """
    execution_input = json.loads(sys.stdin.read())
    inputs = execution_input.get("inputs", {})

    video = inputs.get("video", "")
    reference = inputs.get("reference", "")
    if not video:
        raise ValueError("Video input is required")
    if not reference:
        raise ValueError("Reference image input is required")

    strength = float(inputs.get("strength", 1.0))
    if not math.isfinite(strength):
        # NaN/inf would silently fill the whole LUT with NaN.
        raise ValueError("Strength must be a finite number")
    lut_size = int(inputs.get("lut_size", 33))

    # Validate lut_size
    if lut_size not in (17, 33, 65):
        print(f"Unsupported lut_size {lut_size}; using 33", file=sys.stderr)
        lut_size = 33

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    host_input = os.environ.get("HOST_STAGING_INPUT", INPUT_DIR)
    host_output = os.environ.get("HOST_STAGING_OUTPUT", OUTPUT_DIR)

    # Stage 1: Load reference image and extract first frame as numpy arrays
    print("Loading reference image...", file=sys.stderr)
    ref_array = load_image_as_numpy(reference, host_input, host_output)

    print("Extracting first frame from target video...", file=sys.stderr)
    tgt_array = extract_first_frame_as_numpy(video, host_input, host_output)

    # Compute LAB statistics
    print("Computing color statistics...", file=sys.stderr)
    ref_mean, ref_std = compute_lab_stats(ref_array)
    tgt_mean, tgt_std = compute_lab_stats(tgt_array)

    # Generate 3D LUT
    print(
        f"Generating {lut_size}x{lut_size}x{lut_size} 3D LUT (strength={strength})...",
        file=sys.stderr,
    )
    cube_content = generate_lut_cube(
        ref_mean, ref_std, tgt_mean, tgt_std, strength, lut_size
    )

    lut_path = os.path.join(OUTPUT_DIR, "_transfer.cube")
    try:
        with open(lut_path, "w", encoding="utf-8") as f:
            f.write(cube_content)

        # Stage 2: Apply LUT to video via ffmpeg
        print("Applying LUT to video...", file=sys.stderr)
        # The video name is user-supplied, so it is shell-quoted; the LUT and
        # output paths are fixed constants.
        src = shlex.quote(f"/data/input/{video}")
        lut = "/data/output/_transfer.cube"
        out = "/data/output/color_matched.mp4"

        shell_script = "\n".join([
            "set -e",
            "",
            "HAS_AUDIO=$(ffprobe -v quiet -select_streams a "
            f"-show_entries stream=codec_type -of csv=p=0 {src} | head -1)",
            'if [ -n "$HAS_AUDIO" ]; then',
            f"  ffmpeg -i {src} -vf 'lut3d={lut}' "
            "-c:v libx264 -preset medium -crf 23 -pix_fmt yuv420p "
            "-c:a aac -ar 44100 -ac 2 -b:a 192k "
            f"-movflags +faststart -y '{out}'",
            "else",
            # No audio stream: synthesize silent stereo audio for the output.
            f"  ffmpeg -i {src} -f lavfi -i anullsrc=r=44100:cl=stereo "
            f"-vf 'lut3d={lut}' "
            "-c:v libx264 -preset medium -crf 23 -pix_fmt yuv420p "
            "-c:a aac -ar 44100 -ac 2 -b:a 192k "
            "-map 0:v:0 -map 1:a:0 -shortest "
            f"-movflags +faststart -y '{out}'",
            "fi",
        ])

        result = _run_in_ffmpeg_container(
            shell_script, host_input, host_output,
            memory="2g", cpus="2.0", timeout=1800,
        )
        if result.returncode != 0:
            raise RuntimeError(
                f"ffmpeg LUT application failed (exit {result.returncode}): "
                f"{result.stderr[-2000:]}"
            )
    finally:
        # Clean up temp files, including when the ffmpeg step fails.
        _remove_quietly(lut_path)

    print(json.dumps({"video": "color_matched.mp4"}, indent=2))


if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        # Top-level boundary: report any failure as structured JSON on stderr.
        print(json.dumps({
            "error": str(e),
            "errorType": type(e).__name__,
            "traceback": traceback.format_exc(),
        }), file=sys.stderr)
        sys.exit(1)