@Hadrianneue, last active May 22, 2025
mesa fp8 hack for fsr4: plumbs E4M3FN (8-bit float) cooperative-matrix support through SPIR-V, NIR, RADV and ACO for GFX12 (RDNA4)
diff --git a/src/amd/compiler/aco_interface.cpp b/src/amd/compiler/aco_interface.cpp
index a08f6560578..055ba4c1645 100644
--- a/src/amd/compiler/aco_interface.cpp
+++ b/src/amd/compiler/aco_interface.cpp
@@ -482,6 +482,8 @@ aco_nir_op_supports_packed_math_16bit(const nir_alu_instr* alu)
return (shader->options->force_f2f16_rtz && !nir_is_rounding_mode_rtne(execution_mode, 16)) ||
nir_is_rounding_mode_rtz(execution_mode, 16);
}
+ case nir_op_f2e4m3fn:
+ case nir_op_e4m3fn2f:
case nir_op_fadd:
case nir_op_fsub:
case nir_op_fmul:
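
For reference while reading the rest of this patch: E4M3FN is the OCP 8-bit float format, with 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits, no infinities, and a single NaN encoding (S.1111.111), so finite values top out at +/-448. A minimal C decoder of that layout (my own sketch, not code from the patch):

#include <stdint.h>
#include <math.h>

float e4m3fn_to_float(uint8_t v)
{
   float sign = (v & 0x80) ? -1.0f : 1.0f;
   int exp = (v >> 3) & 0xf;                  /* exponent, biased by 7 */
   int man = v & 0x7;

   if (exp == 0xf && man == 0x7)
      return sign * NAN;                      /* the only NaN; no infinities */
   if (exp == 0)
      return sign * ldexpf((float)man, -9);   /* subnormal: man * 2^-9 */
   return sign * ldexpf(8.0f + (float)man, exp - 10); /* (1 + man/8) * 2^(exp-7) */
}
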
diff --git a/src/amd/compiler/aco_ir.cpp b/src/amd/compiler/aco_ir.cpp
index 8d8bc7fe746..8d1466a6dbc 100644
--- a/src/amd/compiler/aco_ir.cpp
+++ b/src/amd/compiler/aco_ir.cpp
@@ -584,6 +584,8 @@ can_use_opsel(amd_gfx_level gfx_level, aco_opcode op, int idx)
case aco_opcode::v_interp_p10_rtz_f16_f32_inreg: return idx == 0 || idx == 2;
case aco_opcode::v_interp_p2_f16_f32_inreg:
case aco_opcode::v_interp_p2_rtz_f16_f32_inreg: return idx == -1 || idx == 0;
+ case aco_opcode::v_cvt_pk_fp8_f32:
+ case aco_opcode::v_cvt_pk_bf8_f32: return idx == -1;
default:
return gfx_level >= GFX11 && (get_gfx11_true16_mask(op) & BITFIELD_BIT(idx == -1 ? 3 : idx));
}
@@ -715,6 +717,8 @@ get_gfx11_true16_mask(aco_opcode op)
case aco_opcode::v_and_b16:
case aco_opcode::v_or_b16:
case aco_opcode::v_xor_b16: return 0x3 | 0x8;
+ case aco_opcode::v_cvt_pk_f32_fp8:
+ case aco_opcode::v_cvt_pk_f32_bf8:
case aco_opcode::v_cvt_f32_f16:
case aco_opcode::v_cvt_i32_i16:
case aco_opcode::v_cvt_u32_u16: return 0x1;
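
The opsel/true16 entries above encode sub-dword operand selection for the packed FP8 converts. My reading of the RDNA4 ISA (an assumption, not something the patch states): v_cvt_pk_fp8_f32 writes its two result bytes into the 16-bit half of the destination selected by the destination opsel bit and preserves the other half, while v_cvt_pk_f32_fp8 takes its two source bytes from the half of the source dword selected by opsel. The byte routing, modeled in C (the numerics live in the decoder/encoder sketches elsewhere on this page):

#include <stdint.h>
#include <stdbool.h>

/* v_cvt_pk_fp8_f32: place the packed byte pair in the selected dest half */
uint32_t cvt_pk_fp8_f32_route(uint32_t dst_old, uint8_t fp8_lo, uint8_t fp8_hi,
                              bool opsel_dst_hi)
{
   uint32_t pair = fp8_lo | ((uint32_t)fp8_hi << 8);
   return opsel_dst_hi ? (dst_old & 0x0000ffffu) | (pair << 16)
                       : (dst_old & 0xffff0000u) | pair;
}

/* v_cvt_pk_f32_fp8: fetch the byte pair to convert from the selected half */
void cvt_pk_f32_fp8_route(uint32_t src, bool opsel_src_hi,
                          uint8_t *fp8_lo, uint8_t *fp8_hi)
{
   uint32_t word = opsel_src_hi ? src >> 16 : src & 0xffffu;
   *fp8_lo = word & 0xff;        /* becomes the low f32 result */
   *fp8_hi = (word >> 8) & 0xff; /* becomes the high f32 result */
}
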
diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp
index 23c6197ff71..72f931e78c4 100644
--- a/src/amd/compiler/aco_optimizer.cpp
+++ b/src/amd/compiler/aco_optimizer.cpp
@@ -467,7 +467,8 @@ can_apply_sgprs(opt_ctx& ctx, aco_ptr<Instruction>& instr)
instr->opcode != aco_opcode::v_wmma_f16_16x16x16_f16 &&
instr->opcode != aco_opcode::v_wmma_bf16_16x16x16_bf16 &&
instr->opcode != aco_opcode::v_wmma_i32_16x16x16_iu8 &&
- instr->opcode != aco_opcode::v_wmma_i32_16x16x16_iu4;
+ instr->opcode != aco_opcode::v_wmma_i32_16x16x16_iu4 &&
+ instr->opcode != aco_opcode::v_wmma_f32_16x16x16_fp8_fp8;
}
/* only covers special cases */
@@ -528,6 +529,7 @@ alu_can_accept_constant(const aco_ptr<Instruction>& instr, unsigned operand)
case aco_opcode::v_interp_p2_rtz_f16_f32_inreg:
case aco_opcode::v_dot2_bf16_bf16: /* TODO */
case aco_opcode::v_wmma_f32_16x16x16_f16:
+ case aco_opcode::v_wmma_f32_16x16x16_fp8_fp8:
case aco_opcode::v_wmma_f32_16x16x16_bf16:
case aco_opcode::v_wmma_f16_16x16x16_f16:
case aco_opcode::v_wmma_bf16_16x16x16_bf16:
diff --git a/src/amd/compiler/instruction_selection/aco_isel_setup.cpp b/src/amd/compiler/instruction_selection/aco_isel_setup.cpp
index cc15635ddc0..3ffee0e4444 100644
--- a/src/amd/compiler/instruction_selection/aco_isel_setup.cpp
+++ b/src/amd/compiler/instruction_selection/aco_isel_setup.cpp
@@ -412,6 +412,8 @@ init_context(isel_context* ctx, nir_shader* shader)
regclasses[alu_instr->src[0].src.ssa->index].type() == RegType::vgpr)
type = RegType::vgpr;
break;
+ case nir_op_e4m3fn2f:
+ case nir_op_f2e4m3fn:
case nir_op_fmulz:
case nir_op_ffmaz:
case nir_op_f2f64:
diff --git a/src/amd/compiler/instruction_selection/aco_select_nir_alu.cpp b/src/amd/compiler/instruction_selection/aco_select_nir_alu.cpp
index aaeef4cb619..15bce4d7607 100644
--- a/src/amd/compiler/instruction_selection/aco_select_nir_alu.cpp
+++ b/src/amd/compiler/instruction_selection/aco_select_nir_alu.cpp
@@ -101,7 +101,7 @@ get_alu_src(struct isel_context* ctx, nir_alu_src src, unsigned size = 1)
elems[i] = emit_extract_vector(ctx, vec, src.swizzle[i], elem_rc);
vec_instr->operands[i] = Operand{elems[i]};
}
- Temp dst = ctx->program->allocateTmp(RegClass(vec.type(), elem_size * size / 4));
+ Temp dst = ctx->program->allocateTmp(RegClass::get(vec.type(), elem_size * size));
vec_instr->definitions[0] = Definition(dst);
ctx->block->instructions.emplace_back(std::move(vec_instr));
ctx->allocated_vec.emplace(dst.id(), elems);
@@ -2474,6 +2474,35 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
bld.vop1(aco_opcode::v_cvt_f64_f32, Definition(dst), src);
break;
}
+ case nir_op_e4m3fn2f: {
+ if (instr->def.num_components == 2) {
+ Temp src = get_alu_src(ctx, instr->src[0], 2);
+ bld.vop1(aco_opcode::v_cvt_pk_f32_fp8, Definition(dst), src);
+ emit_split_vector(ctx, dst, 2);
+ } else {
+ Temp src = get_alu_src(ctx, instr->src[0]);
+ assert(instr->def.num_components == 1);
+ bld.vop1(aco_opcode::v_cvt_f32_fp8, Definition(dst), src);
+ }
+ break;
+ }
+ case nir_op_f2e4m3fn: {
+ Operand src0, src1;
+ if (instr->def.num_components == 2) {
+ Temp src = get_ssa_temp(ctx, instr->src[0].src.ssa);
+ RegClass rc = RegClass(src.regClass().type(), 1);
+ src0 = Operand(emit_extract_vector(ctx, src, instr->src[0].swizzle[0], rc));
+ src1 = Operand(emit_extract_vector(ctx, src, instr->src[0].swizzle[1], rc));
+ } else {
+ assert(instr->def.num_components == 1);
+ src0 = Operand(get_alu_src(ctx, instr->src[0]));
+ src1 = Operand::c32(0);
+ }
+ bld.vop3(aco_opcode::v_cvt_pk_fp8_f32, Definition(dst), src0, src1);
+ if (instr->def.num_components == 2)
+ emit_split_vector(ctx, dst, 2);
+ break;
+ }
case nir_op_i2f16: {
Temp src = get_alu_src(ctx, instr->src[0]);
const unsigned input_size = instr->src[0].src.ssa->bit_size;
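
Two things happen in this file. The get_alu_src() change replaces RegClass(type, elem_size * size / 4), which counts dwords and truncates to zero for 8-bit elements, with RegClass::get(type, bytes), which handles sub-dword register classes. The new ALU cases then lower the FP8 ops; what the emitted instructions compute, sketched in C (e4m3fn_to_float() is the decoder sketch above, float_to_e4m3fn() the encoder sketch next to the nir_opcodes.py hunk below):

#include <stdint.h>

float e4m3fn_to_float(uint8_t v);  /* decoder sketch above */
uint8_t float_to_e4m3fn(float f);  /* encoder sketch below */

/* vec2 e4m3fn2f: a single v_cvt_pk_f32_fp8 converts both packed bytes */
void e4m3fn2f_vec2(const uint8_t src[2], float dst[2])
{
   dst[0] = e4m3fn_to_float(src[0]);
   dst[1] = e4m3fn_to_float(src[1]);
}

/* scalar f2e4m3fn: v_cvt_pk_fp8_f32 with src1 = 0 still packs two bytes;
 * byte 1 holds fp8(0.0) = 0x00 and only byte 0 is the result */
uint8_t f2e4m3fn_scalar(float s)
{
   return float_to_e4m3fn(s);
}
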
diff --git a/src/amd/compiler/instruction_selection/aco_select_nir_intrinsics.cpp b/src/amd/compiler/instruction_selection/aco_select_nir_intrinsics.cpp
index f4ee6af3f83..6dc808827cc 100644
--- a/src/amd/compiler/instruction_selection/aco_select_nir_intrinsics.cpp
+++ b/src/amd/compiler/instruction_selection/aco_select_nir_intrinsics.cpp
@@ -3718,7 +3718,7 @@ get_replicated_constant(nir_def* def, unsigned stride, uint32_t* constant)
return true;
}
-void
+static void
visit_cmat_muladd(isel_context* ctx, nir_intrinsic_instr* instr)
{
aco_opcode opcode = aco_opcode::num_opcodes;
@@ -3748,6 +3748,7 @@ visit_cmat_muladd(isel_context* ctx, nir_intrinsic_instr* instr)
neg_lo[0] = type_a == GLSL_TYPE_INT8;
neg_lo[1] = type_b == GLSL_TYPE_INT8;
break;
+ case GLSL_TYPE_FLOAT_E4M3FN: opcode = aco_opcode::v_wmma_f32_16x16x16_fp8_fp8; break;
}
default: unreachable("invalid cmat_muladd_amd type");
}
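
Ignoring the per-lane register layout, v_wmma_f32_16x16x16_fp8_fp8 (as selected by the new GLSL_TYPE_FLOAT_E4M3FN case) computes a 16x16x16 matrix multiply-add with e4m3fn inputs and f32 accumulation. A plain C reference of those semantics (a sketch; the real instruction distributes the tiles across the lanes of a wave):

#include <stdint.h>

float e4m3fn_to_float(uint8_t v); /* decoder sketch above */

void wmma_f32_16x16x16_fp8_fp8_ref(const uint8_t A[16][16],
                                   const uint8_t B[16][16],
                                   const float C[16][16], float D[16][16])
{
   for (int i = 0; i < 16; i++) {
      for (int j = 0; j < 16; j++) {
         float acc = C[i][j];
         for (int k = 0; k < 16; k++)
            acc += e4m3fn_to_float(A[i][k]) * e4m3fn_to_float(B[k][j]);
         D[i][j] = acc;
      }
   }
}
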
diff --git a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c
index 9e12b0964da..f58a2769e9f 100644
--- a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c
+++ b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c
@@ -166,13 +166,15 @@ radv_get_base_row(nir_builder *b, struct glsl_cmat_description desc, const lower
if (params->gfx_level >= GFX12) {
base_row = nir_udiv_imm(b, local_idx, 16);
- if (desc.use == GLSL_CMAT_USE_ACCUMULATOR && params->wave_size == 64) {
+ if ((desc.use == GLSL_CMAT_USE_ACCUMULATOR || radv_nir_cmat_bits(desc) == 8) && params->wave_size == 64) {
/* Switch rows from lanes 16..31 to 32..47, offset right shift by -2
* to get implicit * 4.
*/
base_row = nir_ushr_imm(b, nir_bitfield_reverse(b, base_row), 30 - 2);
+ } else if ((desc.use == GLSL_CMAT_USE_ACCUMULATOR || radv_nir_cmat_bits(desc) == 8) && params->wave_size == 32) {
+ base_row = nir_imul_imm(b, base_row, 8);
} else {
- base_row = nir_imul_imm(b, base_row, desc.use == GLSL_CMAT_USE_ACCUMULATOR && params->wave_size == 32 ? 8 : 4);
+ base_row = nir_imul_imm(b, base_row, 4);
}
} else {
base_row = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? nir_udiv_imm(b, local_idx, 16) : nir_imm_int(b, 0);
@@ -181,6 +183,24 @@ radv_get_base_row(nir_builder *b, struct glsl_cmat_description desc, const lower
return base_row;
}
+static unsigned
+radv_get_row_iter(struct glsl_cmat_description desc, const lower_cmat_params *params, unsigned i)
+{
+ if (params->gfx_level >= GFX12) {
+ /* 8bit and ACC are indexed normally, 16bit A/B is weird. */
+ if (desc.use != GLSL_CMAT_USE_ACCUMULATOR && params->wave_size == 32 && radv_nir_cmat_bits(desc) >= 16)
+ return i + (i & 4);
+ else
+ return i;
+ } else {
+ if (desc.use != GLSL_CMAT_USE_ACCUMULATOR)
+ return i;
+ else
+ return i * params->wave_size / 16;
+ }
+}
+
+
static nir_def *
convert_base_type(nir_builder *b, nir_def *src, enum glsl_base_type src_type, enum glsl_base_type dst_type)
{
@@ -193,6 +213,12 @@ convert_base_type(nir_builder *b, nir_def *src, enum glsl_base_type src_type, en
} else if (dst_type == GLSL_TYPE_BFLOAT16) {
src = convert_base_type(b, src, src_type, GLSL_TYPE_FLOAT);
return nir_f2bf(b, src);
+ } else if (src_type == GLSL_TYPE_FLOAT_E4M3FN) {
+ src = nir_e4m3fn2f(b, src);
+ return convert_base_type(b, src, GLSL_TYPE_FLOAT, dst_type);
+ } else if (dst_type == GLSL_TYPE_FLOAT_E4M3FN) {
+ src = convert_base_type(b, src, src_type, GLSL_TYPE_FLOAT);
+ return nir_f2e4m3fn(b, src);
}
nir_op op = nir_type_conversion_op(nir_get_nir_type_for_glsl_base_type(src_type),
@@ -201,6 +227,44 @@ convert_base_type(nir_builder *b, nir_def *src, enum glsl_base_type src_type, en
return nir_build_alu1(b, op, src);
}
+static nir_def *
+radv_swizzle_gfx12_8bit_mat(nir_builder *b, nir_def *src, unsigned wave_size)
+{
+ assert(src->bit_size == 8);
+
+ src = nir_extract_bits(b, &src, 1, 0, src->num_components / 4, 32);
+
+ nir_def *res;
+
+ if (wave_size == 64) {
+ assert(src->num_components == 1);
+
+ nir_def *swapped = nir_rotate(b, src, nir_imm_int(b, 32), .cluster_size = 64);
+ swapped = nir_rotate(b, swapped, nir_imm_int(b, 16), .cluster_size = 32);
+
+ nir_def *cond = nir_inverse_ballot(b, 1, nir_imm_intN_t(b, 0xffffffff0000ull, 64));
+
+ res = nir_bcsel(b, cond, swapped, src);
+ } else {
+ assert(src->num_components == 2);
+
+ nir_def *src0 = nir_channel(b, src, 0);
+ nir_def *src1 = nir_channel(b, src, 1);
+
+ nir_def *swapped0 = nir_rotate(b, src0, nir_imm_int(b, 16), .cluster_size = 32);
+ nir_def *swapped1 = nir_rotate(b, src1, nir_imm_int(b, 16), .cluster_size = 32);
+
+ nir_def *cond = nir_inverse_ballot(b, 1, nir_imm_intN_t(b, 0xffff0000, 32));
+
+ nir_def *res0 = nir_bcsel(b, cond, swapped1, src0);
+ nir_def *res1 = nir_bcsel(b, cond, swapped0, src1);
+
+ res = nir_vec2(b, res0, res1);
+ }
+
+ return nir_extract_bits(b, &res, 1, 0, res->num_components * 4, 8);
+}
+
bool
radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_level, unsigned wave_size)
{
@@ -311,7 +375,6 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
unsigned length = radv_nir_cmat_length(desc, &params);
unsigned mul = radv_nir_cmat_length_mul(desc, &params);
- unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? params.wave_size : 16;
nir_def *vars[16];
if (mul > 1) {
for (unsigned i = 0; i < length; ++i)
@@ -324,16 +387,10 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
for (unsigned i = 0; i < length / mul; ++i) {
nir_def *col_offset = inner_idx;
- nir_def *row_offset;
- uint32_t row_iter;
- if (gfx_level >= GFX12) {
- row_iter = desc.use != GLSL_CMAT_USE_ACCUMULATOR && wave_size == 32 ? i + (i & 4) : i;
- } else {
- row_iter = i * lanes_per_iter / 16;
- }
+ uint32_t row_iter = radv_get_row_iter(desc, &params, i);
- row_offset = nir_iadd_imm(&b, base_row, row_iter);
+ nir_def *row_offset = nir_iadd_imm(&b, base_row, row_iter);
if (layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR) {
nir_def *tmp = col_offset;
@@ -385,7 +442,6 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
unsigned length = radv_nir_cmat_length(desc, &params);
unsigned mul = radv_nir_cmat_length_mul(desc, &params);
- unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? params.wave_size : 16;
nir_def *vars[16];
for (unsigned i = 0; i < length; ++i)
vars[i] = nir_channel(&b, src, i);
@@ -395,16 +451,10 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
for (unsigned i = 0; i < length / mul; ++i) {
nir_def *col_offset = inner_idx;
- nir_def *row_offset;
- uint32_t row_iter;
- if (gfx_level >= GFX12) {
- row_iter = desc.use != GLSL_CMAT_USE_ACCUMULATOR && wave_size == 32 ? i + (i & 4) : i;
- } else {
- row_iter = i * lanes_per_iter / 16;
- }
+ uint32_t row_iter = radv_get_row_iter(desc, &params, i);
- row_offset = nir_iadd_imm(&b, base_row, row_iter);
+ nir_def *row_offset = nir_iadd_imm(&b, base_row, row_iter);
if (layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR) {
nir_def *tmp = col_offset;
@@ -484,8 +534,16 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
src = nir_vec(&b, components, src->num_components / scale);
}
+ if (dst_desc.use != GLSL_CMAT_USE_ACCUMULATOR && gfx_level >= GFX12 &&
+ radv_nir_cmat_bits(src_desc) == 8 && radv_nir_cmat_bits(dst_desc) > 8)
+ src = radv_swizzle_gfx12_8bit_mat(&b, src, wave_size);
+
nir_def *ret = convert_base_type(&b, src, src_element_type, dst_element_type);
+ if (dst_desc.use != GLSL_CMAT_USE_ACCUMULATOR && gfx_level >= GFX12 &&
+ radv_nir_cmat_bits(dst_desc) == 8 && radv_nir_cmat_bits(src_desc) > 8)
+ ret = radv_swizzle_gfx12_8bit_mat(&b, ret, wave_size);
+
if (dst_mul > src_mul) {
nir_def *components[NIR_MAX_VEC_COMPONENTS];
unsigned scale = dst_mul / src_mul;
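
The base-row trick above is easier to see with concrete numbers: for wave64, base_row = local_idx / 16 is 0..3, and nir_ushr_imm(nir_bitfield_reverse(base_row), 30 - 2) maps 0,1,2,3 to 0,8,4,12. That both bakes in the *4 scaling and swaps the row blocks of lanes 16..31 and 32..47, as the in-tree comment says. The change extends this path to 8-bit A/B matrices, whose GFX12 layout evidently follows the accumulator ordering rather than the 16-bit A/B one, and radv_swizzle_gfx12_8bit_mat() does the matching cross-lane swap when converting between 8-bit and wider matrices. A standalone C check of the row mapping:

#include <stdint.h>
#include <stdio.h>

static uint32_t bitfield_reverse(uint32_t v)
{
   uint32_t r = 0;
   for (int i = 0; i < 32; i++)
      r |= ((v >> i) & 1u) << (31 - i);
   return r;
}

int main(void)
{
   for (uint32_t base_row = 0; base_row < 4; base_row++)  /* local_idx / 16 */
      printf("base_row %u -> row %u\n", base_row,
             bitfield_reverse(base_row) >> (30 - 2));
   return 0; /* prints 0->0, 1->8, 2->4, 3->12 */
}
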
diff --git a/src/amd/vulkan/radv_pipeline.c b/src/amd/vulkan/radv_pipeline.c
index e042223f8b2..fa8bdd49cb9 100644
--- a/src/amd/vulkan/radv_pipeline.c
+++ b/src/amd/vulkan/radv_pipeline.c
@@ -265,6 +265,10 @@ opt_vectorize_callback(const nir_instr *instr, const void *_)
return 1;
const nir_alu_instr *alu = nir_instr_as_alu(instr);
+
+ if (alu->op == nir_op_f2e4m3fn || alu->op == nir_op_e4m3fn2f)
+ return 2;
+
const unsigned bit_size = alu->def.bit_size;
if (bit_size != 16)
return 1;
diff --git a/src/amd/vulkan/radv_shader.c b/src/amd/vulkan/radv_shader.c
index dcbc559a1b7..d9deff645b8 100644
--- a/src/amd/vulkan/radv_shader.c
+++ b/src/amd/vulkan/radv_shader.c
@@ -82,6 +82,10 @@ vectorize_vec2_16bit(const nir_instr *instr, const void *_)
return 0;
const nir_alu_instr *alu = nir_instr_as_alu(instr);
+
+ if (alu->op == nir_op_f2e4m3fn || alu->op == nir_op_e4m3fn2f)
+ return 2;
+
const unsigned bit_size = alu->def.bit_size;
if (bit_size == 16)
return 2;
@@ -2140,7 +2144,8 @@ radv_postprocess_binary_config(struct radv_device *device, struct radv_shader_bi
case MESA_SHADER_ANY_HIT:
case MESA_SHADER_COMPUTE:
case MESA_SHADER_TASK:
- config->rsrc1 |= S_00B848_MEM_ORDERED(radv_mem_ordered(pdev)) | S_00B848_WGP_MODE(wgp_mode);
+ config->rsrc1 |= S_00B848_MEM_ORDERED(radv_mem_ordered(pdev)) | S_00B848_WGP_MODE(wgp_mode) |
+ S_00B848_FP16_OVFL(info->uses_f2e4m3fn);
config->rsrc2 |= S_00B84C_TGID_X_EN(info->cs.uses_block_id[0]) | S_00B84C_TGID_Y_EN(info->cs.uses_block_id[1]) |
S_00B84C_TGID_Z_EN(info->cs.uses_block_id[2]) |
S_00B84C_TIDIG_COMP_CNT(info->cs.uses_thread_id[2] ? 2
diff --git a/src/amd/vulkan/radv_shader_info.c b/src/amd/vulkan/radv_shader_info.c
index 0a541bc6e75..4495dae8a88 100644
--- a/src/amd/vulkan/radv_shader_info.c
+++ b/src/amd/vulkan/radv_shader_info.c
@@ -345,6 +345,18 @@ gather_tex_info(const nir_shader *nir, const nir_tex_instr *instr, struct radv_s
}
}
+static void
+gather_alu_info(const nir_shader *nir, const nir_alu_instr *instr, struct radv_shader_info *info)
+{
+ switch (instr->op) {
+ case nir_op_f2e4m3fn:
+ info->uses_f2e4m3fn = true;
+ break;
+ default:
+ break;
+ }
+}
+
static void
gather_info_block(const nir_shader *nir, const nir_block *block, struct radv_shader_info *info,
const struct radv_graphics_state_key *gfx_state, const struct radv_shader_stage_key *stage_key,
@@ -358,6 +370,8 @@ gather_info_block(const nir_shader *nir, const nir_block *block, struct radv_sha
case nir_instr_type_tex:
gather_tex_info(nir, nir_instr_as_tex(instr), info);
break;
+ case nir_instr_type_alu:
+ gather_alu_info(nir, nir_instr_as_alu(instr), info);
default:
break;
}
@@ -1845,6 +1859,7 @@ radv_nir_shader_info_merge(const struct radv_shader_stage *src, struct radv_shad
dst_info->desc_set_used_mask |= src_info->desc_set_used_mask;
dst_info->uses_view_index |= src_info->uses_view_index;
dst_info->uses_prim_id |= src_info->uses_prim_id;
+ dst_info->uses_f2e4m3fn |= src_info->uses_f2e4m3fn;
dst_info->inline_push_constant_mask |= src_info->inline_push_constant_mask;
/* Only inline all push constants if both allows it. */
diff --git a/src/amd/vulkan/radv_shader_info.h b/src/amd/vulkan/radv_shader_info.h
index eb70b764ab5..c6afbc216a0 100644
--- a/src/amd/vulkan/radv_shader_info.h
+++ b/src/amd/vulkan/radv_shader_info.h
@@ -89,6 +89,7 @@ struct radv_shader_info {
bool uses_view_index;
bool uses_invocation_id;
bool uses_prim_id;
+ bool uses_f2e4m3fn;
uint8_t wave_size;
uint8_t ballot_bit_size;
struct radv_userdata_locations user_sgprs_locs;
diff --git a/src/compiler/builtin_types.py b/src/compiler/builtin_types.py
index 99bb1709676..ec48be2bac0 100644
--- a/src/compiler/builtin_types.py
+++ b/src/compiler/builtin_types.py
@@ -62,6 +62,7 @@ vector_type("int8_t", "i8vec", "GLSL_TYPE_INT8", "GL_INT8", "_NV")
vector_type("uint8_t", "u8vec", "GLSL_TYPE_UINT8", "GL_UNSIGNED_INT8", "_NV")
vector_type("bfloat16_t", "bf16vec", "GLSL_TYPE_BFLOAT16", None)
+vector_type("e4m3fn_t", "e4m3fnvec", "GLSL_TYPE_FLOAT_E4M3FN", None)
simple_type("mat2", "GL_FLOAT_MAT2", "GLSL_TYPE_FLOAT", 2, 2)
simple_type("mat3", "GL_FLOAT_MAT3", "GLSL_TYPE_FLOAT", 3, 3)
diff --git a/src/compiler/glsl_types.c b/src/compiler/glsl_types.c
index 10fcd786fac..280f971f244 100644
--- a/src/compiler/glsl_types.c
+++ b/src/compiler/glsl_types.c
@@ -349,6 +349,8 @@ glsl_get_base_glsl_type(const glsl_type *t)
return &glsl_type_builtin_double;
case GLSL_TYPE_BFLOAT16:
return &glsl_type_builtin_bfloat16_t;
+ case GLSL_TYPE_FLOAT_E4M3FN:
+ return &glsl_type_builtin_e4m3fn_t;
case GLSL_TYPE_BOOL:
return &glsl_type_builtin_bool;
case GLSL_TYPE_UINT64:
@@ -387,6 +389,7 @@ glsl_get_bare_type(const glsl_type *t)
case GLSL_TYPE_INT16:
case GLSL_TYPE_FLOAT16:
case GLSL_TYPE_BFLOAT16:
+ case GLSL_TYPE_FLOAT_E4M3FN:
case GLSL_TYPE_UINT:
case GLSL_TYPE_INT:
case GLSL_TYPE_FLOAT:
@@ -597,6 +600,7 @@ glsl_ ## vname ## _type (unsigned components) \
VECN(components, float, vec)
VECN(components, float16_t, f16vec)
VECN(components, bfloat16_t, bf16vec)
+VECN(components, e4m3fn_t, e4m3fnvec)
VECN(components, double, dvec)
VECN(components, int, ivec)
VECN(components, uint, uvec)
@@ -647,6 +651,8 @@ glsl_simple_explicit_type(unsigned base_type, unsigned rows, unsigned columns,
return glsl_f16vec_type(rows);
case GLSL_TYPE_BFLOAT16:
return glsl_bf16vec_type(rows);
+ case GLSL_TYPE_FLOAT_E4M3FN:
+ return glsl_e4m3fnvec_type(rows);
case GLSL_TYPE_DOUBLE:
return glsl_dvec_type(rows);
case GLSL_TYPE_BOOL:
@@ -1749,6 +1755,7 @@ glsl_get_component_slots(const glsl_type *t)
case GLSL_TYPE_FLOAT:
case GLSL_TYPE_FLOAT16:
case GLSL_TYPE_BFLOAT16:
+ case GLSL_TYPE_FLOAT_E4M3FN:
case GLSL_TYPE_BOOL:
return glsl_get_components(t);
@@ -1802,6 +1809,7 @@ glsl_get_component_slots_aligned(const glsl_type *t, unsigned offset)
case GLSL_TYPE_FLOAT:
case GLSL_TYPE_FLOAT16:
case GLSL_TYPE_BFLOAT16:
+ case GLSL_TYPE_FLOAT_E4M3FN:
case GLSL_TYPE_BOOL:
return glsl_get_components(t);
@@ -2889,6 +2897,7 @@ glsl_count_vec4_slots(const glsl_type *t, bool is_gl_vertex_input, bool is_bindl
case GLSL_TYPE_FLOAT:
case GLSL_TYPE_FLOAT16:
case GLSL_TYPE_BFLOAT16:
+ case GLSL_TYPE_FLOAT_E4M3FN:
case GLSL_TYPE_BOOL:
return t->matrix_columns;
case GLSL_TYPE_DOUBLE:
@@ -3094,6 +3103,7 @@ encode_type_to_blob(struct blob *blob, const glsl_type *type)
case GLSL_TYPE_FLOAT:
case GLSL_TYPE_FLOAT16:
case GLSL_TYPE_BFLOAT16:
+ case GLSL_TYPE_FLOAT_E4M3FN:
case GLSL_TYPE_DOUBLE:
case GLSL_TYPE_UINT8:
case GLSL_TYPE_INT8:
@@ -3743,6 +3753,7 @@ glsl_get_natural_size_align_bytes(const glsl_type *type,
case GLSL_TYPE_INT16:
case GLSL_TYPE_FLOAT16:
case GLSL_TYPE_BFLOAT16:
+ case GLSL_TYPE_FLOAT_E4M3FN:
case GLSL_TYPE_UINT:
case GLSL_TYPE_INT:
case GLSL_TYPE_FLOAT:
@@ -3803,6 +3814,7 @@ glsl_get_word_size_align_bytes(const glsl_type *type,
case GLSL_TYPE_INT16:
case GLSL_TYPE_FLOAT16:
case GLSL_TYPE_BFLOAT16:
+ case GLSL_TYPE_FLOAT_E4M3FN:
case GLSL_TYPE_UINT:
case GLSL_TYPE_INT:
case GLSL_TYPE_FLOAT:
@@ -3863,6 +3875,7 @@ glsl_get_vec4_size_align_bytes(const glsl_type *type,
case GLSL_TYPE_INT16:
case GLSL_TYPE_FLOAT16:
case GLSL_TYPE_BFLOAT16:
+ case GLSL_TYPE_FLOAT_E4M3FN:
case GLSL_TYPE_UINT:
case GLSL_TYPE_INT:
case GLSL_TYPE_FLOAT:
diff --git a/src/compiler/glsl_types.h b/src/compiler/glsl_types.h
index 4afba690abf..2defbf3136a 100644
--- a/src/compiler/glsl_types.h
+++ b/src/compiler/glsl_types.h
@@ -64,6 +64,7 @@ enum glsl_base_type {
GLSL_TYPE_FLOAT,
GLSL_TYPE_FLOAT16,
GLSL_TYPE_BFLOAT16,
+ GLSL_TYPE_FLOAT_E4M3FN,
GLSL_TYPE_DOUBLE,
GLSL_TYPE_UINT8,
GLSL_TYPE_INT8,
@@ -107,6 +108,7 @@ static unsigned glsl_base_type_bit_size(enum glsl_base_type type)
case GLSL_TYPE_UINT8:
case GLSL_TYPE_INT8:
+ case GLSL_TYPE_FLOAT_E4M3FN:
return 8;
case GLSL_TYPE_DOUBLE:
@@ -176,6 +178,7 @@ glsl_base_type_get_bit_size(const enum glsl_base_type base_type)
case GLSL_TYPE_UINT8:
case GLSL_TYPE_INT8:
+ case GLSL_TYPE_FLOAT_E4M3FN:
return 8;
case GLSL_TYPE_DOUBLE:
@@ -630,6 +633,12 @@ glsl_type_is_bfloat_16(const glsl_type *t)
return t->base_type == GLSL_TYPE_BFLOAT16;
}
+static inline bool
+glsl_type_is_e4m3fn(const glsl_type *t)
+{
+ return t->base_type == GLSL_TYPE_FLOAT_E4M3FN;
+}
+
static inline bool
glsl_type_is_int_16_32_64(const glsl_type *t)
{
@@ -947,6 +956,7 @@ static inline const glsl_type *glsl_uint8_t_type(void) { return &glsl_type_built
static inline const glsl_type *glsl_bool_type(void) { return &glsl_type_builtin_bool; }
static inline const glsl_type *glsl_atomic_uint_type(void) { return &glsl_type_builtin_atomic_uint; }
static inline const glsl_type *glsl_bfloat16_t_type(void) { return &glsl_type_builtin_bfloat16_t; }
+static inline const glsl_type *glsl_e4m3fn_t_type(void) { return &glsl_type_builtin_e4m3fn_t; }
static inline const glsl_type *
glsl_floatN_t_type(unsigned bit_size)
@@ -999,6 +1009,7 @@ glsl_uintN_t_type(unsigned bit_size)
const glsl_type *glsl_vec_type(unsigned components);
const glsl_type *glsl_f16vec_type(unsigned components);
const glsl_type *glsl_bf16vec_type(unsigned components);
+const glsl_type *glsl_e4m3fnvec_type(unsigned components);
const glsl_type *glsl_dvec_type(unsigned components);
const glsl_type *glsl_ivec_type(unsigned components);
const glsl_type *glsl_uvec_type(unsigned components);
diff --git a/src/compiler/nir/nir.c b/src/compiler/nir/nir.c
index f3a32256a63..18cec968501 100644
--- a/src/compiler/nir/nir.c
+++ b/src/compiler/nir/nir.c
@@ -2904,6 +2904,7 @@ nir_get_nir_type_for_glsl_base_type(enum glsl_base_type base_type)
case GLSL_TYPE_FLOAT: return nir_type_float32;
case GLSL_TYPE_FLOAT16: return nir_type_float16;
case GLSL_TYPE_BFLOAT16: return nir_type_uint16;
+ case GLSL_TYPE_FLOAT_E4M3FN: return nir_type_uint8;
case GLSL_TYPE_DOUBLE: return nir_type_float64;
/* clang-format on */
diff --git a/src/compiler/nir/nir_opcodes.py b/src/compiler/nir/nir_opcodes.py
index c2995cccf09..def4b6f284c 100644
--- a/src/compiler/nir/nir_opcodes.py
+++ b/src/compiler/nir/nir_opcodes.py
@@ -1765,3 +1765,7 @@ opcode("bfdot2_bfadd", 1, tint16, [2, 2, 1], [tint16, tint16, tint16],
dst.x = _mesa_float_to_bfloat16_bits_rte(acc);
""")
+
+
+unop_numeric_convert("e4m3fn2f", tfloat32, tuint8, "0") # TODO constant fold
+unop_numeric_convert("f2e4m3fn", tuint8, tfloat32, "0") # TODO constant fold
diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c
index 05ccce54db6..7b3275fd493 100644
--- a/src/compiler/spirv/spirv_to_nir.c
+++ b/src/compiler/spirv/spirv_to_nir.c
@@ -1890,10 +1890,14 @@ vtn_handle_type(struct vtn_builder *b, SpvOp opcode,
int32_t encoding = count > 3 ? w[3] : -1;
switch (encoding) {
case -1:
- /* No encoding specified, it is a regular FP. */
- vtn_fail_if(bit_size != 16 && bit_size != 32 && bit_size != 64,
- "Invalid float bit size: %u", bit_size);
- val->type->type = glsl_floatN_t_type(bit_size);
+ if (bit_size == 8) {
+ val->type->type = glsl_e4m3fn_t_type();
+ } else {
+ /* No encoding specified, it is a regular FP. */
+ vtn_fail_if(bit_size != 16 && bit_size != 32 && bit_size != 64,
+ "Invalid float bit size: %u", bit_size);
+ val->type->type = glsl_floatN_t_type(bit_size);
+ }
break;
case SpvFPEncodingBFloat16KHR:
diff --git a/src/compiler/spirv/vtn_alu.c b/src/compiler/spirv/vtn_alu.c
index a528b2e1b12..a89d9b2ad9e 100644
--- a/src/compiler/spirv/vtn_alu.c
+++ b/src/compiler/spirv/vtn_alu.c
@@ -697,6 +697,23 @@ vtn_handle_convert(struct vtn_builder *b, SpvOp opcode,
return nir_f2bf(&b->nb, src_as_float);
}
+ if (glsl_type_is_e4m3fn(glsl_src_type)) {
+ nir_def *src_as_float = nir_e4m3fn2f(&b->nb, src);
+ if (glsl_type_is_float(glsl_dest_type))
+ return src_as_float;
+ return vtn_handle_convert(b, opcode, dest_val, glsl_dest_type,
+ glsl_float_type(), src_as_float);
+
+ } else if (glsl_type_is_e4m3fn(glsl_dest_type)) {
+ nir_def *src_as_float;
+ if (glsl_type_is_float(glsl_src_type))
+ src_as_float = src;
+ else
+ src_as_float = vtn_handle_convert(b, opcode, dest_val, glsl_float_type(),
+ glsl_src_type, src);
+ return nir_f2e4m3fn(&b->nb, src_as_float);
+ }
+
/* Use bit_size from NIR source instead of from the original src type,
* to account for mediump_16bit. See vtn_handle_alu() for details.
*/
diff --git a/src/compiler/spirv/vtn_variables.c b/src/compiler/spirv/vtn_variables.c
index baf359c9f9f..c474ea9a81c 100644
--- a/src/compiler/spirv/vtn_variables.c
+++ b/src/compiler/spirv/vtn_variables.c
@@ -716,6 +716,7 @@ _vtn_variable_load_store(struct vtn_builder *b, bool load,
case GLSL_TYPE_INT64:
case GLSL_TYPE_FLOAT:
case GLSL_TYPE_FLOAT16:
+ case GLSL_TYPE_FLOAT_E4M3FN:
case GLSL_TYPE_BFLOAT16:
case GLSL_TYPE_BOOL:
case GLSL_TYPE_DOUBLE:
@@ -811,6 +812,7 @@ _vtn_variable_copy(struct vtn_builder *b, struct vtn_pointer *dest,
case GLSL_TYPE_FLOAT:
case GLSL_TYPE_FLOAT16:
case GLSL_TYPE_BFLOAT16:
+ case GLSL_TYPE_FLOAT_E4M3FN:
case GLSL_TYPE_DOUBLE:
case GLSL_TYPE_BOOL:
/* At this point, we have a scalar, vector, or matrix so we know that