Skip to content

Instantly share code, notes, and snippets.

@j2kun
Last active February 17, 2026 05:37
Show Gist options
  • Select an option

  • Save j2kun/986af859eefa3431c49dfbe4ab5035aa to your computer and use it in GitHub Desktop.

Select an option

Save j2kun/986af859eefa3431c49dfbe4ab5035aa to your computer and use it in GitHub Desktop.
commit 3f899743ce5d6d8c423be7aa0a6033162addb5ed
Author: Jeremy Kun <j2kun@users.noreply.github.com>
Date: Mon Feb 16 21:32:47 2026 -0800
remove mod_arith dep from RNSAttributes
diff --git a/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 463f07c16..a7b3216f2 100644
--- a/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -253,7 +253,17 @@ static LogicalResult verifyNTTOp(Operation* op, PolynomialType input,
for (int i = 0; i < rnsLength; i++) {
auto limbType = dyn_cast<mod_arith::ModArithType>(basis[i]);
APInt cmod = limbType.getModulus().getValue();
- APInt rootValue = rootValueType.getValues()[i].getValue();
+    // The root's i-th limb must be a mod_arith value whose type is exactly
+    // this limb's type before its integer value is checked against cmod.
+    Attribute rootLimb = rootValueType.getValues()[i];
+    auto rootLimbValue = dyn_cast<mod_arith::ModArithAttr>(rootLimb);
+    if (!rootLimbValue || rootLimbValue.getType() != limbType) {
+      return op->emitOpError()
+             << "Ring has coefficient type " << inputRing.getCoefficientType()
+             << ", but primitive root attr had incorrect limb[" << i
+             << "] = " << rootLimb;
+    }
+    APInt rootValue = rootLimbValue.getValue().getValue();
if (!isPrimitiveNthRootOfUnity(rootValue, rootDegree, cmod)) {
return op->emitOpError()
<< "provided root " << rootValue.getZExtValue()
diff --git a/lib/Dialect/RNS/IR/BUILD b/lib/Dialect/RNS/IR/BUILD
index 267871e51..9161d159d 100644
--- a/lib/Dialect/RNS/IR/BUILD
+++ b/lib/Dialect/RNS/IR/BUILD
@@ -30,7 +30,6 @@ cc_library(
":ops_inc_gen",
":type_interfaces_inc_gen",
":types_inc_gen",
- "@heir//lib/Dialect/ModArith/IR:Types",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface",
diff --git a/lib/Dialect/RNS/IR/RNSAttributes.cpp b/lib/Dialect/RNS/IR/RNSAttributes.cpp
index b92cdd3dd..2a50bda3a 100644
--- a/lib/Dialect/RNS/IR/RNSAttributes.cpp
+++ b/lib/Dialect/RNS/IR/RNSAttributes.cpp
@@ -1,6 +1,5 @@
#include "lib/Dialect/RNS/IR/RNSAttributes.h"
-#include "lib/Dialect/ModArith/IR/ModArithTypes.h"
#include "mlir/include/mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
@@ -9,10 +8,21 @@ namespace mlir {
namespace heir {
namespace rns {
-LogicalResult RNSAttr::verify(
- ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
- ::llvm::ArrayRef<mlir::IntegerAttr> values,
- ::mlir::heir::rns::RNSType type) {
+LogicalResult RNSAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                              ArrayRef<Attribute> values, RNSType type) {
+  // Each limb value must be a typed attribute whose type is exactly the
+  // corresponding basis type: the printer omits the RNS type and the parser
+  // re-infers it from the limb types, so a mismatch would not round-trip.
+  // llvm::zip stops at the shorter range; a count mismatch is reported below.
+  unsigned limb = 0;
+  for (auto [value, basisType] : llvm::zip(values, type.getBasisTypes())) {
+    auto typedValue = dyn_cast_if_present<TypedAttr>(value);
+    if (!typedValue || typedValue.getType() != basisType) {
+      return emitError() << "expected RNS limb " << limb << " to have type "
+                         << basisType << ", but got " << value;
+    }
+    ++limb;
+  }
auto basisSize = type.getBasisTypes().size();
if (values.size() != basisSize) {
return emitError() << "expected " << basisSize
@@ -22,51 +32,44 @@ LogicalResult RNSAttr::verify(
return success();
}
-void RNSAttr::print(mlir::AsmPrinter &printer) const {
+void RNSAttr::print(AsmPrinter &printer) const {
printer << "<[";
// Use llvm::interleaveComma to handle the commas between elements nicely
- llvm::interleaveComma(getValues(), printer, [&](mlir::IntegerAttr attr) {
- printer << attr.getValue();
- });
- printer << "] : " << getType() << ">";
+  llvm::interleaveComma(getValues(), printer,
+                        [&](Attribute attr) { printer << attr; });
+  printer << "]>";
}
-mlir::Attribute RNSAttr::parse(mlir::AsmParser &parser, mlir::Type type) {
- llvm::SmallVector<APInt> rawValues;
- RNSType rnsType;
+Attribute RNSAttr::parse(AsmParser &parser, Type type) {
+  // The RNS type is not spelled out in the text; it is inferred from the
+  // types of the limb values, which are collected alongside the values.
+  SmallVector<Attribute> limbValues;
+  SmallVector<Type> basisTypes;
if (parser.parseLess() || parser.parseLSquare()) return {};
- // 1. Parse comma-separated integers: 3, 5, 7
auto elementParser = [&]() {
- APInt val;
- if (parser.parseInteger(val)) return mlir::failure();
- rawValues.push_back(val);
- return mlir::success();
+    auto loc = parser.getCurrentLocation();
+    Attribute val;
+    if (parser.parseAttribute(val)) return failure();
+    auto typedVal = dyn_cast<TypedAttr>(val);
+    if (!typedVal) {
+      // Report why the element was rejected rather than failing silently.
+      parser.emitError(loc)
+          << "expected a typed attribute for each RNS limb, but got " << val;
+      return failure();
+    }
+    limbValues.push_back(val);
+    basisTypes.push_back(typedVal.getType());
+    return success();
};
if (parser.parseCommaSeparatedList(elementParser)) return {};
- if (parser.parseRSquare() || parser.parseColon()) return {};
- if (parser.parseType(rnsType)) return {};
- if (parser.parseGreater()) return {};
+  if (parser.parseRSquare() || parser.parseGreater()) return {};
- llvm::SmallVector<mlir::IntegerAttr> sizedValues;
- auto basisTypes = rnsType.getBasisTypes();
- for (auto [val, basisTy] : llvm::zip(rawValues, basisTypes)) {
- auto modArithTy = llvm::dyn_cast<mod_arith::ModArithType>(basisTy);
- if (!modArithTy) {
- parser.emitError(parser.getNameLoc())
- << "basis type is not a ModArithType";
- return {};
- }
- mlir::Type integerType = modArithTy.getModulus().getType();
- unsigned targetBitWidth = integerType.getIntOrFloatBitWidth();
- sizedValues.push_back(
- mlir::IntegerAttr::get(integerType, val.zextOrTrunc(targetBitWidth)));
- }
+  RNSType rnsType = RNSType::get(parser.getContext(), basisTypes);
return RNSAttr::getChecked(
[&]() { return parser.emitError(parser.getNameLoc()); },
- parser.getContext(), llvm::ArrayRef<mlir::IntegerAttr>(sizedValues),
- rnsType);
+      parser.getContext(), ArrayRef<Attribute>(limbValues), rnsType);
}
} // namespace rns
diff --git a/lib/Dialect/RNS/IR/RNSAttributes.td b/lib/Dialect/RNS/IR/RNSAttributes.td
index 8ce33405d..bbc313a34 100644
--- a/lib/Dialect/RNS/IR/RNSAttributes.td
+++ b/lib/Dialect/RNS/IR/RNSAttributes.td
@@ -25,11 +25,14 @@ def RNS_RNSAttr : RNS_Attr<"RNS", "value"> {
Example:
```mlir
- #v = #rns.value<[3, 9] : !rns.rns<...>>
+ #v1 = #mod_arith.value<3 : !mod_arith.int<17 : i32>>
+ #v2 = #mod_arith.value<9 : !mod_arith.int<19 : i32>>
+ // The !rns.rns type of #v is inferred from the types of #v1 and #v2.
+ #v = #rns.value<[#v1, #v2]>
```
}];
let parameters = (ins
- ArrayRefParameter<"IntegerAttr">:$values,
+ ArrayRefParameter<"Attribute">:$values,
"::mlir::heir::rns::RNSType":$type
);
let genVerifyDecl = 1;
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.