Codebase: diffBloch - PyTorch-based crystal structure refinement
- Input: CIF files (crystal structures)
- Output: Refined atomic positions
- Pipeline: CIF → Atoms → StructureFactorNet → BlochNet → Loss → Optimization → Refined positions
- pytest framework with good fixtures
- Some integration tests (`tests/atoms/test_atoms_integration.py`)
- Basic seeding: `torch.manual_seed(42)`, `np.random.seed()`
- Tolerance assertions: `torch.allclose()`, `np.testing.assert_allclose()`
- Thickness sampling: `torch.randn()` in `utils.py:1552`
- DataLoader: `SubsetRandomSampler` in `rotation_dataset.py:443`
- GPU operations: Missing cuDNN determinism flags
- Optimizers: LBFGS line search, Adam momentum buffers
- Bayesian optimization: `gp_minimize` in `preprocess.py`
- Thread scheduling: ThreadPoolExecutor parallel processing
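Of the sources listed above, thread scheduling is the easiest to illustrate without a GPU. Floating-point addition is not associative, so combining partial results in completion order can change the last bits from run to run. The following is a minimal sketch, assuming the parallel work can be expressed as a map over chunks; `partial_sum` and `ordered_parallel_sum` are hypothetical names, not functions from diffBloch.

```python
from concurrent.futures import ThreadPoolExecutor


def partial_sum(chunk):
    """Stand-in for one unit of parallel work (hypothetical)."""
    return sum(chunk)


def ordered_parallel_sum(chunks, max_workers=4):
    """Reduce per-chunk results in input order, not completion order."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Executor.map yields results in the order of the inputs,
        # regardless of which worker thread finishes first.
        partials = list(pool.map(partial_sum, chunks))
    # The reduction order is now fixed, so the float result is repeatable.
    return sum(partials)
```

Collecting results with `concurrent.futures.as_completed` instead would make the reduction order depend on scheduling, which is the kind of run-to-run drift the tolerance-based tests are meant to absorb.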
PyTorch determinism controls (PyTorch Docs):

```python
import os
import random

import numpy as np
import torch


def set_deterministic_mode(seed: int = 42):
    """Enable full determinism for testing. Use only in tests - slower."""
    # cuBLAS needs this workspace setting for deterministic results on
    # CUDA >= 10.2. Set it before the first CUDA call.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # For multi-GPU

    # Force deterministic algorithms. Ops without a deterministic
    # implementation raise a RuntimeError instead of running silently.
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
```

Tradeoff: Deterministic ops are slower. Good for CI, not production.
Accept that bit-identical results are not always achievable (for example across GPUs, drivers, or library versions) and test within an acceptable tolerance instead.
Tolerance tiers:
| Operation Type | Relative Tolerance | Absolute Tolerance |
|---|---|---|
| Lattice math, reciprocal space | 1e-10 | 1e-12 |
| Structure factors | 1e-6 | 1e-8 |
| Bloch wave intensities | 1e-4 | 1e-6 |
| Refined positions (single run) | 1e-3 | 1e-5 |
| End-to-end (loss convergence) | 5-15% | - |
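One way to keep these tiers in a single place is a small lookup used by every tolerance assertion. This is a sketch only: the tier keys and the helper name `assert_close_for_tier` are made up for illustration, and the numbers are copied from the table above. The percentage-based end-to-end tier is left out because it is not an `rtol`/`atol` pair.

```python
import numpy as np

# (rtol, atol) per tier, copied from the table above.
# The tier keys are hypothetical names chosen for this sketch.
TOLERANCE_TIERS = {
    "lattice": (1e-10, 1e-12),
    "structure_factors": (1e-6, 1e-8),
    "bloch_intensities": (1e-4, 1e-6),
    "refined_positions": (1e-3, 1e-5),
}


def assert_close_for_tier(actual, expected, tier):
    """Assert |actual - expected| <= atol + rtol * |expected| for a tier."""
    rtol, atol = TOLERANCE_TIERS[tier]
    np.testing.assert_allclose(
        np.asarray(actual), np.asarray(expected), rtol=rtol, atol=atol
    )
```

A call such as `assert_close_for_tier(intensities, reference, "bloch_intensities")` then documents which tier a test belongs to. PyTorch tensors would need converting first, for example with `.detach().cpu().numpy()`.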
Instead of exact values, test invariants that must hold regardless of randomness:
- Conservation laws: Total intensity conserved, positions within unit cell
- Symmetry: Output respects crystal symmetry
- Monotonicity: Loss decreases over optimization steps
- Bounds: Atomic positions in valid range, thermal parameters positive
- Dimensional consistency: Shapes match, units correct
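Monotonicity is the one invariant in this list that rarely holds step by step, since stochastic sampling adds noise to the loss. A common workaround is to assert on a smoothed trend instead. The sketch below assumes positive loss values and uses a plain moving average; `loss_trend_decreases` and its `factor` parameter are illustrative names, not part of the codebase.

```python
import numpy as np


def moving_average(values, window=10):
    """Simple moving average; returns len(values) - window + 1 points."""
    values = np.asarray(values, dtype=float)
    if window < 1 or window > values.size:
        raise ValueError("window must be between 1 and len(values)")
    kernel = np.ones(window) / window
    return np.convolve(values, kernel, mode="valid")


def loss_trend_decreases(losses, window=10, factor=0.5):
    """True if the smoothed final loss is below factor * smoothed initial loss."""
    smoothed = moving_average(losses, window)
    return bool(smoothed[-1] < smoothed[0] * factor)
```

The window and factor trade sensitivity against flakiness: a short window reacts to noise, while a strict factor can fail on runs that start close to the optimum.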
Run multiple times and test distribution properties (ML Testing Guide):

```python
import numpy as np


def test_refinement_converges_statistically(n_runs=5):
    """Test that refinement converges to similar results across runs."""
    # run_refinement and CONVERGENCE_THRESHOLD are assumed to be defined
    # elsewhere in the test suite.
    final_losses = []
    for seed in range(n_runs):
        result = run_refinement(seed=seed)
        final_losses.append(result.final_loss)

    # Test distribution properties
    assert np.std(final_losses) / np.mean(final_losses) < 0.1  # CV < 10%
    assert all(loss < CONVERGENCE_THRESHOLD for loss in final_losses)
```

Store reference outputs, compare with tolerance:
```python
import pytest
import torch


@pytest.fixture
def golden_reference():
    """Load pre-computed reference for comparison."""
    return torch.load("tests/golden/refinement_reference.pt")


def test_refinement_matches_golden(golden_reference, deterministic_seed):
    result = run_refinement(seed=deterministic_seed)
    # Compare with tolerance, not exact match
    assert torch.allclose(
        result.positions,
        golden_reference.positions,
        rtol=1e-3, atol=1e-5
    )
```

Quick sanity checks that the pipeline doesn't crash:
```python
import pytest


@pytest.mark.integration
def test_pipeline_runs_without_error():
    """Smoke test: pipeline completes without exception."""
    result = run_short_refinement(max_epochs=2)
    assert result.final_loss is not None
    assert result.final_loss < result.initial_loss  # At least some improvement
```

```
           /\
          /  \            E2E Tests (rare, slow)
         /    \           - Full refinement on real CIF
        /------\
       /        \         Integration Tests
      /          \        - Pipeline segments
     /            \       - StructureFactor → Bloch → Loss
    /--------------\
   /                \     Unit Tests (many, fast)
  /                  \    - Individual functions
 /                    \   - Utility functions
/______________________\
```
- Reproducibility: Statistical (multiple runs, test distribution properties)
- Goal: Correctness validation against known structures
- Reference data: Available - known/published refined structures
- Runtime: Thorough nightly (10+ min per test acceptable)
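For the statistical reproducibility choice, the spread across runs is usually summarized as a coefficient of variation (CV). With only a handful of runs, the sample standard deviation (n - 1 in the denominator) is the more conservative estimate. A minimal standard-library sketch, with `coefficient_of_variation` as a hypothetical helper name:

```python
import statistics


def coefficient_of_variation(values):
    """Sample standard deviation divided by |mean|; needs >= 2 values."""
    mean = statistics.fmean(values)
    if mean == 0:
        raise ValueError("CV is undefined for a zero mean")
    # statistics.stdev uses the n - 1 (sample) denominator.
    return statistics.stdev(values) / abs(mean)
```

A test over N seeds could then assert `coefficient_of_variation(final_losses) < 0.1`. Note that `np.std` defaults to the population formula (`ddof=0`), which gives a slightly smaller value for small N.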
File: tests/conftest.py (create/extend)

```python
import random

import numpy as np
import pytest
import torch


# Markers for test categorization
def pytest_configure(config):
    config.addinivalue_line("markers", "integration: end-to-end integration tests")
    config.addinivalue_line("markers", "statistical: multi-run statistical tests")
    config.addinivalue_line("markers", "nightly: thorough tests for nightly CI")


@pytest.fixture
def seeded_environment():
    """Return a function that seeds all RNGs for a single reproducible run."""
    def _seed(seed=42):
        random.seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
    return _seed


@pytest.fixture
def reference_structures():
    """Load known reference structures for validation."""
    # TODO: load from tests/integration/reference_data/
    pass
```

Source: the shreshth/journal-review branch contains ideal reference data:
```
data/cspbbr3/
├── experimental_data/
│   ├── cspbbr3.cif                       # Ground truth (CCDC structure)
│   └── CsPbBr_P67-crot-050_dyn.cif_pets  # Experimental diffraction
└── synthetic_data/synthetic_structural_data_paper/
    ├── cspbbr3_01_displacement.cif       # 0.1 Å perturbed
    ├── cspbbr3_02_displacement.cif       # 0.2 Å perturbed
    ├── cspbbr3_03_displacement.cif       # 0.3 Å perturbed
    └── cspbbr3_paper_data.cif_pets       # Synthetic diffraction
```
Testing strategy with this data:
- Recovery tests: Start from displaced CIF (0.1-0.3 Å), refine back toward ground truth
- Compare: Refined positions vs `cspbbr3.cif` (known correct structure)
- Tolerance: RMSD < displacement amount (e.g., start at 0.2 Å, refine to < 0.1 Å)
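The recovery criterion can be written down as a per-atom RMSD in Å plus a comparison against the starting displacement. The sketch below assumes positions are fractional coordinates and that the lattice is a 3x3 matrix whose rows are the cell vectors in Å. The component-wise wrap is only a valid minimum-image convention for small displacements in a near-orthogonal cell. `rmsd_angstrom` and `recovered` are hypothetical names, not diffBloch functions.

```python
import numpy as np


def rmsd_angstrom(frac_a, frac_b, lattice):
    """Per-atom RMSD in Å between two (N, 3) sets of fractional coordinates."""
    delta = np.asarray(frac_a, dtype=float) - np.asarray(frac_b, dtype=float)
    # Wrap each component into [-0.5, 0.5] so atoms that crossed a cell
    # boundary are compared to their nearest periodic image.
    delta -= np.round(delta)
    cart = delta @ np.asarray(lattice, dtype=float)  # fractional -> Å
    # Sum over x, y, z per atom, then average over atoms.
    return float(np.sqrt((cart ** 2).sum(axis=1).mean()))


def recovered(initial_rmsd, final_rmsd, max_final_rmsd):
    """True if refinement improved on the start and met the target RMSD."""
    return final_rmsd < initial_rmsd and final_rmsd < max_final_rmsd
```

For the example in the list above, a run that starts at 0.2 Å and ends at 0.08 Å would satisfy `recovered(0.2, 0.08, 0.1)`.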
Test data directory (create in repo):

```
tests/integration/
├── conftest.py
├── reference_data/
│   ├── cspbbr3/                  # Copy or symlink from journal-review
│   │   ├── ground_truth.cif      # Copy of experimental cspbbr3.cif
│   │   ├── displaced_01.cif      # 0.1 Å perturbation
│   │   ├── displaced_02.cif      # 0.2 Å perturbation
│   │   ├── diffraction.cif_pets  # Diffraction data
│   │   └── metadata.yaml         # Test parameters
│   └── README.md
└── test_refinement_correctness.py
```
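Because the nightly tests depend on this layout, it can help to fail early with a clear message when a case directory is incomplete. The following sketch checks the file names shown in the tree above. The helper name `find_reference_cases` and the choice of which files count as required are assumptions made for illustration.

```python
from pathlib import Path

# File names taken from the layout above; treating all three as
# mandatory is an assumption of this sketch.
REQUIRED_FILES = ("ground_truth.cif", "diffraction.cif_pets", "metadata.yaml")


def find_reference_cases(ref_dir):
    """Return case directories under ref_dir, raising if one is incomplete."""
    cases = []
    for case_dir in sorted(Path(ref_dir).iterdir()):
        if not case_dir.is_dir():
            continue  # skip README.md and other loose files
        missing = [n for n in REQUIRED_FILES if not (case_dir / n).exists()]
        if missing:
            raise FileNotFoundError(
                f"{case_dir.name}: missing {', '.join(missing)}"
            )
        if not list(case_dir.glob("displaced_*.cif")):
            raise FileNotFoundError(f"{case_dir.name}: no displaced_*.cif input")
        cases.append(case_dir)
    return cases
```

Sorting the directories keeps the case order stable across file systems, which makes failures easier to compare between runs.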
metadata.yaml format:

```yaml
name: "CsPbBr3 recovery from 0.2 Å displacement"
source: "shreshth/journal-review branch, CCDC structure"
ground_truth: "ground_truth.cif"
displaced_inputs:
  - file: "displaced_01.cif"
    initial_displacement: 0.1  # Angstroms
    expected_final_rmsd: 0.05  # Should recover to < 0.05 Å
  - file: "displaced_02.cif"
    initial_displacement: 0.2
    expected_final_rmsd: 0.08
tolerances:
  position_rmsd: 0.1  # Angstroms - must improve from initial
  r_factor: 0.10      # 10% R-factor tolerance
validation_runs: 3    # Number of statistical runs
max_epochs: 50        # Epochs per test run
```

File: tests/integration/test_refinement_correctness.py
```python
from pathlib import Path

import pytest
import torch

# compute_rmsd is defined in tests/integration/utils.py; load_metadata is
# assumed to live there too. Adjust the import to the repo's package layout.
from .utils import compute_rmsd, load_metadata


@pytest.mark.integration
@pytest.mark.statistical
@pytest.mark.nightly
class TestRefinementCorrectness:
    """Statistical integration tests against known structures."""

    @pytest.fixture
    def reference_cases(self):
        """Load all reference test cases."""
        ref_dir = Path(__file__).parent / "reference_data"
        cases = []
        for case_dir in sorted(ref_dir.iterdir()):
            # Each case directory is identified by its metadata.yaml.
            if case_dir.is_dir() and (case_dir / "metadata.yaml").exists():
                cases.append(case_dir)
        return cases

    def run_refinement_with_seed(self, case_dir: Path, seed: int):
        """Run refinement for a single seed."""
        # Load config, run refinement, return results
        pass

    def run_refinement_from_positions(self, case_dir: Path, positions, seed: int):
        """Run refinement starting from the given positions."""
        # Placeholder, same contract as run_refinement_with_seed
        pass

    def test_converges_to_known_structure(self, reference_cases):
        """
        Statistical test: Run refinement N times, verify:
        1. Mean refined positions are close to reference
        2. Variance across runs is acceptably low
        3. All runs achieve acceptable R-factor
        """
        for case_dir in reference_cases:
            metadata = load_metadata(case_dir)
            # expected_positions.pt is not part of the reference_data layout;
            # it must be generated from ground_truth.cif beforehand.
            expected = torch.load(case_dir / "expected_positions.pt")

            # Run multiple times with different seeds
            n_runs = metadata.get("validation_runs", 5)
            results = []
            for seed in range(n_runs):
                result = self.run_refinement_with_seed(case_dir, seed)
                results.append(result)

            # Statistical assertions
            positions = torch.stack([r.positions for r in results])
            mean_positions = positions.mean(dim=0)
            std_positions = positions.std(dim=0)

            # 1. Mean close to reference
            rmsd = compute_rmsd(mean_positions, expected)
            assert rmsd < metadata["tolerances"]["position_rmsd"], \
                f"RMSD {rmsd:.4f} exceeds tolerance"

            # 2. Low variance (convergence consistency)
            max_std = std_positions.max().item()
            assert max_std < 0.005, f"High variance: max_std={max_std:.4f}"

            # 3. All runs achieve acceptable R-factor
            r_factors = [r.r_factor for r in results]
            assert all(r < metadata["tolerances"]["r_factor"] for r in r_factors)

    def test_refinement_improves_from_perturbed(self, reference_cases):
        """
        Test that refinement recovers correct structure from perturbed start.
        Perturb known structure, verify refinement returns to correct answer.
        """
        for case_dir in reference_cases:
            expected = torch.load(case_dir / "expected_positions.pt")

            # Perturb by small random displacement
            for seed in range(3):
                torch.manual_seed(seed)
                # 0.05 Angstrom; assumes positions are stored in Cartesian Å
                perturbation = torch.randn_like(expected) * 0.05
                perturbed = expected + perturbation

                # Run refinement from perturbed start
                result = self.run_refinement_from_positions(case_dir, perturbed, seed)

                # Should recover close to original
                rmsd = compute_rmsd(result.positions, expected)
                assert rmsd < 0.02  # Within 0.02 Angstroms
```

File: tests/integration/test_invariants.py
```python
import pytest
import torch

# moving_average is an assumed helper (for example in
# tests/integration/utils.py). The refinement_result, refinement_history and
# input_structure fixtures are assumed to come from
# tests/integration/conftest.py.
from .utils import moving_average


@pytest.mark.integration
class TestPhysicalInvariants:
    """Test that physical invariants hold regardless of randomness."""

    def test_positions_remain_in_unit_cell(self, refinement_result):
        """Fractional coordinates must be in [0, 1)."""
        positions = refinement_result.fractional_positions
        assert (positions >= 0).all() and (positions < 1).all()

    def test_symmetry_preserved(self, refinement_result, input_structure):
        """Refined structure respects input space group symmetry."""
        # Apply symmetry operations, verify equivalent positions
        pass

    def test_thermal_parameters_positive(self, refinement_result):
        """Thermal displacement parameters must be positive definite."""
        Uij = refinement_result.thermal_parameters
        eigenvalues = torch.linalg.eigvalsh(Uij)
        assert (eigenvalues > 0).all()

    def test_loss_monotonic_decrease(self, refinement_history):
        """Loss should generally decrease during optimization."""
        losses = refinement_history.losses
        # Allow some noise, but overall trend should be down
        smoothed = moving_average(losses, window=10)
        assert smoothed[-1] < smoothed[0] * 0.5  # At least 50% reduction
```

Add to pyproject.toml:
```toml
[tool.pytest.ini_options]
markers = [
    "integration: end-to-end integration tests",
    "statistical: multi-run statistical tests (slow)",
    "nightly: thorough tests for nightly CI",
]
# Default: skip slow tests
addopts = "-m 'not nightly'"
```

CI configurations:
- PR checks: `pytest -m "not nightly"` (fast tests only; excludes the nightly suite)
- Nightly: `pytest -m "nightly"` (full statistical suite)
File: tests/integration/utils.py

```python
from pathlib import Path

import torch


def compute_rmsd(positions1: torch.Tensor, positions2: torch.Tensor) -> float:
    """Compute per-atom RMSD between two (N, 3) position tensors."""
    diff = positions1 - positions2
    # Sum squared differences over x, y, z, then average over atoms.
    return torch.sqrt((diff ** 2).sum(dim=-1).mean()).item()


def compare_with_tolerance(actual, expected, metadata):
    """Compare results against expected with metadata-specified tolerances."""
    pass


def generate_golden_reference(case_dir: Path, n_runs: int = 10):
    """
    Generate golden reference from multiple converged runs.
    Use median of converged runs as reference (robust to outliers).
    """
    pass
```

Copy the CsPbBr3 reference files to tests/integration/reference_data/cspbbr3/:
```bash
# Create the target directory first; the redirects below do not create it.
mkdir -p tests/integration/reference_data/cspbbr3

git show origin/shreshth/journal-review:data/cspbbr3/experimental_data/cspbbr3.cif \
  > tests/integration/reference_data/cspbbr3/ground_truth.cif
git show origin/shreshth/journal-review:data/cspbbr3/synthetic_data/synthetic_structural_data_paper/cspbbr3_01_displacement.cif \
  > tests/integration/reference_data/cspbbr3/displaced_01.cif
git show origin/shreshth/journal-review:data/cspbbr3/synthetic_data/synthetic_structural_data_paper/cspbbr3_02_displacement.cif \
  > tests/integration/reference_data/cspbbr3/displaced_02.cif
git show origin/shreshth/journal-review:data/cspbbr3/synthetic_data/synthetic_structural_data_paper/cspbbr3_paper_data.cif_pets \
  > tests/integration/reference_data/cspbbr3/diffraction.cif_pets
```

| File | Action | Purpose |
|---|---|---|
| `tests/conftest.py` | Extend | Add markers, fixtures |
| `tests/integration/conftest.py` | Create | Integration-specific fixtures |
| `tests/integration/reference_data/cspbbr3/` | Create | CIF files from journal-review |
| `tests/integration/reference_data/cspbbr3/metadata.yaml` | Create | Test parameters |
| `tests/integration/test_refinement_correctness.py` | Create | Statistical correctness tests |
| `tests/integration/test_invariants.py` | Create | Property/invariant tests |
| `tests/integration/utils.py` | Create | Test helpers (RMSD, comparison) |
| `pyproject.toml` | Modify | Add pytest markers |
- Statistical over deterministic: Run N times (default 5), test distribution properties
- Reference-based validation: Compare against known/published structures
- Tolerance hierarchy: RMSD for positions, R-factor for fit quality, std for consistency
- Property tests as safety net: Physical invariants catch bugs even without reference data
- Nightly marker: Keep fast tests for PRs, thorough tests for nightly CI