Skip to content

Instantly share code, notes, and snippets.

@carstenbauer
Last active April 21, 2022 09:44
Show Gist options
  • Select an option

  • Save carstenbauer/cd1bbb44aba928b90c134872f7a37500 to your computer and use it in GitHub Desktop.

Select an option

Save carstenbauer/cd1bbb44aba928b90c134872f7a37500 to your computer and use it in GitHub Desktop.
using DelimitedFiles
using LinearAlgebra
using Test
"""
    _parse_nersc_header(header_str) -> Dict{String,String}

Split the `KEY = VALUE` lines of a configuration-file header into a dictionary of stripped
strings. Lines without `=` (such as `BEGIN_HEADER`) are skipped. Values are kept as strings
on purpose so that fields like the hexadecimal `CHECKSUM` are never mangled by automatic
numeric parsing; callers parse numeric fields explicitly.
"""
function _parse_nersc_header(header_str::AbstractString)
    header = Dict{String,String}()
    for line in eachline(IOBuffer(header_str))
        # limit=2: only the first '=' separates key from value
        parts = split(line, '='; limit=2)
        length(parts) == 2 || continue
        header[strip(parts[1])] = strip(parts[2])
    end
    return header
end

"""
    readqcd(fname::AbstractString; verify=true)

Read a gauge configuration file whose header ends in `END_HEADER` and whose data section
stores each link as its first two rows (2×3 complex numbers, row-major) in big-endian
`Float32` (`FLOATING_POINT = IEEE32BIG`).

Returns `(Us, (Nx=Nx, Ny=Ny, Nz=Nz, Nt=Nt))`, where:
- `Us` is a vector of 3×3 `ComplexF32` matrices in file order, with the third row
  reconstructed by `unitarize`;
- the named tuple holds the lattice extents from the `DIMENSION_1`…`DIMENSION_4` header
  fields.

Emits warnings (but still returns data) when the number of matrices differs from
`linkcount` of the extents, or when the checksum does not match the header `CHECKSUM`.
When `verify=true`, it additionally reports whether all matrices pass `isSU3` and whether
the average `Re tr(U)/3` matches the header `LINK_TRACE`.

Throws an error for an unsupported `FLOATING_POINT` format, or when the data section is
not a whole number of compressed matrices.
"""
function readqcd(fname::AbstractString; verify=true)
    open(fname, "r") do f
        header = _parse_nersc_header(readuntil(f, "END_HEADER\n"))
        Nx = parse(Int64, header["DIMENSION_1"])
        Ny = parse(Int64, header["DIMENSION_2"])
        Nz = parse(Int64, header["DIMENSION_3"])
        Nt = parse(Int64, header["DIMENSION_4"])
        FLOATING_POINT = header["FLOATING_POINT"]
        if FLOATING_POINT != "IEEE32BIG"
            error("Unsupported floating point format: ", FLOATING_POINT)
        end

        # read rest of file, i.e. U matrices in compressed form:
        # 6 complex nums per matrix, 2 real nums per complex num
        raw_data = read(f)
        bytes_per_matrix = 2 * 6 * sizeof(Float32)
        if sizeof(raw_data) % bytes_per_matrix != 0
            error("Data section of ", sizeof(raw_data), " bytes is not a multiple of ",
                  bytes_per_matrix, " bytes per compressed U matrix.")
        end
        n_matrices = sizeof(raw_data) ÷ bytes_per_matrix
        if n_matrices != linkcount((Nx, Ny, Nz, Nt))
            @warn("Expected a different number of U matrices...", n_matrices, linkcount((Nx, Ny, Nz, Nt)), Nx, Ny, Nz, Nt)
        end

        # reinterpret data as Float32 and convert from big endian to host order (in place)
        raw_numbers = reinterpret(Float32, raw_data)
        map!(ntoh, raw_numbers, raw_numbers)
        check_sum = checksum(raw_numbers)
        if check_sum != parse(Cuint, "0x" * header["CHECKSUM"])
            @warn("Checksum doesn't match!", header["CHECKSUM"], check_sum)
        end

        # combine two real numbers -> complex numbers
        U_data = reinterpret(ComplexF32, raw_numbers)
        # reshape into cube (rowidx, colidx, Uidx) + use permutedims to go from row major
        # to col major order
        U_cube = permutedims(reshape(U_data, (3, 2, n_matrices)), (2, 1, 3))
        Us = Matrix{ComplexF32}[unitarize(@view U_cube[:, :, i]) for i in axes(U_cube, 3)]

        if verify
            if all(isSU3, Us)
                @info("Resulting U matrices are ∈ SU(3). ✓")
            else
                @warn("Resulting U matrices appear to not be ∈ SU(3).")
            end
            # average of Re tr(U) / 3 over all links actually read; `real ∘ tr` avoids
            # allocating a real copy of every matrix
            link_trace = sum(real ∘ tr, Us) / (3 * length(Us))
            if link_trace ≈ parse(Float64, header["LINK_TRACE"])
                @info("Link trace check passed. ✓")
            else
                @warn("Link trace check failed.", link_trace, header["LINK_TRACE"])
            end
        end
        return Us, (Nx=Nx, Ny=Ny, Nz=Nz, Nt=Nt)
    end
end
"""
    unitarize(U::AbstractMatrix)

Reconstruct a full 3×3 complex matrix from `U`, the first two rows (a 2×3 matrix) of a
compressed link. The two given rows are copied unchanged (they are not re-orthonormalized).
The third row is set to the complex conjugate of the cross product of the first two:

    R[3, :] = conj(R[1, :] × R[2, :])

For orthonormal first rows this completes `R` to an SU(3) matrix.

The element type of the result is `complex(eltype(U))`.

Throws `DimensionMismatch` if `size(U) != (2, 3)`.
"""
function unitarize(U::AbstractMatrix)
    size(U) == (2, 3) || throw(DimensionMismatch(
        "expected the 2×3 first two rows of a 3×3 matrix, got size $(size(U))"))
    R = Matrix{complex(eltype(U))}(undef, 3, 3)
    # copy over existing values from U
    R[1:2, :] .= U
    a1, a2, a3 = R[1, 1], R[1, 2], R[1, 3]
    b1, b2, b3 = R[2, 1], R[2, 2], R[2, 3]
    # third row = conj(row1 × row2)
    R[3, 1] = conj(a2 * b3 - a3 * b2)
    R[3, 2] = conj(a3 * b1 - a1 * b3)
    R[3, 3] = conj(a1 * b2 - a2 * b1)
    return R
end
"""
    isunitary(U)

Return `true` when `U * U'` is approximately the identity (default `isapprox` tolerances).
"""
isunitary(U) = isapprox(U * adjoint(U), I)
"""
    isSU3(U)

Return `true` if `U` is a 3×3 unitary matrix (see `isunitary`) with determinant ≈ 1, i.e.
an element of SU(3) up to the default `isapprox` tolerances.

Checking only `abs(det(U)) ≈ 1` would accept every U(3) matrix (e.g. `diag(im, 1, 1)`), so
the determinant itself is compared to 1.
"""
isSU3(U) = size(U) == (3, 3) && isunitary(U) && det(U) ≈ 1
"""
    linkcount(N::Integer; d=4)

Number of links of a hypercubic lattice with extent `N` in each of `d` dimensions:
one link per site and direction, i.e. `d * N^d`.
"""
linkcount(N::Integer; d=4) = d * N^d

"""
    linkcount(Ns::Tuple)

Number of links of a lattice with per-direction extents `Ns`: the number of sites
`prod(Ns)` times the number of directions `length(Ns)`.
"""
linkcount(Ns::Tuple) = length(Ns) * prod(Ns)
"""
    checksum(raw_data)

Return a `Cuint` checksum of `raw_data`. The data is reinterpreted as `Cuint` words, the
words are summed, and the sum is truncated to its lowest 32 bits (the sum modulo 2^32).

The words are read in native byte order, so `raw_data` is expected to already be converted
to host endianness.
"""
function checksum(raw_data)
    words = reinterpret(Cuint, raw_data)
    # `x % Cuint` keeps the low 32 bits, discarding any carry accumulated by `sum`
    return sum(words) % Cuint
end
"""
    lin2cart(idx; Nx, Ny, Nz, Nt)

Convert the 1-based linear link index `idx` into 1-based lattice coordinates and direction,
returned as the named tuple `(x=…, y=…, z=…, t=…, mu=…)`. Inverse of `cart2lin`.

Index layout (0-based), with `V = Nx*Ny*Nz*Nt`:
- the index first splits into direction blocks of size `V`;
- within a block, the first `V ÷ 2` entries are even sites and the rest are odd sites
  (parity of `x + y + z + t`);
- within each parity half, entry `k` refers to the pair of lexicographic sites
  `2k, 2k + 1` (x running fastest).

The lattice extents are assumed to be even.
"""
function lin2cart(idx; Nx, Ny, Nz, Nt)
    stride_y = Nx
    stride_z = stride_y * Ny
    stride_t = stride_z * Nz
    volume = stride_t * Nt
    half = volume ÷ 2

    # zero-based: direction block, then parity half, then index of the site pair
    mu, site = divrem(idx - 1, volume)
    parity_int, pair = divrem(site, half)
    odd_parity = Bool(parity_int)  # 0 → even half, 1 → odd half

    # The pair index doubled is the lexicographic index of the even-x member of the pair.
    # Peel off the slow coordinates by dividing by the product of the faster extents.
    lexi = 2 * pair
    t, rest = divrem(lexi, stride_t)
    z, rest = divrem(rest, stride_z)
    y, x = divrem(rest, stride_y)

    # x came out even. Step to the odd partner within the same pair when the parity of the
    # site half disagrees with the parity of y + z + t.
    if iseven(x) && (odd_parity ⊻ isodd(y + z + t))
        x += 1
    end

    return (x=x + 1, y=y + 1, z=z + 1, t=t + 1, mu=mu + 1)
end
"""
    cart2lin(cidx; Nx, Ny, Nz, Nt)

Convert 1-based lattice coordinates and direction `cidx = (x, y, z, t, mu)` into the 1-based
linear link index. Inverse of `lin2cart`.

With 0-based coordinates and `V = Nx*Ny*Nz*Nt`, the (0-based) index is

    lexi ÷ 2 + (V ÷ 2) * isodd(x + y + z + t) + mu * V

where `lexi = x + Nx*y + Nx*Ny*z + Nx*Ny*Nz*t` is the lexicographic site index.

`cidx` may be any 5-element collection supporting `.- 1`, or a named tuple with fields
`x, y, z, t, mu`. Throws an error if `cidx` does not have exactly 5 components.
"""
function cart2lin(cidx; Nx, Ny, Nz, Nt)
    if length(cidx) != 5
        error("Cartesian index must have x,y,z,t,mu components.")
    end
    x0, y0, z0, t0, mu0 = cidx .- 1
    plane = Nx * Ny
    cube = plane * Nz
    volume = cube * Nt
    lexi = x0 + Nx * y0 + plane * z0 + cube * t0
    # odd sites are stored after all even sites of the same direction
    parity_offset = (volume ÷ 2) * isodd(x0 + y0 + z0 + t0)
    return lexi ÷ 2 + parity_offset + mu0 * volume + 1
end
cart2lin(cidx::NamedTuple; kwargs...) =
    cart2lin((cidx.x, cidx.y, cidx.z, cidx.t, cidx.mu); kwargs...)
"""
    test_indexconversion(Nmax=10; Ns=(Nx=Nmax, Ny=Nmax, Nz=Nmax, Nt=Nmax))

Round-trip test of the index conversions. For every site `(x, y, z, t)` of a lattice with
extents `Ns` and every direction `mu ∈ 1:4`, check that

    lin2cart(cart2lin(c; Ns...); Ns...) == c

By default the lattice is `Nmax^4`. `Ns` must be a named tuple with fields
`Nx, Ny, Nz, Nt`. All extents must be even, because the even/odd site ordering used by
`lin2cart`/`cart2lin` pairs adjacent sites.

Throws `ArgumentError` if any extent is odd.
"""
function test_indexconversion(Nmax=10; Ns=(Nx=Nmax, Ny=Nmax, Nz=Nmax, Nt=Nmax))
    # Extents are an explicit argument (not a global), so the test is self-contained and
    # only visits coordinates that exist on the given lattice.
    all(iseven, values(Ns)) ||
        throw(ArgumentError("lattice extents must all be even, got $Ns"))
    @testset "lin2cart / cart2lin" begin
        for x in 1:Ns.Nx, y in 1:Ns.Ny, z in 1:Ns.Nz, t in 1:Ns.Nt, mu in 1:4
            coords = (; x, y, z, t, mu)
            @test lin2cart(cart2lin(coords; Ns...); Ns...) == coords
        end
    end
end
# main
# Usage: julia <this file> [configuration_file]
# Without an argument, the example configuration below is read.
input_file = isempty(ARGS) ? "l20t20b06498a_nersc.302500" : ARGS[1]
Us, Ns = readqcd(input_file; verify=true);
Nx, Ny, Nz, Nt = Ns
# Lattice coordinates of every link, if needed:
# indices = [lin2cart(i; Ns...) for i in eachindex(Us)]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment