Last active
March 5, 2026 08:21
-
-
Save KarlAmort/c2f7f16179e3e30701a4271c6b2fa69b to your computer and use it in GitHub Desktop.
SONYHATE3000 Avatar Training Pipeline
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "nbformat": 4, | |
| "nbformat_minor": 0, | |
| "metadata": { | |
| "colab": { | |
| "provenance": [], | |
| "gpuType": "T4" | |
| }, | |
| "kernelspec": { | |
| "name": "python3", | |
| "display_name": "Python 3" | |
| }, | |
| "accelerator": "GPU" | |
| }, | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
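| "source": "# SONYHATE3000 Avatar Training Pipeline\n\nCreates 3D avatars from photos using:\n- **SMPL-X** → parametric body model (10,475 vertices, 54 joints); joint positions are exported to `skeleton.json`\n- **DECA** → FLAME face-shape estimate averaged over the top-priority photos\n- **Appearance extraction** → skin and hair colours sampled from the photos, plus a 512×512 face crop from the best photo\n- **Export** → GLB (browser) + USDZ (macOS native)\n- **Checkpoint slider** → scrub through body-shape interpolation\n\n> Note: the body-shape estimation cell (Cell 5) is not included in this notebook, so the body shape is not fitted to the photos.\n\n## Quick Start\n1. Upload `smplx-models.tar` and `avatars-<person>.tar` from your local machine (or mount Google Drive with photos in `My Drive/avatars/<person>/`)\n2. Set `PERSON_NAME` in Cell 2, then run all cells\n3. Use the interactive slider to explore body shapes", | |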
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": {}, | |
| "source": [ | |
| "# Cell 1: GPU check + upload tars\n", | |
| "import os\n", | |
| "import tarfile\n", | |
| "\n", | |
| "import torch\n", | |
| "from google.colab import files\n", | |
| "\n", | |
| "print(f\"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'NONE — switch to GPU runtime!'}\")\n", | |
| "if torch.cuda.is_available():\n", | |
| "    print(f\"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")\n", | |
| "\n", | |
| "# Upload tar files\n", | |
| "print('\\nUpload these files from ~/data/:')\n", | |
| "print(' 1. smplx-models.tar (311 MB) — SMPL-X model weights')\n", | |
| "print(' 2. avatars-valeska.tar (2.4 MB) — photos')\n", | |
| "print('\\nSelect files now:')\n", | |
| "uploaded = files.upload()\n", | |
| "\n", | |
| "# Use the stdlib 'data' extraction filter when this Python has it (3.12+ and\n", | |
| "# security backports): it refuses members or links that would end up outside\n", | |
| "# the destination directory.\n", | |
| "EXTRACT_KWARGS = {'filter': 'data'} if hasattr(tarfile, 'data_filter') else {}\n", | |
| "\n", | |
| "# Extract tars: archives whose name contains 'smplx' or 'model' go to\n", | |
| "# /content/models/, everything else is treated as photos -> /content/avatars/.\n", | |
| "for fn in uploaded:\n", | |
| "    print(f'Extracting {fn}...')\n", | |
| "    is_model_tar = 'smplx' in fn.lower() or 'model' in fn.lower()\n", | |
| "    dest = '/content/models/' if is_model_tar else '/content/avatars/'\n", | |
| "    with tarfile.open(fn) as tf:\n", | |
| "        tf.extractall(dest, **EXTRACT_KWARGS)\n", | |
| "    os.remove(fn)  # the uploaded archive is no longer needed once extracted\n", | |
| "    print(' Done.')\n", | |
| "\n", | |
| "print('\\nExtracted contents:')\n", | |
| "for root, _dirs, fls in os.walk('/content/models'):\n", | |
| "    for f in fls:\n", | |
| "        print(f' models: {os.path.join(root, f)}')\n", | |
| "for root, _dirs, fls in os.walk('/content/avatars'):\n", | |
| "    # List at most 5 files per directory.\n", | |
| "    for f in fls[:5]:\n", | |
| "        print(f' avatars: {os.path.join(root, f)}')\n", | |
| "    if len(fls) > 5:\n", | |
| "        print(f' ... and {len(fls)-5} more')" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Cell 2: Configuration\n", | |
| "import os\n", | |
| "\n", | |
| "PERSON_NAME = 'Valeska' # @param {type:\"string\"}\n", | |
| "MAX_PHOTOS_BODY = 50 # @param {type:\"integer\"}\n", | |
| "MAX_PHOTOS_FACE = 20 # @param {type:\"integer\"}\n", | |
| "\n", | |
| "# Photo source: Google Drive first, then the tar uploaded in Cell 1.\n", | |
| "# NOTE: this notebook never mounts Drive itself; run\n", | |
| "#   from google.colab import drive; drive.mount('/content/drive')\n", | |
| "# before this cell if the photos live there.\n", | |
| "DRIVE_DIR = f'/content/drive/MyDrive/avatars/{PERSON_NAME}'\n", | |
| "LOCAL_DIR = f'/content/avatars/{PERSON_NAME}'\n", | |
| "PHOTO_EXTENSIONS = ('.jpg', '.jpeg', '.png')\n", | |
| "\n", | |
| "if os.path.isdir(DRIVE_DIR) and os.listdir(DRIVE_DIR):\n", | |
| "    PHOTO_DIR = DRIVE_DIR\n", | |
| "    print(f'Using Google Drive: {DRIVE_DIR}')\n", | |
| "elif os.path.isdir(LOCAL_DIR) and os.listdir(LOCAL_DIR):\n", | |
| "    PHOTO_DIR = LOCAL_DIR\n", | |
| "    print(f'Using uploaded files: {LOCAL_DIR}')\n", | |
| "else:\n", | |
| "    raise FileNotFoundError(\n", | |
| "        f'No photos found! Either:\\n'\n", | |
| "        f' 1. Upload to Google Drive: My Drive/avatars/{PERSON_NAME}/\\n'\n", | |
| "        f' 2. Run the upload cell above'\n", | |
| "    )\n", | |
| "\n", | |
| "# Generated files (textures, meshes, configs) are written next to the photos.\n", | |
| "OUTPUT_DIR = f'{PHOTO_DIR}/output'\n", | |
| "os.makedirs(OUTPUT_DIR, exist_ok=True)\n", | |
| "\n", | |
| "\n", | |
| "def photo_priority(name):\n", | |
| "    \"\"\"Sort key: 'f-kk-*' photos first, then other 'f-*' photos, then the rest; by name within each group.\"\"\"\n", | |
| "    if name.startswith('f-kk-'):\n", | |
| "        return (0, name)\n", | |
| "    if name.startswith('f-'):\n", | |
| "        return (1, name)\n", | |
| "    return (2, name)\n", | |
| "\n", | |
| "\n", | |
| "# Image files only (case-insensitive extension check), in priority order.\n", | |
| "photos = sorted(\n", | |
| "    (f for f in os.listdir(PHOTO_DIR) if f.lower().endswith(PHOTO_EXTENSIONS)),\n", | |
| "    key=photo_priority,\n", | |
| ")\n", | |
| "print(f'Found {len(photos)} photos for {PERSON_NAME}')\n", | |
| "print(f'Top 5: {photos[:5]}')\n" | |
| ], | |
| "metadata": {}, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": {}, | |
| "source": [ | |
| "%%capture install_output\n", | |
| "# Cell 3: Install dependencies\n", | |
| "# %%capture must be the very first line of the cell (IPython only recognises\n", | |
| "# cell magics there). All output, including pip errors, is kept in\n", | |
| "# `install_output`; run `install_output.show()` in a new cell to inspect it.\n", | |
| "!pip install smplx==0.1.28 pyrender trimesh chumpy opencv-python-headless \\\n", | |
| "    face-alignment dlib scikit-image torchvision pillow \\\n", | |
| "    pygltflib usd-core\n", | |
| "\n", | |
| "# Clone PyMAF-X for body estimation\n", | |
| "!test -d /content/PyMAF-X || git clone https://github.com/HongwenZhang/PyMAF-X.git /content/PyMAF-X\n", | |
| "%cd /content/PyMAF-X\n", | |
| "!pip install -r requirements.txt\n", | |
| "\n", | |
| "# Clone DECA for face reconstruction\n", | |
| "!test -d /content/DECA || git clone https://github.com/yfeng95/DECA.git /content/DECA\n", | |
| "%cd /content/DECA\n", | |
| "!pip install -r requirements.txt\n", | |
| "\n", | |
| "%cd /content\n", | |
| "print('Dependencies installed.')" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": {}, | |
| "source": [ | |
| "# Cell 4: Locate model files\n", | |
| "# SMPL-X models uploaded via smplx-models.tar (extracted to /content/models/)\n", | |
| "# or from Google Drive if mounted.\n", | |
| "import glob\n", | |
| "import os\n", | |
| "\n", | |
| "# Searched in order: tar-uploaded locations first, then Drive.\n", | |
| "SMPLX_SEARCH_DIRS = ['/content/models/smplx', '/content/models', '/content/drive/MyDrive/models/smplx']\n", | |
| "\n", | |
| "# Use the directory holding the first SMPLX_NEUTRAL.npz found (at any depth);\n", | |
| "# otherwise keep the default location so the checks below report it missing.\n", | |
| "SMPLX_MODEL_DIR = '/content/models/smplx'\n", | |
| "for candidate in SMPLX_SEARCH_DIRS:\n", | |
| "    matches = glob.glob(os.path.join(candidate, '**', 'SMPLX_NEUTRAL.npz'), recursive=True)\n", | |
| "    if matches:\n", | |
| "        SMPLX_MODEL_DIR = os.path.dirname(matches[0])\n", | |
| "        break\n", | |
| "\n", | |
| "FLAME_MODEL_DIR = '/content/models/flame'\n", | |
| "DECA_MODEL_DIR = '/content/models/deca'\n", | |
| "\n", | |
| "# Check model files exist\n", | |
| "smplx_ok = os.path.exists(f'{SMPLX_MODEL_DIR}/SMPLX_NEUTRAL.npz')\n", | |
| "flame_ok = os.path.exists(f'{FLAME_MODEL_DIR}/generic_model.pkl')\n", | |
| "deca_ok = os.path.exists(f'{DECA_MODEL_DIR}/deca_model.tar')\n", | |
| "\n", | |
| "print(f'SMPL-X dir: {SMPLX_MODEL_DIR}')\n", | |
| "print(f'SMPL-X model: {\"OK\" if smplx_ok else \"MISSING — upload smplx-models.tar\"}')\n", | |
| "print(f'FLAME model: {\"OK\" if flame_ok else \"MISSING — download from flame.is.tue.mpg.de\"}')\n", | |
| "print(f'DECA model: {\"OK\" if deca_ok else \"MISSING — download from deca.is.tue.mpg.de\"}')\n", | |
| "\n", | |
| "if smplx_ok:\n", | |
| "    # List what we have\n", | |
| "    for f in sorted(os.listdir(SMPLX_MODEL_DIR)):\n", | |
| "        sz = os.path.getsize(os.path.join(SMPLX_MODEL_DIR, f))\n", | |
| "        print(f' {f}: {sz/1e6:.1f} MB')\n", | |
| "else:\n", | |
| "    print('\\n⚠️ Upload smplx-models.tar in Cell 1, or place files in Google Drive.')\n", | |
| "    print('Required: SMPLX_NEUTRAL.npz, SMPLX_MALE.npz, SMPLX_FEMALE.npz')" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": {}, | |
| "source": [ | |
| "# Cell 6: Face Reconstruction via DECA\n", | |
| "import os\n", | |
| "import sys\n", | |
| "\n", | |
| "import numpy as np\n", | |
| "import torch\n", | |
| "\n", | |
| "# DECA is cloned (not pip-installed) in Cell 3, so put it on the import path once.\n", | |
| "if '/content/DECA' not in sys.path:\n", | |
| "    sys.path.insert(0, '/content/DECA')\n", | |
| "\n", | |
| "face_photos = photos[:MAX_PHOTOS_FACE]\n", | |
| "print(f'Processing {len(face_photos)} photos for face reconstruction...')\n", | |
| "\n", | |
| "all_face_shapes = []\n", | |
| "best_face_texture = None\n", | |
| "best_face_score = 0\n", | |
| "# 100-d zero shape code used when no face is reconstructed (exported as faceShape in Cell 11).\n", | |
| "avg_face_shape = np.zeros(100)\n", | |
| "\n", | |
| "try:\n", | |
| "    from decalib.datasets import datasets as deca_datasets\n", | |
| "    from decalib.deca import DECA\n", | |
| "    from decalib.utils.config import cfg as deca_cfg\n", | |
| "\n", | |
| "    deca_cfg.model.topology_path = os.path.join(DECA_MODEL_DIR, 'head_template.obj')\n", | |
| "    deca_cfg.model.flame_model_path = os.path.join(FLAME_MODEL_DIR, 'generic_model.pkl')\n", | |
| "    deca_cfg.pretrained_modelpath = os.path.join(DECA_MODEL_DIR, 'deca_model.tar')\n", | |
| "\n", | |
| "    deca = DECA(config=deca_cfg, device='cuda')\n", | |
| "    print('DECA loaded successfully')\n", | |
| "\n", | |
| "    # DECA has no per-image `preprocess` method. Its TestData dataset does the\n", | |
| "    # face detection, crop and resize that DECA.encode expects (mirrors DECA's\n", | |
| "    # demo_reconstruct.py — verify against the cloned version). TestData may\n", | |
| "    # re-sort the paths, so each sample's path is read back from the dataset.\n", | |
| "    face_data = deca_datasets.TestData(\n", | |
| "        [os.path.join(PHOTO_DIR, name) for name in face_photos],\n", | |
| "        iscrop=True, face_detector='fan',\n", | |
| "    )\n", | |
| "    n_faces = len(face_data)\n", | |
| "\n", | |
| "    for i in range(n_faces):\n", | |
| "        photo_path = face_data.imagepath_list[i]\n", | |
| "        photo_name = os.path.basename(photo_path)\n", | |
| "        try:\n", | |
| "            images = face_data[i]['image'].to('cuda')[None, ...]\n", | |
| "            with torch.no_grad():\n", | |
| "                codedict = deca.encode(images)\n", | |
| "            all_face_shapes.append(codedict['shape'][0].cpu().numpy())\n", | |
| "\n", | |
| "            # Track the \"best\" photo for the face texture. When codedict has no\n", | |
| "            # 'confidence' entry every photo scores 1.0, so the first successful\n", | |
| "            # photo is kept. NOTE(review): confirm whether DECA provides a score.\n", | |
| "            confidence = float(codedict['confidence'][0]) if 'confidence' in codedict else 1.0\n", | |
| "            if confidence > best_face_score:\n", | |
| "                best_face_score = confidence\n", | |
| "                best_face_texture = photo_path\n", | |
| "\n", | |
| "            print(f' [{i+1}/{n_faces}] {photo_name}: OK (confidence: {confidence:.2f})')\n", | |
| "        except Exception as e:\n", | |
| "            print(f' [{i+1}/{n_faces}] {photo_name}: ERROR - {e}')\n", | |
| "\n", | |
| "    if all_face_shapes:\n", | |
| "        avg_face_shape = np.mean(all_face_shapes, axis=0)\n", | |
| "        print(f'\\nAverage face shape from {len(all_face_shapes)} photos')\n", | |
| "        print(f'Best face texture: {best_face_texture}')\n", | |
| "    else:\n", | |
| "        print('No faces reconstructed!')\n", | |
| "\n", | |
| "except Exception as e:\n", | |
| "    print(f'DECA face reconstruction failed: {e}')\n", | |
| "    print('Continuing without face reconstruction...')" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": {}, | |
| "source": [ | |
| "# Cell 7: Extract appearance parameters from photos\n", | |
| "import os\n", | |
| "\n", | |
| "import cv2\n", | |
| "import numpy as np\n", | |
| "\n", | |
| "N_APPEARANCE_PHOTOS = 15  # top-priority photos sampled for skin/hair colour\n", | |
| "\n", | |
| "\n", | |
| "def _bgr_to_hex(bgr):\n", | |
| "    \"\"\"Format a (B, G, R) triple (OpenCV channel order) as an RGB hex string '#rrggbb'.\"\"\"\n", | |
| "    b, g, r = (int(c) for c in bgr)\n", | |
| "    return f'#{r:02x}{g:02x}{b:02x}'\n", | |
| "\n", | |
| "\n", | |
| "def extract_skin_tone(img_path, face_region_pct=(0.3, 0.15, 0.7, 0.55)):\n", | |
| "    \"\"\"Sample skin tone from a fixed face region of a photo.\n", | |
| "\n", | |
| "    Args:\n", | |
| "        img_path: image file to read with OpenCV.\n", | |
| "        face_region_pct: (x1, y1, x2, y2) of the region as fractions of the image\n", | |
| "            width/height; the default assumes a centred, upright subject.\n", | |
| "\n", | |
| "    Returns:\n", | |
| "        '#rrggbb' median colour of the skin-like pixels in the region (or of the\n", | |
| "        whole region when fewer than 50 pixels qualify), or None when the image\n", | |
| "        cannot be read or the region is empty.\n", | |
| "    \"\"\"\n", | |
| "    img = cv2.imread(img_path)\n", | |
| "    if img is None:\n", | |
| "        return None\n", | |
| "    h, w = img.shape[:2]\n", | |
| "    x1, y1, x2, y2 = face_region_pct\n", | |
| "    face_crop = img[int(h * y1):int(h * y2), int(w * x1):int(w * x2)]\n", | |
| "    if face_crop.size == 0:\n", | |
| "        return None\n", | |
| "    # Skin-like pixels in OpenCV's 8-bit LAB: L>60, 130<a<175, 130<b<185.\n", | |
| "    lab = cv2.cvtColor(face_crop, cv2.COLOR_BGR2LAB)\n", | |
| "    lum, a_ch, b_ch = lab[:, :, 0], lab[:, :, 1], lab[:, :, 2]\n", | |
| "    mask = (lum > 60) & (a_ch > 130) & (a_ch < 175) & (b_ch > 130) & (b_ch < 185)\n", | |
| "    skin_pixels = face_crop[mask]\n", | |
| "    if len(skin_pixels) < 50:\n", | |
| "        # Too few skin-like pixels: fall back to the median of the whole region.\n", | |
| "        skin_pixels = face_crop.reshape(-1, 3)\n", | |
| "    return _bgr_to_hex(np.median(skin_pixels, axis=0))\n", | |
| "\n", | |
| "\n", | |
| "def extract_hair_color(img_path):\n", | |
| "    \"\"\"Sample hair colour as the median of the top-centre strip of a photo.\n", | |
| "\n", | |
| "    The strip is the top 15% of the height and the middle 40% of the width, so it\n", | |
| "    assumes the head is at the top of the frame. Returns '#rrggbb', or None when\n", | |
| "    the image cannot be read or the strip is empty.\n", | |
| "    \"\"\"\n", | |
| "    img = cv2.imread(img_path)\n", | |
| "    if img is None:\n", | |
| "        return None\n", | |
| "    h, w = img.shape[:2]\n", | |
| "    hair_region = img[0:int(h * 0.15), int(w * 0.3):int(w * 0.7)]\n", | |
| "    if hair_region.size == 0:\n", | |
| "        return None\n", | |
| "    return _bgr_to_hex(np.median(hair_region.reshape(-1, 3), axis=0))\n", | |
| "\n", | |
| "\n", | |
| "def median_hex(hex_colors, default):\n", | |
| "    \"\"\"Per-channel median of '#rrggbb' strings (rounded); `default` when the list is empty.\"\"\"\n", | |
| "    if not hex_colors:\n", | |
| "        return default\n", | |
| "    rgb = np.array([[int(c[i:i + 2], 16) for i in (1, 3, 5)] for c in hex_colors])\n", | |
| "    r, g, b = (int(v) for v in np.rint(np.median(rgb, axis=0)))\n", | |
| "    return f'#{r:02x}{g:02x}{b:02x}'\n", | |
| "\n", | |
| "\n", | |
| "# Sample from top priority photos\n", | |
| "skin_samples = []\n", | |
| "hair_samples = []\n", | |
| "for photo_name in photos[:N_APPEARANCE_PHOTOS]:\n", | |
| "    path = os.path.join(PHOTO_DIR, photo_name)\n", | |
| "    skin = extract_skin_tone(path)\n", | |
| "    hair = extract_hair_color(path)\n", | |
| "    if skin:\n", | |
| "        skin_samples.append(skin)\n", | |
| "    if hair:\n", | |
| "        hair_samples.append(hair)\n", | |
| "\n", | |
| "# Combine per-photo colours with a per-channel median. (A most-common vote over\n", | |
| "# exact hex strings almost never sees a repeat, so it degenerates to\n", | |
| "# \"first sample wins\".)\n", | |
| "skin_tone = median_hex(skin_samples, '#ffccaa')\n", | |
| "hair_color = median_hex(hair_samples, '#3b2818')\n", | |
| "\n", | |
| "print(f'Skin tone: {skin_tone}')\n", | |
| "print(f'Hair color: {hair_color}')" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": {}, | |
| "source": [ | |
| "# Cell 8: Generate mesh + texture with checkpoint visualization\n", | |
| "import os\n", | |
| "\n", | |
| "import numpy as np\n", | |
| "import torch\n", | |
| "import trimesh\n", | |
| "from PIL import Image\n", | |
| "\n", | |
| "N_BETAS = 10  # SMPL-X shape coefficients used throughout this notebook\n", | |
| "\n", | |
| "# `smplx_model` and `all_betas` are expected from a body-estimation cell (Cell 5)\n", | |
| "# that is not part of this notebook. On a fresh kernel, fall back to a neutral\n", | |
| "# SMPL-X model and no estimates so the rest of the pipeline still runs.\n", | |
| "if 'smplx_model' not in globals():\n", | |
| "    import smplx\n", | |
| "    smplx_model = smplx.SMPLX(SMPLX_MODEL_DIR, gender='neutral', ext='npz', num_betas=N_BETAS)\n", | |
| "    smplx_model = smplx_model.to('cuda' if torch.cuda.is_available() else 'cpu')\n", | |
| "if 'all_betas' not in globals():\n", | |
| "    all_betas = []\n", | |
| "\n", | |
| "device = next(smplx_model.buffers()).device\n", | |
| "neutral_betas = torch.zeros(1, N_BETAS, device=device)\n", | |
| "\n", | |
| "# Mean of the per-photo body estimates, if any. Assumes each entry of\n", | |
| "# all_betas holds N_BETAS coefficients — TODO confirm against Cell 5.\n", | |
| "if len(all_betas) > 0:\n", | |
| "    target_betas = torch.as_tensor(\n", | |
| "        np.mean(np.asarray(all_betas), axis=0), dtype=torch.float32, device=device\n", | |
| "    ).reshape(1, -1)\n", | |
| "else:\n", | |
| "    target_betas = None\n", | |
| "\n", | |
| "# Shape used for body.obj / body.glb / body.usdz and saved as bodyBetas:\n", | |
| "# the estimated mean when available, otherwise the neutral shape.\n", | |
| "avg_betas = target_betas if target_betas is not None else neutral_betas\n", | |
| "with torch.no_grad():\n", | |
| "    output = smplx_model(betas=avg_betas, return_verts=True)\n", | |
| "\n", | |
| "vertices = output.vertices.cpu().numpy()[0]\n", | |
| "faces = smplx_model.faces\n", | |
| "\n", | |
| "# Create trimesh\n", | |
| "mesh = trimesh.Trimesh(vertices=vertices, faces=faces)\n", | |
| "print(f'Mesh: {len(vertices)} vertices, {len(faces)} faces')\n", | |
| "\n", | |
| "# Face texture: fixed crop (centre half of the width, top half of the height)\n", | |
| "# of the best photo from Cell 6, resized to 512x512.\n", | |
| "if best_face_texture:\n", | |
| "    face_img = Image.open(best_face_texture).convert('RGB')\n", | |
| "    w, h = face_img.size\n", | |
| "    face_crop = face_img.crop((int(w*0.25), 0, int(w*0.75), int(h*0.5))).resize((512, 512))\n", | |
| "    face_crop.save(os.path.join(OUTPUT_DIR, 'face_texture.png'))\n", | |
| "    print(f'Face texture saved from: {os.path.basename(best_face_texture)}')\n", | |
| "else:\n", | |
| "    print('No face texture (no best photo found)')\n", | |
| "\n", | |
| "# Apply the uniform skin color from Cell 7 as RGBA vertex colors\n", | |
| "skin_rgb = tuple(int(skin_tone.lstrip('#')[i:i+2], 16) for i in (0, 2, 4))\n", | |
| "vertex_colors = np.full((len(vertices), 4), [*skin_rgb, 255], dtype=np.uint8)\n", | |
| "mesh.visual.vertex_colors = vertex_colors\n", | |
| "\n", | |
| "# Save OBJ for inspection\n", | |
| "mesh.export(os.path.join(OUTPUT_DIR, 'body.obj'))\n", | |
| "print(f'OBJ saved to {OUTPUT_DIR}/body.obj')\n", | |
| "\n", | |
| "# --- Generate checkpoints with varying beta parameters ---\n", | |
| "print('\\nGenerating shape checkpoints for visualization...')\n", | |
| "checkpoints = []\n", | |
| "n_checkpoints = 20\n", | |
| "\n", | |
| "\n", | |
| "def _save_shape_checkpoint(step, betas, label):\n", | |
| "    \"\"\"Build the SMPL-X mesh for `betas`, export it as checkpoint_<step:03d>.obj, return its metadata.\"\"\"\n", | |
| "    with torch.no_grad():\n", | |
| "        verts = smplx_model(betas=betas, return_verts=True).vertices.cpu().numpy()[0]\n", | |
| "    cp_mesh = trimesh.Trimesh(vertices=verts, faces=faces)\n", | |
| "    cp_mesh.visual.vertex_colors = vertex_colors\n", | |
| "    cp_mesh.export(os.path.join(OUTPUT_DIR, f'checkpoint_{step:03d}.obj'))\n", | |
| "    return {'step': step, 'label': label, 'betas': betas.cpu().numpy()[0].tolist()}\n", | |
| "\n", | |
| "\n", | |
| "if target_betas is not None:\n", | |
| "    # Interpolate from the neutral shape (t=0) to the estimated shape (t=1).\n", | |
| "    for i in range(n_checkpoints + 1):\n", | |
| "        t = i / n_checkpoints\n", | |
| "        interp_betas = neutral_betas * (1 - t) + target_betas * t\n", | |
| "        checkpoints.append(_save_shape_checkpoint(i, interp_betas, f'Step {i}/{n_checkpoints} (t={t:.2f})'))\n", | |
| "else:\n", | |
| "    # No estimates: sweep the first three shape components together, t in [-1, 1].\n", | |
| "    for i in range(n_checkpoints + 1):\n", | |
| "        t = (i / n_checkpoints) * 2 - 1  # -1 to 1\n", | |
| "        betas = torch.zeros(1, N_BETAS, device=device)\n", | |
| "        betas[0, 0] = t * 2    # roughly height/weight\n", | |
| "        betas[0, 1] = t * 1    # roughly torso width\n", | |
| "        betas[0, 2] = t * 0.5  # roughly limb proportions\n", | |
| "        checkpoints.append(_save_shape_checkpoint(i, betas, f'Shape exploration {i}/{n_checkpoints}'))\n", | |
| "\n", | |
| "print(f'Saved {len(checkpoints)} checkpoints to {OUTPUT_DIR}/')\n" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Cell 8b: Interactive checkpoint viewer with slider\n", | |
| "# Renders each checkpoint from multiple angles, navigate with slider\n", | |
| "import os\n", | |
| "\n", | |
| "import ipywidgets as widgets\n", | |
| "import matplotlib.pyplot as plt\n", | |
| "import numpy as np\n", | |
| "import trimesh\n", | |
| "from IPython.display import clear_output, display\n", | |
| "from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers the '3d' projection on older matplotlib)\n", | |
| "from mpl_toolkits.mplot3d.art3d import Poly3DCollection\n", | |
| "\n", | |
| "MAX_RENDER_TRIS = 5000  # triangles drawn per view; matplotlib 3D is slow\n", | |
| "# Matplotlib azimuth that shows the body's front after the Y-up -> Z-up rotation\n", | |
| "# in _plot_mesh. Assumes the mesh faces +z in model space — TODO confirm.\n", | |
| "FRONT_AZIM = -90\n", | |
| "\n", | |
| "\n", | |
| "def render_checkpoint(step, angle=30, elevation=10):\n", | |
| "    \"\"\"Render saved checkpoint `step` from the front, the side, and `angle` degrees off the front.\"\"\"\n", | |
| "    cp_path = os.path.join(OUTPUT_DIR, f'checkpoint_{step:03d}.obj')\n", | |
| "    if not os.path.exists(cp_path):\n", | |
| "        print(f'Checkpoint {step} not found')\n", | |
| "        return\n", | |
| "\n", | |
| "    m = trimesh.load(cp_path)\n", | |
| "    verts = np.array(m.vertices)\n", | |
| "    tris = np.array(m.faces)\n", | |
| "\n", | |
| "    fig = plt.figure(figsize=(12, 5))\n", | |
| "    views = [\n", | |
| "        ('Front', FRONT_AZIM),\n", | |
| "        ('Side', FRONT_AZIM + 90),\n", | |
| "        (f'Angle: {angle}°', FRONT_AZIM + angle),  # rotatable 3/4 view\n", | |
| "    ]\n", | |
| "    for panel, (title, azim) in enumerate(views, start=1):\n", | |
| "        ax = fig.add_subplot(1, 3, panel, projection='3d')\n", | |
| "        ax.set_title(title)\n", | |
| "        _plot_mesh(ax, verts, tris, azim=azim, elev=elevation)\n", | |
| "\n", | |
| "    plt.suptitle(f'Checkpoint {step}/{n_checkpoints}', fontsize=14, fontweight='bold')\n", | |
| "    plt.tight_layout()\n", | |
| "    plt.show()\n", | |
| "\n", | |
| "\n", | |
| "def _plot_mesh(ax, verts, tris, azim=0, elev=10):\n", | |
| "    \"\"\"Plot a Y-up mesh on a (Z-up) matplotlib 3D axis with consistent styling.\n", | |
| "\n", | |
| "    Vertices are rotated (x, y, z) -> (x, -z, y) so the model's up axis becomes\n", | |
| "    matplotlib's z axis; azim/elev are passed straight to ax.view_init.\n", | |
| "    \"\"\"\n", | |
| "    # Fixed-seed subsample: every view and every re-render draws the same triangles.\n", | |
| "    if len(tris) > MAX_RENDER_TRIS:\n", | |
| "        keep = np.random.default_rng(0).choice(len(tris), MAX_RENDER_TRIS, replace=False)\n", | |
| "        tris = tris[keep]\n", | |
| "\n", | |
| "    disp = np.column_stack([verts[:, 0], -verts[:, 2], verts[:, 1]])\n", | |
| "    poly = Poly3DCollection(disp[tris], alpha=0.7, facecolor=skin_tone, edgecolor='none')\n", | |
| "    ax.add_collection3d(poly)\n", | |
| "\n", | |
| "    # Equal-sized cube around the mesh so body proportions are not distorted.\n", | |
| "    center = disp.mean(axis=0)\n", | |
| "    extent = (disp.max(axis=0) - disp.min(axis=0)).max() * 0.6\n", | |
| "    ax.set_xlim(center[0] - extent, center[0] + extent)\n", | |
| "    ax.set_ylim(center[1] - extent, center[1] + extent)\n", | |
| "    ax.set_zlim(center[2] - extent, center[2] + extent)\n", | |
| "    ax.set_box_aspect((1, 1, 1))\n", | |
| "    ax.view_init(elev=elev, azim=azim)\n", | |
| "    ax.set_axis_off()\n", | |
| "\n", | |
| "\n", | |
| "# Create interactive widgets\n", | |
| "output_area = widgets.Output()\n", | |
| "\n", | |
| "step_slider = widgets.IntSlider(\n", | |
| "    value=n_checkpoints // 2, min=0, max=n_checkpoints,\n", | |
| "    step=1, description='Checkpoint:',\n", | |
| "    style={'description_width': '100px'},\n", | |
| "    layout=widgets.Layout(width='80%')\n", | |
| ")\n", | |
| "\n", | |
| "angle_slider = widgets.IntSlider(\n", | |
| "    value=30, min=-180, max=180,\n", | |
| "    step=15, description='Rotation:',\n", | |
| "    style={'description_width': '100px'},\n", | |
| "    layout=widgets.Layout(width='80%')\n", | |
| ")\n", | |
| "\n", | |
| "elev_slider = widgets.IntSlider(\n", | |
| "    value=10, min=-30, max=60,\n", | |
| "    step=5, description='Elevation:',\n", | |
| "    style={'description_width': '100px'},\n", | |
| "    layout=widgets.Layout(width='40%')\n", | |
| ")\n", | |
| "\n", | |
| "# Start the player at the slider's value so linking the two does not move\n", | |
| "# the slider away from the initially rendered checkpoint.\n", | |
| "play_btn = widgets.Play(\n", | |
| "    value=step_slider.value, min=0, max=n_checkpoints,\n", | |
| "    step=1, interval=500,\n", | |
| "    description='Play'\n", | |
| ")\n", | |
| "widgets.jslink((play_btn, 'value'), (step_slider, 'value'))\n", | |
| "\n", | |
| "\n", | |
| "def on_change(change):\n", | |
| "    \"\"\"Re-render the current checkpoint whenever any slider changes.\"\"\"\n", | |
| "    with output_area:\n", | |
| "        clear_output(wait=True)\n", | |
| "        render_checkpoint(step_slider.value, angle_slider.value, elev_slider.value)\n", | |
| "\n", | |
| "\n", | |
| "step_slider.observe(on_change, names='value')\n", | |
| "angle_slider.observe(on_change, names='value')\n", | |
| "elev_slider.observe(on_change, names='value')\n", | |
| "\n", | |
| "print('Use the sliders to explore shape checkpoints:')\n", | |
| "print(' Checkpoint: interpolation from neutral → estimated body shape')\n", | |
| "print(' Rotation: view angle')\n", | |
| "print(' ▶ Play: animate through all checkpoints')\n", | |
| "display(widgets.HBox([play_btn, step_slider]))\n", | |
| "display(widgets.HBox([angle_slider, elev_slider]))\n", | |
| "display(output_area)\n", | |
| "\n", | |
| "# Render initial view\n", | |
| "with output_area:\n", | |
| "    render_checkpoint(step_slider.value, angle_slider.value, elev_slider.value)\n" | |
| ], | |
| "metadata": {}, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": {}, | |
| "source": [ | |
| "# Cell 9: Export GLB + skeleton data\n", | |
| "# The GLB holds the mesh (with vertex colours) only; joint positions are\n", | |
| "# written separately to skeleton.json for the viewer.\n", | |
| "import json\n", | |
| "import os\n", | |
| "\n", | |
| "glb_path = os.path.join(OUTPUT_DIR, 'body.glb')\n", | |
| "\n", | |
| "# SMPL-X joint positions for the body shape generated in Cell 8.\n", | |
| "joints = output.joints.detach().cpu().numpy()[0]\n", | |
| "\n", | |
| "# Binary glTF written directly by trimesh.\n", | |
| "mesh.export(glb_path, file_type='glb')\n", | |
| "print(f'GLB exported to {glb_path}')\n", | |
| "print(f'File size: {os.path.getsize(glb_path) / 1e6:.1f} MB')\n", | |
| "\n", | |
| "# Save joint positions for the viewer; names come from the model when it\n", | |
| "# exposes them, otherwise generic 'joint_<i>' placeholders.\n", | |
| "if hasattr(smplx_model, 'joint_names'):\n", | |
| "    joint_names = smplx_model.joint_names\n", | |
| "else:\n", | |
| "    joint_names = [f'joint_{i}' for i in range(joints.shape[0])]\n", | |
| "joint_data = {\n", | |
| "    'joint_names': joint_names,\n", | |
| "    'joint_positions': joints.tolist(),\n", | |
| "    'num_joints': joints.shape[0],\n", | |
| "}\n", | |
| "with open(os.path.join(OUTPUT_DIR, 'skeleton.json'), 'w') as f:\n", | |
| "    json.dump(joint_data, f, indent=2)\n", | |
| "print(f'Skeleton data saved ({joints.shape[0]} joints)')" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": {}, | |
| "source": [ | |
| "# Cell 10: Export USDZ for macOS native viewing\n", | |
| "import os\n", | |
| "\n", | |
| "try:\n", | |
| "    from pxr import Gf, Sdf, Usd, UsdGeom, UsdShade, UsdUtils\n", | |
| "\n", | |
| "    usd_path = os.path.join(OUTPUT_DIR, 'body.usdc')\n", | |
| "    usdz_path = os.path.join(OUTPUT_DIR, 'body.usdz')\n", | |
| "\n", | |
| "    stage = Usd.Stage.CreateNew(usd_path)\n", | |
| "    UsdGeom.SetStageUpAxis(stage, UsdGeom.Tokens.y)\n", | |
| "    UsdGeom.SetStageMetersPerUnit(stage, 1.0)\n", | |
| "\n", | |
| "    # Typed root prim registered as the stage's defaultPrim, so USDZ viewers and\n", | |
| "    # anything referencing this file know which prim to load.\n", | |
| "    root = UsdGeom.Xform.Define(stage, '/Avatar')\n", | |
| "    stage.SetDefaultPrim(root.GetPrim())\n", | |
| "\n", | |
| "    # Triangle mesh from the vertices/faces generated in Cell 8\n", | |
| "    mesh_prim = UsdGeom.Mesh.Define(stage, '/Avatar/Body')\n", | |
| "    mesh_prim.CreatePointsAttr([(v[0], v[1], v[2]) for v in vertices])\n", | |
| "    mesh_prim.CreateFaceVertexCountsAttr([3] * len(faces))\n", | |
| "    mesh_prim.CreateFaceVertexIndicesAttr(faces.flatten().tolist())\n", | |
| "    mesh_prim.CreateSubdivisionSchemeAttr(UsdGeom.Tokens.none)  # render triangles as-is\n", | |
| "\n", | |
| "    # UsdPreviewSurface material with the sampled skin colour (0-255 -> 0-1)\n", | |
| "    material = UsdShade.Material.Define(stage, '/Avatar/SkinMaterial')\n", | |
| "    shader = UsdShade.Shader.Define(stage, '/Avatar/SkinMaterial/PBRShader')\n", | |
| "    shader.CreateIdAttr('UsdPreviewSurface')\n", | |
| "    shader.CreateInput('diffuseColor', Sdf.ValueTypeNames.Color3f).Set(\n", | |
| "        Gf.Vec3f(skin_rgb[0]/255, skin_rgb[1]/255, skin_rgb[2]/255)\n", | |
| "    )\n", | |
| "    shader.CreateInput('roughness', Sdf.ValueTypeNames.Float).Set(0.5)\n", | |
| "    shader.CreateInput('metallic', Sdf.ValueTypeNames.Float).Set(0.0)\n", | |
| "    material.CreateSurfaceOutput().ConnectToSource(shader.ConnectableAPI(), 'surface')\n", | |
| "    mesh_prim.GetPrim().ApplyAPI(UsdShade.MaterialBindingAPI)\n", | |
| "    UsdShade.MaterialBindingAPI(mesh_prim).Bind(material)\n", | |
| "\n", | |
| "    stage.Save()\n", | |
| "\n", | |
| "    # Package the .usdc into a single .usdz archive\n", | |
| "    UsdUtils.CreateNewUsdzPackage(Sdf.AssetPath(usd_path), usdz_path)\n", | |
| "    print(f'USDZ exported to {usdz_path}')\n", | |
| "    print(f'File size: {os.path.getsize(usdz_path) / 1e6:.1f} MB')\n", | |
| "    print('Open on Mac: double-click or Quick Look (spacebar in Finder)')\n", | |
| "\n", | |
| "except ImportError:\n", | |
| "    print('usd-core not available, skipping USDZ export')\n", | |
| "    print('Install with: pip install usd-core')" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": {}, | |
| "source": [ | |
| "# Cell 11: Save person config JSON for the browser viewer\n", | |
| "import json\n", | |
| "import os\n", | |
| "import re\n", | |
| "\n", | |
| "import numpy as np\n", | |
| "import torch\n", | |
| "\n", | |
| "# Photo year range from filenames: an optional 'f-' or 'f-kk-' prefix followed\n", | |
| "# by a 4-digit year at the start of the name; years outside 1951-2029 are ignored.\n", | |
| "YEAR_PATTERN = re.compile(r'(?:f-(?:kk-)?)?(\\d{4})')\n", | |
| "years = []\n", | |
| "for p in photos:\n", | |
| "    m = YEAR_PATTERN.match(p)\n", | |
| "    if m:\n", | |
| "        year = int(m.group(1))\n", | |
| "        if 1950 < year < 2030:\n", | |
| "            years.append(year)\n", | |
| "\n", | |
| "# Fallback range when no filename carries a year\n", | |
| "min_year = min(years) if years else 2020\n", | |
| "max_year = max(years) if years else 2024\n", | |
| "\n", | |
| "config = {\n", | |
| "    'personId': PERSON_NAME.lower().replace(' ', '-'),\n", | |
| "    'displayName': PERSON_NAME,\n", | |
| "    'skinTone': skin_tone,\n", | |
| "    'hairColor': hair_color,\n", | |
| "    'photoCount': len(photos),\n", | |
| "    'photoYearRange': [min_year, max_year],\n", | |
| "    'bodyBetas': avg_betas.cpu().numpy().tolist()[0] if torch.is_tensor(avg_betas) else [0]*10,\n", | |
| "    'faceShape': avg_face_shape.tolist() if isinstance(avg_face_shape, np.ndarray) else [0]*100,\n", | |
| "    'meshFile': 'body.glb',\n", | |
| "    'usdzFile': 'body.usdz',\n", | |
| "    'faceTexture': 'face_texture.png' if best_face_texture else None,\n", | |
| "    'skeletonFile': 'skeleton.json',\n", | |
| "}\n", | |
| "\n", | |
| "config_path = os.path.join(OUTPUT_DIR, 'config.json')\n", | |
| "with open(config_path, 'w') as f:\n", | |
| "    json.dump(config, f, indent=2)\n", | |
| "\n", | |
| "print(f'Config saved to {config_path}')\n", | |
| "print(json.dumps(config, indent=2))" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": {}, | |
| "source": [ | |
| "# Cell 12: Summary\n", | |
| "# Completion banner, counts, a size listing of OUTPUT_DIR and next steps.\n", | |
| "\n", | |
| "banner_rule = '=' * 60\n", | |
| "print(f'\\n{banner_rule}')\n", | |
| "print(f'Avatar for {PERSON_NAME} — COMPLETE')\n", | |
| "print(banner_rule)\n", | |
| "print(f'Photos processed: {len(photos)}')\n", | |
| "print(f'Face shapes extracted: {len(all_face_shapes)}')\n", | |
| "print(f'\\nOutput files in {OUTPUT_DIR}:')\n", | |
| "\n", | |
| "\n", | |
| "def _human_size(num_bytes):\n", | |
| "    \"\"\"Format a byte count as MB (1 decimal) when above 1e6, else as KB (0 decimals).\"\"\"\n", | |
| "    if num_bytes > 1e6:\n", | |
| "        return f'{num_bytes/1e6:.1f} MB'\n", | |
| "    return f'{num_bytes/1e3:.0f} KB'\n", | |
| "\n", | |
| "\n", | |
| "for entry in sorted(os.listdir(OUTPUT_DIR)):\n", | |
| "    entry_bytes = os.path.getsize(os.path.join(OUTPUT_DIR, entry))\n", | |
| "    print(f' {entry:30s} {_human_size(entry_bytes)}')\n", | |
| "\n", | |
| "next_steps = [\n", | |
| "    ' 1. Download output/ folder',\n", | |
| "    f' 2. Place in diorama/avatar/models/{config[\"personId\"]}/',\n", | |
| "    ' 3. Open avatar/index.html in browser',\n", | |
| "    ' 4. Open body.usdz on Mac for native 3D viewing',\n", | |
| "]\n", | |
| "print('\\nNext steps:')\n", | |
| "print('\\n'.join(next_steps))" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| } | |
| ] | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment