Skip to content

Instantly share code, notes, and snippets.

@javidcf
Created February 13, 2025 16:54
Show Gist options
  • Select an option

  • Save javidcf/b16285c84c3b8744e228f476c1964591 to your computer and use it in GitHub Desktop.

Select an option

Save javidcf/b16285c84c3b8744e228f476c1964591 to your computer and use it in GitHub Desktop.
// lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
// ...
void mlir::torch::onnx_c::populateDefaultDomainGtoP(
OnnxCustomOpConversionPattern &patterns) {
// ...
// Lower the ONNX "Loop" op (opset version 13) to Torch::PrimLoopOp.
// ONNX Loop operands are (maxTripCount, cond, ...loop-carried inits...);
// its body region takes (iterationNum, condIn, ...carried...) and its
// terminator yields (condOut, ...carried..., ...scan outputs...).
// Scan outputs are accumulated here via Torch list append + AtenStackOp.
patterns.onOp(
"Loop", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
// Get all operands (maxTripCount, cond, ....inits....)
llvm::SmallVector<Value> operands;
if (binder.tensorOperandsList(operands) || operands.size() == 0 ||
binder.getNumOperands() < 2) {
return rewriter.notifyMatchFailure(binder.op,
"Failed to get required operands");
}
llvm::SmallVector<mlir::Type> operandTypeVec;
if (binder.tensorOperandTypes(operandTypeVec) ||
operandTypeVec.size() == 0) {
return rewriter.notifyMatchFailure(binder.op,
"Failed to get operandTypes");
}
// The single region of the ONNX Loop op is its loop body.
Region *loopBodyIn;
if (binder.getRegionAtIndex(loopBodyIn, 0)) {
return rewriter.notifyMatchFailure(binder.op,
"Failed getting LoopBody Region");
}
// MaxTripCount - tensor int64 scalar (or empty).
// Extract the scalar trip count as a torch.int for PrimLoopOp.
Value maxTripCountTensor = operands[0];
auto maxTripCountInt = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
maxTripCountTensor);
// Condition - tensor bool scalar (or empty).
// Extract as int, then convert to torch.bool (AtenItemOp yields an int).
Value conditionTensor = operands[1];
auto conditionInt = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
conditionTensor);
auto conditionBool = rewriter.create<Torch::AtenBoolIntOp>(
binder.getLoc(), rewriter.getType<Torch::BoolType>(), conditionInt);
// To be used for the "for-like" loop case (condition always true).
auto constBoolTrue = rewriter.create<Torch::ConstantBoolOp>(
binder.getLoc(), rewriter.getBoolAttr(true));
// Others (if present) - variadic (can be tensors and scalar values).
// Drop maxTripCount and cond so `operands`/`operandTypeVec` hold only
// the loop-carried dependencies passed through to PrimLoopOp.
if (binder.getNumOperands() > 2) {
operandTypeVec.erase(operandTypeVec.begin(),
operandTypeVec.begin() + 2);
operands.erase(operands.begin(), operands.begin() + 2);
}
// Scan operands - add one empty torch list per scan output, used to
// accumulate per-iteration scan elements.
// numScanOutputs = (#terminator operands) - (#carried deps), where
// #carried deps = (#body args - 1): the body has one extra leading arg
// (iterationNum) relative to the (condOut, carried...) terminator, and
// the condIn arg pairs with condOut.
OpBinder terminatorInBinder(loopBodyIn->front().getTerminator());
const unsigned numScanOutputs =
terminatorInBinder.getNumOperands() -
(loopBodyIn->front().getNumArguments() - 1);
for (unsigned scanOutputIdx = 0; scanOutputIdx < numScanOutputs;
scanOutputIdx++) {
// find scan operand type: scan outputs follow condOut and the
// carried deps in the terminator's operand list
const unsigned scanOperandIdx =
loopBodyIn->front().getNumArguments() - 1 + scanOutputIdx;
Value scanOperand;
if (terminatorInBinder.tensorOperandAtIndex(scanOperand,
scanOperandIdx)) {
return rewriter.notifyMatchFailure(
terminatorInBinder.op,
"Failed getting tensor type of scan output");
}
// add new (initially empty) list operand and matching result type
Torch::ListType scanListType =
rewriter.getType<Torch::ListType>(scanOperand.getType());
operands.push_back(rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(), scanListType, llvm::SmallVector<Value>{}));
operandTypeVec.push_back(scanListType);
}
// Resolve an op's name; unconverted ONNX ops appear as torch.operator
// with the original ONNX name stored in their "name" attribute.
auto getOpName = [](Operation *op) -> std::string {
std::string name = op->getName().getStringRef().str();
if (name != "torch.operator")
return name;
// for unconverted onnx ops
return mlir::dyn_cast<StringAttr>(op->getAttr("name"))
.getValue()
.str();
};
// PrimLoop Op expects inputCondition to be boolConstantTrue
// to decide if the loopOp is `for-like`. Use loopIsForLike to
// ensure appropriate inputCondition is set.
// Case 1 : loopCondInp -> identity -> terminator(loopCondOut)
// i.e. the body passes its incoming condition through unchanged,
// so the loop runs for exactly maxTripCount iterations.
bool loopIsForLike = false;
auto case1ForLike = [&getOpName](Region *loopBody) -> bool {
// Body arg 1 is the incoming condition (arg 0 is iterationNum).
Value onnxLoopBodyCondIn = loopBody->front().getArgument(1);
if (!onnxLoopBodyCondIn.hasOneUse())
return false;
Operation *inpCondUser = *onnxLoopBodyCondIn.getUsers().begin();
if (getOpName(inpCondUser) != "onnx.Identity") {
return false;
}
if (!inpCondUser->hasOneUse() ||
getOpName(*(inpCondUser->getUsers().begin())) !=
"torch.operator_terminator")
return false;
return true;
};
loopIsForLike = case1ForLike(loopBodyIn);
Value loopInitCondition =
loopIsForLike ? constBoolTrue : conditionBool.getResult();
auto loc = binder.getLoc();
mlir::ImplicitLocOpBuilder b(loc, rewriter);
auto loop = b.create<Torch::PrimLoopOp>(
TypeRange(operandTypeVec), maxTripCountInt, loopInitCondition,
ValueRange(operands));
// Clone the ONNX loop body into the PrimLoopOp region, then fix up
// its arguments to match PrimLoopOp's calling convention below.
rewriter.cloneRegionBefore(*loopBodyIn, loop.getRegion(),
loop.getRegion().begin());
// primLoopOp loopBody expects torch.int as first arg:
// insert torch.int arg in loop body, convert it back to a tensor of
// the original arg type, replace all uses of old arg, delete old arg.
auto loopVarArg = loop.getRegion().front().getArgument(0);
// insert new torch.int arg at position 0 (shifts old arg to index 1)
loop.getRegion().front().insertArgument(
0U, rewriter.getType<Torch::IntType>(), binder.getLoc());
auto newLoopVarArg = loop.getRegion().front().getArgument(0);
// convert the int arg to a tensor of the original arg's type
rewriter.setInsertionPointToStart(&loop.getRegion().front());
// NOTE(review): BlockArgument::Value(loopVarArg) just views the block
// argument as a plain Value to query its (tensor) type.
Value loopVarVal = BlockArgument::Value(loopVarArg);
auto newTensor = rewriter.create<Torch::PrimNumToTensorScalarOp>(
loop.getRegion().op_begin()->getLoc(), loopVarVal.getType(),
newLoopVarArg);
loopVarArg.replaceAllUsesWith(newTensor);
// erase the original tensor iteration arg (now at index 1)
loop.getRegion().eraseArgument(1);
// primLoopOp loopBody has no condition arg; the cloned body's condIn
// arg (index 1 after the erase above) must be replaced.
auto condArg = loop.getRegion().front().getArgument(1);
if (!condArg.use_empty())
// NOTE(review): this substitutes the loop's *initial* condition
// tensor for the per-iteration condIn arg — confirm this is the
// intended semantics for bodies that read their condition input.
condArg.replaceAllUsesWith(conditionTensor);
// scan accumulator list arguments for loopBody, one per scan output;
// PrimLoopOp threads the lists through as loop-carried values
llvm::SmallVector<mlir::BlockArgument> scanArguments;
scanArguments.reserve(numScanOutputs);
for (unsigned scanOutputIdx = 0; scanOutputIdx < numScanOutputs;
scanOutputIdx++) {
// the scan lists were appended to `operands` above, in order
auto scanOperandListType =
(operands.end() - (numScanOutputs - scanOutputIdx))->getType();
loop.getRegion().front().addArgument(scanOperandListType,
binder.getLoc());
scanArguments.push_back(
loop.getRegion().front().getArguments().back());
}
// replace the cloned ONNX terminator with PrimLoopConditionOp
PatternRewriter::InsertionGuard guard(rewriter);
Operation *terminator = loop.getRegion().front().getTerminator();
rewriter.setInsertionPoint(terminator);
// Get remaining operands from onnxLoopBody's terminator Op
// (skipping condOut at index 0); these are all the loop-carried
// dependencies plus the per-iteration scan elements.
auto terminatorOperands = terminator->getOperands();
llvm::SmallVector<Value> remTerminatorOperands(
terminatorOperands.begin() + 1, terminatorOperands.end());
Value terminatorCond;
if (loopIsForLike) {
terminatorCond = constBoolTrue;
} else {
// Only used when the loop is not for-like: convert the yielded
// condition tensor to a torch.bool for PrimLoopConditionOp.
Value terminatorCondTensor = terminatorOperands[0];
auto terminatorCondInt = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
terminatorCondTensor);
auto terminatorCondBool = rewriter.create<Torch::AtenBoolIntOp>(
binder.getLoc(), rewriter.getType<Torch::BoolType>(),
terminatorCondInt);
terminatorCond = terminatorCondBool.getResult();
}
// append each per-iteration scan element to its accumulator list and
// yield the list (instead of the element) from the loop body
const unsigned numNonScanOutputs =
remTerminatorOperands.size() - numScanOutputs;
llvm::SmallVector<Value> remTerminatorOperandsWithScanLists(
remTerminatorOperands.begin(),
remTerminatorOperands.begin() + numNonScanOutputs);
for (unsigned scanOutputIdx = 0; scanOutputIdx < numScanOutputs;
scanOutputIdx++) {
const unsigned outputIdx = numNonScanOutputs + scanOutputIdx;
auto scanElement = remTerminatorOperands[outputIdx];
auto scanList = rewriter.create<Torch::AtenAppendTOp>(
terminator->getLoc(), scanArguments[scanOutputIdx].getType(),
scanArguments[scanOutputIdx], scanElement);
// NOTE(review): redirects uses of the scan element to the appended
// list, excluding the append op itself (both replaceAllUsesExcept
// args are scanList) — confirm other uses of the element are meant
// to see the list rather than the element.
scanElement.replaceAllUsesExcept(scanList, scanList);
remTerminatorOperandsWithScanLists.push_back(scanList);
}
rewriter.replaceOpWithNewOp<Torch::PrimLoopConditionOp>(
terminator, terminatorCond, remTerminatorOperandsWithScanLists);
// drop the now-unused condIn block arg (index 1)
loop.getRegion().eraseArgument(1);
// collect results: carried deps come first, scan outputs last
const unsigned numNonScanResults =
loop.getNumResults() - numScanOutputs;
llvm::SmallVector<Value> loopResults(
loop.result_begin(), loop.result_begin() + numNonScanResults);
// add stack operations (along dim 0) to turn each accumulated scan
// list result into the tensor the original ONNX op produced
rewriter.setInsertionPointAfter(loop);
Value constZero = rewriter.create<Torch::ConstantIntOp>(loc, 0);
for (unsigned scanOutputIdx = 0; scanOutputIdx < numScanOutputs;
scanOutputIdx++) {
const unsigned resultIdx = numNonScanResults + scanOutputIdx;
auto scanResult = loop.getResult(resultIdx);
// use the original ONNX op's result type for the stacked tensor
Torch::ValueTensorType scanStackType;
if (binder.tensorResultTypeAtIndex(scanStackType, resultIdx)) {
return rewriter.notifyMatchFailure(
binder.op, "Failed getting scan result tensor type");
}
loopResults.push_back(rewriter.create<Torch::AtenStackOp>(
binder.getLoc(), scanStackType, scanResult, constZero));
}
rewriter.replaceOp(binder.op, loopResults);
return success();
});
// ...
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment