Last active
May 22, 2025 21:08
-
-
Save Hadrianneue/6815b5cccc6280cea6ef67ca73b01a9a to your computer and use it in GitHub Desktop.
Mesa hack adding FP8 (E4M3FN) cooperative-matrix support to RADV/ACO, enabling FSR4
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| diff --git a/src/amd/compiler/aco_interface.cpp b/src/amd/compiler/aco_interface.cpp | |
| index a08f6560578..055ba4c1645 100644 | |
| --- a/src/amd/compiler/aco_interface.cpp | |
| +++ b/src/amd/compiler/aco_interface.cpp | |
| @@ -482,6 +482,8 @@ aco_nir_op_supports_packed_math_16bit(const nir_alu_instr* alu) | |
| return (shader->options->force_f2f16_rtz && !nir_is_rounding_mode_rtne(execution_mode, 16)) || | |
| nir_is_rounding_mode_rtz(execution_mode, 16); | |
| } | |
| + case nir_op_f2e4m3fn: | |
| + case nir_op_e4m3fn2f: | |
| case nir_op_fadd: | |
| case nir_op_fsub: | |
| case nir_op_fmul: | |
| diff --git a/src/amd/compiler/aco_ir.cpp b/src/amd/compiler/aco_ir.cpp | |
| index 8d8bc7fe746..8d1466a6dbc 100644 | |
| --- a/src/amd/compiler/aco_ir.cpp | |
| +++ b/src/amd/compiler/aco_ir.cpp | |
| @@ -584,6 +584,8 @@ can_use_opsel(amd_gfx_level gfx_level, aco_opcode op, int idx) | |
| case aco_opcode::v_interp_p10_rtz_f16_f32_inreg: return idx == 0 || idx == 2; | |
| case aco_opcode::v_interp_p2_f16_f32_inreg: | |
| case aco_opcode::v_interp_p2_rtz_f16_f32_inreg: return idx == -1 || idx == 0; | |
| + case aco_opcode::v_cvt_pk_fp8_f32: | |
| + case aco_opcode::v_cvt_pk_bf8_f32: return idx == -1; | |
| default: | |
| return gfx_level >= GFX11 && (get_gfx11_true16_mask(op) & BITFIELD_BIT(idx == -1 ? 3 : idx)); | |
| } | |
| @@ -715,6 +717,8 @@ get_gfx11_true16_mask(aco_opcode op) | |
| case aco_opcode::v_and_b16: | |
| case aco_opcode::v_or_b16: | |
| case aco_opcode::v_xor_b16: return 0x3 | 0x8; | |
| + case aco_opcode::v_cvt_pk_f32_fp8: | |
| + case aco_opcode::v_cvt_pk_f32_bf8: | |
| case aco_opcode::v_cvt_f32_f16: | |
| case aco_opcode::v_cvt_i32_i16: | |
| case aco_opcode::v_cvt_u32_u16: return 0x1; | |
| diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp | |
| index 23c6197ff71..72f931e78c4 100644 | |
| --- a/src/amd/compiler/aco_optimizer.cpp | |
| +++ b/src/amd/compiler/aco_optimizer.cpp | |
| @@ -467,7 +467,8 @@ can_apply_sgprs(opt_ctx& ctx, aco_ptr<Instruction>& instr) | |
| instr->opcode != aco_opcode::v_wmma_f16_16x16x16_f16 && | |
| instr->opcode != aco_opcode::v_wmma_bf16_16x16x16_bf16 && | |
| instr->opcode != aco_opcode::v_wmma_i32_16x16x16_iu8 && | |
| - instr->opcode != aco_opcode::v_wmma_i32_16x16x16_iu4; | |
| + instr->opcode != aco_opcode::v_wmma_i32_16x16x16_iu4 && | |
| + instr->opcode != aco_opcode::v_wmma_f32_16x16x16_fp8_fp8; | |
| } | |
| /* only covers special cases */ | |
| @@ -528,6 +529,7 @@ alu_can_accept_constant(const aco_ptr<Instruction>& instr, unsigned operand) | |
| case aco_opcode::v_interp_p2_rtz_f16_f32_inreg: | |
| case aco_opcode::v_dot2_bf16_bf16: /* TODO */ | |
| case aco_opcode::v_wmma_f32_16x16x16_f16: | |
| + case aco_opcode::v_wmma_f32_16x16x16_fp8_fp8: | |
| case aco_opcode::v_wmma_f32_16x16x16_bf16: | |
| case aco_opcode::v_wmma_f16_16x16x16_f16: | |
| case aco_opcode::v_wmma_bf16_16x16x16_bf16: | |
| diff --git a/src/amd/compiler/instruction_selection/aco_isel_setup.cpp b/src/amd/compiler/instruction_selection/aco_isel_setup.cpp | |
| index cc15635ddc0..3ffee0e4444 100644 | |
| --- a/src/amd/compiler/instruction_selection/aco_isel_setup.cpp | |
| +++ b/src/amd/compiler/instruction_selection/aco_isel_setup.cpp | |
| @@ -412,6 +412,8 @@ init_context(isel_context* ctx, nir_shader* shader) | |
| regclasses[alu_instr->src[0].src.ssa->index].type() == RegType::vgpr) | |
| type = RegType::vgpr; | |
| break; | |
| + case nir_op_e4m3fn2f: | |
| + case nir_op_f2e4m3fn: | |
| case nir_op_fmulz: | |
| case nir_op_ffmaz: | |
| case nir_op_f2f64: | |
| diff --git a/src/amd/compiler/instruction_selection/aco_select_nir_alu.cpp b/src/amd/compiler/instruction_selection/aco_select_nir_alu.cpp | |
| index aaeef4cb619..15bce4d7607 100644 | |
| --- a/src/amd/compiler/instruction_selection/aco_select_nir_alu.cpp | |
| +++ b/src/amd/compiler/instruction_selection/aco_select_nir_alu.cpp | |
| @@ -101,7 +101,7 @@ get_alu_src(struct isel_context* ctx, nir_alu_src src, unsigned size = 1) | |
| elems[i] = emit_extract_vector(ctx, vec, src.swizzle[i], elem_rc); | |
| vec_instr->operands[i] = Operand{elems[i]}; | |
| } | |
| - Temp dst = ctx->program->allocateTmp(RegClass(vec.type(), elem_size * size / 4)); | |
| + Temp dst = ctx->program->allocateTmp(RegClass::get(vec.type(), elem_size * size)); | |
| vec_instr->definitions[0] = Definition(dst); | |
| ctx->block->instructions.emplace_back(std::move(vec_instr)); | |
| ctx->allocated_vec.emplace(dst.id(), elems); | |
| @@ -2474,6 +2474,35 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr) | |
| bld.vop1(aco_opcode::v_cvt_f64_f32, Definition(dst), src); | |
| break; | |
| } | |
| + case nir_op_e4m3fn2f: { | |
| + if (instr->def.num_components == 2) { | |
| + Temp src = get_alu_src(ctx, instr->src[0], 2); | |
| + bld.vop1(aco_opcode::v_cvt_pk_f32_fp8, Definition(dst), src); | |
| + emit_split_vector(ctx, dst, 2); | |
| + } else { | |
| + Temp src = get_alu_src(ctx, instr->src[0]); | |
| + assert(instr->def.num_components == 1); | |
| + bld.vop1(aco_opcode::v_cvt_f32_fp8, Definition(dst), src); | |
| + } | |
| + break; | |
| + } | |
| + case nir_op_f2e4m3fn: { | |
| + Operand src0, src1; | |
| + if (instr->def.num_components == 2) { | |
| + Temp src = get_ssa_temp(ctx, instr->src[0].src.ssa); | |
| + RegClass rc = RegClass(src.regClass().type(), 1); | |
| + src0 = Operand(emit_extract_vector(ctx, src, instr->src[0].swizzle[0], rc)); | |
| + src1 = Operand(emit_extract_vector(ctx, src, instr->src[0].swizzle[1], rc)); | |
| + } else { | |
| + assert(instr->def.num_components == 1); | |
| + src0 = Operand(get_alu_src(ctx, instr->src[0])); | |
| + src1 = Operand::c32(0); | |
| + } | |
| + bld.vop3(aco_opcode::v_cvt_pk_fp8_f32, Definition(dst), src0, src1); | |
| + if (instr->def.num_components == 2) | |
| + emit_split_vector(ctx, dst, 2); | |
| + break; | |
| + } | |
| case nir_op_i2f16: { | |
| Temp src = get_alu_src(ctx, instr->src[0]); | |
| const unsigned input_size = instr->src[0].src.ssa->bit_size; | |
| diff --git a/src/amd/compiler/instruction_selection/aco_select_nir_intrinsics.cpp b/src/amd/compiler/instruction_selection/aco_select_nir_intrinsics.cpp | |
| index f4ee6af3f83..6dc808827cc 100644 | |
| --- a/src/amd/compiler/instruction_selection/aco_select_nir_intrinsics.cpp | |
| +++ b/src/amd/compiler/instruction_selection/aco_select_nir_intrinsics.cpp | |
| @@ -3718,7 +3718,7 @@ get_replicated_constant(nir_def* def, unsigned stride, uint32_t* constant) | |
| return true; | |
| } | |
| -void | |
| +static void | |
| visit_cmat_muladd(isel_context* ctx, nir_intrinsic_instr* instr) | |
| { | |
| aco_opcode opcode = aco_opcode::num_opcodes; | |
| @@ -3748,6 +3748,7 @@ visit_cmat_muladd(isel_context* ctx, nir_intrinsic_instr* instr) | |
| neg_lo[0] = type_a == GLSL_TYPE_INT8; | |
| neg_lo[1] = type_b == GLSL_TYPE_INT8; | |
| break; | |
| } | |
| + case GLSL_TYPE_FLOAT_E4M3FN: opcode = aco_opcode::v_wmma_f32_16x16x16_fp8_fp8; break; | |
| default: unreachable("invalid cmat_muladd_amd type"); | |
| } | |
| diff --git a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c | |
| index 9e12b0964da..f58a2769e9f 100644 | |
| --- a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c | |
| +++ b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c | |
| @@ -166,13 +166,15 @@ radv_get_base_row(nir_builder *b, struct glsl_cmat_description desc, const lower | |
| if (params->gfx_level >= GFX12) { | |
| base_row = nir_udiv_imm(b, local_idx, 16); | |
| - if (desc.use == GLSL_CMAT_USE_ACCUMULATOR && params->wave_size == 64) { | |
| + if ((desc.use == GLSL_CMAT_USE_ACCUMULATOR || radv_nir_cmat_bits(desc) == 8) && params->wave_size == 64) { | |
| /* Switch rows from lanes 16..31 to 32..47, offset right shift by -2 | |
| * to get implicit * 4. | |
| */ | |
| base_row = nir_ushr_imm(b, nir_bitfield_reverse(b, base_row), 30 - 2); | |
| + } else if ((desc.use == GLSL_CMAT_USE_ACCUMULATOR || radv_nir_cmat_bits(desc) == 8) && params->wave_size == 32) { | |
| + base_row = nir_imul_imm(b, base_row, 8); | |
| } else { | |
| - base_row = nir_imul_imm(b, base_row, desc.use == GLSL_CMAT_USE_ACCUMULATOR && params->wave_size == 32 ? 8 : 4); | |
| + base_row = nir_imul_imm(b, base_row, 4); | |
| } | |
| } else { | |
| base_row = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? nir_udiv_imm(b, local_idx, 16) : nir_imm_int(b, 0); | |
| @@ -181,6 +183,24 @@ radv_get_base_row(nir_builder *b, struct glsl_cmat_description desc, const lower | |
| return base_row; | |
| } | |
| +static unsigned | |
| +radv_get_row_iter(struct glsl_cmat_description desc, const lower_cmat_params *params, unsigned i) | |
| +{ | |
| + if (params->gfx_level >= GFX12) { | |
| + /* 8bit and ACC are indexed normally, 16bit A/B is weird. */ | |
| + if (desc.use != GLSL_CMAT_USE_ACCUMULATOR && params->wave_size == 32 && radv_nir_cmat_bits(desc) >= 16) | |
| + return i + (i & 4); | |
| + else | |
| + return i; | |
| + } else { | |
| + if (desc.use != GLSL_CMAT_USE_ACCUMULATOR) | |
| + return i; | |
| + else | |
| + return i * params->wave_size / 16; | |
| + } | |
| +} | |
| + | |
| + | |
| static nir_def * | |
| convert_base_type(nir_builder *b, nir_def *src, enum glsl_base_type src_type, enum glsl_base_type dst_type) | |
| { | |
| @@ -193,6 +213,12 @@ convert_base_type(nir_builder *b, nir_def *src, enum glsl_base_type src_type, en | |
| } else if (dst_type == GLSL_TYPE_BFLOAT16) { | |
| src = convert_base_type(b, src, src_type, GLSL_TYPE_FLOAT); | |
| return nir_f2bf(b, src); | |
| + } else if (src_type == GLSL_TYPE_FLOAT_E4M3FN) { | |
| + src = nir_e4m3fn2f(b, src); | |
| + return convert_base_type(b, src, GLSL_TYPE_FLOAT, dst_type); | |
| + } else if (dst_type == GLSL_TYPE_FLOAT_E4M3FN) { | |
| + src = convert_base_type(b, src, src_type, GLSL_TYPE_FLOAT); | |
| + return nir_f2e4m3fn(b, src); | |
| } | |
| nir_op op = nir_type_conversion_op(nir_get_nir_type_for_glsl_base_type(src_type), | |
| @@ -201,6 +227,44 @@ convert_base_type(nir_builder *b, nir_def *src, enum glsl_base_type src_type, en | |
| return nir_build_alu1(b, op, src); | |
| } | |
| +static nir_def * | |
| +radv_swizzle_gfx12_8bit_mat(nir_builder *b, nir_def *src, unsigned wave_size) | |
| +{ | |
| + assert(src->bit_size == 8); | |
| + | |
| + src = nir_extract_bits(b, &src, 1, 0, src->num_components / 4, 32); | |
| + | |
| + nir_def *res; | |
| + | |
| + if (wave_size == 64) { | |
| + assert(src->num_components == 1); | |
| + | |
| + nir_def *swapped = nir_rotate(b, src, nir_imm_int(b, 32), .cluster_size = 64); | |
| + swapped = nir_rotate(b, swapped, nir_imm_int(b, 16), .cluster_size = 32); | |
| + | |
| + nir_def *cond = nir_inverse_ballot(b, 1, nir_imm_intN_t(b, 0xffffffff0000ull, 64)); | |
| + | |
| + res = nir_bcsel(b, cond, swapped, src); | |
| + } else { | |
| + assert(src->num_components == 2); | |
| + | |
| + nir_def *src0 = nir_channel(b, src, 0); | |
| + nir_def *src1 = nir_channel(b, src, 1); | |
| + | |
| + nir_def *swapped0 = nir_rotate(b, src0, nir_imm_int(b, 16), .cluster_size = 32); | |
| + nir_def *swapped1 = nir_rotate(b, src1, nir_imm_int(b, 16), .cluster_size = 32); | |
| + | |
| + nir_def *cond = nir_inverse_ballot(b, 1, nir_imm_intN_t(b, 0xffff0000, 32)); | |
| + | |
| + nir_def *res0 = nir_bcsel(b, cond, swapped1, src0); | |
| + nir_def *res1 = nir_bcsel(b, cond, swapped0, src1); | |
| + | |
| + res = nir_vec2(b, res0, res1); | |
| + } | |
| + | |
| + return nir_extract_bits(b, &res, 1, 0, res->num_components * 4, 8); | |
| +} | |
| + | |
| bool | |
| radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_level, unsigned wave_size) | |
| { | |
| @@ -311,7 +375,6 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev | |
| unsigned length = radv_nir_cmat_length(desc, ¶ms); | |
| unsigned mul = radv_nir_cmat_length_mul(desc, ¶ms); | |
| - unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? params.wave_size : 16; | |
| nir_def *vars[16]; | |
| if (mul > 1) { | |
| for (unsigned i = 0; i < length; ++i) | |
| @@ -324,16 +387,10 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev | |
| for (unsigned i = 0; i < length / mul; ++i) { | |
| nir_def *col_offset = inner_idx; | |
| - nir_def *row_offset; | |
| - uint32_t row_iter; | |
| - if (gfx_level >= GFX12) { | |
| - row_iter = desc.use != GLSL_CMAT_USE_ACCUMULATOR && wave_size == 32 ? i + (i & 4) : i; | |
| - } else { | |
| - row_iter = i * lanes_per_iter / 16; | |
| - } | |
| + uint32_t row_iter = radv_get_row_iter(desc, ¶ms, i); | |
| - row_offset = nir_iadd_imm(&b, base_row, row_iter); | |
| + nir_def *row_offset = nir_iadd_imm(&b, base_row, row_iter); | |
| if (layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR) { | |
| nir_def *tmp = col_offset; | |
| @@ -385,7 +442,6 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev | |
| unsigned length = radv_nir_cmat_length(desc, ¶ms); | |
| unsigned mul = radv_nir_cmat_length_mul(desc, ¶ms); | |
| - unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? params.wave_size : 16; | |
| nir_def *vars[16]; | |
| for (unsigned i = 0; i < length; ++i) | |
| vars[i] = nir_channel(&b, src, i); | |
| @@ -395,16 +451,10 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev | |
| for (unsigned i = 0; i < length / mul; ++i) { | |
| nir_def *col_offset = inner_idx; | |
| - nir_def *row_offset; | |
| - uint32_t row_iter; | |
| - if (gfx_level >= GFX12) { | |
| - row_iter = desc.use != GLSL_CMAT_USE_ACCUMULATOR && wave_size == 32 ? i + (i & 4) : i; | |
| - } else { | |
| - row_iter = i * lanes_per_iter / 16; | |
| - } | |
| + uint32_t row_iter = radv_get_row_iter(desc, ¶ms, i); | |
| - row_offset = nir_iadd_imm(&b, base_row, row_iter); | |
| + nir_def *row_offset = nir_iadd_imm(&b, base_row, row_iter); | |
| if (layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR) { | |
| nir_def *tmp = col_offset; | |
| @@ -484,8 +534,16 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev | |
| src = nir_vec(&b, components, src->num_components / scale); | |
| } | |
| + if (dst_desc.use != GLSL_CMAT_USE_ACCUMULATOR && gfx_level >= GFX12 && | |
| + radv_nir_cmat_bits(src_desc) == 8 && radv_nir_cmat_bits(dst_desc) > 8) | |
| + src = radv_swizzle_gfx12_8bit_mat(&b, src, wave_size); | |
| + | |
| nir_def *ret = convert_base_type(&b, src, src_element_type, dst_element_type); | |
| + if (dst_desc.use != GLSL_CMAT_USE_ACCUMULATOR && gfx_level >= GFX12 && | |
| + radv_nir_cmat_bits(dst_desc) == 8 && radv_nir_cmat_bits(src_desc) > 8) | |
| + ret = radv_swizzle_gfx12_8bit_mat(&b, ret, wave_size); | |
| + | |
| if (dst_mul > src_mul) { | |
| nir_def *components[NIR_MAX_VEC_COMPONENTS]; | |
| unsigned scale = dst_mul / src_mul; | |
| diff --git a/src/amd/vulkan/radv_pipeline.c b/src/amd/vulkan/radv_pipeline.c | |
| index e042223f8b2..fa8bdd49cb9 100644 | |
| --- a/src/amd/vulkan/radv_pipeline.c | |
| +++ b/src/amd/vulkan/radv_pipeline.c | |
| @@ -265,6 +265,10 @@ opt_vectorize_callback(const nir_instr *instr, const void *_) | |
| return 1; | |
| const nir_alu_instr *alu = nir_instr_as_alu(instr); | |
| + | |
| + if (alu->op == nir_op_f2e4m3fn || alu->op == nir_op_e4m3fn2f) | |
| + return 2; | |
| + | |
| const unsigned bit_size = alu->def.bit_size; | |
| if (bit_size != 16) | |
| return 1; | |
| diff --git a/src/amd/vulkan/radv_shader.c b/src/amd/vulkan/radv_shader.c | |
| index dcbc559a1b7..d9deff645b8 100644 | |
| --- a/src/amd/vulkan/radv_shader.c | |
| +++ b/src/amd/vulkan/radv_shader.c | |
| @@ -82,6 +82,10 @@ vectorize_vec2_16bit(const nir_instr *instr, const void *_) | |
| return 0; | |
| const nir_alu_instr *alu = nir_instr_as_alu(instr); | |
| + | |
| + if (alu->op == nir_op_f2e4m3fn || alu->op == nir_op_e4m3fn2f) | |
| + return 2; | |
| + | |
| const unsigned bit_size = alu->def.bit_size; | |
| if (bit_size == 16) | |
| return 2; | |
| @@ -2140,7 +2144,8 @@ radv_postprocess_binary_config(struct radv_device *device, struct radv_shader_bi | |
| case MESA_SHADER_ANY_HIT: | |
| case MESA_SHADER_COMPUTE: | |
| case MESA_SHADER_TASK: | |
| - config->rsrc1 |= S_00B848_MEM_ORDERED(radv_mem_ordered(pdev)) | S_00B848_WGP_MODE(wgp_mode); | |
| + config->rsrc1 |= S_00B848_MEM_ORDERED(radv_mem_ordered(pdev)) | S_00B848_WGP_MODE(wgp_mode) | | |
| + S_00B848_FP16_OVFL(info->uses_f2e4m3fn); | |
| config->rsrc2 |= S_00B84C_TGID_X_EN(info->cs.uses_block_id[0]) | S_00B84C_TGID_Y_EN(info->cs.uses_block_id[1]) | | |
| S_00B84C_TGID_Z_EN(info->cs.uses_block_id[2]) | | |
| S_00B84C_TIDIG_COMP_CNT(info->cs.uses_thread_id[2] ? 2 | |
| diff --git a/src/amd/vulkan/radv_shader_info.c b/src/amd/vulkan/radv_shader_info.c | |
| index 0a541bc6e75..4495dae8a88 100644 | |
| --- a/src/amd/vulkan/radv_shader_info.c | |
| +++ b/src/amd/vulkan/radv_shader_info.c | |
| @@ -345,6 +345,18 @@ gather_tex_info(const nir_shader *nir, const nir_tex_instr *instr, struct radv_s | |
| } | |
| } | |
| +static void | |
| +gather_alu_info(const nir_shader *nir, const nir_alu_instr *instr, struct radv_shader_info *info) | |
| +{ | |
| + switch (instr->op) { | |
| + case nir_op_f2e4m3fn: | |
| + info->uses_f2e4m3fn = true; | |
| + break; | |
| + default: | |
| + break; | |
| + } | |
| +} | |
| + | |
| static void | |
| gather_info_block(const nir_shader *nir, const nir_block *block, struct radv_shader_info *info, | |
| const struct radv_graphics_state_key *gfx_state, const struct radv_shader_stage_key *stage_key, | |
| @@ -358,6 +370,9 @@ gather_info_block(const nir_shader *nir, const nir_block *block, struct radv_sha | |
| case nir_instr_type_tex: | |
| gather_tex_info(nir, nir_instr_as_tex(instr), info); | |
| break; | |
| + case nir_instr_type_alu: | |
| + gather_alu_info(nir, nir_instr_as_alu(instr), info); | |
| + break; | |
| default: | |
| break; | |
| } | |
| @@ -1845,6 +1859,7 @@ radv_nir_shader_info_merge(const struct radv_shader_stage *src, struct radv_shad | |
| dst_info->desc_set_used_mask |= src_info->desc_set_used_mask; | |
| dst_info->uses_view_index |= src_info->uses_view_index; | |
| dst_info->uses_prim_id |= src_info->uses_prim_id; | |
| + dst_info->uses_f2e4m3fn |= src_info->uses_f2e4m3fn; | |
| dst_info->inline_push_constant_mask |= src_info->inline_push_constant_mask; | |
| /* Only inline all push constants if both allows it. */ | |
| diff --git a/src/amd/vulkan/radv_shader_info.h b/src/amd/vulkan/radv_shader_info.h | |
| index eb70b764ab5..c6afbc216a0 100644 | |
| --- a/src/amd/vulkan/radv_shader_info.h | |
| +++ b/src/amd/vulkan/radv_shader_info.h | |
| @@ -89,6 +89,7 @@ struct radv_shader_info { | |
| bool uses_view_index; | |
| bool uses_invocation_id; | |
| bool uses_prim_id; | |
| + bool uses_f2e4m3fn; | |
| uint8_t wave_size; | |
| uint8_t ballot_bit_size; | |
| struct radv_userdata_locations user_sgprs_locs; | |
| diff --git a/src/compiler/builtin_types.py b/src/compiler/builtin_types.py | |
| index 99bb1709676..ec48be2bac0 100644 | |
| --- a/src/compiler/builtin_types.py | |
| +++ b/src/compiler/builtin_types.py | |
| @@ -62,6 +62,7 @@ vector_type("int8_t", "i8vec", "GLSL_TYPE_INT8", "GL_INT8", "_NV") | |
| vector_type("uint8_t", "u8vec", "GLSL_TYPE_UINT8", "GL_UNSIGNED_INT8", "_NV") | |
| vector_type("bfloat16_t", "bf16vec", "GLSL_TYPE_BFLOAT16", None) | |
| +vector_type("e4m3fn_t", "e4m3fnvec", "GLSL_TYPE_FLOAT_E4M3FN", None) | |
| simple_type("mat2", "GL_FLOAT_MAT2", "GLSL_TYPE_FLOAT", 2, 2) | |
| simple_type("mat3", "GL_FLOAT_MAT3", "GLSL_TYPE_FLOAT", 3, 3) | |
| diff --git a/src/compiler/glsl_types.c b/src/compiler/glsl_types.c | |
| index 10fcd786fac..280f971f244 100644 | |
| --- a/src/compiler/glsl_types.c | |
| +++ b/src/compiler/glsl_types.c | |
| @@ -349,6 +349,8 @@ glsl_get_base_glsl_type(const glsl_type *t) | |
| return &glsl_type_builtin_double; | |
| case GLSL_TYPE_BFLOAT16: | |
| return &glsl_type_builtin_bfloat16_t; | |
| + case GLSL_TYPE_FLOAT_E4M3FN: | |
| + return &glsl_type_builtin_e4m3fn_t; | |
| case GLSL_TYPE_BOOL: | |
| return &glsl_type_builtin_bool; | |
| case GLSL_TYPE_UINT64: | |
| @@ -387,6 +389,7 @@ glsl_get_bare_type(const glsl_type *t) | |
| case GLSL_TYPE_INT16: | |
| case GLSL_TYPE_FLOAT16: | |
| case GLSL_TYPE_BFLOAT16: | |
| + case GLSL_TYPE_FLOAT_E4M3FN: | |
| case GLSL_TYPE_UINT: | |
| case GLSL_TYPE_INT: | |
| case GLSL_TYPE_FLOAT: | |
| @@ -597,6 +600,7 @@ glsl_ ## vname ## _type (unsigned components) \ | |
| VECN(components, float, vec) | |
| VECN(components, float16_t, f16vec) | |
| VECN(components, bfloat16_t, bf16vec) | |
| +VECN(components, e4m3fn_t, e4m3fnvec) | |
| VECN(components, double, dvec) | |
| VECN(components, int, ivec) | |
| VECN(components, uint, uvec) | |
| @@ -647,6 +651,8 @@ glsl_simple_explicit_type(unsigned base_type, unsigned rows, unsigned columns, | |
| return glsl_f16vec_type(rows); | |
| case GLSL_TYPE_BFLOAT16: | |
| return glsl_bf16vec_type(rows); | |
| + case GLSL_TYPE_FLOAT_E4M3FN: | |
| + return glsl_e4m3fnvec_type(rows); | |
| case GLSL_TYPE_DOUBLE: | |
| return glsl_dvec_type(rows); | |
| case GLSL_TYPE_BOOL: | |
| @@ -1749,6 +1755,7 @@ glsl_get_component_slots(const glsl_type *t) | |
| case GLSL_TYPE_FLOAT: | |
| case GLSL_TYPE_FLOAT16: | |
| case GLSL_TYPE_BFLOAT16: | |
| + case GLSL_TYPE_FLOAT_E4M3FN: | |
| case GLSL_TYPE_BOOL: | |
| return glsl_get_components(t); | |
| @@ -1802,6 +1809,7 @@ glsl_get_component_slots_aligned(const glsl_type *t, unsigned offset) | |
| case GLSL_TYPE_FLOAT: | |
| case GLSL_TYPE_FLOAT16: | |
| case GLSL_TYPE_BFLOAT16: | |
| + case GLSL_TYPE_FLOAT_E4M3FN: | |
| case GLSL_TYPE_BOOL: | |
| return glsl_get_components(t); | |
| @@ -2889,6 +2897,7 @@ glsl_count_vec4_slots(const glsl_type *t, bool is_gl_vertex_input, bool is_bindl | |
| case GLSL_TYPE_FLOAT: | |
| case GLSL_TYPE_FLOAT16: | |
| case GLSL_TYPE_BFLOAT16: | |
| + case GLSL_TYPE_FLOAT_E4M3FN: | |
| case GLSL_TYPE_BOOL: | |
| return t->matrix_columns; | |
| case GLSL_TYPE_DOUBLE: | |
| @@ -3094,6 +3103,7 @@ encode_type_to_blob(struct blob *blob, const glsl_type *type) | |
| case GLSL_TYPE_FLOAT: | |
| case GLSL_TYPE_FLOAT16: | |
| case GLSL_TYPE_BFLOAT16: | |
| + case GLSL_TYPE_FLOAT_E4M3FN: | |
| case GLSL_TYPE_DOUBLE: | |
| case GLSL_TYPE_UINT8: | |
| case GLSL_TYPE_INT8: | |
| @@ -3743,6 +3753,7 @@ glsl_get_natural_size_align_bytes(const glsl_type *type, | |
| case GLSL_TYPE_INT16: | |
| case GLSL_TYPE_FLOAT16: | |
| case GLSL_TYPE_BFLOAT16: | |
| + case GLSL_TYPE_FLOAT_E4M3FN: | |
| case GLSL_TYPE_UINT: | |
| case GLSL_TYPE_INT: | |
| case GLSL_TYPE_FLOAT: | |
| @@ -3803,6 +3814,7 @@ glsl_get_word_size_align_bytes(const glsl_type *type, | |
| case GLSL_TYPE_INT16: | |
| case GLSL_TYPE_FLOAT16: | |
| case GLSL_TYPE_BFLOAT16: | |
| + case GLSL_TYPE_FLOAT_E4M3FN: | |
| case GLSL_TYPE_UINT: | |
| case GLSL_TYPE_INT: | |
| case GLSL_TYPE_FLOAT: | |
| @@ -3863,6 +3875,7 @@ glsl_get_vec4_size_align_bytes(const glsl_type *type, | |
| case GLSL_TYPE_INT16: | |
| case GLSL_TYPE_FLOAT16: | |
| case GLSL_TYPE_BFLOAT16: | |
| + case GLSL_TYPE_FLOAT_E4M3FN: | |
| case GLSL_TYPE_UINT: | |
| case GLSL_TYPE_INT: | |
| case GLSL_TYPE_FLOAT: | |
| diff --git a/src/compiler/glsl_types.h b/src/compiler/glsl_types.h | |
| index 4afba690abf..2defbf3136a 100644 | |
| --- a/src/compiler/glsl_types.h | |
| +++ b/src/compiler/glsl_types.h | |
| @@ -64,6 +64,7 @@ enum glsl_base_type { | |
| GLSL_TYPE_FLOAT, | |
| GLSL_TYPE_FLOAT16, | |
| GLSL_TYPE_BFLOAT16, | |
| + GLSL_TYPE_FLOAT_E4M3FN, | |
| GLSL_TYPE_DOUBLE, | |
| GLSL_TYPE_UINT8, | |
| GLSL_TYPE_INT8, | |
| @@ -107,6 +108,7 @@ static unsigned glsl_base_type_bit_size(enum glsl_base_type type) | |
| case GLSL_TYPE_UINT8: | |
| case GLSL_TYPE_INT8: | |
| + case GLSL_TYPE_FLOAT_E4M3FN: | |
| return 8; | |
| case GLSL_TYPE_DOUBLE: | |
| @@ -176,6 +178,7 @@ glsl_base_type_get_bit_size(const enum glsl_base_type base_type) | |
| case GLSL_TYPE_UINT8: | |
| case GLSL_TYPE_INT8: | |
| + case GLSL_TYPE_FLOAT_E4M3FN: | |
| return 8; | |
| case GLSL_TYPE_DOUBLE: | |
| @@ -630,6 +633,12 @@ glsl_type_is_bfloat_16(const glsl_type *t) | |
| return t->base_type == GLSL_TYPE_BFLOAT16; | |
| } | |
| +static inline bool | |
| +glsl_type_is_e4m3fn(const glsl_type *t) | |
| +{ | |
| + return t->base_type == GLSL_TYPE_FLOAT_E4M3FN; | |
| +} | |
| + | |
| static inline bool | |
| glsl_type_is_int_16_32_64(const glsl_type *t) | |
| { | |
| @@ -947,6 +956,7 @@ static inline const glsl_type *glsl_uint8_t_type(void) { return &glsl_type_built | |
| static inline const glsl_type *glsl_bool_type(void) { return &glsl_type_builtin_bool; } | |
| static inline const glsl_type *glsl_atomic_uint_type(void) { return &glsl_type_builtin_atomic_uint; } | |
| static inline const glsl_type *glsl_bfloat16_t_type(void) { return &glsl_type_builtin_bfloat16_t; } | |
| +static inline const glsl_type *glsl_e4m3fn_t_type(void) { return &glsl_type_builtin_e4m3fn_t; } | |
| static inline const glsl_type * | |
| glsl_floatN_t_type(unsigned bit_size) | |
| @@ -999,6 +1009,7 @@ glsl_uintN_t_type(unsigned bit_size) | |
| const glsl_type *glsl_vec_type(unsigned components); | |
| const glsl_type *glsl_f16vec_type(unsigned components); | |
| const glsl_type *glsl_bf16vec_type(unsigned components); | |
| +const glsl_type *glsl_e4m3fnvec_type(unsigned components); | |
| const glsl_type *glsl_dvec_type(unsigned components); | |
| const glsl_type *glsl_ivec_type(unsigned components); | |
| const glsl_type *glsl_uvec_type(unsigned components); | |
| diff --git a/src/compiler/nir/nir.c b/src/compiler/nir/nir.c | |
| index f3a32256a63..18cec968501 100644 | |
| --- a/src/compiler/nir/nir.c | |
| +++ b/src/compiler/nir/nir.c | |
| @@ -2904,6 +2904,7 @@ nir_get_nir_type_for_glsl_base_type(enum glsl_base_type base_type) | |
| case GLSL_TYPE_FLOAT: return nir_type_float32; | |
| case GLSL_TYPE_FLOAT16: return nir_type_float16; | |
| case GLSL_TYPE_BFLOAT16: return nir_type_uint16; | |
| + case GLSL_TYPE_FLOAT_E4M3FN: return nir_type_uint8; | |
| case GLSL_TYPE_DOUBLE: return nir_type_float64; | |
| /* clang-format on */ | |
| diff --git a/src/compiler/nir/nir_opcodes.py b/src/compiler/nir/nir_opcodes.py | |
| index c2995cccf09..def4b6f284c 100644 | |
| --- a/src/compiler/nir/nir_opcodes.py | |
| +++ b/src/compiler/nir/nir_opcodes.py | |
| @@ -1765,3 +1765,7 @@ opcode("bfdot2_bfadd", 1, tint16, [2, 2, 1], [tint16, tint16, tint16], | |
| dst.x = _mesa_float_to_bfloat16_bits_rte(acc); | |
| """) | |
| + | |
| + | |
| +unop_numeric_convert("e4m3fn2f", tfloat32, tuint8, "0") # TODO constant fold | |
| +unop_numeric_convert("f2e4m3fn", tuint8, tfloat32, "0") # TODO constant fold | |
| diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c | |
| index 05ccce54db6..7b3275fd493 100644 | |
| --- a/src/compiler/spirv/spirv_to_nir.c | |
| +++ b/src/compiler/spirv/spirv_to_nir.c | |
| @@ -1890,10 +1890,14 @@ vtn_handle_type(struct vtn_builder *b, SpvOp opcode, | |
| int32_t encoding = count > 3 ? w[3] : -1; | |
| switch (encoding) { | |
| case -1: | |
| - /* No encoding specified, it is a regular FP. */ | |
| - vtn_fail_if(bit_size != 16 && bit_size != 32 && bit_size != 64, | |
| + if (bit_size == 8) { | |
| + val->type->type = glsl_e4m3fn_t_type(); | |
| + } else { | |
| + /* No encoding specified, it is a regular FP. */ | |
| + vtn_fail_if(bit_size != 16 && bit_size != 32 && bit_size != 64, | |
| "Invalid float bit size: %u", bit_size); | |
| - val->type->type = glsl_floatN_t_type(bit_size); | |
| + val->type->type = glsl_floatN_t_type(bit_size); | |
| + } | |
| break; | |
| case SpvFPEncodingBFloat16KHR: | |
| diff --git a/src/compiler/spirv/vtn_alu.c b/src/compiler/spirv/vtn_alu.c | |
| index a528b2e1b12..a89d9b2ad9e 100644 | |
| --- a/src/compiler/spirv/vtn_alu.c | |
| +++ b/src/compiler/spirv/vtn_alu.c | |
| @@ -697,6 +697,23 @@ vtn_handle_convert(struct vtn_builder *b, SpvOp opcode, | |
| return nir_f2bf(&b->nb, src_as_float); | |
| } | |
| + if (glsl_type_is_e4m3fn(glsl_src_type)) { | |
| + nir_def *src_as_float = nir_e4m3fn2f(&b->nb, src); | |
| + if (glsl_type_is_float(glsl_dest_type)) | |
| + return src_as_float; | |
| + return vtn_handle_convert(b, opcode, dest_val, glsl_dest_type, | |
| + glsl_float_type(), src_as_float); | |
| + | |
| + } else if (glsl_type_is_e4m3fn(glsl_dest_type)) { | |
| + nir_def *src_as_float; | |
| + if (glsl_type_is_float(glsl_src_type)) | |
| + src_as_float = src; | |
| + else | |
| + src_as_float = vtn_handle_convert(b, opcode, dest_val, glsl_float_type(), | |
| + glsl_src_type, src); | |
| + return nir_f2e4m3fn(&b->nb, src_as_float); | |
| + } | |
| + | |
| /* Use bit_size from NIR source instead of from the original src type, | |
| * to account for mediump_16bit. See vtn_handle_alu() for details. | |
| */ | |
| diff --git a/src/compiler/spirv/vtn_variables.c b/src/compiler/spirv/vtn_variables.c | |
| index baf359c9f9f..c474ea9a81c 100644 | |
| --- a/src/compiler/spirv/vtn_variables.c | |
| +++ b/src/compiler/spirv/vtn_variables.c | |
| @@ -716,6 +716,7 @@ _vtn_variable_load_store(struct vtn_builder *b, bool load, | |
| case GLSL_TYPE_INT64: | |
| case GLSL_TYPE_FLOAT: | |
| case GLSL_TYPE_FLOAT16: | |
| + case GLSL_TYPE_FLOAT_E4M3FN: | |
| case GLSL_TYPE_BFLOAT16: | |
| case GLSL_TYPE_BOOL: | |
| case GLSL_TYPE_DOUBLE: | |
| @@ -811,6 +812,7 @@ _vtn_variable_copy(struct vtn_builder *b, struct vtn_pointer *dest, | |
| case GLSL_TYPE_FLOAT: | |
| case GLSL_TYPE_FLOAT16: | |
| case GLSL_TYPE_BFLOAT16: | |
| + case GLSL_TYPE_FLOAT_E4M3FN: | |
| case GLSL_TYPE_DOUBLE: | |
| case GLSL_TYPE_BOOL: | |
| /* At this point, we have a scalar, vector, or matrix so we know that |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment