Integration Testing Strategy for ML Codebase with Non-Determinism

Context

Codebase: diffBloch - PyTorch-based crystal structure refinement

  • Input: CIF files (crystal structures)
  • Output: Refined atomic positions
  • Pipeline: CIF → Atoms → StructureFactorNet → BlochNet → Loss → Optimization → Refined positions

Current State

Existing Testing

  • pytest framework with good fixtures
  • Some integration tests (tests/atoms/test_atoms_integration.py)
  • Basic seeding: torch.manual_seed(42), np.random.seed()
  • Tolerance assertions: torch.allclose(), np.testing.assert_allclose()

Non-Determinism Sources Identified

  1. Thickness sampling: torch.randn() in utils.py:1552 (see the generator sketch after this list)
  2. DataLoader: SubsetRandomSampler in rotation_dataset.py:443
  3. GPU operations: Missing cuDNN determinism flags
  4. Optimizers: LBFGS line search, Adam momentum buffers
  5. Bayesian optimization: gp_minimize in preprocess.py
  6. Thread scheduling: ThreadPoolExecutor parallel processing
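
Several of these sources can be made test-controllable by threading an explicit torch.Generator through the sampling code instead of relying on global RNG state. A minimal sketch for the thickness sampling (the function name and signature are hypothetical, not diffBloch's actual API):

from typing import Optional
import torch

def sample_thickness(mean: float, std: float, n: int,
                     generator: Optional[torch.Generator] = None) -> torch.Tensor:
    """Draw n thickness samples; deterministic when a seeded generator is given."""
    return mean + std * torch.randn(n, generator=generator)

# In a test: pass a seeded generator for reproducible draws
gen = torch.Generator().manual_seed(42)
samples = sample_thickness(500.0, 25.0, 8, generator=gen)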

Research Findings: Best Practices for ML Integration Testing

Strategy 1: Deterministic Mode (For CI/Testing Only)

PyTorch determinism controls (PyTorch Docs):

def set_deterministic_mode(seed: int = 42):
    """Enable full determinism for testing. Use only in tests - slower."""
    import os
    import random
    import numpy as np
    import torch

    # For operations with no deterministic implementation otherwise;
    # must be set before the first cuBLAS call to take effect
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # For multi-GPU

    # Force deterministic algorithms
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

Tradeoff: Deterministic ops are slower. Good for CI, not production.
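
One way to wire this into pytest is an opt-in fixture, so only tests that need bit-level reproducibility pay the speed cost (the fixture name is illustrative):

import pytest
import torch

@pytest.fixture
def deterministic_mode():
    """Opt-in fixture built on set_deterministic_mode above."""
    set_deterministic_mode(seed=42)
    yield
    # Restore PyTorch defaults so the rest of the suite runs at full speed
    torch.use_deterministic_algorithms(False)
    torch.backends.cudnn.deterministic = False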

Strategy 2: Tolerance-Based Testing (Recommended Primary Approach)

Accept that bit-identical results aren't achievable. Test within acceptable tolerance.

Tolerance tiers:

Operation Type                    Relative Tolerance   Absolute Tolerance
Lattice math, reciprocal space    1e-10                1e-12
Structure factors                 1e-6                 1e-8
Bloch wave intensities            1e-4                 1e-6
Refined positions (single run)    1e-3                 1e-5
End-to-end (loss convergence)     5-15%                -
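
These tiers can be encoded once and reused across tests. A minimal sketch (tier names and values mirror the table above; nothing here exists in the codebase yet):

import torch

# Central registry of tolerance tiers so tests don't hard-code rtol/atol
TOLERANCES = {
    "lattice":           dict(rtol=1e-10, atol=1e-12),
    "structure_factor":  dict(rtol=1e-6,  atol=1e-8),
    "bloch_intensity":   dict(rtol=1e-4,  atol=1e-6),
    "refined_positions": dict(rtol=1e-3,  atol=1e-5),
}

def assert_close(actual: torch.Tensor, expected: torch.Tensor, tier: str):
    """Assert tensors match within the tolerance tier for their operation type."""
    tol = TOLERANCES[tier]
    assert torch.allclose(actual, expected, **tol), \
        f"Mismatch beyond {tier} tolerance (rtol={tol['rtol']}, atol={tol['atol']})"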

Strategy 3: Property-Based Testing

Instead of exact values, test invariants that must hold regardless of randomness (an example follows the list):

  1. Conservation laws: Total intensity conserved, positions within unit cell
  2. Symmetry: Output respects crystal symmetry
  3. Monotonicity: Loss decreases over optimization steps
  4. Bounds: Atomic positions in valid range, thermal parameters positive
  5. Dimensional consistency: Shapes match, units correct
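
An illustrative invariant test in plain pytest (run_short_refinement is the same placeholder used in Strategy 6 below; the seed parameter is assumed):

def test_invariants_hold_for_any_seed():
    """Physical invariants must hold no matter how the RNG is seeded."""
    for seed in [0, 1, 2, 7, 123]:
        result = run_short_refinement(max_epochs=2, seed=seed)
        frac = result.fractional_positions
        assert ((frac >= 0) & (frac < 1)).all(), "positions left the unit cell"
        assert result.final_loss <= result.initial_loss, "loss did not improve"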

Strategy 4: Statistical/Distribution Testing

Run multiple times and test distribution properties (ML Testing Guide):

import numpy as np

def test_refinement_converges_statistically(n_runs=5):
    """Test that refinement converges to similar results across runs."""
    final_losses = []
    for seed in range(n_runs):
        result = run_refinement(seed=seed)  # placeholder pipeline entry point
        final_losses.append(result.final_loss)

    # Test distribution properties
    assert np.std(final_losses) / np.mean(final_losses) < 0.1  # CV < 10%
    assert all(loss < CONVERGENCE_THRESHOLD for loss in final_losses)

Strategy 5: Golden/Snapshot Testing with Tolerance

Store reference outputs, compare with tolerance:

@pytest.fixture
def golden_reference():
    """Load pre-computed reference for comparison."""
    return torch.load("tests/golden/refinement_reference.pt")

def test_refinement_matches_golden(golden_reference, deterministic_seed):
    result = run_refinement(seed=deterministic_seed)

    # Compare with tolerance, not exact match
    assert torch.allclose(
        result.positions,
        golden_reference.positions,
        rtol=1e-3, atol=1e-5
    )

Strategy 6: Smoke Testing / Fast Integration Tests

Quick sanity checks that the pipeline doesn't crash:

@pytest.mark.integration
def test_pipeline_runs_without_error():
    """Smoke test: pipeline completes without exception."""
    result = run_short_refinement(max_epochs=2)
    assert result.final_loss is not None
    assert result.final_loss < result.initial_loss  # At least some improvement

Recommended Testing Pyramid for diffBloch

                    /\
                   /  \  E2E Tests (rare, slow)
                  /    \ - Full refinement on real CIF
                 /------\
                /        \ Integration Tests
               /          \ - Pipeline segments
              /            \ - StructureFactor → Bloch → Loss
             /--------------\
            /                \ Unit Tests (many, fast)
           /                  \ - Individual functions
          /                    \ - Utility functions
         /______________________\

User Requirements (Clarified)

  • Reproducibility: Statistical (multiple runs, test distribution properties)
  • Goal: Correctness validation against known structures
  • Reference data: Available - known/published refined structures
  • Runtime: Thorough nightly (10+ min per test acceptable)

Implementation Plan

Phase 1: Testing Infrastructure (tests/conftest.py)

File: tests/conftest.py (create/extend)

import pytest
import torch
import numpy as np
from pathlib import Path

# Markers for test categorization
def pytest_configure(config):
    config.addinivalue_line("markers", "integration: end-to-end integration tests")
    config.addinivalue_line("markers", "statistical: multi-run statistical tests")
    config.addinivalue_line("markers", "nightly: thorough tests for nightly CI")

@pytest.fixture
def seeded_environment():
    """Seed all RNGs for a single reproducible run."""
    def _seed(seed=42):
        torch.manual_seed(seed)
        np.random.seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
    return _seed

@pytest.fixture
def reference_structures():
    """Load known reference structure cases for validation."""
    # Case directories live under tests/integration/reference_data/
    ref_dir = Path(__file__).parent / "integration" / "reference_data"
    return sorted(p for p in ref_dir.iterdir() if p.is_dir())

Phase 2: Reference Data Setup

Source: the shreshth/journal-review branch contains ideal reference data:

data/cspbbr3/
├── experimental_data/
│   ├── cspbbr3.cif                      # Ground truth (CCDC structure)
│   └── CsPbBr_P67-crot-050_dyn.cif_pets # Experimental diffraction
└── synthetic_data/synthetic_structural_data_paper/
    ├── cspbbr3_01_displacement.cif      # 0.1 Å perturbed
    ├── cspbbr3_02_displacement.cif      # 0.2 Å perturbed
    ├── cspbbr3_03_displacement.cif      # 0.3 Å perturbed
    └── cspbbr3_paper_data.cif_pets      # Synthetic diffraction

Testing strategy with this data:

  1. Recovery tests: Start from displaced CIF (0.1-0.3 Å), refine back toward ground truth
  2. Compare: Refined positions vs cspbbr3.cif (known correct structure)
  3. Tolerance: RMSD < displacement amount (e.g., start at 0.2 Å, refine to < 0.1 Å); see the check sketched below
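
The acceptance criterion in code (compute_rmsd is defined in Phase 6; the position variables are placeholders for the loaded CIF coordinates):

# Recovery check: refinement must at least halve the error toward ground truth
initial_rmsd = compute_rmsd(displaced_positions, ground_truth_positions)
final_rmsd = compute_rmsd(refined_positions, ground_truth_positions)
assert final_rmsd < 0.5 * initial_rmsd, \
    f"no recovery: {initial_rmsd:.3f} Å -> {final_rmsd:.3f} Å"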

Test data directory (create in repo):

tests/integration/
├── conftest.py
├── reference_data/
│   ├── cspbbr3/                         # Copy or symlink from journal-review
│   │   ├── ground_truth.cif             # Copy of experimental cspbbr3.cif
│   │   ├── displaced_01.cif             # 0.1 Å perturbation
│   │   ├── displaced_02.cif             # 0.2 Å perturbation
│   │   ├── diffraction.cif_pets         # Diffraction data
│   │   └── metadata.yaml                # Test parameters
│   └── README.md
└── test_refinement_correctness.py

metadata.yaml format:

name: "CsPbBr3 recovery from 0.2 Å displacement"
source: "shreshth/journal-review branch, CCDC structure"
ground_truth: "ground_truth.cif"
displaced_inputs:
  - file: "displaced_01.cif"
    initial_displacement: 0.1  # Angstroms
    expected_final_rmsd: 0.05  # Should recover to < 0.05 Å
  - file: "displaced_02.cif"
    initial_displacement: 0.2
    expected_final_rmsd: 0.08
tolerances:
  position_rmsd: 0.1   # Angstroms - must improve from initial
  r_factor: 0.10       # 10% R-factor tolerance
validation_runs: 3     # Number of statistical runs
max_epochs: 50         # Epochs per test run
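
The load_metadata helper used in Phase 3 below can be a thin PyYAML wrapper; a minimal sketch:

from pathlib import Path
import yaml

def load_metadata(case_dir: Path) -> dict:
    """Read per-case test parameters from metadata.yaml."""
    with open(case_dir / "metadata.yaml") as f:
        return yaml.safe_load(f)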

Phase 3: Statistical Integration Tests

File: tests/integration/test_refinement_correctness.py

import pytest
import torch
import numpy as np
from pathlib import Path

@pytest.mark.integration
@pytest.mark.statistical
@pytest.mark.nightly
class TestRefinementCorrectness:
    """Statistical integration tests against known structures."""

    @pytest.fixture
    def reference_cases(self):
        """Load all reference test cases."""
        ref_dir = Path(__file__).parent / "reference_data"
        cases = []
        for case_dir in ref_dir.iterdir():
            if case_dir.is_dir() and (case_dir / "metadata.yaml").exists():
                cases.append(case_dir)
        return cases

    def run_refinement_with_seed(self, case_dir: Path, seed: int):
        """Run refinement for a single seed."""
        # Load config, run refinement, return results
        pass

    def test_converges_to_known_structure(self, reference_cases):
        """
        Statistical test: Run refinement N times, verify:
        1. Mean refined positions are close to reference
        2. Variance across runs is acceptably low
        3. All runs achieve acceptable R-factor
        """
        for case_dir in reference_cases:
            metadata = load_metadata(case_dir)
            expected = torch.load(case_dir / "expected_positions.pt")

            # Run multiple times with different seeds
            n_runs = metadata.get("validation_runs", 5)
            results = []
            for seed in range(n_runs):
                result = self.run_refinement_with_seed(case_dir, seed)
                results.append(result)

            # Statistical assertions
            positions = torch.stack([r.positions for r in results])
            mean_positions = positions.mean(dim=0)
            std_positions = positions.std(dim=0)

            # 1. Mean close to reference
            rmsd = compute_rmsd(mean_positions, expected)
            assert rmsd < metadata["tolerances"]["position_rmsd"], \
                f"RMSD {rmsd:.4f} exceeds tolerance"

            # 2. Low variance (convergence consistency)
            max_std = std_positions.max().item()
            assert max_std < 0.005, f"High variance: max_std={max_std:.4f}"

            # 3. All runs achieve acceptable loss
            r_factors = [r.r_factor for r in results]
            assert all(r < metadata["tolerances"]["r_factor"] for r in r_factors)

    def test_refinement_improves_from_perturbed(self, reference_cases):
        """
        Test that refinement recovers correct structure from perturbed start.
        Perturb known structure, verify refinement returns to correct answer.
        """
        for case_dir in reference_cases:
            expected = torch.load(case_dir / "expected_positions.pt")

            # Perturb by small random displacement
            for seed in range(3):
                torch.manual_seed(seed)
                perturbation = torch.randn_like(expected) * 0.05  # 0.05 Angstrom
                perturbed = expected + perturbation

                # Run refinement from perturbed start
                result = self.run_refinement_from_positions(case_dir, perturbed, seed)

                # Should recover close to original
                rmsd = compute_rmsd(result.positions, expected)
                assert rmsd < 0.02  # Within 0.02 Angstroms

Phase 4: Property-Based Invariant Tests

File: tests/integration/test_invariants.py

@pytest.mark.integration
class TestPhysicalInvariants:
    """Test that physical invariants hold regardless of randomness."""

    def test_positions_remain_in_unit_cell(self, refinement_result):
        """Fractional coordinates must be in [0, 1)."""
        positions = refinement_result.fractional_positions
        assert (positions >= 0).all() and (positions < 1).all()

    def test_symmetry_preserved(self, refinement_result, input_structure):
        """Refined structure respects input space group symmetry."""
        # Apply symmetry operations, verify equivalent positions
        pass

    def test_thermal_parameters_positive(self, refinement_result):
        """Thermal displacement parameters must be positive definite."""
        Uij = refinement_result.thermal_parameters
        eigenvalues = torch.linalg.eigvalsh(Uij)
        assert (eigenvalues > 0).all()

    def test_loss_monotonic_decrease(self, refinement_history):
        """Loss should generally decrease during optimization."""
        losses = refinement_history.losses
        # Allow some noise, but overall trend should be down
        smoothed = moving_average(losses, window=10)
        assert smoothed[-1] < smoothed[0] * 0.5  # At least 50% reduction
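
The moving_average helper referenced above isn't defined elsewhere in the plan; a simple NumPy version:

import numpy as np

def moving_average(values, window: int = 10) -> np.ndarray:
    """Uniform moving average; output has len(values) - window + 1 points."""
    return np.convolve(np.asarray(values, dtype=float),
                       np.ones(window) / window, mode="valid")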

Phase 5: Pytest Configuration

Add to pyproject.toml:

[tool.pytest.ini_options]
markers = [
    "integration: end-to-end integration tests",
    "statistical: multi-run statistical tests (slow)",
    "nightly: thorough tests for nightly CI",
]

# Default: skip slow tests
addopts = "-m 'not nightly'"

CI configurations:

  • PR checks: pytest -m "not nightly" (fast unit tests only)
  • Nightly: pytest -m "nightly" (full statistical suite)

Phase 6: Helper Utilities

File: tests/integration/utils.py

import torch
from pathlib import Path

def compute_rmsd(positions1: torch.Tensor, positions2: torch.Tensor) -> float:
    """Compute RMSD between two (n_atoms, 3) position tensors."""
    # Sum squared differences per atom, then average over atoms
    sq_dist = ((positions1 - positions2) ** 2).sum(dim=-1)
    return torch.sqrt(sq_dist.mean()).item()

def compare_with_tolerance(actual, expected, metadata):
    """Compare results against expected with metadata-specified tolerances."""
    tol = metadata["tolerances"]
    assert compute_rmsd(actual.positions, expected.positions) < tol["position_rmsd"]
    assert actual.r_factor < tol["r_factor"]

def generate_golden_reference(case_dir: Path, n_runs: int = 10):
    """
    Generate golden reference from multiple converged runs.
    Use median of converged runs as reference (robust to outliers).
    """
    # run_refinement is the placeholder pipeline entry point used throughout
    runs = [run_refinement(case_dir, seed=seed).positions for seed in range(n_runs)]
    golden = torch.stack(runs).median(dim=0).values
    torch.save(golden, case_dir / "expected_positions.pt")

Implementation Steps

Step 1: Copy Reference Data from shreshth/journal-review

# Copy CsPbBr3 reference files to tests/integration/reference_data/cspbbr3/
mkdir -p tests/integration/reference_data/cspbbr3

git show origin/shreshth/journal-review:data/cspbbr3/experimental_data/cspbbr3.cif \
    > tests/integration/reference_data/cspbbr3/ground_truth.cif

git show origin/shreshth/journal-review:data/cspbbr3/synthetic_data/synthetic_structural_data_paper/cspbbr3_01_displacement.cif \
    > tests/integration/reference_data/cspbbr3/displaced_01.cif

git show origin/shreshth/journal-review:data/cspbbr3/synthetic_data/synthetic_structural_data_paper/cspbbr3_02_displacement.cif \
    > tests/integration/reference_data/cspbbr3/displaced_02.cif

git show origin/shreshth/journal-review:data/cspbbr3/synthetic_data/synthetic_structural_data_paper/cspbbr3_paper_data.cif_pets \
    > tests/integration/reference_data/cspbbr3/diffraction.cif_pets

Step 2: Create Test Infrastructure

File                                                      Action   Purpose
tests/conftest.py                                         Extend   Add markers, fixtures
tests/integration/conftest.py                             Create   Integration-specific fixtures
tests/integration/reference_data/cspbbr3/                 Create   CIF files from journal-review
tests/integration/reference_data/cspbbr3/metadata.yaml    Create   Test parameters
tests/integration/test_refinement_correctness.py          Create   Statistical correctness tests
tests/integration/test_invariants.py                      Create   Property/invariant tests
tests/integration/utils.py                                Create   Test helpers (RMSD, comparison)
pyproject.toml                                            Modify   Add pytest markers

Key Design Decisions

  1. Statistical over deterministic: Run N times (default 5), test distribution properties
  2. Reference-based validation: Compare against known/published structures
  3. Tolerance hierarchy: RMSD for positions, R-factor for fit quality, std for consistency
  4. Property tests as safety net: Physical invariants catch bugs even without reference data
  5. Nightly marker: Keep fast tests for PRs, thorough tests for nightly CI
