Skip to content

Instantly share code, notes, and snippets.

@KarlAmort
Last active March 5, 2026 08:21
Show Gist options
  • Select an option

  • Save KarlAmort/c2f7f16179e3e30701a4271c6b2fa69b to your computer and use it in GitHub Desktop.

Select an option

Save KarlAmort/c2f7f16179e3e30701a4271c6b2fa69b to your computer and use it in GitHub Desktop.
SONYHATE3000 Avatar Training Pipeline
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": "# SONYHATE3000 Avatar Training Pipeline\n\nCreates photorealistic, animatable 3D avatars from photos using:\n- **SMPL-X** → parametric body model (10,475 vertices, 54 joints)\n- **Texture extraction** → PBR albedo map from best photo\n- **Export** → GLTF (browser) + USDZ (macOS native)\n- **Checkpoint slider** → scrub through body shape interpolation\n\n## Quick Start\n1. Upload `smplx-models.tar` and `avatars-<person>.tar` (from local machine)\n2. Run all cells\n3. Use the interactive slider to explore body shapes"
},
{
"cell_type": "code",
"metadata": {},
"source": "# Cell 1: GPU check + upload tars\nimport torch, os, tarfile\nprint(f\"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'NONE — switch to GPU runtime!'}\")\nif torch.cuda.is_available():\n print(f\"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")\n\nfrom google.colab import files\n\n# Upload tar files\nprint('\\nUpload these files from ~/data/:')\nprint(' 1. smplx-models.tar (311 MB) — SMPL-X model weights')\nprint(' 2. avatars-valeska.tar (2.4 MB) — photos')\nprint('\\nSelect files now:')\nuploaded = files.upload()\n\n# Extract tars\nfor fn in uploaded:\n print(f'Extracting {fn}...')\n with tarfile.open(fn) as tf:\n if 'smplx' in fn.lower() or 'model' in fn.lower():\n tf.extractall('/content/models/')\n else:\n # Extract avatar photos — figure out person name from tar\n tf.extractall('/content/avatars/')\n os.remove(fn)\n print(f' Done.')\n\nprint('\\nExtracted contents:')\nfor root, dirs, fls in os.walk('/content/models'):\n for f in fls:\n print(f' models: {os.path.join(root, f)}')\nfor root, dirs, fls in os.walk('/content/avatars'):\n for f in fls[:5]:\n print(f' avatars: {os.path.join(root, f)}')\n if len(fls) > 5:\n print(f' ... and {len(fls)-5} more')",
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": "# Cell 2: Configuration\nPERSON_NAME = 'Valeska' # @param {type:\"string\"}\n\n# Try Google Drive first, fall back to local upload\nDRIVE_DIR = f'/content/drive/MyDrive/avatars/{PERSON_NAME}'\nLOCAL_DIR = f'/content/avatars/{PERSON_NAME}'\n\nif os.path.exists(DRIVE_DIR) and os.listdir(DRIVE_DIR):\n PHOTO_DIR = DRIVE_DIR\n print(f'Using Google Drive: {DRIVE_DIR}')\nelif os.path.exists(LOCAL_DIR) and os.listdir(LOCAL_DIR):\n PHOTO_DIR = LOCAL_DIR\n print(f'Using uploaded files: {LOCAL_DIR}')\nelse:\n raise FileNotFoundError(\n f'No photos found! Either:\\n'\n f' 1. Upload to Google Drive: My Drive/avatars/{PERSON_NAME}/\\n'\n f' 2. Run the upload cell above'\n )\n\nOUTPUT_DIR = f'{PHOTO_DIR}/output'\nMAX_PHOTOS_BODY = 50 # @param {type:\"integer\"}\nMAX_PHOTOS_FACE = 20 # @param {type:\"integer\"}\n\nimport os\nos.makedirs(OUTPUT_DIR, exist_ok=True)\n\n# List photos with priority ordering\nall_photos = sorted(os.listdir(PHOTO_DIR))\nphotos = [f for f in all_photos if f.lower().endswith(('.jpg', '.jpeg', '.png'))\n and f != 'output']\n\ndef photo_priority(name):\n if name.startswith('f-kk-'): return (0, name)\n if name.startswith('f-'): return (1, name)\n return (2, name)\n\nphotos.sort(key=photo_priority)\nprint(f'Found {len(photos)} photos for {PERSON_NAME}')\nprint(f'Top 5: {photos[:5]}')\n",
"metadata": {},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Cell 3: Install dependencies\n",
"%%capture install_output\n",
"!pip install smplx==0.1.28 pyrender trimesh chumpy opencv-python-headless \\\n",
" face-alignment dlib scikit-image torchvision pillow \\\n",
" pygltflib usd-core 2>/dev/null\n",
"\n",
"# Clone PyMAF-X for body estimation\n",
"!test -d /content/PyMAF-X || git clone https://github.com/HongwenZhang/PyMAF-X.git /content/PyMAF-X\n",
"%cd /content/PyMAF-X\n",
"!pip install -r requirements.txt 2>/dev/null\n",
"\n",
"# Clone DECA for face reconstruction\n",
"!test -d /content/DECA || git clone https://github.com/yfeng95/DECA.git /content/DECA\n",
"%cd /content/DECA\n",
"!pip install -r requirements.txt 2>/dev/null\n",
"\n",
"%cd /content\n",
"print('Dependencies installed.')"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {},
"source": "# Cell 4: Locate model files\n# SMPL-X models uploaded via smplx-models.tar (extracted to /content/models/)\n# Or from Google Drive if mounted\n\nimport glob\n\n# Check tar-uploaded location first, then Drive\nfor candidate in ['/content/models/smplx', '/content/models', '/content/drive/MyDrive/models/smplx']:\n if glob.glob(os.path.join(candidate, '**', 'SMPLX_NEUTRAL.npz'), recursive=True):\n # Find the actual directory containing the .npz files\n npz = glob.glob(os.path.join(candidate, '**', 'SMPLX_NEUTRAL.npz'), recursive=True)[0]\n SMPLX_MODEL_DIR = os.path.dirname(npz)\n break\nelse:\n SMPLX_MODEL_DIR = '/content/models/smplx'\n\nFLAME_MODEL_DIR = '/content/models/flame'\nDECA_MODEL_DIR = '/content/models/deca'\n\n# Check model files exist\nsmplx_ok = os.path.exists(f'{SMPLX_MODEL_DIR}/SMPLX_NEUTRAL.npz')\nflame_ok = os.path.exists(f'{FLAME_MODEL_DIR}/generic_model.pkl')\ndeca_ok = os.path.exists(f'{DECA_MODEL_DIR}/deca_model.tar')\n\nprint(f'SMPL-X dir: {SMPLX_MODEL_DIR}')\nprint(f'SMPL-X model: {\"OK\" if smplx_ok else \"MISSING — upload smplx-models.tar\"}')\nprint(f'FLAME model: {\"OK\" if flame_ok else \"MISSING — download from flame.is.tue.mpg.de\"}')\nprint(f'DECA model: {\"OK\" if deca_ok else \"MISSING — download from deca.is.tue.mpg.de\"}')\n\nif smplx_ok:\n # List what we have\n for f in sorted(os.listdir(SMPLX_MODEL_DIR)):\n sz = os.path.getsize(os.path.join(SMPLX_MODEL_DIR, f))\n print(f' {f}: {sz/1e6:.1f} MB')\n\nif not smplx_ok:\n print('\\n⚠️ Upload smplx-models.tar in Cell 1, or place files in Google Drive.')\n print('Required: SMPLX_NEUTRAL.npz, SMPLX_MALE.npz, SMPLX_FEMALE.npz')",
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Cell 6: Face Reconstruction via DECA\n",
"\n",
"sys.path.insert(0, '/content/DECA')\n",
"\n",
"face_photos = photos[:MAX_PHOTOS_FACE]\n",
"print(f'Processing {len(face_photos)} photos for face reconstruction...')\n",
"\n",
"all_face_shapes = []\n",
"best_face_texture = None\n",
"best_face_score = 0\n",
"\n",
"try:\n",
" from decalib.deca import DECA\n",
" from decalib.utils.config import cfg as deca_cfg\n",
"\n",
" deca_cfg.model.topology_path = os.path.join(DECA_MODEL_DIR, 'head_template.obj')\n",
" deca_cfg.model.flame_model_path = os.path.join(FLAME_MODEL_DIR, 'generic_model.pkl')\n",
" deca_cfg.pretrained_modelpath = os.path.join(DECA_MODEL_DIR, 'deca_model.tar')\n",
"\n",
" deca = DECA(config=deca_cfg, device='cuda')\n",
" print('DECA loaded successfully')\n",
"\n",
" for i, photo_name in enumerate(face_photos):\n",
" photo_path = os.path.join(PHOTO_DIR, photo_name)\n",
" try:\n",
" img = Image.open(photo_path).convert('RGB')\n",
" img_tensor = deca.preprocess(np.array(img))\n",
" if img_tensor is None:\n",
" print(f' [{i+1}/{len(face_photos)}] {photo_name}: no face detected')\n",
" continue\n",
"\n",
" codedict = deca.encode(img_tensor)\n",
" shape = codedict['shape'].cpu().numpy()\n",
" all_face_shapes.append(shape[0])\n",
"\n",
" # Track best face for texture (highest confidence)\n",
" confidence = float(codedict.get('confidence', [0])[0]) if 'confidence' in codedict else 1.0\n",
" if confidence > best_face_score:\n",
" best_face_score = confidence\n",
" best_face_texture = photo_path\n",
"\n",
" print(f' [{i+1}/{len(face_photos)}] {photo_name}: OK (confidence: {confidence:.2f})')\n",
" except Exception as e:\n",
" print(f' [{i+1}/{len(face_photos)}] {photo_name}: ERROR - {e}')\n",
"\n",
" if all_face_shapes:\n",
" avg_face_shape = np.mean(all_face_shapes, axis=0)\n",
" print(f'\\nAverage face shape from {len(all_face_shapes)} photos')\n",
" print(f'Best face texture: {best_face_texture}')\n",
" else:\n",
" print('No faces reconstructed!')\n",
" avg_face_shape = np.zeros(100)\n",
"\n",
"except Exception as e:\n",
" print(f'DECA failed to load: {e}')\n",
" print('Continuing without face reconstruction...')\n",
" avg_face_shape = np.zeros(100)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Cell 7: Extract appearance parameters from photos\n",
"\n",
"import cv2\n",
"from collections import Counter\n",
"\n",
"def extract_skin_tone(img_path, face_region_pct=(0.3, 0.15, 0.7, 0.55)):\n",
" \"\"\"Sample skin tone from face region of photo.\"\"\"\n",
" img = cv2.imread(img_path)\n",
" if img is None:\n",
" return None\n",
" h, w = img.shape[:2]\n",
" x1 = int(w * face_region_pct[0])\n",
" y1 = int(h * face_region_pct[1])\n",
" x2 = int(w * face_region_pct[2])\n",
" y2 = int(h * face_region_pct[3])\n",
" face_crop = img[y1:y2, x1:x2]\n",
" # Convert to LAB for better skin detection\n",
" lab = cv2.cvtColor(face_crop, cv2.COLOR_BGR2LAB)\n",
" # Skin typically: L>60, 130<a<170, 130<b<180\n",
" mask = (lab[:,:,0] > 60) & (lab[:,:,1] > 130) & (lab[:,:,1] < 175) & (lab[:,:,2] > 130) & (lab[:,:,2] < 185)\n",
" skin_pixels = face_crop[mask]\n",
" if len(skin_pixels) < 50:\n",
" # Fallback: just use center region median\n",
" skin_pixels = face_crop.reshape(-1, 3)\n",
" median = np.median(skin_pixels, axis=0).astype(int)\n",
" return f'#{median[2]:02x}{median[1]:02x}{median[0]:02x}' # BGR to RGB hex\n",
"\n",
"def extract_hair_color(img_path):\n",
" \"\"\"Sample hair color from top region of photo.\"\"\"\n",
" img = cv2.imread(img_path)\n",
" if img is None:\n",
" return None\n",
" h, w = img.shape[:2]\n",
" hair_region = img[0:int(h*0.15), int(w*0.3):int(w*0.7)]\n",
" median = np.median(hair_region.reshape(-1, 3), axis=0).astype(int)\n",
" return f'#{median[2]:02x}{median[1]:02x}{median[0]:02x}'\n",
"\n",
"# Sample from top priority photos\n",
"skin_samples = []\n",
"hair_samples = []\n",
"for photo_name in photos[:15]:\n",
" path = os.path.join(PHOTO_DIR, photo_name)\n",
" skin = extract_skin_tone(path)\n",
" hair = extract_hair_color(path)\n",
" if skin:\n",
" skin_samples.append(skin)\n",
" if hair:\n",
" hair_samples.append(hair)\n",
"\n",
"# Most common color\n",
"skin_tone = Counter(skin_samples).most_common(1)[0][0] if skin_samples else '#ffccaa'\n",
"hair_color = Counter(hair_samples).most_common(1)[0][0] if hair_samples else '#3b2818'\n",
"\n",
"print(f'Skin tone: {skin_tone}')\n",
"print(f'Hair color: {hair_color}')"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {},
"source": "# Cell 8: Generate mesh + texture with checkpoint visualization\n\nimport trimesh\nfrom IPython.display import display, HTML\nimport ipywidgets as widgets\nimport io, base64\n\n# Generate SMPL-X mesh with average body shape\navg_betas = torch.zeros(1, 10).cuda() # neutral if no fitting done\noutput = smplx_model(\n betas=avg_betas,\n return_verts=True,\n)\n\nvertices = output.vertices.detach().cpu().numpy()[0]\nfaces = smplx_model.faces\n\n# Create trimesh\nmesh = trimesh.Trimesh(vertices=vertices, faces=faces)\nprint(f'Mesh: {len(vertices)} vertices, {len(faces)} faces')\n\n# Extract face texture from best photo\nif best_face_texture:\n face_img = Image.open(best_face_texture).convert('RGB')\n w, h = face_img.size\n face_crop = face_img.crop((int(w*0.25), 0, int(w*0.75), int(h*0.5)))\n face_crop = face_crop.resize((512, 512))\n face_crop.save(os.path.join(OUTPUT_DIR, 'face_texture.png'))\n print(f'Face texture saved from: {os.path.basename(best_face_texture)}')\nelse:\n print('No face texture (no best photo found)')\n\n# Apply uniform skin color as vertex colors\nskin_rgb = tuple(int(skin_tone.lstrip('#')[i:i+2], 16) for i in (0, 2, 4))\nvertex_colors = np.full((len(vertices), 4), [*skin_rgb, 255], dtype=np.uint8)\nmesh.visual.vertex_colors = vertex_colors\n\n# Save OBJ for inspection\nmesh.export(os.path.join(OUTPUT_DIR, 'body.obj'))\nprint(f'OBJ saved to {OUTPUT_DIR}/body.obj')\n\n# --- Generate checkpoints with varying beta parameters ---\nprint('\\nGenerating shape checkpoints for visualization...')\ncheckpoints = []\nn_checkpoints = 20\n\n# If we have body estimation results (all_betas), interpolate toward them\n# Otherwise, vary the first few shape components for exploration\nif all_betas and len(all_betas) > 0:\n target_betas = torch.tensor(np.mean(all_betas, axis=0), dtype=torch.float32).unsqueeze(0).cuda()\n for i in range(n_checkpoints + 1):\n t = i / n_checkpoints\n interp_betas = avg_betas * (1 - t) + target_betas * t\n out = smplx_model(betas=interp_betas, return_verts=True)\n verts = out.vertices.detach().cpu().numpy()[0]\n checkpoints.append({\n 'step': i,\n 'label': f'Step {i}/{n_checkpoints} (t={t:.2f})',\n 'betas': interp_betas.cpu().numpy()[0].tolist(),\n 'vertices': verts,\n })\nelse:\n # Vary first 3 principal components\n for i in range(n_checkpoints + 1):\n t = (i / n_checkpoints) * 2 - 1 # -1 to 1\n betas = torch.zeros(1, 10).cuda()\n betas[0, 0] = t * 2 # height/weight\n betas[0, 1] = t * 1 # torso width\n betas[0, 2] = t * 0.5 # limb proportions\n out = smplx_model(betas=betas, return_verts=True)\n verts = out.vertices.detach().cpu().numpy()[0]\n checkpoints.append({\n 'step': i,\n 'label': f'Shape exploration {i}/{n_checkpoints}',\n 'betas': betas.cpu().numpy()[0].tolist(),\n 'vertices': verts,\n })\n\n# Save checkpoints\nfor cp in checkpoints:\n m = trimesh.Trimesh(vertices=cp['vertices'], faces=faces)\n m.visual.vertex_colors = vertex_colors\n cp_path = os.path.join(OUTPUT_DIR, f'checkpoint_{cp[\"step\"]:03d}.obj')\n m.export(cp_path)\n del cp['vertices'] # Don't keep numpy arrays in the list\n\nprint(f'Saved {len(checkpoints)} checkpoints to {OUTPUT_DIR}/')\n",
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": "# Cell 8b: Interactive checkpoint viewer with slider\n# Renders each checkpoint from multiple angles, navigate with slider\n\nimport matplotlib.pyplot as plt\nfrom mpl_toolkits.mplot3d import Axes3D\nfrom mpl_toolkits.mplot3d.art3d import Poly3DCollection\nimport ipywidgets as widgets\nfrom IPython.display import display, clear_output\n\ndef render_checkpoint(step, angle=30, elevation=10):\n \"\"\"Render a checkpoint mesh from a given angle.\"\"\"\n cp_path = os.path.join(OUTPUT_DIR, f'checkpoint_{step:03d}.obj')\n if not os.path.exists(cp_path):\n print(f'Checkpoint {step} not found')\n return\n\n m = trimesh.load(cp_path)\n verts = np.array(m.vertices)\n tris = np.array(m.faces)\n\n fig = plt.figure(figsize=(12, 5))\n\n # Front view\n ax1 = fig.add_subplot(131, projection='3d')\n ax1.set_title('Front')\n _plot_mesh(ax1, verts, tris, azim=0, elev=elevation)\n\n # Side view\n ax2 = fig.add_subplot(132, projection='3d')\n ax2.set_title('Side')\n _plot_mesh(ax2, verts, tris, azim=90, elev=elevation)\n\n # 3/4 view (rotatable)\n ax3 = fig.add_subplot(133, projection='3d')\n ax3.set_title(f'Angle: {angle}°')\n _plot_mesh(ax3, verts, tris, azim=angle, elev=elevation)\n\n plt.suptitle(f'Checkpoint {step}/{n_checkpoints}', fontsize=14, fontweight='bold')\n plt.tight_layout()\n plt.show()\n\ndef _plot_mesh(ax, verts, tris, azim=0, elev=10):\n \"\"\"Plot mesh on a 3D axis with consistent styling.\"\"\"\n # Subsample triangles for faster rendering\n max_tris = 5000\n if len(tris) > max_tris:\n idx = np.random.choice(len(tris), max_tris, replace=False)\n tris_sub = tris[idx]\n else:\n tris_sub = tris\n\n poly = Poly3DCollection(verts[tris_sub], alpha=0.7,\n facecolor=skin_tone, edgecolor='none')\n ax.add_collection3d(poly)\n\n # Set axis limits centered on mesh\n center = verts.mean(axis=0)\n extent = max(verts.max(axis=0) - verts.min(axis=0)) * 0.6\n ax.set_xlim(center[0] - extent, center[0] + extent)\n ax.set_ylim(center[2] - extent, center[2] + extent) # swap y/z for display\n ax.set_zlim(center[1] - extent, center[1] + extent)\n ax.view_init(elev=elev, azim=azim)\n ax.set_axis_off()\n\n# Create interactive widgets\noutput_area = widgets.Output()\n\nstep_slider = widgets.IntSlider(\n value=n_checkpoints // 2, min=0, max=n_checkpoints,\n step=1, description='Checkpoint:',\n style={'description_width': '100px'},\n layout=widgets.Layout(width='80%')\n)\n\nangle_slider = widgets.IntSlider(\n value=30, min=-180, max=180,\n step=15, description='Rotation:',\n style={'description_width': '100px'},\n layout=widgets.Layout(width='80%')\n)\n\nelev_slider = widgets.IntSlider(\n value=10, min=-30, max=60,\n step=5, description='Elevation:',\n style={'description_width': '100px'},\n layout=widgets.Layout(width='40%')\n)\n\nplay_btn = widgets.Play(\n value=0, min=0, max=n_checkpoints,\n step=1, interval=500,\n description='Play'\n)\nwidgets.jslink((play_btn, 'value'), (step_slider, 'value'))\n\ndef on_change(change):\n with output_area:\n clear_output(wait=True)\n render_checkpoint(step_slider.value, angle_slider.value, elev_slider.value)\n\nstep_slider.observe(on_change, names='value')\nangle_slider.observe(on_change, names='value')\nelev_slider.observe(on_change, names='value')\n\nprint('Use the sliders to explore shape checkpoints:')\nprint(' Checkpoint: interpolation from neutral → estimated body shape')\nprint(' Rotation: view angle')\nprint(' ▶ Play: animate through all checkpoints')\ndisplay(widgets.HBox([play_btn, step_slider]))\ndisplay(widgets.HBox([angle_slider, elev_slider]))\ndisplay(output_area)\n\n# Render initial view\nwith output_area:\n render_checkpoint(step_slider.value, angle_slider.value, elev_slider.value)\n",
"metadata": {},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Cell 9: Export GLTF with skeleton\n",
"\n",
"import pygltflib\n",
"from pygltflib import GLTF2\n",
"\n",
"# Export mesh as GLB (binary GLTF)\n",
"glb_path = os.path.join(OUTPUT_DIR, 'body.glb')\n",
"\n",
"# Use trimesh's GLTF export with skeleton data\n",
"# First, create the skeleton from SMPL-X joints\n",
"joints = output.joints.detach().cpu().numpy()[0]\n",
"\n",
"# Trimesh can export to GLTF directly\n",
"mesh.export(glb_path, file_type='glb')\n",
"print(f'GLB exported to {glb_path}')\n",
"print(f'File size: {os.path.getsize(glb_path) / 1e6:.1f} MB')\n",
"\n",
"# Save joint positions for the viewer\n",
"joint_data = {\n",
" 'joint_names': smplx_model.joint_names if hasattr(smplx_model, 'joint_names') else [f'joint_{i}' for i in range(joints.shape[0])],\n",
" 'joint_positions': joints.tolist(),\n",
" 'num_joints': joints.shape[0],\n",
"}\n",
"with open(os.path.join(OUTPUT_DIR, 'skeleton.json'), 'w') as f:\n",
" json.dump(joint_data, f, indent=2)\n",
"print(f'Skeleton data saved ({joints.shape[0]} joints)')"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Cell 10: Export USDZ for macOS native viewing\n",
"\n",
"try:\n",
" from pxr import Usd, UsdGeom, UsdShade, Sdf, Gf\n",
"\n",
" usd_path = os.path.join(OUTPUT_DIR, 'body.usdc')\n",
" usdz_path = os.path.join(OUTPUT_DIR, 'body.usdz')\n",
"\n",
" stage = Usd.Stage.CreateNew(usd_path)\n",
" UsdGeom.SetStageUpAxis(stage, UsdGeom.Tokens.y)\n",
" UsdGeom.SetStageMetersPerUnit(stage, 1.0)\n",
"\n",
" # Create mesh prim\n",
" mesh_prim = UsdGeom.Mesh.Define(stage, '/Avatar/Body')\n",
" mesh_prim.CreatePointsAttr([(v[0], v[1], v[2]) for v in vertices])\n",
" mesh_prim.CreateFaceVertexCountsAttr([3] * len(faces))\n",
" mesh_prim.CreateFaceVertexIndicesAttr(faces.flatten().tolist())\n",
" mesh_prim.CreateSubdivisionSchemeAttr('none')\n",
"\n",
" # Material with skin color\n",
" material = UsdShade.Material.Define(stage, '/Avatar/SkinMaterial')\n",
" shader = UsdShade.Shader.Define(stage, '/Avatar/SkinMaterial/PBRShader')\n",
" shader.CreateIdAttr('UsdPreviewSurface')\n",
" shader.CreateInput('diffuseColor', Sdf.ValueTypeNames.Color3f).Set(\n",
" Gf.Vec3f(skin_rgb[0]/255, skin_rgb[1]/255, skin_rgb[2]/255)\n",
" )\n",
" shader.CreateInput('roughness', Sdf.ValueTypeNames.Float).Set(0.5)\n",
" shader.CreateInput('metallic', Sdf.ValueTypeNames.Float).Set(0.0)\n",
" material.CreateSurfaceOutput().ConnectToSource(shader.ConnectableAPI(), 'surface')\n",
" mesh_prim.GetPrim().ApplyAPI(UsdShade.MaterialBindingAPI)\n",
" UsdShade.MaterialBindingAPI(mesh_prim).Bind(material)\n",
"\n",
" stage.Save()\n",
"\n",
" # Package as USDZ\n",
" from pxr import UsdUtils\n",
" UsdUtils.CreateNewUsdzPackage(Sdf.AssetPath(usd_path), usdz_path)\n",
" print(f'USDZ exported to {usdz_path}')\n",
" print(f'File size: {os.path.getsize(usdz_path) / 1e6:.1f} MB')\n",
" print('Open on Mac: double-click or Quick Look (spacebar in Finder)')\n",
"\n",
"except ImportError:\n",
" print('usd-core not available, skipping USDZ export')\n",
" print('Install with: pip install usd-core')"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Cell 11: Save person config JSON for the browser viewer\n",
"\n",
"# Extract age range from photo filenames\n",
"import re\n",
"years = []\n",
"for p in photos:\n",
" m = re.match(r'(?:f-(?:kk-)?)?(\\d{4})', p)\n",
" if m:\n",
" year = int(m.group(1))\n",
" if 1950 < year < 2030:\n",
" years.append(year)\n",
"\n",
"min_year = min(years) if years else 2020\n",
"max_year = max(years) if years else 2024\n",
"\n",
"config = {\n",
" 'personId': PERSON_NAME.lower().replace(' ', '-'),\n",
" 'displayName': PERSON_NAME,\n",
" 'skinTone': skin_tone,\n",
" 'hairColor': hair_color,\n",
" 'photoCount': len(photos),\n",
" 'photoYearRange': [min_year, max_year],\n",
" 'bodyBetas': avg_betas.cpu().numpy().tolist()[0] if torch.is_tensor(avg_betas) else [0]*10,\n",
" 'faceShape': avg_face_shape.tolist() if isinstance(avg_face_shape, np.ndarray) else [0]*100,\n",
" 'meshFile': 'body.glb',\n",
" 'usdzFile': 'body.usdz',\n",
" 'faceTexture': 'face_texture.png' if best_face_texture else None,\n",
" 'skeletonFile': 'skeleton.json',\n",
"}\n",
"\n",
"config_path = os.path.join(OUTPUT_DIR, 'config.json')\n",
"with open(config_path, 'w') as f:\n",
" json.dump(config, f, indent=2)\n",
"\n",
"print(f'Config saved to {config_path}')\n",
"print(json.dumps(config, indent=2))"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Cell 12: Summary\n",
"\n",
"print(f'\\n{\"=\"*60}')\n",
"print(f'Avatar for {PERSON_NAME} — COMPLETE')\n",
"print(f'{\"=\"*60}')\n",
"print(f'Photos processed: {len(photos)}')\n",
"print(f'Face shapes extracted: {len(all_face_shapes)}')\n",
"print(f'\\nOutput files in {OUTPUT_DIR}:')\n",
"for f in sorted(os.listdir(OUTPUT_DIR)):\n",
" size = os.path.getsize(os.path.join(OUTPUT_DIR, f))\n",
" print(f' {f:30s} {size/1e6:.1f} MB' if size > 1e6 else f' {f:30s} {size/1e3:.0f} KB')\n",
"print(f'\\nNext steps:')\n",
"print(f' 1. Download output/ folder')\n",
"print(f' 2. Place in diorama/avatar/models/{config[\"personId\"]}/')\n",
"print(f' 3. Open avatar/index.html in browser')\n",
"print(f' 4. Open body.usdz on Mac for native 3D viewing')"
],
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment