Last active
February 17, 2026 05:37
-
-
Save j2kun/986af859eefa3431c49dfbe4ab5035aa to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| commit 3f899743ce5d6d8c423be7aa0a6033162addb5ed | |
| Author: Jeremy Kun <j2kun@users.noreply.github.com> | |
| Date: Mon Feb 16 21:32:47 2026 -0800 | |
| remove mod_arith dep from RNSAttributes | |
| diff --git a/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/lib/Dialect/Polynomial/IR/PolynomialOps.cpp | |
| index 463f07c16..a7b3216f2 100644 | |
| --- a/lib/Dialect/Polynomial/IR/PolynomialOps.cpp | |
| +++ b/lib/Dialect/Polynomial/IR/PolynomialOps.cpp | |
| @@ -253,7 +253,17 @@ static LogicalResult verifyNTTOp(Operation* op, PolynomialType input, | |
| for (int i = 0; i < rnsLength; i++) { | |
| auto limbType = dyn_cast<mod_arith::ModArithType>(basis[i]); | |
| APInt cmod = limbType.getModulus().getValue(); | |
| - APInt rootValue = rootValueType.getValues()[i].getValue(); | |
| + mod_arith::ModArithAttr rootLimbValue = | |
| + dyn_cast<mod_arith::ModArithAttr>( | |
| + rootValueType.getValues()[i]); | |
| + if (!rootLimbValue || rootLimbValue.getType() != limbType) { | |
| + return op->emitOpError() | |
| + << "Ring has coefficient type " | |
| + << inputRing.getCoefficientType() | |
| + << ", but primitive root attr had incorrect limb[" << i | |
| + << "] = " << rootValueType.getValues()[i]; | |
| + } | |
| + APInt rootValue = rootLimbValue.getValue().getValue(); | |
| if (!isPrimitiveNthRootOfUnity(rootValue, rootDegree, cmod)) { | |
| return op->emitOpError() | |
| << "provided root " << rootValue.getZExtValue() | |
| diff --git a/lib/Dialect/RNS/IR/BUILD b/lib/Dialect/RNS/IR/BUILD | |
| index 267871e51..9161d159d 100644 | |
| --- a/lib/Dialect/RNS/IR/BUILD | |
| +++ b/lib/Dialect/RNS/IR/BUILD | |
| @@ -30,7 +30,6 @@ cc_library( | |
| ":ops_inc_gen", | |
| ":type_interfaces_inc_gen", | |
| ":types_inc_gen", | |
| - "@heir//lib/Dialect/ModArith/IR:Types", | |
| "@llvm-project//llvm:Support", | |
| "@llvm-project//mlir:IR", | |
| "@llvm-project//mlir:InferTypeOpInterface", | |
| diff --git a/lib/Dialect/RNS/IR/RNSAttributes.cpp b/lib/Dialect/RNS/IR/RNSAttributes.cpp | |
| index b92cdd3dd..2a50bda3a 100644 | |
| --- a/lib/Dialect/RNS/IR/RNSAttributes.cpp | |
| +++ b/lib/Dialect/RNS/IR/RNSAttributes.cpp | |
| @@ -1,6 +1,5 @@ | |
| #include "lib/Dialect/RNS/IR/RNSAttributes.h" | |
| -#include "lib/Dialect/ModArith/IR/ModArithTypes.h" | |
| #include "mlir/include/mlir/IR/Attributes.h" // from @llvm-project | |
| #include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project | |
| #include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project | |
| @@ -9,10 +8,8 @@ namespace mlir { | |
| namespace heir { | |
| namespace rns { | |
| -LogicalResult RNSAttr::verify( | |
| - ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, | |
| - ::llvm::ArrayRef<mlir::IntegerAttr> values, | |
| - ::mlir::heir::rns::RNSType type) { | |
| +LogicalResult RNSAttr::verify(function_ref<InFlightDiagnostic()> emitError, | |
| + ArrayRef<Attribute> values, RNSType type) { | |
| auto basisSize = type.getBasisTypes().size(); | |
| if (values.size() != basisSize) { | |
| return emitError() << "expected " << basisSize | |
| @@ -22,51 +19,40 @@ LogicalResult RNSAttr::verify( | |
| return success(); | |
| } | |
| -void RNSAttr::print(mlir::AsmPrinter &printer) const { | |
| +void RNSAttr::print(AsmPrinter &printer) const { | |
| printer << "<["; | |
| // Use llvm::interleaveComma to handle the commas between elements nicely | |
| - llvm::interleaveComma(getValues(), printer, [&](mlir::IntegerAttr attr) { | |
| - printer << attr.getValue(); | |
| - }); | |
| - printer << "] : " << getType() << ">"; | |
| + llvm::interleaveComma(getValues(), printer, | |
| + [&](Attribute attr) { printer << attr; }); | |
| + printer << "]>"; | |
| } | |
| -mlir::Attribute RNSAttr::parse(mlir::AsmParser &parser, mlir::Type type) { | |
| - llvm::SmallVector<APInt> rawValues; | |
| - RNSType rnsType; | |
| +Attribute RNSAttr::parse(AsmParser &parser, Type type) { | |
| + SmallVector<TypedAttr> attrs; | |
| if (parser.parseLess() || parser.parseLSquare()) return {}; | |
| - // 1. Parse comma-separated integers: 3, 5, 7 | |
| auto elementParser = [&]() { | |
| - APInt val; | |
| - if (parser.parseInteger(val)) return mlir::failure(); | |
| - rawValues.push_back(val); | |
| - return mlir::success(); | |
| + Attribute val; | |
| + if (parser.parseAttribute(val)) return failure(); | |
| + if (auto typedVal = dyn_cast<TypedAttr>(val)) { | |
| + attrs.push_back(typedVal); | |
| + return success(); | |
| + } | |
| + return failure(); | |
| }; | |
| if (parser.parseCommaSeparatedList(elementParser)) return {}; | |
| - if (parser.parseRSquare() || parser.parseColon()) return {}; | |
| - if (parser.parseType(rnsType)) return {}; | |
| - if (parser.parseGreater()) return {}; | |
| + if (parser.parseRSquare() || parser.parseGreater()) return {}; | |
| - llvm::SmallVector<mlir::IntegerAttr> sizedValues; | |
| - auto basisTypes = rnsType.getBasisTypes(); | |
| - for (auto [val, basisTy] : llvm::zip(rawValues, basisTypes)) { | |
| - auto modArithTy = llvm::dyn_cast<mod_arith::ModArithType>(basisTy); | |
| - if (!modArithTy) { | |
| - parser.emitError(parser.getNameLoc()) | |
| - << "basis type is not a ModArithType"; | |
| - return {}; | |
| - } | |
| - mlir::Type integerType = modArithTy.getModulus().getType(); | |
| - unsigned targetBitWidth = integerType.getIntOrFloatBitWidth(); | |
| - sizedValues.push_back( | |
| - mlir::IntegerAttr::get(integerType, val.zextOrTrunc(targetBitWidth))); | |
| - } | |
| + // The rns type can be inferred from the types of the attribute values | |
| + SmallVector<Type> basisTypes = | |
| + map_to_vector(attrs, [](TypedAttr attr) { return attr.getType(); }); | |
| + RNSType rnsType = RNSType::get(parser.getContext(), basisTypes); | |
| + SmallVector<Attribute> attrValues = map_to_vector( | |
| + attrs, [](TypedAttr attr) { return cast<Attribute>(attr); }); | |
| return RNSAttr::getChecked( | |
| [&]() { return parser.emitError(parser.getNameLoc()); }, | |
| - parser.getContext(), llvm::ArrayRef<mlir::IntegerAttr>(sizedValues), | |
| - rnsType); | |
| + parser.getContext(), ArrayRef<Attribute>(attrValues), rnsType); | |
| } | |
| } // namespace rns | |
| diff --git a/lib/Dialect/RNS/IR/RNSAttributes.td b/lib/Dialect/RNS/IR/RNSAttributes.td | |
| index 8ce33405d..bbc313a34 100644 | |
| --- a/lib/Dialect/RNS/IR/RNSAttributes.td | |
| +++ b/lib/Dialect/RNS/IR/RNSAttributes.td | |
| @@ -25,11 +25,13 @@ def RNS_RNSAttr : RNS_Attr<"RNS", "value"> { | |
| Example: | |
| ```mlir | |
| - #v = #rns.value<[3, 9] : !rns.rns<...>> | |
| + #v1 = #mod_arith.value<17 : !mod_arith.int<256 : i32>> | |
| + #v2 = #mod_arith.value<19 : !mod_arith.int<256 : i32>> | |
| + #v = #rns.value<[#v1, #v2]> | |
| ``` | |
| }]; | |
| let parameters = (ins | |
| - ArrayRefParameter<"IntegerAttr">:$values, | |
| + ArrayRefParameter<"Attribute">:$values, | |
| "::mlir::heir::rns::RNSType":$type | |
| ); | |
| let genVerifyDecl = 1; |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment