// lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
// ...
void mlir::torch::onnx_c::populateDefaultDomainGtoP(
    OnnxCustomOpConversionPattern &patterns) {
  // ...
  patterns.onOp(
      "Loop", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        // Get all operands (maxTripCount, cond, ....inits....)
        llvm::SmallVector<Value> operands;
        if (binder.tensorOperandsList(operands) || operands.size() == 0 ||
            binder.getNumOperands() < 2) {
          return rewriter.notifyMatchFailure(
              binder.op, "Failed to get required operands");
        }
        llvm::SmallVector<mlir::Type> operandTypeVec;
        if (binder.tensorOperandTypes(operandTypeVec) ||
            operandTypeVec.size() == 0) {
          return rewriter.notifyMatchFailure(binder.op,
                                             "Failed to get operandTypes");
        }
        Region *loopBodyIn;
        if (binder.getRegionAtIndex(loopBodyIn, 0)) {
          return rewriter.notifyMatchFailure(binder.op,
                                             "Failed getting LoopBody Region");
        }
        // MaxTripCount - tensor int64 scalar (or empty)
        Value maxTripCountTensor = operands[0];
        auto maxTripCountInt = rewriter.create<Torch::AtenItemOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(),
            maxTripCountTensor);
        // Condition - tensor bool scalar (or empty)
        Value conditionTensor = operands[1];
        auto conditionInt = rewriter.create<Torch::AtenItemOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(),
            conditionTensor);
        auto conditionBool = rewriter.create<Torch::AtenBoolIntOp>(
            binder.getLoc(), rewriter.getType<Torch::BoolType>(),
            conditionInt);
        // To be used for "for like" loop case
        auto constBoolTrue = rewriter.create<Torch::ConstantBoolOp>(
            binder.getLoc(), rewriter.getBoolAttr(true));
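        // Note (added): torch.prim.Loop models both `for` and `while`
        // behavior in one op; downstream lowering treats it as a plain
        // `for` loop when the initial condition (and the condition the
        // body yields) is the constant `true`, which is why a dedicated
        // true constant is materialized here.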
        // Others (if present) - variadic (can be tensors and scalar values)
        if (binder.getNumOperands() > 2) {
          operandTypeVec.erase(operandTypeVec.begin(),
                               operandTypeVec.begin() + 2);
          operands.erase(operands.begin(), operands.begin() + 2);
        }
        // Scan operands - add one list per scan output
        OpBinder terminatorInBinder(loopBodyIn->front().getTerminator());
        const unsigned numScanOutputs =
            terminatorInBinder.getNumOperands() -
            (loopBodyIn->front().getNumArguments() - 1);
        for (unsigned scanOutputIdx = 0; scanOutputIdx < numScanOutputs;
             scanOutputIdx++) {
          // find scan operand type
          const unsigned scanOperandIdx =
              loopBodyIn->front().getNumArguments() - 1 + scanOutputIdx;
          Value scanOperand;
          if (terminatorInBinder.tensorOperandAtIndex(scanOperand,
                                                      scanOperandIdx)) {
            return rewriter.notifyMatchFailure(
                terminatorInBinder.op,
                "Failed getting tensor type of scan output");
          }
          // add new list operand and return type
          Torch::ListType scanListType =
              rewriter.getType<Torch::ListType>(scanOperand.getType());
          operands.push_back(rewriter.create<Torch::PrimListConstructOp>(
              binder.getLoc(), scanListType, llvm::SmallVector<Value>{}));
          operandTypeVec.push_back(scanListType);
        }
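        // Note (added): the scan count above works out because the body
        // block has (2 + numCarried) arguments (iter_num, cond, carried...)
        // while its terminator has (1 + numCarried + numScans) operands
        // (cond_out, carried..., scans...), so
        //   numScanOutputs = numTerminatorOperands - (numArguments - 1).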
        auto getOpName = [](Operation *op) -> std::string {
          std::string name = op->getName().getStringRef().str();
          if (name != "torch.operator")
            return name;
          // for unconverted onnx ops
          return mlir::dyn_cast<StringAttr>(op->getAttr("name"))
              .getValue()
              .str();
        };
        // PrimLoop Op expects inputCondition to be boolConstantTrue
        // to decide if the loopOp is `forlike`. Use loopIsForLike to
        // ensure the appropriate inputCondition is set.
        // Case 1 : loopCondInp -> identity -> terminator(loopCondOut)
        bool loopIsForLike = false;
        auto case1ForLike = [&getOpName](Region *loopBody) -> bool {
          Value onnxLoopBodyCondIn = loopBody->front().getArgument(1);
          if (!onnxLoopBodyCondIn.hasOneUse())
            return false;
          Operation *inpCondUser = *onnxLoopBodyCondIn.getUsers().begin();
          if (getOpName(inpCondUser) != "onnx.Identity") {
            return false;
          }
          if (!inpCondUser->hasOneUse() ||
              getOpName(*(inpCondUser->getUsers().begin())) !=
                  "torch.operator_terminator")
            return false;
          return true;
        };
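        // Illustrative sketch (added, hand-written IR, not compiler
        // output): the pattern detected above looks roughly like
        //   ^bb0(%iter: ..., %cond_in: !torch.vtensor<[],i1>, ...):
        //     %cond_out = torch.operator "onnx.Identity"(%cond_in) ...
        //     torch.operator_terminator %cond_out, ...
        // i.e. the body passes its input condition through unchanged, so
        // the loop can never terminate early and behaves like `for`.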
        loopIsForLike = case1ForLike(loopBodyIn);
        Value loopInitCondition =
            loopIsForLike ? constBoolTrue : conditionBool.getResult();
        auto loc = binder.getLoc();
        mlir::ImplicitLocOpBuilder b(loc, rewriter);
        auto loop = b.create<Torch::PrimLoopOp>(
            TypeRange(operandTypeVec), maxTripCountInt, loopInitCondition,
            ValueRange(operands));
        rewriter.cloneRegionBefore(*loopBodyIn, loop.getRegion(),
                                   loop.getRegion().begin());
        // primLoopOp loopBody expects torch.int as first arg
        // insert torch.int arg in loop body, convert to tensor,
        // replace all uses of old arg, delete old arg.
        auto loopVarArg = loop.getRegion().front().getArgument(0);
        // insert new Arg
        loop.getRegion().front().insertArgument(
            0U, rewriter.getType<Torch::IntType>(), binder.getLoc());
        auto newLoopVarArg = loop.getRegion().front().getArgument(0);
        // convert int arg to tensor of original Type
        rewriter.setInsertionPointToStart(&loop.getRegion().front());
        Value loopVarVal = BlockArgument::Value(loopVarArg);
        auto newTensor = rewriter.create<Torch::PrimNumToTensorScalarOp>(
            loop.getRegion().op_begin()->getLoc(), loopVarVal.getType(),
            newLoopVarArg);
        loopVarArg.replaceAllUsesWith(newTensor);
        loop.getRegion().eraseArgument(1);
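        // Note (added): insertArgument(0, ...) above shifted the cloned
        // body's original iteration-number tensor argument to index 1,
        // so eraseArgument(1) removes exactly that (now unused) argument,
        // leaving (%iter: !torch.int, %cond_in, carried args...).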
        // primLoopOp loopBody has no condition arg
        auto condArg = loop.getRegion().front().getArgument(1);
        if (!condArg.use_empty())
          condArg.replaceAllUsesWith(conditionTensor);
        // scan arguments for loopBody
        llvm::SmallVector<mlir::BlockArgument> scanArguments;
        scanArguments.reserve(numScanOutputs);
        for (unsigned scanOutputIdx = 0; scanOutputIdx < numScanOutputs;
             scanOutputIdx++) {
          auto scanOperandListType =
              (operands.end() - (numScanOutputs - scanOutputIdx))->getType();
          loop.getRegion().front().addArgument(scanOperandListType,
                                               binder.getLoc());
          scanArguments.push_back(
              loop.getRegion().front().getArguments().back());
        }
        // replace terminator
        PatternRewriter::InsertionGuard guard(rewriter);
        Operation *terminator = loop.getRegion().front().getTerminator();
        rewriter.setInsertionPoint(terminator);
        // Get remaining operands from onnxLoopBody's terminator Op
        // these are all the loop carried dependencies in the loop body
        auto terminatorOperands = terminator->getOperands();
        llvm::SmallVector<Value> remTerminatorOperands(
            terminatorOperands.begin() + 1, terminatorOperands.end());
        Value terminatorCond;
        if (loopIsForLike) {
          terminatorCond = constBoolTrue;
        } else {
          // Only use when loop is not forlike
          Value terminatorCondTensor = terminatorOperands[0];
          auto terminatorCondInt = rewriter.create<Torch::AtenItemOp>(
              binder.getLoc(), rewriter.getType<Torch::IntType>(),
              terminatorCondTensor);
          auto terminatorCondBool = rewriter.create<Torch::AtenBoolIntOp>(
              binder.getLoc(), rewriter.getType<Torch::BoolType>(),
              terminatorCondInt);
          terminatorCond = terminatorCondBool.getResult();
        }
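        // Note (added): for a while-like loop the condition the body
        // yields is a bool tensor, so it is unwrapped to !torch.bool via
        // aten.item followed by aten.Bool.int on every iteration;
        // for-like loops simply re-yield the constant `true`.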
        // add scan outputs to input lists and replace them in terminator
        // operands
        const unsigned numNonScanOutputs =
            remTerminatorOperands.size() - numScanOutputs;
        llvm::SmallVector<Value> remTerminatorOperandsWithScanLists(
            remTerminatorOperands.begin(),
            remTerminatorOperands.begin() + numNonScanOutputs);
        for (unsigned scanOutputIdx = 0; scanOutputIdx < numScanOutputs;
             scanOutputIdx++) {
          const unsigned outputIdx = numNonScanOutputs + scanOutputIdx;
          auto scanElement = remTerminatorOperands[outputIdx];
          auto scanList = rewriter.create<Torch::AtenAppendTOp>(
              terminator->getLoc(), scanArguments[scanOutputIdx].getType(),
              scanArguments[scanOutputIdx], scanElement);
          scanElement.replaceAllUsesExcept(scanList, scanList);
          remTerminatorOperandsWithScanLists.push_back(scanList);
        }
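        // Note (added): aten.append.t returns the appended list, so
        // yielding it as an iter arg threads the growing accumulator of
        // per-iteration scan elements through the loop.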
        rewriter.replaceOpWithNewOp<Torch::PrimLoopConditionOp>(
            terminator, terminatorCond, remTerminatorOperandsWithScanLists);
        loop.getRegion().eraseArgument(1);
        // collect results
        const unsigned numNonScanResults =
            loop.getNumResults() - numScanOutputs;
        llvm::SmallVector<Value> loopResults(
            loop.result_begin(), loop.result_begin() + numNonScanResults);
        // add stack operations for scan results
        rewriter.setInsertionPointAfter(loop);
        Value constZero = rewriter.create<Torch::ConstantIntOp>(loc, 0);
        for (unsigned scanOutputIdx = 0; scanOutputIdx < numScanOutputs;
             scanOutputIdx++) {
          const unsigned resultIdx = numNonScanResults + scanOutputIdx;
          auto scanResult = loop.getResult(resultIdx);
          // use original result type
          Torch::ValueTensorType scanStackType;
          if (binder.tensorResultTypeAtIndex(scanStackType, resultIdx)) {
            return rewriter.notifyMatchFailure(
                binder.op, "Failed getting scan result tensor type");
          }
          loopResults.push_back(rewriter.create<Torch::AtenStackOp>(
              binder.getLoc(), scanStackType, scanResult, constZero));
        }
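        // Note (added): aten.stack joins the accumulated list along a
        // new leading dimension (dim 0 here), matching the ONNX
        // convention of stacking each scan output across iterations.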
        rewriter.replaceOp(binder.op, loopResults);
        return success();
      });
  // ...
}
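
// Illustrative end-to-end sketch (added, hand-written IR with types
// elided, not compiler output): an ONNX loop such as
//   %r = torch.operator "onnx.Loop"(%max, %cond, %init) : ... {
//     ^bb0(%iter, %cond_in, %v):
//       ...
//       torch.operator_terminator %cond_out, %v_next, %scan_elem
//   }
// converts, roughly, to
//   %ct    = torch.aten.item %max
//   %ci    = torch.aten.item %cond
//   %cbool = torch.aten.Bool.int %ci
//   %empty = torch.prim.ListConstruct : () -> !torch.list<vtensor>
//   %v_fin, %lst = torch.prim.Loop %ct, %cbool, init(%init, %empty) {
//     ^bb0(%i: !torch.int, %v, %acc):
//       ...
//       %acc2 = torch.aten.append.t %acc, %scan_elem
//       torch.prim.Loop.condition %continue, iter(%v_next, %acc2)
//   }
//   %scan = torch.aten.stack %lst, %c0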