diff --git a/.github/actions/setup-build/action.yml b/.github/actions/setup-build/action.yml index a21c9a1d7296..7ed50e866492 100644 --- a/.github/actions/setup-build/action.yml +++ b/.github/actions/setup-build/action.yml @@ -27,7 +27,7 @@ runs: steps: - name: Set up Python if: ${{ runner.arch == 'X64' }} - uses: actions/setup-python@v4 + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: '3.11' @@ -74,7 +74,7 @@ runs: - name: Enable ccache if: ${{ inputs.cache-enabled == 'true' }} - uses: actions/cache@v3 + uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 with: path: ${{ github.workspace }}/.ccache key: ${{ runner.os }}-${{ inputs.cache-suffix }}-${{ github.sha }} diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 000000000000..3ab6783bdb61 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,15 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. +# Please see the documentation for all configuration options: +# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates + +version: 2 +updates: + - package-ecosystem: "gitsubmodule" + directory: "/" + allow: + - dependency-name: "externals/llvm-project" + schedule: + interval: "daily" + time: "06:00" + timezone: "Europe/Berlin" diff --git a/.github/workflows/RollPyTorch.yml b/.github/workflows/RollPyTorch.yml index 1c0f8f568728..454142eb44dd 100644 --- a/.github/workflows/RollPyTorch.yml +++ b/.github/workflows/RollPyTorch.yml @@ -22,7 +22,7 @@ jobs: sudo rm -rf $GITHUB_WORKSPACE/* - name: Get torch-mlir - uses: actions/checkout@v3 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: submodules: 'false' token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} @@ -53,19 +53,19 @@ jobs: sudo apt-get install unzip # Fetch the most recent nightly torchvision release - VISION_RELEASE=$(python -m pip index versions -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre torchvision | grep "Available versions" | tr ' ' '\n' | grep "^[0-9]" | sort --version-sort --reverse | head -n1 | tr -d ',' | sed 's/\([^+]*\).*/\1/') + VISION_RELEASE=$(python -m pip index versions -f https://download.pytorch.org/whl/nightly/cpu/torchvision/ --pre torchvision | grep "Available versions" | tr ' ' '\n' | grep "^[0-9]" | sort --version-sort --reverse | head -n1 | tr -d ',' | sed 's/\([^+]*\).*/\1/') echo "Found torchvision release ${VISION_RELEASE}" # Fetch the whl file associated with the nightly torchvision release rm -f torch*.whl - python -m pip download -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre "torchvision==${VISION_RELEASE}" + python -m pip download -f https://download.pytorch.org/whl/nightly/cpu/torchvision/ --pre "torchvision==${VISION_RELEASE}" # Downloading the torchvision WHL also downloads the PyTorch WHL file # Read the version from the downloaded whl file without extracting it PT_RELEASE=$(unzip -p torch-*.whl 'torch-*/METADATA' | grep "^Version:" | awk '{ print $2 }' | sed 's/\([^+]*\).*/\1/') echo "Found torch release ${PT_RELEASE}" - printf -- "-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html\n--pre\ntorch==%s\n" "${PT_RELEASE}" > pytorch-requirements.txt - printf -- "-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html\n--pre\ntorchvision==%s\n" "${VISION_RELEASE}" > torchvision-requirements.txt 
+ printf -- "-f https://download.pytorch.org/whl/nightly/cpu/torch\n--pre\ntorch==%s\n" "${PT_RELEASE}" > pytorch-requirements.txt + printf -- "-f https://download.pytorch.org/whl/nightly/cpu/torchvision\n--pre\ntorchvision==%s\n" "${VISION_RELEASE}" > torchvision-requirements.txt # Read the commit hash from the downloaded whl file without extracting it PT_HASH=$(unzip -p torch-"${PT_RELEASE}"*.whl torch/version.py | grep git_version | tail -1 | awk '{ print $3 }' | tr -d "'") @@ -95,7 +95,7 @@ jobs: - name: Post issue comment on build failure if: failure() - uses: peter-evans/create-or-update-comment@v2 + uses: peter-evans/create-or-update-comment@71345be0265236311c031f5c7866368bd1eff043 # v4.0.0 with: issue-number: 1690 body: | @@ -111,7 +111,7 @@ jobs: - name: Update PyTorch Build Cache (if running on main branch) if: github.ref_name == 'main' id: cache-pytorch - uses: actions/cache@v3 + uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 with: path: ${{ github.workspace }}/build_tools/python_deploy/wheelhouse key: ${{ runner.os }}-pytorch-${{ env.PT_HASH }} @@ -127,7 +127,7 @@ jobs: git pull origin main - name: Create pull request - uses: peter-evans/create-pull-request@v5.0.1 + uses: peter-evans/create-pull-request@67ccf781d68cd99b580ae25a5c18a1cc84ffff1f # v7.0.6 with: author: Roll PyTorch Action branch: rollpytorch diff --git a/.github/workflows/approve_dependabot.yml b/.github/workflows/approve_dependabot.yml new file mode 100644 index 000000000000..ca3f6b6e9930 --- /dev/null +++ b/.github/workflows/approve_dependabot.yml @@ -0,0 +1,28 @@ +name: Dependabot auto-approve & auto-merge +on: pull_request + +permissions: + pull-requests: write + # Needed to enable auto-merge + contents: write + +jobs: + dependabot: + runs-on: ubuntu-latest + if: github.actor == 'dependabot[bot]' + steps: + - name: Dependabot metadata + id: metadata + uses: dependabot/fetch-metadata@v2 + with: + github-token: "${{ secrets.GITHUB_TOKEN }}" + - name: Approve a PR + run: gh pr review --approve "$PR_URL" + env: + PR_URL: ${{github.event.pull_request.html_url}} + GH_TOKEN: ${{secrets.GITHUB_TOKEN}} + - name: Enable auto-merge for Dependabot PRs + run: gh pr merge --auto --merge "$PR_URL" + env: + PR_URL: ${{github.event.pull_request.html_url}} + GH_TOKEN: ${{secrets.GITHUB_TOKEN}} diff --git a/.github/workflows/bazelBuildAndTest.yml b/.github/workflows/bazelBuildAndTest.yml index 23f2addbe5af..030dde79fc51 100644 --- a/.github/workflows/bazelBuildAndTest.yml +++ b/.github/workflows/bazelBuildAndTest.yml @@ -22,7 +22,7 @@ concurrency: jobs: ubuntu-build: name: ubuntu-x86_64 - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - name: Prepare workspace @@ -32,7 +32,7 @@ jobs: sudo rm -rf $GITHUB_WORKSPACE/* - name: Checkout torch-mlir - uses: actions/checkout@v3 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: submodules: 'true' @@ -40,7 +40,7 @@ jobs: # restore to avoid the cache going stale over time # https://github.com/actions/cache/blob/main/workarounds.md#update-a-cache - name: Setup cache for bazel - uses: actions/cache@v3 + uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 with: path: ~/.cache/bazel key: torch_mlir-bazel-build-cache-${{ runner.os }}-${{ github.sha }} @@ -102,7 +102,7 @@ jobs: - name: Send mail if: failure() - uses: dawidd6/action-send-mail@v3 + uses: dawidd6/action-send-mail@611879133a9569642c41be66f4a323286e9b8a3b # v4 with: server_address: ${{ secrets.SMTP_SERVER }} server_port: ${{ secrets.SMTP_PORT }} diff --git 
a/.github/workflows/buildRelease.yml b/.github/workflows/buildRelease.yml index e84aabb4b388..f6a99614e68c 100644 --- a/.github/workflows/buildRelease.yml +++ b/.github/workflows/buildRelease.yml @@ -14,11 +14,16 @@ on: jobs: build_linux: name: Manylinux x86_64 Build - runs-on: a100 + runs-on: ubuntu-latest + permissions: + contents: write + actions: write + packages: write strategy: matrix: package: [torch-mlir] - py_version: [cp38-cp38, cp311-cp311] + py_version: [cp38-cp38, cp310-cp310] # cp311-cp311 + torch-version: [stable] # nightly steps: - name: Prepare workspace @@ -28,7 +33,7 @@ jobs: sudo rm -rf $GITHUB_WORKSPACE/* - name: Get torch-mlir - uses: actions/checkout@v3 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: submodules: 'true' fetch-depth: 0 @@ -38,19 +43,22 @@ jobs: cache-enabled: 'false' - name: Build Python wheels and smoke test. run: | - cd $GITHUB_WORKSPACE - TM_PACKAGE_VERSION=${{ github.event.inputs.python_package_version }} - printf "TORCH_MLIR_PYTHON_PACKAGE_VERSION=%s\n" $TM_PACKAGE_VERSION > ./torch_mlir_package_version - TM_PYTHON_VERSIONS=${{ matrix.py_version }} TM_PACKAGES=${{ matrix.package }} ./build_tools/python_deploy/build_linux_packages.sh - + cd $GITHUB_WORKSPACE + TM_PACKAGE_VERSION=${{ github.event.inputs.python_package_version }} + printf "TORCH_MLIR_PYTHON_PACKAGE_VERSION=%s\n" $TM_PACKAGE_VERSION > ./torch_mlir_package_version + TM_SKIP_TESTS=ON \ + TM_PYTHON_VERSIONS=${{ matrix.py_version }} \ + TM_PACKAGES=${{ matrix.package }} \ + TM_TORCH_VERSION="${{ matrix.torch-version }}" \ + ./build_tools/python_deploy/build_linux_packages.sh # If we were given a release_id, then upload the package we just built # to the github releases page. - name: Upload Release Assets (if requested) if: github.event.inputs.release_id != '' id: upload-release-assets - uses: dwenegar/upload-release-assets@v1 + uses: dwenegar/upload-release-assets@v3 env: - GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: release_id: ${{ github.event.inputs.release_id }} assets_path: ./build_tools/python_deploy/wheelhouse/torch*.whl @@ -59,9 +67,9 @@ jobs: - name: Publish Release (if requested) if: github.event.inputs.release_id != '' id: publish_release - uses: eregon/publish-release@v1 + uses: eregon/publish-release@01df127f5e9a3c26935118e22e738d95b59d10ce # v1.0.6 env: - GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: release_id: ${{ github.event.inputs.release_id }} - name: Create dist directory @@ -75,12 +83,13 @@ jobs: # # See https://github.com/pypa/gh-action-pypi-publish/discussions/15 - name: Store the binary wheel - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0 with: name: wheels path: dist build_linux_arm64: + if: false name: Manylinux arm64 Build runs-on: linux-arm64 strategy: @@ -96,7 +105,7 @@ jobs: sudo rm -rf $GITHUB_WORKSPACE/* - name: Get torch-mlir - uses: actions/checkout@v3 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: submodules: 'true' fetch-depth: 0 @@ -116,9 +125,9 @@ jobs: - name: Upload Release Assets (if requested) if: github.event.inputs.release_id != '' id: upload-release-assets - uses: dwenegar/upload-release-assets@v1 + uses: dwenegar/upload-release-assets@v3 env: - GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: release_id: ${{ 
github.event.inputs.release_id }} assets_path: ./build_tools/python_deploy/wheelhouse/torch*.whl @@ -127,9 +136,9 @@ jobs: - name: Publish Release (if requested) if: github.event.inputs.release_id != '' id: publish_release - uses: eregon/publish-release@v1 + uses: eregon/publish-release@01df127f5e9a3c26935118e22e738d95b59d10ce # v1.0.6 env: - GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: release_id: ${{ github.event.inputs.release_id }} - name: Create dist directory @@ -143,12 +152,13 @@ jobs: # # See https://github.com/pypa/gh-action-pypi-publish/discussions/15 - name: Store the binary wheel - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0 with: name: wheels path: dist build_macos: + if: false name: MacOS Build runs-on: macos-latest strategy: @@ -156,7 +166,7 @@ jobs: package: [torch-mlir] steps: - name: Get torch-mlir - uses: actions/checkout@v3 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: submodules: 'true' - uses: ./.github/actions/setup-build @@ -176,9 +186,9 @@ jobs: - name: Upload Release Assets (if requested) if: github.event.inputs.release_id != '' id: upload-release-assets - uses: dwenegar/upload-release-assets@v1 + uses: dwenegar/upload-release-assets@v3 env: - GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: release_id: ${{ github.event.inputs.release_id }} assets_path: ./build_tools/python_deploy/wheelhouse/torch*.whl @@ -187,9 +197,9 @@ jobs: - name: Publish Release (if requested) if: github.event.inputs.release_id != '' id: publish_release - uses: eregon/publish-release@v1 + uses: eregon/publish-release@01df127f5e9a3c26935118e22e738d95b59d10ce # v1.0.6 env: - GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: release_id: ${{ github.event.inputs.release_id }} - name: Create dist directory @@ -203,12 +213,13 @@ jobs: # # See https://github.com/pypa/gh-action-pypi-publish/discussions/15 - name: Store the binary wheel - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0 with: name: wheels path: dist build_windows: + if: false name: Windows Build runs-on: windows-latest strategy: @@ -216,7 +227,7 @@ jobs: package: [torch-mlir] steps: - name: Get torch-mlir - uses: actions/checkout@v3 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: submodules: 'true' - uses: ./.github/actions/setup-build @@ -239,9 +250,9 @@ jobs: - name: Upload Release Assets (if requested) if: github.event.inputs.release_id != '' id: upload-release-assets - uses: dwenegar/upload-release-assets@v1 + uses: dwenegar/upload-release-assets@v3 env: - GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: release_id: ${{ github.event.inputs.release_id }} assets_path: ./wheelhouse/torch*.whl @@ -250,9 +261,9 @@ jobs: - name: Publish Release (if requested) if: github.event.inputs.release_id != '' id: publish_release - uses: eregon/publish-release@v1 + uses: eregon/publish-release@01df127f5e9a3c26935118e22e738d95b59d10ce # v1.0.6 env: - GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: release_id: ${{ github.event.inputs.release_id }} - name: Create dist directory @@ -267,28 +278,32 @@ jobs: # # See 
https://github.com/pypa/gh-action-pypi-publish/discussions/15 - name: Store the binary wheel - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0 with: name: wheels path: dist publish_releases: - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 + permissions: + contents: write + actions: write + packages: write needs: - build_linux - - build_linux_arm64 - - build_macos - - build_windows + #- build_linux_arm64 + #- build_macos + #- build_windows # Publish even if one of the builds failed if: ${{ always() }} steps: - name: Invoke Publish Releases Page - uses: benc-uk/workflow-dispatch@v1 + uses: benc-uk/workflow-dispatch@e2e5e9a103e331dad343f381a29e654aea3cf8fc # v1.2.4 with: workflow: Publish releases page - token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} + token: ${{ secrets.GITHUB_TOKEN }} # Wheels must be published from a linux environment. # diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 63ef01cdeb51..2a98944240ba 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -5,9 +5,8 @@ on: workflow_dispatch: workflow_call: pull_request: - branches: [main] push: - branches: [main] + branches: [main, feature/*] concurrency: # A PR number if a pull request and otherwise the commit hash. This cancels @@ -19,59 +18,75 @@ concurrency: jobs: build-test-linux: strategy: - fail-fast: true + # AMD: Disable fail-fast to see whether failures are different between stable & nightly + fail-fast: false matrix: torch-version: [nightly, stable] name: Build and Test (Linux, torch-${{ matrix.torch-version }}, assertions) - runs-on: torch-mlir-cpubuilder-manylinux-x86-64 + runs-on: ubuntu-22.04 env: CACHE_DIR: ${{ github.workspace }}/.container-cache steps: - - name: Configure local git mirrors - run: | - # Our stock runners have access to certain local git caches. If these - # files are available, it will prime the cache and configure git to - # use them. Practically, this eliminates network/latency for cloning - # llvm. 
- if [[ -x /gitmirror/scripts/trigger_update_mirrors.sh ]]; then - /gitmirror/scripts/trigger_update_mirrors.sh - /gitmirror/scripts/git_config.sh - fi - name: "Checking out repository" - uses: actions/checkout@8f4b7f84864484a7bf31766abe9204da3cbe65b3 # v3.5.0 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: submodules: true + - name: Runner setup + run: | + sudo apt-get update + sudo apt-get install -y ccache clang + - name: Enable cache - uses: actions/cache/restore@v3 + uses: actions/cache/restore@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 with: path: ${{ env.CACHE_DIR }} key: build-test-cpp-asserts-manylinux-${{ matrix.torch-version }}-v2-${{ github.sha }} restore-keys: | build-test-cpp-asserts-manylinux-${{ matrix.torch-version }}-v2- + - name: "Setting up Python" # AMD: python 3.10 and not 3.11 + run: | + sudo apt update + sudo apt install python3.10 python3-pip -y + sudo apt-get install python3.10-dev python3.10-venv build-essential -y + - name: Install python deps (torch-${{ matrix.torch-version }}) run: | export cache_dir="${{ env.CACHE_DIR }}" bash build_tools/ci/install_python_deps.sh ${{ matrix.torch-version }} + - name: ccache + uses: hendrikmuhs/ccache-action@53911442209d5c18de8a31615e0923161e435875 # v1.2.16 + with: + key: ${{ github.job }}-${{ matrix.torch-version }} + save: ${{ needs.setup.outputs.write-caches == 1 }} + - name: Build project run: | export cache_dir="${{ env.CACHE_DIR }}" bash build_tools/ci/build_posix.sh - name: Save cache - uses: actions/cache/save@v3 + uses: actions/cache/save@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 if: ${{ !cancelled() }} with: path: ${{ env.CACHE_DIR }} key: build-test-cpp-asserts-manylinux-${{ matrix.torch-version }}-v2-${{ github.sha }} - name: Integration tests (torch-${{ matrix.torch-version }}) + if: ${{ matrix.torch-version == 'nightly' }} + continue-on-error: true + run: | + bash build_tools/ci/test_posix.sh ${{ matrix.torch-version }} + + - name: Integration tests (torch-${{ matrix.torch-version }}) + if: ${{ matrix.torch-version != 'nightly' }} run: | bash build_tools/ci/test_posix.sh ${{ matrix.torch-version }} - name: Check generated sources (torch-nightly only) if: ${{ matrix.torch-version == 'nightly' }} + continue-on-error: true run: | bash build_tools/ci/check_generated_sources.sh diff --git a/.github/workflows/gh-pages-releases.yml b/.github/workflows/gh-pages-releases.yml index a0eb45257b11..4e7e17d0ee0c 100644 --- a/.github/workflows/gh-pages-releases.yml +++ b/.github/workflows/gh-pages-releases.yml @@ -8,10 +8,12 @@ on: jobs: scrape_and_publish_releases: name: "Scrape and publish releases" - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 + permissions: + contents: write # Don't run this in everyone's forks. - if: github.repository == 'llvm/torch-mlir' + if: github.repository == 'xilinx/torch-mlir' steps: - name: Prepare workspace @@ -20,11 +22,9 @@ jobs: # existing lock files. sudo rm -rf $GITHUB_WORKSPACE/* - name: Checking out repository - uses: actions/checkout@v3 - with: - token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Run scrape releases script - run: python ./build_tools/scrape_releases.py llvm torch-mlir > /tmp/index.html + run: python ./build_tools/scrape_releases.py xilinx torch-mlir > /tmp/index.html shell: bash - run: git fetch --all - run: git switch github-pages @@ -37,7 +37,7 @@ jobs: - run: git diff --cached --exit-code || git commit -m "Update releases." 
- name: GitHub Push - uses: ad-m/github-push-action@v0.6.0 + uses: ad-m/github-push-action@v0.8.0 with: github_token: ${{ secrets.GITHUB_TOKEN }} branch: github-pages diff --git a/.github/workflows/merge-rollpytorch.yml b/.github/workflows/merge-rollpytorch.yml index 58a91fd1d409..e335f1fdfd7d 100644 --- a/.github/workflows/merge-rollpytorch.yml +++ b/.github/workflows/merge-rollpytorch.yml @@ -9,7 +9,7 @@ on: jobs: merge-pr: - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 if: | github.repository == 'llvm/torch-mlir' && github.event.workflow_run.actor.login == 'stellaraccident' && @@ -18,7 +18,7 @@ jobs: steps: # Fetch the repo first so that the gh command knows where to look for the PR - name: Fetch Repo - uses: actions/checkout@v3 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} diff --git a/.github/workflows/oneshotSnapshotPackage.yml b/.github/workflows/oneshotSnapshotPackage.yml index ec1878606624..23c10bb9384f 100644 --- a/.github/workflows/oneshotSnapshotPackage.yml +++ b/.github/workflows/oneshotSnapshotPackage.yml @@ -9,7 +9,7 @@ jobs: name: "Tag snapshot release" runs-on: ubuntu-latest # Don't run this in everyone's forks. - if: github.repository == 'llvm/torch-mlir' + #if: github.repository == 'llvm/torch-mlir' steps: - name: Prepare workspace run: | @@ -18,9 +18,10 @@ jobs: sudo rm -rf $GITHUB_WORKSPACE/* - name: Checking out repository - uses: actions/checkout@v3 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: - token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} + submodules: 'true' + fetch-depth: 0 - name: Compute version run: | @@ -35,7 +36,7 @@ jobs: git tag "${tag_name}" - name: Pushing changes - uses: ad-m/github-push-action@v0.6.0 + uses: ad-m/github-push-action@v0.8.0 with: github_token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} branch: ${{ github.ref_name }} @@ -43,16 +44,15 @@ jobs: - name: Create Release id: create_release - uses: actions/create-release@v1 - env: - GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} + uses: ncipollo/release-action@cdcc88a9acf3ca41c16c37bb7d21b9ad48560d87 # v1.15.0 with: - tag_name: ${{ env.tag_name }} - release_name: torch-mlir snapshot ${{ env.tag_name }} + tag: ${{ env.tag_name }} + name: torch-mlir snapshot ${{ env.tag_name }} body: | Automatic snapshot release of torch-mlir. 
draft: true prerelease: false + token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} - name: "Invoke workflow :: Build and Test" uses: benc-uk/workflow-dispatch@v1 diff --git a/.github/workflows/pre-commit-all.yml b/.github/workflows/pre-commit-all.yml index e17d4ebdbb43..2c0d61e92747 100644 --- a/.github/workflows/pre-commit-all.yml +++ b/.github/workflows/pre-commit-all.yml @@ -6,10 +6,10 @@ on: jobs: pre-commit: - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v3 - - uses: pre-commit/action@v3.0.1 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1 with: extra_args: --color=always --all-files diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 29733c2e5d45..6a848fe8674f 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -5,13 +5,13 @@ on: jobs: pre-commit: - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: # requites to grab the history of the PR fetch-depth: 0 - - uses: actions/setup-python@v3 - - uses: pre-commit/action@v3.0.1 + - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1 with: extra_args: --color=always --from-ref ${{ github.event.pull_request.base.sha }} --to-ref ${{ github.event.pull_request.head.sha }} diff --git a/.github/workflows/releaseSnapshotPackage.yml b/.github/workflows/releaseSnapshotPackage.yml index 8a0ec914440f..63047c77e716 100644 --- a/.github/workflows/releaseSnapshotPackage.yml +++ b/.github/workflows/releaseSnapshotPackage.yml @@ -9,9 +9,14 @@ on: jobs: release_snapshot_package: name: "Tag snapshot release" - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 # Don't run this in everyone's forks. - if: github.repository == 'llvm/torch-mlir' + #if: github.repository == 'llvm/torch-mlir' + permissions: + contents: write + actions: write + env: + BRANCH_NAME: ${{ github.head_ref || github.ref_name }} steps: - name: Prepare workspace @@ -21,9 +26,7 @@ jobs: sudo rm -rf $GITHUB_WORKSPACE/* - name: Checking out repository - uses: actions/checkout@v3 - with: - token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Compute version run: | @@ -38,36 +41,49 @@ jobs: git tag "${tag_name}" - name: Pushing changes - uses: ad-m/github-push-action@v0.6.0 + uses: ad-m/github-push-action@v0.8.0 with: - github_token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} - branch: main + github_token: ${{ secrets.GITHUB_TOKEN }} + branch: ${{ env.BRANCH_NAME }} tags: true - name: Create Release id: create_release - uses: actions/create-release@v1 - env: - GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} + uses: ncipollo/release-action@cdcc88a9acf3ca41c16c37bb7d21b9ad48560d87 # v1.15.0 with: - tag_name: ${{ env.tag_name }} - release_name: torch-mlir snapshot ${{ env.tag_name }} + tag: ${{ env.tag_name }} + name: torch-mlir snapshot ${{ env.tag_name }} body: | Automatic snapshot release of torch-mlir. 
draft: true prerelease: false + token: ${{ secrets.GITHUB_TOKEN }} - name: "Invoke workflow :: Build and Test" - uses: benc-uk/workflow-dispatch@v1 + uses: benc-uk/workflow-dispatch@e2e5e9a103e331dad343f381a29e654aea3cf8fc # v1.2.4 with: workflow: Build and Test - token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} + token: ${{ secrets.GITHUB_TOKEN }} ref: "${{ env.tag_name }}" - name: "Invoke workflow :: Release Build" - uses: benc-uk/workflow-dispatch@v1 + uses: benc-uk/workflow-dispatch@e2e5e9a103e331dad343f381a29e654aea3cf8fc # v1.2.4 with: workflow: Release Build - token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} ref: "${{ env.tag_name }}" inputs: '{"release_id": "${{ steps.create_release.outputs.id }}", "python_package_version": "${{ env.package_version }}"}' + + - name: Download nightly pytorch and torchvision wheels + run: | + pip download -r pytorch-requirements.txt -r torchvision-requirements.txt --no-deps --dest deps --python-version 3.8 + pip download -r pytorch-requirements.txt -r torchvision-requirements.txt --no-deps --dest deps --python-version 3.10 + pip download -r pytorch-requirements.txt -r torchvision-requirements.txt --no-deps --dest deps --python-version 3.11 + + - name: Upload nightly pytorch and torchvision wheels into release + id: upload-release-assets-nightly + uses: dwenegar/upload-release-assets@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + release_id: ${{ steps.create_release.outputs.id }} + assets_path: ./deps/*.whl diff --git a/.gitignore b/.gitignore index 00a5bc96f221..7cc823a3fe28 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,7 @@ externals/pytorch/ libtorch* /build/ +.build-cache/ /setup_build/ __pycache__ *.pyc diff --git a/.gitmodules b/.gitmodules index 8b46098d9615..f685f95dfa82 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,7 @@ [submodule "externals/llvm-project"] path = externals/llvm-project - url = https://github.com/llvm/llvm-project.git + url = https://github.com/Xilinx/llvm-aie.git + branch = aie-public [submodule "externals/stablehlo"] path = externals/stablehlo url = https://github.com/openxla/stablehlo.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 4740f2312394..c0f940467630 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -35,32 +35,50 @@ option(TORCH_MLIR_ENABLE_WERROR_FLAG "Enable `-Werror` flag on supported directo option(TORCH_MLIR_USE_INSTALLED_PYTORCH "If depending on PyTorch use it as installed in the current Python environment" ON) option(TORCH_MLIR_ENABLE_REFBACKEND "Enable reference backend" ON) + if(TORCH_MLIR_ENABLE_REFBACKEND) add_definitions(-DTORCH_MLIR_ENABLE_REFBACKEND) endif() +set(TORCH_MLIR_TABLEGEN_FLAGS "") + option(TORCH_MLIR_ENABLE_STABLEHLO "Add stablehlo dialect" ON) if(TORCH_MLIR_ENABLE_STABLEHLO) add_definitions(-DTORCH_MLIR_ENABLE_STABLEHLO) + list(APPEND TORCH_MLIR_TABLEGEN_FLAGS "-DTORCH_MLIR_ENABLE_STABLEHLO") +endif() +# It is possible that both the stablehlo and torch_mlir projects are used in the same compiler project. +# In this case, we don't want to use the stablehlo that is downloaded by torch_mlir (into the externals/stablehlo +# folder) but instead want to use the stablehlo that is part of the top-level compiler project. +# With TORCH_MLIR_USE_EXTERNAL_STABLEHLO enabled, it is assumed that the top-level compiler project makes +# stablehlo targets AND includes available (for example with `add_subdirectory` and `include_directories`).
+option(TORCH_MLIR_USE_EXTERNAL_STABLEHLO "Use stablehlo from top level project" OFF) + +option(TORCH_MLIR_ENABLE_TOSA "Add TOSA support" ON) +if(TORCH_MLIR_ENABLE_TOSA) + add_definitions(-DTORCH_MLIR_ENABLE_TOSA) + list(APPEND TORCH_MLIR_TABLEGEN_FLAGS "-DTORCH_MLIR_ENABLE_TOSA") endif() option(TORCH_MLIR_OUT_OF_TREE_BUILD "Specifies an out of tree build" OFF) # PyTorch native extension gate. If OFF, then no features which depend on -# native extensions will be built. -option(TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS "Enables PyTorch native extension features" ON) +# native extensions will be built. TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS is disabled by default. +# But it is manually enabled in the CI build to enable jit_ir_importer.build_tools.torch_ods_gen +# and abstract_interp_lib_gen.py. Once a pure Python version of build_tools is finished, it will no longer need to be set in CI. +option(TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS "Enables PyTorch native extension features" OFF) +if(TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS) + add_definitions(-DTORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS) +endif() +# NOTE: The JIT_IR_IMPORTER paths have become unsupportable due to age and lack of maintainers. +# Turning this off disables the old TorchScript path, leaving FX based import as the current supported option. +# The option will be retained for a time, and if a maintainer is interested in setting up testing for it, +# please reach out on the list and speak up for it. It will only be enabled in CI for test usage. cmake_dependent_option(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER "Enables JIT IR Importer" ON TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS OFF) cmake_dependent_option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS OFF) option(TORCH_MLIR_ENABLE_ONNX_C_IMPORTER "Enables the ONNX C importer" OFF) -# TODO(#3299): migrate to from member x.cast() to mlir::cast(x). -if(MSVC) - add_compile_options(/wd4996) -else() - add_compile_options(-Wno-deprecated-declarations) -endif() - macro(torch_mlir_enable_werror) if(TORCH_MLIR_ENABLE_WERROR_FLAG) if(NOT MSVC) @@ -69,6 +87,10 @@ macro(torch_mlir_enable_werror) endif() endmacro() +if(MSVC) + add_definitions(-D_USE_MATH_DEFINES) +endif() + #------------------------------------------------------------------------------- # Configure out-of-tree vs in-tree build #------------------------------------------------------------------------------- @@ -125,10 +147,6 @@ else() set(MLIR_INCLUDE_DIRS "${MLIR_INCLUDE_DIR};${MLIR_GENERATED_INCLUDE_DIR}") endif() -if (TORCH_MLIR_ENABLE_STABLEHLO) - include_directories(${CMAKE_CURRENT_SOURCE_DIR}/externals/stablehlo) -endif() - set(TORCH_MLIR_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}") set(TORCH_MLIR_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}") message(STATUS "Building torch-mlir project at ${TORCH_MLIR_SOURCE_DIR} (into ${TORCH_MLIR_BINARY_DIR})") @@ -140,7 +158,8 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) function(torch_mlir_target_includes target) set(_dirs - $ + $ + $ $ $ ) @@ -230,8 +249,10 @@ endif() # Getting this wrong results in building large parts of the stablehlo # project that we don't actually depend on. Further some of those parts # do not even compile on all platforms.
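For orientation, the new options above can be exercised together in an out-of-tree configure step. The sketch below is illustrative only: the flag values mirror the abbreviated CI invocations in `build_tools/ci/build_posix.sh` and `build_tools/python_deploy/build_windows_ci.sh` later in this change, the build directory name and the assumption of running from the repository root are mine, and the full CI flag set is longer than shown here.

```shell
# Illustrative configure of torch-mlir as an LLVM external project from the
# repository root, combining the options introduced above (a sketch, not the
# exact CI command; see build_tools/ci/build_posix.sh for the full flag set).
cmake -S externals/llvm-project/llvm -B build -GNinja \
  -DCMAKE_BUILD_TYPE=Release \
  -DLLVM_ENABLE_ASSERTIONS=ON \
  -DLLVM_TARGETS_TO_BUILD=host \
  -DLLVM_EXTERNAL_PROJECTS="torch-mlir" \
  -DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$PWD" \
  -DMLIR_ENABLE_BINDINGS_PYTHON=ON \
  -DTORCH_MLIR_ENABLE_TOSA=ON \
  -DTORCH_MLIR_ENABLE_STABLEHLO=OFF \
  -DTORCH_MLIR_ENABLE_LTC=OFF \
  -DTORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS=ON
```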
-if (TORCH_MLIR_ENABLE_STABLEHLO) +# Only configure StableHLO if it isn't provided from a top-level project +if (TORCH_MLIR_ENABLE_STABLEHLO AND NOT TORCH_MLIR_USE_EXTERNAL_STABLEHLO) set(STABLEHLO_BUILD_EMBEDDED ON) + set(STABLEHLO_ENABLE_BINDINGS_PYTHON ON) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/externals/stablehlo ${CMAKE_CURRENT_BINARY_DIR}/stablehlo EXCLUDE_FROM_ALL) diff --git a/README.md b/README.md index 70268ba729f0..56371b949487 100644 --- a/README.md +++ b/README.md @@ -21,17 +21,8 @@ Several vendors have adopted MLIR as the middle layer in their systems, enabling ## All the roads from PyTorch to Torch MLIR Dialect We have few paths to lower down to the Torch MLIR Dialect. - -![Simplified Architecture Diagram for README](docs/images/readme_architecture_diagram.png) - - - TorchScript - This is the most tested path down to Torch MLIR Dialect. - - LazyTensorCore - Read more details [here](docs/ltc_backend.md). - - We also have basic TorchDynamo/PyTorch 2.0 support, see our - [long-term roadmap](docs/roadmap.md) and - [Thoughts on PyTorch 2.0](https://discourse.llvm.org/t/thoughts-on-pytorch-2-0/67000/3) - for more details. + - ONNX as an entry point. + - FX as an entry point. ## Project Communication @@ -39,17 +30,6 @@ We have few paths to lower down to the Torch MLIR Dialect. - Github issues [here](https://github.com/llvm/torch-mlir/issues) - [`torch-mlir` section](https://llvm.discourse.group/c/projects-that-want-to-become-official-llvm-projects/torch-mlir/41) of LLVM Discourse -### Meetings - -Community Meeting / Developer Hour: -- 1st and 3rd Monday of the month at 9 am PST -- 2nd and 4th Monday of the month at 5 pm PST - -Office Hours: -- Every Thursday at 8:30 am PST - -Meeting links can be found [here](https://discourse.llvm.org/t/new-community-meeting-developer-hour-schedule/73868). - ## Install torch-mlir snapshot At the time of writing, we release [pre-built snapshots of torch-mlir](https://github.com/llvm/torch-mlir-release) for Python 3.11 and Python 3.10. @@ -70,36 +50,36 @@ python -m pip install --upgrade pip Then, we can install torch-mlir with the corresponding torch and torchvision nightlies. ``` pip install --pre torch-mlir torchvision \ - --extra-index-url https://download.pytorch.org/whl/nightly/cpu -pip install torch-mlir -f https://github.com/llvm/torch-mlir-release/releases/expanded_assets/dev-wheels + --extra-index-url https://download.pytorch.org/whl/nightly/cpu \ + -f https://github.com/llvm/torch-mlir-release/releases/expanded_assets/dev-wheels ``` -## Demos +## Using torch-mlir + +Torch-MLIR is primarily a project that is integrated into compilers to bridge them to PyTorch and ONNX. If contemplating a new integration, it may be helpful to refer to existing downstreams: -### TorchScript ResNet18 +* [IREE](https://github.com/iree-org/iree.git) +* [Blade](https://github.com/alibaba/BladeDISC) -Standalone script to Convert a PyTorch ResNet18 model to MLIR and run it on the CPU Backend: +While most of the project is exercised via testing paths, there are some ways that an end user can directly use the APIs without further integration: +### FxImporter ResNet18 ```shell # Get the latest example if you haven't checked out the code -wget https://raw.githubusercontent.com/llvm/torch-mlir/main/projects/pt1/examples/torchscript_resnet18.py +wget https://raw.githubusercontent.com/llvm/torch-mlir/main/projects/pt1/examples/fximporter_resnet18.py # Run ResNet18 as a standalone script.
-python projects/pt1/examples/torchscript_resnet18.py +python projects/pt1/examples/fximporter_resnet18.py +# Output load image from https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg -Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /home/mlir/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth -100.0% +... PyTorch prediction -[('Labrador retriever', 70.66319274902344), ('golden retriever', 4.956596374511719), ('Chesapeake Bay retriever', 4.195662975311279)] +[('Labrador retriever', 70.65674591064453), ('golden retriever', 4.988346099853516), ('Saluki, gazelle hound', 4.477451324462891)] torch-mlir prediction -[('Labrador retriever', 70.66320037841797), ('golden retriever', 4.956601619720459), ('Chesapeake Bay retriever', 4.195651531219482)] +[('Labrador retriever', 70.6567153930664), ('golden retriever', 4.988325119018555), ('Saluki, gazelle hound', 4.477458477020264)] ``` -### Lazy Tensor Core - -View examples [here](docs/ltc_examples.md). - ## Repository Layout The project follows the conventions of typical MLIR-based projects: diff --git a/build-requirements.txt b/build-requirements.txt index 1566aa67606d..f45b51399ac2 100644 --- a/build-requirements.txt +++ b/build-requirements.txt @@ -5,6 +5,7 @@ setuptools cmake ninja packaging +nanobind>=2.4, <3.0 # Workaround for what should be a torch dep # See discussion in #1174 diff --git a/build_tools/autogen_ltc_backend.py b/build_tools/autogen_ltc_backend.py index 13753a6d5949..f18af385c41b 100644 --- a/build_tools/autogen_ltc_backend.py +++ b/build_tools/autogen_ltc_backend.py @@ -30,6 +30,12 @@ TORCHGEN_DIR = Path(torchgen.__path__[0]).resolve() TORCH_MLIR_DIR = Path(__file__).resolve().parent.parent +# Safely load fast C Yaml loader if it is are available +try: + from yaml import CSafeLoader as Loader +except ImportError: + from yaml import SafeLoader as Loader # type:ignore[assignment, misc] + def reindent(text, prefix=""): return indent(dedent(text), prefix) @@ -175,7 +181,7 @@ def generate_native_functions(self): ) ts_native_yaml = None if ts_native_yaml_path.exists(): - ts_native_yaml = yaml.load(ts_native_yaml_path.read_text(), yaml.CLoader) + ts_native_yaml = yaml.load(ts_native_yaml_path.read_text(), Loader) else: logging.warning( f"Could not find `ts_native_functions.yaml` at {ts_native_yaml_path}" @@ -208,7 +214,7 @@ def get_opnames(ops): ) with self.config_path.open() as f: - config = yaml.load(f, yaml.CLoader) + config = yaml.load(f, Loader) # List of unsupported ops in LTC autogen because of some error blacklist = set(config.get("blacklist", [])) diff --git a/build_tools/ci/build_posix.sh b/build_tools/ci/build_posix.sh index fec5e252e8d7..c7d076c939a5 100755 --- a/build_tools/ci/build_posix.sh +++ b/build_tools/ci/build_posix.sh @@ -20,7 +20,7 @@ echo "Caching to ${cache_dir}" mkdir -p "${cache_dir}/ccache" mkdir -p "${cache_dir}/pip" -python="$(which python)" +python="$(which python3)" echo "Using python: $python" export CMAKE_TOOLCHAIN_FILE="$this_dir/linux_default_toolchain.cmake" @@ -40,7 +40,7 @@ echo "::group::CMake configure" cmake -S "$repo_root/externals/llvm-project/llvm" -B "$build_dir" \ -GNinja \ -DCMAKE_BUILD_TYPE=Release \ - -DPython3_EXECUTABLE="$(which python)" \ + -DPython3_EXECUTABLE="$(which python3)" \ -DLLVM_ENABLE_ASSERTIONS=ON \ -DTORCH_MLIR_ENABLE_WERROR_FLAG=ON \ -DCMAKE_INSTALL_PREFIX="$install_dir" \ @@ -50,7 +50,9 @@ cmake -S "$repo_root/externals/llvm-project/llvm" -B "$build_dir" \ -DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$repo_root" \ 
-DLLVM_TARGETS_TO_BUILD=host \ -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ - -DTORCH_MLIR_ENABLE_LTC=ON + -DTORCH_MLIR_ENABLE_LTC=OFF \ + -DTORCH_MLIR_ENABLE_STABLEHLO=OFF \ + -DTORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS=ON echo "::endgroup::" echo "::group::Build" diff --git a/build_tools/ci/install_python_deps.sh b/build_tools/ci/install_python_deps.sh index 6b49689ce8ea..375d0a6d1c24 100755 --- a/build_tools/ci/install_python_deps.sh +++ b/build_tools/ci/install_python_deps.sh @@ -7,7 +7,7 @@ repo_root="$(cd $this_dir/../.. && pwd)" torch_version="${1:-unknown}" echo "::group::installing llvm python deps" -python -m pip install --no-cache-dir -r $repo_root/externals/llvm-project/mlir/python/requirements.txt +python3 -m pip install --no-cache-dir -r $repo_root/externals/llvm-project/mlir/python/requirements.txt echo "::endgroup::" case $torch_version in @@ -19,7 +19,7 @@ case $torch_version in ;; stable) echo "::group::installing stable torch" - python3 -m pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu + python3 -m pip install --no-cache-dir -r $repo_root/stable-requirements.txt python3 -m pip install --no-cache-dir -r $repo_root/build-requirements.txt echo "::endgroup::" ;; @@ -30,5 +30,5 @@ case $torch_version in esac echo "::group::installing test requirements" -python -m pip install --no-cache-dir -r $repo_root/test-requirements.txt +python3 -m pip install --no-cache-dir -r $repo_root/test-requirements.txt echo "::endgroup::" diff --git a/build_tools/ci/test_posix.sh b/build_tools/ci/test_posix.sh index accdc41990c3..d5aa9d5ab79c 100755 --- a/build_tools/ci/test_posix.sh +++ b/build_tools/ci/test_posix.sh @@ -8,24 +8,12 @@ torch_version="${1:-unknown}" export PYTHONPATH="$repo_root/build/tools/torch-mlir/python_packages/torch_mlir:$repo_root/projects/pt1" -echo "::group::Run Linalg e2e integration tests" -python -m e2e_testing.main --config=linalg -v -echo "::endgroup::" - -echo "::group::Run make_fx + TOSA e2e integration tests" -python -m e2e_testing.main --config=make_fx_tosa -v -echo "::endgroup::" - -echo "::group::Run TOSA e2e integration tests" -python -m e2e_testing.main --config=tosa -v -echo "::endgroup::" - -echo "::group::Run Stablehlo e2e integration tests" -python -m e2e_testing.main --config=stablehlo -v +echo "::group::Run fx_importer_tosa e2e integration tests" +python -m e2e_testing.main --config=fx_importer_tosa -v echo "::endgroup::" echo "::group::Run ONNX e2e integration tests" -python -m e2e_testing.main --config=onnx -v +python3 -m e2e_testing.main --config=onnx -v echo "::endgroup::" case $torch_version in @@ -39,10 +27,19 @@ case $torch_version in # TODO: Need to verify in the stable version echo "::group::Run FxImporter e2e integration tests" - python -m e2e_testing.main --config=fx_importer -v + python3 -m e2e_testing.main --config=fx_importer -v echo "::endgroup::" + + # AMD: Disabled stablehlo. 
+ # TODO: Need to verify in the stable version + # echo "::group::Run FxImporter2Stablehlo e2e integration tests" + # python3 -m e2e_testing.main --config=fx_importer_stablehlo -v + # echo "::endgroup::" ;; stable) + echo "::group::Run FxImporter e2e integration tests" + python -m e2e_testing.main --config=fx_importer -v + echo "::endgroup::" ;; *) echo "Unrecognized torch version '$torch_version' (specify 'nightly' or 'stable' with cl arg)" diff --git a/build_tools/python_deploy/build_linux_packages.sh b/build_tools/python_deploy/build_linux_packages.sh index 625020836797..c2b97ff52866 100755 --- a/build_tools/python_deploy/build_linux_packages.sh +++ b/build_tools/python_deploy/build_linux_packages.sh @@ -16,7 +16,7 @@ # ./build_tools/python_deploy/build_linux_packages.sh # # Build specific Python versions and packages to custom directory: -# TM_PYTHON_VERSIONS="cp38-cp38 cp39-cp39" \ +# TM_PYTHON_VERSIONS="cp39-cp39 cp310-cp310" \ # TM_PACKAGES="torch-mlir" \ # TM_OUTPUT_DIR="/tmp/wheelhouse" \ # ./build_tools/python_deploy/build_linux_packages.sh @@ -46,11 +46,11 @@ TM_RELEASE_DOCKER_IMAGE="${TM_RELEASE_DOCKER_IMAGE:-quay.io/pypa/manylinux2014_$ # ./build_tools/docker/Dockerfile TM_CI_DOCKER_IMAGE="${TM_CI_DOCKER_IMAGE:-powderluv/torch-mlir-ci:latest}" # Version of Python to use in Release builds. Ignored in CIs. -TM_PYTHON_VERSIONS="${TM_PYTHON_VERSIONS:-cp38-cp38 cp310-cp310 cp311-cp311}" +TM_PYTHON_VERSIONS="${TM_PYTHON_VERSIONS:-cp310-cp310 cp311-cp311 cp312-cp312}" # Location to store Release wheels TM_OUTPUT_DIR="${TM_OUTPUT_DIR:-${this_dir}/wheelhouse}" # What "packages to build" -TM_PACKAGES="${TM_PACKAGES:-torch-mlir torch-mlir-core}" +TM_PACKAGES="${TM_PACKAGES:-torch-mlir torch-mlir-ext}" # Use pre-built Pytorch TM_USE_PYTORCH_BINARY="${TM_USE_PYTORCH_BINARY:-ON}" # Skip running tests if you want quick iteration @@ -83,12 +83,12 @@ function run_on_host() { fi mkdir -p "${TM_OUTPUT_DIR}" case "$package" in - torch-mlir) + torch-mlir-ext) TM_CURRENT_DOCKER_IMAGE=${TM_RELEASE_DOCKER_IMAGE} export USERID=0 export GROUPID=0 ;; - torch-mlir-core) + torch-mlir) TM_CURRENT_DOCKER_IMAGE=${TM_RELEASE_DOCKER_IMAGE} export USERID=0 export GROUPID=0 @@ -158,22 +158,22 @@ function run_in_docker() { export PATH=$python_dir/bin:$orig_path echo ":::: Python version $(python3 --version)" case "$package" in - torch-mlir) - clean_wheels torch_mlir "$python_version" - build_torch_mlir "$TM_TORCH_VERSION" + torch-mlir-ext) + clean_wheels torch_mlir_ext "$python_version" + build_torch_mlir_ext "$TM_TORCH_VERSION" # Disable audit wheel until we can fix ODR torch issues. 
See # https://github.com/llvm/torch-mlir/issues/1709 # - #run_audit_wheel torch_mlir "$python_version" + #run_audit_wheel torch_mlir_ext "$python_version" - clean_build torch_mlir "$python_version" + clean_build torch_mlir_ext "$python_version" ;; - torch-mlir-core) - clean_wheels torch_mlir_core "$python_version" - build_torch_mlir_core - run_audit_wheel torch_mlir_core "$python_version" - clean_build torch_mlir_core "$python_version" + torch-mlir) + clean_wheels torch_mlir "$python_version" + build_torch_mlir + run_audit_wheel torch_mlir "$python_version" + clean_build torch_mlir "$python_version" ;; out-of-tree) setup_venv "$python_version" "$TM_TORCH_VERSION" @@ -324,9 +324,6 @@ function test_in_tree() { ;; esac - echo ":::: Run make_fx + TOSA e2e integration tests" - python -m e2e_testing.main --config=make_fx_tosa -v - echo ":::: Run TOSA e2e integration tests" python -m e2e_testing.main --config=tosa -v } @@ -350,7 +347,7 @@ function setup_venv() { ;; stable) echo ":::: Using stable dependencies" - python3 -m pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu + python3 -m pip install --no-cache-dir -r /main_checkout/torch-mlir/stable-requirements.txt python3 -m pip install --no-cache-dir -r /main_checkout/torch-mlir/build-requirements.txt ;; *) @@ -431,7 +428,7 @@ function clean_build() { rm -rf /main_checkout/torch-mlir/build /main_checkout/torch-mlir/llvm-build /main_checkout/torch-mlir/docker_venv /main_checkout/torch-mlir/libtorch } -function build_torch_mlir() { +function build_torch_mlir_ext() { # Disable LTC build for releases export TORCH_MLIR_ENABLE_LTC=0 local torch_version="$1" @@ -439,16 +436,16 @@ function build_torch_mlir() { nightly) echo ":::: Using nightly dependencies" python -m pip install --no-cache-dir -r /main_checkout/torch-mlir/requirements.txt \ - --extra-index-url https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html + --extra-index-url https://download.pytorch.org/whl/nightly/cpu/torch/ CMAKE_GENERATOR=Ninja \ TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \ python -m pip wheel -v --no-build-isolation -w /wheelhouse /main_checkout/torch-mlir \ - -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html \ + -f https://download.pytorch.org/whl/nightly/cpu/torch/ \ -r /main_checkout/torch-mlir/whl-requirements.txt ;; stable) echo ":::: Using stable dependencies" - python3 -m pip install --no-cache-dir torch torchvision + python3 -m pip install --no-cache-dir -r /main_checkout/torch-mlir/stable-requirements.txt python3 -m pip install --no-cache-dir -r /main_checkout/torch-mlir/build-requirements.txt CMAKE_GENERATOR=Ninja \ TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \ @@ -470,7 +467,9 @@ function run_audit_wheel() { rm "$generic_wheel" } -function build_torch_mlir_core() { +function build_torch_mlir() { + # Disable LTC build for releases + export TORCH_MLIR_ENABLE_LTC=0 python -m pip install --no-cache-dir -r /main_checkout/torch-mlir/build-requirements.txt CMAKE_GENERATOR=Ninja \ TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \ diff --git a/build_tools/python_deploy/build_macos_packages.sh b/build_tools/python_deploy/build_macos_packages.sh index b928c1e48cf6..5b4b2031cdc5 100755 --- a/build_tools/python_deploy/build_macos_packages.sh +++ b/build_tools/python_deploy/build_macos_packages.sh @@ -6,7 +6,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # build_macos_packages.sh -# One stop build of IREE Python 
packages for MacOS. This presumes that +# One stop build of torch-mlir Python packages for MacOS. This presumes that # dependencies are installed from install_macos_deps.sh. This will build # for a list of Python versions synchronized with that script and corresponding # with directory names under: @@ -30,7 +30,7 @@ echo "Setting torch-mlir Python Package version to: ${TORCH_MLIR_PYTHON_PACKAGE_ # Note that this typically is selected to match the version that the official # Python distributed is built at. -export MACOSX_DEPLOYMENT_TARGET="${TORCH_MLIR_OSX_TARGET:-11.0}" +export MACOSX_DEPLOYMENT_TARGET="${TORCH_MLIR_OSX_TARGET:-11.1}" export CMAKE_OSX_ARCHITECTURES="${TORCH_MLIR_OSX_ARCH:-arm64;x86_64}" echo "CMAKE_OSX_ARCHITECTURES: $CMAKE_OSX_ARCHITECTURES" echo "MACOSX_DEPLOYMENT_TARGET $MACOSX_DEPLOYMENT_TARGET" @@ -56,16 +56,16 @@ function run() { export PATH=$python_dir/bin:$orig_path echo ":::: Python version $(python3 --version)" case "$package" in + torch-mlir-ext) + clean_wheels torch_mlir_ext "$python_version" + build_torch_mlir_ext torch_mlir_ext "$python_version" + run_audit_wheel torch_mlir_ext "$python_version" + ;; torch-mlir) clean_wheels torch_mlir "$python_version" build_torch_mlir torch_mlir "$python_version" run_audit_wheel torch_mlir "$python_version" ;; - torch-mlir-core) - clean_wheels torch_mlir_core "$python_version" - build_torch_mlir_core torch_mlir_core "$python_version" - run_audit_wheel torch_mlir_core "$python_version" - ;; *) echo "Unrecognized package '$package'" exit 1 @@ -75,7 +75,7 @@ function run() { done } -function build_torch_mlir() { +function build_torch_mlir_ext() { local wheel_basename="$1" local python_version="$2" rm -rf "$output_dir"/build_venv @@ -88,12 +88,12 @@ function build_torch_mlir() { TORCH_MLIR_PYTHON_PACKAGE_VERSION=${TORCH_MLIR_PYTHON_PACKAGE_VERSION} \ MACOSX_DEPLOYMENT_TARGET=$MACOSX_DEPLOYMENT_TARGET \ CMAKE_OSX_ARCHITECTURES=$CMAKE_OSX_ARCHITECTURES \ - python"${python_version}" -m pip wheel -v -w "$output_dir" "$repo_root" --extra-index-url https://download.pytorch.org/whl/nightly/cpu + python"${python_version}" -m pip wheel -v --no-build-isolation -w "$output_dir" "$repo_root" --extra-index-url https://download.pytorch.org/whl/nightly/cpu deactivate rm -rf "$output_dir"/build_venv } -function build_torch_mlir_core() { +function build_torch_mlir() { local wheel_basename="$1" local python_version="$2" rm -rf "$output_dir"/build_venv @@ -107,7 +107,7 @@ function build_torch_mlir_core() { CMAKE_OSX_ARCHITECTURES=$CMAKE_OSX_ARCHITECTURES \ TORCH_MLIR_ENABLE_JIT_IR_IMPORTER=0 \ TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS=1 \ - python"${python_version}" -m pip wheel -v -w "$output_dir" "$repo_root" + python"${python_version}" -m pip wheel -v --no-build-isolation -w "$output_dir" "$repo_root" deactivate rm -rf "$output_dir"/build_venv } diff --git a/build_tools/python_deploy/build_windows.ps1 b/build_tools/python_deploy/build_windows.ps1 index 808a16cb18e7..bc829a87d6d3 100644 --- a/build_tools/python_deploy/build_windows.ps1 +++ b/build_tools/python_deploy/build_windows.ps1 @@ -21,7 +21,7 @@ Write-Host "Build Deps installation completed successfully" Write-Host "Building torch-mlir" $env:CMAKE_GENERATOR='Ninja' $env:TORCH_MLIR_ENABLE_LTC='0' -python -m pip wheel -v -w wheelhouse ./ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html -r whl-requirements.txt +python -m pip wheel -v -w wheelhouse ./ -f https://download.pytorch.org/whl/nightly/cpu/torch/ -r whl-requirements.txt Write-Host "Build completed successfully" 
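The build scripts above, like the RollPyTorch workflow earlier in this change, switch from the monolithic `torch_nightly.html` index to the per-package nightly indexes under `https://download.pytorch.org/whl/nightly/cpu/`. A quick local sanity check of those URLs, mirroring the extraction pipeline already used in `RollPyTorch.yml`, could look like the sketch below; note that `pip index` is still marked experimental by pip, so its output format may shift between releases.

```shell
# Ask the per-package nightly index for the newest torchvision build and strip
# the local version suffix, exactly as RollPyTorch.yml does above.
VISION_RELEASE=$(python -m pip index versions \
  -f https://download.pytorch.org/whl/nightly/cpu/torchvision/ --pre torchvision \
  | grep "Available versions" | tr ' ' '\n' | grep "^[0-9]" \
  | sort --version-sort --reverse | head -n1 | tr -d ',' | sed 's/\([^+]*\).*/\1/')
echo "Latest nightly torchvision: ${VISION_RELEASE}"

# Downloading that torchvision wheel also pulls in the matching torch nightly wheel.
python -m pip download -f https://download.pytorch.org/whl/nightly/cpu/torchvision/ \
  --pre "torchvision==${VISION_RELEASE}"
```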
diff --git a/build_tools/python_deploy/build_windows_ci.sh b/build_tools/python_deploy/build_windows_ci.sh index c5da1adf6cae..2e1648679c57 100644 --- a/build_tools/python_deploy/build_windows_ci.sh +++ b/build_tools/python_deploy/build_windows_ci.sh @@ -14,6 +14,7 @@ cmake -GNinja -Bbuild \ -DLLVM_EXTERNAL_PROJECTS="torch-mlir" \ -DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$PWD" \ -DPython3_EXECUTABLE="$(which python)" \ + -DTORCH_MLIR_ENABLE_STABLEHLO=OFF \ $GITHUB_WORKSPACE/externals/llvm-project/llvm cmake --build build --config Release diff --git a/build_tools/python_deploy/install_macos_deps.sh b/build_tools/python_deploy/install_macos_deps.sh index 4d91a244c75f..32b4b294ca51 100755 --- a/build_tools/python_deploy/install_macos_deps.sh +++ b/build_tools/python_deploy/install_macos_deps.sh @@ -19,14 +19,14 @@ if [[ "$(whoami)" != "root" ]]; then fi PYTHON_INSTALLER_URLS=( - "https://www.python.org/ftp/python/3.11.2/python-3.11.2-macos11.pkg" - "https://www.python.org/ftp/python/3.10.10/python-3.10.10-macos11.pkg" + "https://www.python.org/ftp/python/3.11.9/python-3.11.9-macos11.pkg" + "https://www.python.org/ftp/python/3.10.11/python-3.10.11-macos11.pkg" "https://www.python.org/ftp/python/3.9.13/python-3.9.13-macos11.pkg" ) PYTHON_SPECS=( - 3.11@https://www.python.org/ftp/python/3.11.2/python-3.11.2-macos11.pkg - 3.10@https://www.python.org/ftp/python/3.10.5/python-3.10.5-macos11.pkg + 3.11@https://www.python.org/ftp/python/3.11.9/python-3.11.9-macos11.pkg + 3.10@https://www.python.org/ftp/python/3.10.11/python-3.10.11-macos11.pkg 3.9@https://www.python.org/ftp/python/3.9.13/python-3.9.13-macos11.pkg ) diff --git a/build_tools/update_abstract_interp_lib.sh b/build_tools/update_abstract_interp_lib.sh index cb44a4e8b27c..4da20c3e715a 100755 --- a/build_tools/update_abstract_interp_lib.sh +++ b/build_tools/update_abstract_interp_lib.sh @@ -41,7 +41,10 @@ if [ ! -z ${TORCH_MLIR_EXT_MODULES} ]; then ext_module="${TORCH_MLIR_EXT_MODULES} " fi -PYTHONPATH="${pypath}" python \ +# To enable this python package, manually build torch_mlir with: +# -DTORCH_MLIR_ENABLE_JIT_IR_IMPORTER=ON +# TODO: move this package out of JIT_IR_IMPORTER. +PYTHONPATH="${pypath}" python3 \ -m torch_mlir.jit_ir_importer.build_tools.abstract_interp_lib_gen \ --pytorch_op_extensions=${ext_module:-""} \ --torch_transforms_cpp_dir="${torch_transforms_cpp_dir}" diff --git a/build_tools/update_torch_ods.sh b/build_tools/update_torch_ods.sh index cb0599f16f10..efe0055d7e06 100755 --- a/build_tools/update_torch_ods.sh +++ b/build_tools/update_torch_ods.sh @@ -42,7 +42,10 @@ if [ ! -z ${TORCH_MLIR_EXT_MODULES} ]; then fi set +u -PYTHONPATH="${PYTHONPATH}:${pypath}" python \ +# To enable this python package, manually build torch_mlir with: +# -DTORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS=ON +# TODO: move this package out of JIT_IR_IMPORTER. +PYTHONPATH="${PYTHONPATH}:${pypath}" python3 \ -m torch_mlir.jit_ir_importer.build_tools.torch_ods_gen \ --torch_ir_include_dir="${torch_ir_include_dir}" \ --pytorch_op_extensions="${ext_module}" \ diff --git a/docs/add_ops.md b/docs/add_ops.md index b8e5ce37ec45..aea7f3ab6ed7 100644 --- a/docs/add_ops.md +++ b/docs/add_ops.md @@ -2,71 +2,49 @@ Collected links and contacts for how to add ops to torch-mlir. -
-Turbine Camp: Start Here -This document was previously known as `turbine-camp.md` to Nod.ai. "Turbine Camp" is part of Nod.ai's onboarding process. Welcome to turbine camp. This document originated at Nod.ai as a part of onboardding process, where new nod-ai folks learn about the architecture of our work by adding support for 2 ops to torch-mlir. I decided to put this into torch mlir because a lot of this is about torch-mlir. +## [How to Add a Torch Operator](https://github.com/llvm/torch-mlir/blob/main/docs/Torch-ops-E2E-implementation.md) -Written & maintained by @renxida - -Guides by other folks that were used during the creation of this document: -- [Chi Liu](https://gist.github.com/AmosLewis/dd31ab37517977b1c499d06495b4adc2) -- [Sunsoon](https://docs.google.com/document/d/1H79DwW_wnVzUU81EogwY5ueXgnl-QzKet1p2lnqPar4/edit?pli=1) - -## Before you begin... - -Nod-ai maintains the pipeline below, which allows us to take a ML model from e.g. huggingface, and compile it to a variety of devices including llvm-cpu, rocm and cuda and more as an optimized `vmfb` binary. - -1. The pipeline begins with a huggingface model, or some other supported source like llama.cpp. -2. [nod-ai/SHARK-Turbine](https://github.com/nod-ai/SHARK-Turbine) takes a huggingface model and exports a `.mlir` file. -3. **[llvm/torch-mlir](https://github.com/llvm/torch-mlir)**, which you will be working on in turbine-camp, will lower torchscript, torch dialect, and torch aten ops further into a mixture `linalg` or `math` MLIR dialects (with occasionally other dialects in the mix) -4. [IREE](https://github.com/openxla/iree) converts the final `.mlir` file into a binary (typically `.vmfb`) for running on a device (llvm-cpu, rocm, vulcan, cuda, etc). - -The details of how we do it and helpful commands to help you set up each repo is in [Sungsoon's Shark Getting Started Google Doc](https://docs.google.com/document/d/1H79DwW_wnVzUU81EogwY5ueXgnl-QzKet1p2lnqPar4/edit?pli=1) - -PS: IREE is pronounced Eerie, and hence the ghost icon. - -## How to begin -0. Set up torch-mlir according to the instructions here: https://github.com/llvm/torch-mlir/blob/main/docs/development.md -1. You will start by adding support for 2 ops in torch-mlir, to get you familiar with the center of our pipeline. Begin by reading [torch-mlir's documentation on how to implement a new torch op](https://github.com/llvm/torch-mlir/blob/main/docs/Torch-ops-E2E-implementation.md), and set up `llvm/torch_mlir` using https://github.com/llvm/torch-mlir/blob/main/docs/development.md -2. Pick 1 of the yet-unimplemented from the following. You should choose something that looks easy to you. **Make sure you create an issue by clicking the little "target" icon to the right of the op, thereby marking the op as yours** - - [TorchToLinalg ops tracking issue](https://github.com/nod-ai/SHARK-Turbine/issues/347) - - [TorchOnnnxToTorch ops tracking issue](https://github.com/nod-ai/SHARK-Turbine/issues/215) -3. Implement it. For torch -> linalg, see the how to torchop section below. For Onnx ops, see how to onnx below. -5. Make a pull request and reference your issue. When the pull request is closed, also close your issue to mark the op as done - -
+## How to Add a Conversion for an Operator ### How to TorchToLinalg -You will need to do 4 things: +You will need to do 5 things: + +- make sure -DTORCH_MLIR_ENABLE_JIT_IR_IMPORTER=ON is added during build. This is to enable the python file used in `build_tools/update_torch_ods.sh` and `build_tools/update_abstract_interp_lib.sh` - make sure the op exists in `torch_ods_gen.py`, and then run `build_tools/update_torch_ods.sh`, and then build. This generates `GeneratedTorchOps.td`, which is used to generate the cpp and h files where ops function signatures are defined. - - Reference [torch op registry](https://github.com/pytorch/pytorch/blob/7451dd058564b5398af79bfc1e2669d75f9ecfa2/torch/csrc/jit/passes/utils/op_registry.cpp#L21) + - Reference [torch op registry](https://github.com/pytorch/pytorch/blob/7451dd058564b5398af79bfc1e2669d75f9ecfa2/torch/csrc/jit/passes/utils/op_registry.cpp#L21) - make sure the op exists in `abstract_interp_lib_gen.py`, and then run `build_tools/update_abstract_interp_lib.sh`, and then build. This generates `AbstractInterpLib.cpp`, which is used to generate the cpp and h files where ops function signatures are defined. - - Reference [torch shape functions](https://github.com/pytorch/pytorch/blob/7451dd058564b5398af79bfc1e2669d75f9ecfa2/torch/jit/_shape_functions.py#L1311) + - Reference [torch shape functions](https://github.com/pytorch/pytorch/blob/7451dd058564b5398af79bfc1e2669d75f9ecfa2/torch/jit/_shape_functions.py#L1311) - write test cases. They live in `projects/pt1`. See the [Dec 2023 example](https://github.com/llvm/torch-mlir/pull/2640/files). - implement the op in one of the `lib/Conversion/TorchToLinalg/*.cpp` files Reference Examples + - [A Dec 2023 example with the most up to date lowering](https://github.com/llvm/torch-mlir/pull/2640/files) - [Chi's simple example of adding op lowering](https://github.com/llvm/torch-mlir/pull/1454) useful instructions and referring links for you to understand the op lowering pipeline in torch-mlir in the comments Resources: -- how to set up torch-mlir: [https://github.com/llvm/torch-mlir/blob/main/docs/development.md](https://github.com/llvm/torch-mlir/blob/main/docs/development.md#checkout-and-build-from-source) -- torch-mlir doc on how to debug and test: [ttps://github.com/llvm/torch-mlir/blob/main/docs/development.md#testing](https://github.com/llvm/torch-mlir/blob/main/docs/development.md#testing) + +- [how to set up torch-mlir](https://github.com/llvm/torch-mlir/blob/main/docs/development.md) +- [torch-mlir doc on how to debug and test](https://github.com/llvm/torch-mlir/blob/main/docs/development.md#testing) - [torch op registry](https://github.com/pytorch/pytorch/blob/7451dd058564b5398af79bfc1e2669d75f9ecfa2/torch/csrc/jit/passes/utils/op_registry.cpp#L21) - [torch shape functions](https://github.com/pytorch/pytorch/blob/7451dd058564b5398af79bfc1e2669d75f9ecfa2/torch/jit/_shape_functions.py#L1311) ### How to TorchOnnxToTorch -0. Generate the big folder of ONNX IR. Use https://github.com/llvm/torch-mlir/blob/main/test/python/onnx_importer/import_smoke_test.py . Alternatively, if you're trying to support a certain model, convert that model to onnx IR with - ``` + +1. Generate the big folder of ONNX IR. Use [this Python script](https://github.com/llvm/torch-mlir/blob/main/test/python/onnx_importer/import_smoke_test.py). 
Alternatively, if you're trying to support a certain model, convert that model to onnx IR with + + ```shell optimum-cli export onnx --model facebook/opt-125M fb-opt python -m torch_mlir.tools.import_onnx fb-opt/model.onnx -o fb-opt-125m.onnx.mlir ``` -2. Find an instance of the Op that you're trying to implement inside the smoke tests folder or the generated model IR, and write a test case. Later you will save it to one of the files in `torch-mlir/test/Conversion/TorchOnnxToTorch`, but for now feel free to put it anywhere. -3. Implement the op in `lib/Conversion/TorchOnnxToTorch/something.cpp`. -4. Test the conversion by running `./build/bin/torch-mlir-opt -split-input-file -verify-diagnostics -convert-torch-onnx-to-torch your_mlir_file.mlir`. For more details, see https://github.com/llvm/torch-mlir/blob/main/docs/development.md#testing . Xida usually creates a separate MLIR file to test it to his satisfaction before integrating it into one of the files at `torch-mlir/test/Conversion/TorchOnnxToTorch`. + +1. Find an instance of the Op that you're trying to implement inside the smoke tests folder or the generated model IR, and write a test case. Later you will save it to one of the files in `torch-mlir/test/Conversion/TorchOnnxToTorch`, but for now feel free to put it anywhere. +1. Implement the op in `lib/Conversion/TorchOnnxToTorch/something.cpp`. +1. Test the conversion by running `./build/bin/torch-mlir-opt -split-input-file -verify-diagnostics -convert-torch-onnx-to-torch your_mlir_file.mlir`. For more details, see [the testing section of the doc on development](https://github.com/llvm/torch-mlir/blob/main/docs/development.md#testing). Xida usually creates a separate MLIR file to test it to his satisfaction before integrating it into one of the files at `torch-mlir/test/Conversion/TorchOnnxToTorch`. Helpful examples: + - [A Dec 2023 example where an ONNX op is implemented](https://github.com/llvm/torch-mlir/pull/2641/files#diff-b584b152020af6d2e5dbf62a08b2f25ed5afc2c299228383b9651d22d44b5af4R493) - [Vivek's example of ONNX op lowering](https://github.com/llvm/torch-mlir/commit/dc9ea08db5ac295b4b3f91fc776fef6a702900b9) @@ -75,17 +53,9 @@ Helpful examples: - Generate FILECHECK tests from MLIR test cases: `torch-mlir-opt -convert- /tmp/your_awesome_testcase.mlir | externals/llvm-project/mlir/utils/generate-test-checks.py `. 
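+
+  For instance, with the placeholder pass name filled in (here the ONNX-to-Torch conversion used above;
+  the input file is just a scratch path you substitute with your own test case), the full pipeline
+  might look like:
+
+  ```shell
+  # Sketch: run your new conversion on a scratch test case and let the LLVM
+  # helper script draft the CHECK lines for you.
+  ./build/bin/torch-mlir-opt -convert-torch-onnx-to-torch /tmp/your_awesome_testcase.mlir \
+    | externals/llvm-project/mlir/utils/generate-test-checks.py
+  ```
+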
Please don't just paste the generated tests - reference them to write your own -## Contacts -People who've worked on this for a while -- Vivek (@vivek97 on discord) -- Chi.Liu@amd.com - -Recent Turbine Camp Attendees, from recent to less recent -- Xida.ren@amd.com (@xida_ren on discord) -- Sungsoon.Cho@amd.com - ## Links -- IMPORTANT: read the LLVM style guide: https://llvm.org/docs/CodingStandards.html#use-early-exits-and-continue-to-simplify-code + +- IMPORTANT: read [the LLVM style guide](https://llvm.org/docs/CodingStandards.html#style-issues) - Tutorials - [Sungsoon's Shark Getting Started Google Doc](https://docs.google.com/document/d/1H79DwW_wnVzUU81EogwY5ueXgnl-QzKet1p2lnqPar4/edit?pli=1) - This document contains commands that would help you set up shark and run demos @@ -98,24 +68,18 @@ Recent Turbine Camp Attendees, from recent to less recent - If you have questions, reach out to [Chi on Discord](https://discordapp.com/channels/973663919757492264/1104195883307892837/1180233875058868224) - [Vivek's example of ONNX op lowering](https://github.com/llvm/torch-mlir/commit/dc9ea08db5ac295b4b3f91fc776fef6a702900b9) - Find Ops To Lower - - [Torch MLIR + ONNX Unimplemented Ops on Sharepoint](https://amdcloud-my.sharepoint.com/:x:/r/personal/esaimana_amd_com/Documents/Torch%20MLIR%20+%20ONNX%20Unimplemented%20Ops.xlsx?d=w438f26fac8fd44eeafb89bc99e2c563b&csf=1&web=1&e=Qd4eHm) + - Torch MLIR + ONNX Unimplemented Ops on Sharepoint ( see SharePoint: esaimana_amd_com/Documents/Torch%20MLIR%20+%20ONNX%20Unimplemented%20Ops.xlsx?d=w438f26fac8fd44eeafb89bc99e2c563b&csf=1&web=1&e=Qd4eHm) - If you don't have access yet, request it. - nod-ai/SHARK-Turbine ssues tracking op support - [Model and Op Support](https://github.com/nod-ai/SHARK-Turbine/issues/119) - [ONNX op support](https://github.com/nod-ai/SHARK-Turbine/issues/215) +## [Chi's useful commands for debugging torch mlir](https://gist.github.com/AmosLewis/dd31ab37517977b1c499d06495b4adc2) -## Chi's useful commands for debugging torch mlir - -https://gist.github.com/AmosLewis/dd31ab37517977b1c499d06495b4adc2 - -## How to write test cases and test your new op - -https://github.com/llvm/torch-mlir/blob/main/docs/development.md#testing - - +## [How to write test cases and test your new op](https://github.com/llvm/torch-mlir/blob/main/docs/development.md#testing) ## How to set up vs code and intellisence for [torch-mlir] + Xida: This is optional. If you're using VS code like me, you might want to set it up so you can use the jump to definition / references, auto fix, and other features. Feel free to contact me on discord if you have trouble figuring this out. @@ -161,4 +125,5 @@ under `torch-mlir` "cmake.cmakePath": "/home/xida/miniconda/envs/torch-mlir/bin/cmake", // make sure this is a cmake that knows where your python is } ``` + The important things to note are the `cmake.configureArgs`, which specify the location of your torch mlir, and the `cmake.sourceDirectory`, which indicates that CMAKE should not build from the current directory and should instead build from `externals/llvm-project/llvm` diff --git a/docs/architecture.md b/docs/architecture.md index e2ef378bd99c..1c8752092549 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -62,7 +62,7 @@ program representations can eventually bottom-out on the JIT IR via some path provided by PyTorch. 
The `torch` dialect is almost entirely in 1:1 correspondence with the JIT IR -- this allows the importer to be extremely small (the core is -[under 500 lines of code](https://github.com/llvm/torch-mlir/blob/e322f6a8784009b37aa354abfa9a40a80f30877d/python/torch_mlir/jit_ir_importer/csrc/node_importer.cpp#L1)). +[under 500 lines of code](https://github.com/llvm/torch-mlir/blob/e322f6a8784009b37aa354abfa9a40a80f30877d/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp)). ### Ops diff --git a/docs/development.md b/docs/development.md index fe997447c319..61c60a646a5c 100644 --- a/docs/development.md +++ b/docs/development.md @@ -14,7 +14,7 @@ While this is running, you can already setup the Python venv and dependencies in ## Setup your Python VirtualEnvironment and Dependencies ```shell -python -m venv mlir_venv +python3 -m venv mlir_venv source mlir_venv/bin/activate # Some older pip installs may not be able to handle the recent PyTorch deps python -m pip install --upgrade pip @@ -53,42 +53,52 @@ Two setups are possible to build: in-tree and out-of-tree. The in-tree setup is The following command generates configuration files to build the project *in-tree*, that is, using llvm/llvm-project as the main build. This will build LLVM as well as torch-mlir and its subprojects. On Windows, use the "Developer PowerShell for Visual Studio" to ensure that the compiler and linker binaries are in the `PATH` variable. +This requires `lld`, `clang`, `ccache`, and other dependencies for building `libtorch` / `PyTorch` wheels from source. If you run into issues because of these, try the [simplified build command](#simplified-build). + ```shell cmake -GNinja -Bbuild \ + externals/llvm-project/llvm \ -DCMAKE_BUILD_TYPE=Release \ + -DLLVM_ENABLE_ASSERTIONS=ON \ -DPython3_FIND_VIRTUALENV=ONLY \ -DLLVM_ENABLE_PROJECTS=mlir \ -DLLVM_EXTERNAL_PROJECTS="torch-mlir" \ -DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$PWD" \ -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ -DLLVM_TARGETS_TO_BUILD=host \ - externals/llvm-project/llvm -``` -#### Flags that can reduce build time: -* Enabling clang on Linux -```shell - -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -``` -* Enabling ccache -```shell - -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache -``` -* Enabling LLD (links in seconds compared to minutes) -```shell - -DCMAKE_EXE_LINKER_FLAGS_INIT="-fuse-ld=lld" -DCMAKE_MODULE_LINKER_FLAGS_INIT="-fuse-ld=lld" -DCMAKE_SHARED_LINKER_FLAGS_INIT="-fuse-ld=lld" -# Use --ld-path= instead of -fuse-ld=lld for clang > 13 -``` -* Enabling libtorch binary cache -By default we download the latest version of libtorch everytime you build so we are always on the latest version. Set `-DLIBTORCH_CACHE=ON` to -not download the latest version everytime. If libtorch gets out of date and you test against a newer PyTorch you may notice failures. -```shell - -DLIBTORCH_CACHE=ON -``` -* Enabling building libtorch as part of your build -By default we download the latest version of libtorch. We have an experimental path to build libtorch (and PyTorch wheels) from source. 
+ `# use clang`\ + -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ \ + `# use ccache to cache build results` \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + `# use LLD to link in seconds, rather than minutes` \ + `# if using clang <= 13, replace --ld-path=ld.lld with -fuse-ld=lld` \ + -DCMAKE_EXE_LINKER_FLAGS_INIT="--ld-path=ld.lld" \ + -DCMAKE_MODULE_LINKER_FLAGS_INIT="--ld-path=ld.lld" \ + -DCMAKE_SHARED_LINKER_FLAGS_INIT="--ld-path=ld.lld" \ + `# Enabling libtorch binary cache instead of downloading the latest libtorch everytime.` \ + `# Testing against a mismatched version of libtorch may cause failures` \ + -DLIBTORCH_CACHE=ON \ + `# Enable an experimental path to build libtorch (and PyTorch wheels) from source,` \ + `# instead of downloading them` \ + -DLIBTORCH_SRC_BUILD=ON \ + `# Set the variant of libtorch to build / link against. (shared|static and optionally cxxabi11)` \ + -DLIBTORCH_VARIANT=shared +``` + +# Simplified build + +If you're running into issues with the above build command, consider using the following: + ```shell - -DLIBTORCH_SRC_BUILD=ON # Build Libtorch from source - -DLIBTORCH_VARIANT=shared # Set the variant of libtorch to build / link against. (`shared`|`static` and optionally `cxxabi11`) +cmake -GNinja -Bbuild \ + -DCMAKE_BUILD_TYPE=Release \ + -DPython3_FIND_VIRTUALENV=ONLY \ + -DLLVM_ENABLE_PROJECTS=mlir \ + -DLLVM_EXTERNAL_PROJECTS="torch-mlir" \ + -DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$PWD" \ + -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ + -DLLVM_TARGETS_TO_BUILD=host \ + externals/llvm-project/llvm ``` #### Flags to enable MLIR debugging: @@ -99,6 +109,15 @@ By default we download the latest version of libtorch. We have an experimental p -DLLVM_ENABLE_ASSERTIONS=ON \ ``` +#### Flags to run end-to-end tests: + +Running the end-to-end execution tests locally requires enabling the native PyTorch extension features and the JIT IR importer, which depends on the +former and defaults to `ON` if not changed: +```shell + -DTORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS=ON \ + -DTORCH_MLIR_ENABLE_JIT_IR_IMPORTER=ON \ +``` + ### Building against a pre-built LLVM If you have built llvm-project separately in the directory `$LLVM_INSTALL_DIR`, you can also build the project *out-of-tree* using the following command as template: @@ -330,9 +349,9 @@ The following additional environmental variables can be used to customize your d ``` * Custom Python Versions for Release builds: - Version of Python to use in Release builds. Ignored in CIs. Defaults to `cp38-cp38 cp39-cp39 cp310-cp310` + Version of Python to use in Release builds. Ignored in CIs. Defaults to `cp39-cp39 cp310-cp310 cp312-cp312` ```shell - TM_PYTHON_VERSIONS="cp38-cp38 cp39-cp39 cp310-cp310" + TM_PYTHON_VERSIONS="cp39-cp39 cp310-cp310 cp312-cp312" ``` * Location to store Release build wheels @@ -386,6 +405,8 @@ Torch-MLIR has two types of tests: a homegrown testing framework (see `projects/pt1/python/torch_mlir_e2e_test/framework.py`) and the test suite lives at `projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py`. + The tests require to build with `TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS` (and + the dependent option `TORCH_MLIR_ENABLE_JIT_IR_IMPORTER`) set to `ON`. 2. Compiler and Python API unit tests. These use LLVM's `lit` testing framework. 
For example, these might involve using `torch-mlir-opt` to run a pass and @@ -419,6 +440,20 @@ cd projects/pt1 python -m e2e_testing.main -f 'AtenEmbeddingBag' ``` +The default mode of running tests uses the multi-processing framework and is +not tolerant of certain types of errors. If encountering native crashes/hangs, +enable debug variables to run sequentially/in-process with more verbosity: + +``` +export TORCH_MLIR_TEST_CONCURRENCY=1 +export TORCH_MLIR_TEST_VERBOSE=1 +``` + +In this way, you can run under `gdb`, etc and get useful results. Having env +vars like this makes it easy to set in GH action files, etc. Note that the +verbose flags are very verbose. Basic sequential progress reports will be +printed regardless when not running in parallel. + ## Running unit tests. To run all of the unit tests, run: diff --git a/docs/ltc_examples.md b/docs/ltc_examples.md index 217761a51ebd..0c6da17bbf2a 100644 --- a/docs/ltc_examples.md +++ b/docs/ltc_examples.md @@ -51,4 +51,4 @@ In Mark Step: true ``` ## Example Models -There are also examples of a [HuggingFace BERT](../examples/ltc_backend_bert.py) and [MNIST](../examples/ltc_backend_mnist.py) model running on the example LTC backend. +There are also examples of a [HuggingFace BERT](../projects/pt1/examples/ltc_backend_bert.py) and [MNIST](../projects/pt1/examples/ltc_backend_mnist.py) model running on the example LTC backend. diff --git a/externals/llvm-project b/externals/llvm-project index dabdec1001dc..42131eee8342 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit dabdec1001dc368373dd581cf72f37a440873ce3 +Subproject commit 42131eee834229f457f62d39f2a31134a86dea9b diff --git a/externals/stablehlo b/externals/stablehlo index ab92adeda911..b62dc66da994 160000 --- a/externals/stablehlo +++ b/externals/stablehlo @@ -1 +1 @@ -Subproject commit ab92adeda9119a6c3914cd42367b0a2b70765e91 +Subproject commit b62dc66da9946b4c400c0d99c9d5bb8e04edaee6 diff --git a/include/torch-mlir-c/TorchTypes.h b/include/torch-mlir-c/TorchTypes.h index b214e147d5d9..dd7cfb5c428f 100644 --- a/include/torch-mlir-c/TorchTypes.h +++ b/include/torch-mlir-c/TorchTypes.h @@ -220,6 +220,19 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchQUInt8TypeGet(MlirContext context); /// Gets the !torch.quint8 typeid. MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchQUInt8TypeGetTypeID(void); +//===----------------------------------------------------------------------===// +// torch.qint16 type. +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is a !torch.qint16 type +MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchQInt16(MlirType t); + +/// Gets the !torch.qint16 type. +MLIR_CAPI_EXPORTED MlirType torchMlirTorchQInt16TypeGet(MlirContext context); + +/// Gets the !torch.qint16 typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchQInt16TypeGetTypeID(void); + //===----------------------------------------------------------------------===// // torch.tensor type. 
//===----------------------------------------------------------------------===// diff --git a/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.td b/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.td index 8e7be05e198c..3dce86149fa8 100644 --- a/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.td +++ b/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.td @@ -125,7 +125,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> { llvm::copy_if(getInputOperands(), std::back_inserter(result), [](OpOperand *opOperand) { - return opOperand->get().getType().template isa(); + return isa(opOperand->get().getType()); }); return result; }] @@ -144,7 +144,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> { llvm::copy_if(getInputOperands(), std::back_inserter(result), [](OpOperand *opOperand) { - return opOperand->get().getType().template isa(); + return isa(opOperand->get().getType()); }); return result; }] @@ -200,7 +200,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> { llvm::copy_if(getOutputOperands(), std::back_inserter(result), [](OpOperand *opOperand) { - return opOperand->get().getType().template isa(); + return isa(opOperand->get().getType()); }); return result; }] @@ -219,7 +219,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> { llvm::copy_if(getOutputOperands(), std::back_inserter(result), [](OpOperand *opOperand) { - return opOperand->get().getType().template isa(); + return isa(opOperand->get().getType()); }); return result; }] @@ -238,7 +238,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> { llvm::transform(getOutputBufferOperands(), std::back_inserter(result), [](OpOperand *opOperands) { - return opOperands->get().getType().cast(); + return cast(opOperands->get().getType()); }); return result; }] @@ -257,7 +257,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> { llvm::transform(getOutputTensorOperands(), std::back_inserter(result), [](OpOperand *opOperands) { - return opOperands->get().getType().cast(); + return cast(opOperands->get().getType()); }); return result; }] @@ -318,7 +318,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> { /*args=*/(ins "OpOperand *":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - if (!opOperand->get().getType().template isa()) + if (!isa(opOperand->get().getType())) return false; if (opOperand->getOperandNumber() < $_op.getNumInputs()) return true; @@ -334,7 +334,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> { /*args=*/(ins "OpOperand *":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ - if (!opOperand->get().getType().template isa()) + if (!isa(opOperand->get().getType())) return false; if (opOperand->getOperandNumber() >= $_op.getNumInputs()) return true; @@ -367,7 +367,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> { /*defaultImplementation=*/[{ assert(opOperand->getOwner() == this->getOperation()); if (auto shapedType = - opOperand->get().getType().template dyn_cast()) + dyn_cast(opOperand->get().getType())) return shapedType.getRank(); return 0; }] @@ -383,7 +383,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> { /*defaultImplementation=*/[{ assert(opOperand->getOwner() == this->getOperation()); if (auto shapedType = - opOperand->get().getType().template dyn_cast()) + dyn_cast(opOperand->get().getType())) return shapedType.getShape(); return {}; }] @@ -398,7 +398,7 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> { /*methodBody=*/"", /*defaultImplementation=*/[{ assert(opOperand->getOwner() == 
this->getOperation()); - return !opOperand->get().getType().template isa(); + return !isa(opOperand->get().getType()); }] >, //===------------------------------------------------------------------===// @@ -416,10 +416,10 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> { return this->getOperation()->getNumResults() == 0 && llvm::all_of(getInputOperands(), [&](OpOperand *opOperand) { return isScalar(opOperand) || - opOperand->get().getType().template isa(); + isa(opOperand->get().getType()); }) && llvm::all_of(getOutputOperands(), [](OpOperand *opOperand) { - return opOperand->get().getType().template isa(); + return isa(opOperand->get().getType()); }); }] >, @@ -435,10 +435,10 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> { return llvm::all_of(getInputOperands(), [&](OpOperand *opOperand) { return isScalar(opOperand) || - opOperand->get().getType().template isa(); + isa(opOperand->get().getType()); }) && llvm::all_of(getOutputOperands(), [](OpOperand *opOperand) { - return opOperand->get().getType().template isa(); + return isa(opOperand->get().getType()); }); }] >, @@ -478,8 +478,8 @@ def TMTensorInterface : OpInterface<"TMTensorOp"> { private: void setOperandSegmentAt(unsigned idx, unsigned val) { - auto attr = (*this)->getAttr("operand_segment_sizes") - .cast(); + auto attr = cast((*this)->getAttr("operand_segment_sizes") + ); unsigned i = 0; auto newAttr = attr.mapValues(IntegerType::get(getContext(), 32), [&](const APInt &v) { return (i++ == idx) ? APInt(32, val) : v; }); diff --git a/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td b/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td index 12a74faa44d3..c47eaabf7364 100644 --- a/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td +++ b/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td @@ -88,7 +88,7 @@ def TMTensor_ScanOp : TMTensor_Op<"scan", return getOutputOperand(0)->get(); } ShapedType getOperandType() { - return input().getType().cast(); + return cast(input().getType()); } int64_t getOperandRank() { return getOperandType().getRank(); @@ -151,10 +151,10 @@ def TMTensor_ScatterOp : TMTensor_Op<"scatter", let extraClassDeclaration = extraTMTensorOpClassDeclaration # [{ int64_t getIndexDepth() { - return getInputOperand(1) + return cast(getInputOperand(1) ->get() .getType() - .cast() + ) .getShape() .back(); } @@ -164,7 +164,7 @@ def TMTensor_ScatterOp : TMTensor_Op<"scatter", } ShapedType getUpdateType() { - return updates().getType().cast(); + return cast(updates().getType()); } Value indices() { @@ -172,7 +172,7 @@ def TMTensor_ScatterOp : TMTensor_Op<"scatter", } ShapedType getIndicesType() { - return indices().getType().cast(); + return cast(indices().getType()); } Value original() { @@ -180,11 +180,11 @@ def TMTensor_ScatterOp : TMTensor_Op<"scatter", } ShapedType getOriginalType() { - return original().getType().cast(); + return cast(original().getType()); } int64_t getUpdateSliceRank() { - return updates().getType().cast().getRank() - 1; + return cast(updates().getType()).getRank() - 1; } bool isScalarUpdate() { @@ -224,7 +224,7 @@ def TMTensor_SortOp : TMTensor_Op<"sort", return getOutputs()[index]; } ShapedType getOperandType(int index) { - return operand(index).getType().cast(); + return cast(operand(index).getType()); } int64_t getOperandRank() { return getOperandType(0).getRank(); @@ -252,13 +252,14 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention", ["generateScalarImplementation"]>]> { let summary = "Attention operator"; let description = [{ - This 
operator takes in 3 tensors: query(Q), key(K) and value(V) and computes
-    the attention. Each of the inputs has shape BxNxd where B is the
-    of the batch dimension, N is the sequence length and d is head dimension.
-    Typically N >>> d. Mathematically, the attention is defined as
-    matmul(softmax(matmul(Q, transpose(K))), V) and has shape BxNxd. Usually,
-    this operator also performs scaling, masking and dropout, but we leave
-    that out of the current implementation.
+    This operator takes in 3 to 4 tensors: query(Q), key(K), value(V), and an
+    optional mask(M) to compute the attention. These tensors must take on shapes
+    BxMxK1 for Q, BxK2xK1 for K, BxK2xN for V, and BxMxK2 for M. For all these
+    shapes, B represents the batch dimension, M represents sequence length, N
+    represents head dimension, and K1 and K2 are hidden dimensions.
+    Attention is defined as matmul(softmax(matmul(Q, transpose(K))+M), V) and
+    has shape BxMxN. Usually, this operator also performs scaling, masking and
+    dropout, but we leave that out of the current implementation.
   }];
   let arguments = (ins Variadic:$inputs,
@@ -287,20 +288,32 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
     Value getValue() {
       return getInputOperand(2)->get();
     }
+    std::optional getAttnMask() {
+      if (getNumInputs() < 4) {
+        return std::nullopt;
+      }
+      return getInputOperand(3)->get();
+    }
     Value getOutput() {
       return getOutputOperand(0)->get();
     }
     ShapedType getQueryType() {
-      return getQuery().getType().cast();
+      return cast(getQuery().getType());
     }
     ShapedType getKeyType() {
-      return getKey().getType().cast();
+      return cast(getKey().getType());
     }
     ShapedType getValueType() {
-      return getValue().getType().cast();
+      return cast(getValue().getType());
+    }
+    std::optional getAttnMaskType() {
+      if (getAttnMask()){
+        return cast((*getAttnMask()).getType());
+      }
+      return std::nullopt;
     }
     ShapedType getOutputType() {
-      return getOutput().getType().cast();
+      return cast(getOutput().getType());
     }
     int64_t getQueryRank() {
       return getQueryType().getRank();
@@ -311,6 +324,12 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
     int64_t getValueRank() {
       return getValueType().getRank();
     }
+    std::optional getAttnMaskRank() {
+      if (getAttnMask()){
+        return (*getAttnMaskType()).getRank();
+      }
+      return std::nullopt;
+    }
     int64_t getOutputRank() {
       return getOutputType().getRank();
     }
@@ -326,6 +345,81 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
   }];
 }
+def TMTensor_TopkOp : TMTensor_Op<"topk",
+    [DeclareOpInterfaceMethods,
+     DeclareOpInterfaceMethods]> {
+  let summary = "Top-K operator";
+  let description = [{
+    A Top-K operation for N-D tensors. Reduces the target dimension from the input
+    size N down to K elements based on the supplied binary region.
+
+    Accepts an N-D tensor input consisting of values and an optional N-D tensor
+    for indices of those values (i32 type). If input indices aren't provided, the
+    index mapping is inferred based on the k dim. Both input values/indices
+    tensors and output values/indices tensors must have the same shape. Top-K is
+    computed along the target dimension (from dimension()). Returns two output
+    tensors of values and the indices of Top-K results. The output dimensions
+    must match the input save for the dimension that is reduced to K results.
+
+    Region accepts lhs=[next N input] and rhs=[existing K output] and yields an
+    i1. If true, the two values are swapped:
+      - For Top-K comparison: >
+      - For Min-K comparison: <
+    Note: when the two values are equal, the first occurrence is always selected.
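+
+    As a rough sketch (the element types, shapes, SSA names, and the comparator
+    below are only one possible illustration, following the assembly format
+    declared for this op), selecting the top 3 values along dimension 1 of a
+    2x10 input could look like:
+
+    ```mlir
+    // Top-3 along dim 1; %out_values / %out_indices are pre-allocated
+    // destination tensors, and the region picks the larger value.
+    %topk:2 = tm_tensor.topk
+              dimension(1)
+              ins(%input_values : tensor<2x10xf32>)
+              outs(%out_values, %out_indices : tensor<2x3xf32>, tensor<2x3xi32>) {
+              ^bb0(%current: f32, %existing: f32):
+                %cmp = arith.cmpf ogt, %current, %existing : f32
+                tm_tensor.yield %cmp : i1
+              } -> tensor<2x3xf32>, tensor<2x3xi32>
+    ```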
+ }]; + + let arguments = (ins Variadic:$inputs, + Variadic:$outputs, + I64Attr:$dimension + ); + + let results = (outs Variadic:$results); + let regions = (region AnyRegion:$region); + let assemblyFormat = [{ + attr-dict + `dimension` `(` $dimension `)` + `ins` `(` $inputs `:` type($inputs) `)` + `outs` `(` $outputs `:` type($outputs) `)` + $region (`->` type($results)^)? + }]; + + let extraClassDeclaration = extraTMTensorOpClassDeclaration # [{ + Value values() { + return getInputOperand(0)->get(); + } + std::optional indices() { + if (getNumInputs() < 2) { + return {}; + } + return getInputOperand(1)->get(); + } + Value outputValues() { + return getOutputOperand(0)->get(); + } + Value outputIndices() { + return getOutputOperand(1)->get(); + } + ShapedType getInputType() { + return cast(values().getType()); + } + int64_t getInputRank() { + return getInputType().getRank(); + } + + // Method to implement for specifying output range for + // DestinationStyleOpInterface + std::pair getDpsInitsPositionRange() { + std::pair outputsIndexAndLength = + getODSOperandIndexAndLength(1); + return std::make_pair( + outputsIndexAndLength.first, + outputsIndexAndLength.first + outputsIndexAndLength.second); + } + }]; +} + //===----------------------------------------------------------------------===// // Pure ops //===----------------------------------------------------------------------===// diff --git a/include/torch-mlir/Conversion/CMakeLists.txt b/include/torch-mlir/Conversion/CMakeLists.txt index c2e757f7a0ff..7c1361200925 100644 --- a/include/torch-mlir/Conversion/CMakeLists.txt +++ b/include/torch-mlir/Conversion/CMakeLists.txt @@ -1,11 +1,11 @@ add_subdirectory(TorchOnnxToTorch) set(LLVM_TARGET_DEFINITIONS Passes.td) -if(TORCH_MLIR_ENABLE_STABLEHLO) - mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_STABLEHLO) -else() - mlir_tablegen(Passes.h.inc -gen-pass-decls) -endif() + + + +mlir_tablegen(Passes.h.inc -gen-pass-decls ${TORCH_MLIR_TABLEGEN_FLAGS}) + add_public_tablegen_target(TorchMLIRConversionPassIncGen) add_mlir_doc(Passes TorchMLIRConversionPasses ./ -gen-pass-doc) diff --git a/include/torch-mlir/Conversion/Passes.td b/include/torch-mlir/Conversion/Passes.td index ed58c699559c..2bace8e4f231 100644 --- a/include/torch-mlir/Conversion/Passes.td +++ b/include/torch-mlir/Conversion/Passes.td @@ -114,6 +114,7 @@ def ConvertTorchToTensor : Pass<"convert-torch-to-tensor", "func::FuncOp"> { let constructor = "mlir::torch::createConvertTorchToTensorPass()"; } +#ifdef TORCH_MLIR_ENABLE_TOSA def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> { let summary = "Convert Torch ops to TOSA ops"; let description = [{ @@ -122,6 +123,7 @@ def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> { }]; let constructor = "mlir::torch::createConvertTorchToTosaPass()"; } +#endif def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> { let summary = "Convert recognized Torch ops to TMTensor/Linalg ops"; diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h index 3230cc8b46a0..431d014adc0e 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h @@ -34,6 +34,7 @@ struct OpBinder { Location getLoc() { return op->getLoc(); } int getNumOperands() { return op->getNumOperands(); } + int getNumResults() { return op->getNumResults(); } // Operand matches of different arities. 
ParseResult tensorOperand(Value &value0) { @@ -45,6 +46,18 @@ struct OpBinder { return success(); } + ParseResult optionalTensorOperand(Value &value0) { + if (op->getNumOperands() != 1) + return failure(); + value0 = op->getOperand(0); + auto ot = dyn_cast(value0.getType()); + if (!ot) + return failure(); + if (!toValidTensorType(ot.getContainedType())) + return failure(); + return success(); + } + ParseResult tensorOperands(Value &value0, Value &value1) { if (op->getNumOperands() != 2) return failure(); @@ -97,6 +110,58 @@ struct OpBinder { return success(); } + // Operand matches of different arities. + ParseResult tensorListOperand(Value &value0) { + if (op->getNumOperands() != 1) + return failure(); + value0 = op->getOperand(0); + auto tt = dyn_cast(value0.getType()); + if (!tt) + return failure(); + if (!toValidTensorType(tt.getContainedType())) + return failure(); + return success(); + } + + ParseResult optionalTensorListOperand(Value &value0) { + if (op->getNumOperands() != 1) + return failure(); + value0 = op->getOperand(0); + auto ot = dyn_cast(value0.getType()); + if (!ot) + return failure(); + auto tt = dyn_cast(ot.getContainedType()); + if (!tt) + return failure(); + if (!toValidTensorType(tt.getContainedType())) + return failure(); + return success(); + } + + ParseResult tensorListOperandAtIndex(Value &valueIdx, int64_t idx) { + if (idx >= op->getNumOperands()) + return failure(); + valueIdx = op->getOperand(idx); + auto tt = dyn_cast(valueIdx.getType()); + if (!tt) + return failure(); + if (!toValidTensorType(tt.getContainedType())) + return failure(); + return success(); + } + + ParseResult tensorListResultType(Torch::ListType &type0) { + if (op->getNumResults() != 1) + return failure(); + auto tt = dyn_cast(op->getResult(0).getType()); + if (!tt) + return failure(); + if (!toValidTensorType(tt.getContainedType())) + return failure(); + type0 = tt; + return success(); + } + ParseResult tensorResultTypes(llvm::SmallVector &typeList) { for (auto result : op->getResults()) { auto t = toValidTensorType(result.getType()); @@ -107,6 +172,54 @@ struct OpBinder { return success(); } + ParseResult optionalResultType(Torch::OptionalType &type0) { + if (op->getNumResults() != 1) + return failure(); + auto ot = dyn_cast(op->getResult(0).getType()); + if (!ot) + return failure(); + type0 = ot; + return success(); + } + + ParseResult optionalTensorResultType(Torch::ValueTensorType &type0) { + if (op->getNumResults() != 1) + return failure(); + auto ot = dyn_cast(op->getResult(0).getType()); + if (!ot) + return failure(); + auto t = toValidTensorType(ot.getContainedType()); + if (!t) + return failure(); + type0 = t; + return success(); + } + + ParseResult optionalTensorListResultType(Torch::ListType &type0) { + if (op->getNumResults() != 1) + return failure(); + auto ot = dyn_cast(op->getResult(0).getType()); + if (!ot) + return failure(); + auto tt = dyn_cast(ot.getContainedType()); + if (!tt) + return failure(); + if (!toValidTensorType(tt.getContainedType())) + return failure(); + type0 = tt; + return success(); + } + + ParseResult tensorOperandTypes(llvm::SmallVector &typeList) { + for (auto operand : op->getOperands()) { + auto t = toValidTensorType(operand.getType()); + if (!t) + return failure(); + typeList.push_back(t); + } + return success(); + } + // The importer imports Onnx.GraphProto attributes as regions attached to the // op. 
ParseResult getRegionAtIndex(mlir::Region *®ion, int64_t idx) { @@ -172,6 +285,16 @@ struct OpBinder { return failure(); } + ParseResult optionalS64IntegerAttr(int64_t &value, StringRef nameSuffix) { + SmallString<64> name("torch.onnx."); + name.append(nameSuffix); + auto attr = op->getAttr(name); + if (!attr) { + return failure(); + } + return s64IntegerAttr(value, nameSuffix); + } + ParseResult f32FloatAttr(float &value, StringRef nameSuffix, float defaultValue = 0.0f) { SmallString<64> name("torch.onnx."); @@ -216,6 +339,31 @@ struct OpBinder { return failure(); } + ParseResult f32FloatArrayAttr(llvm::SmallVector &values, + StringRef nameSuffix, + ArrayRef defaults) { + SmallString<64> name("torch.onnx."); + name.append(nameSuffix); + auto attr = op->getAttr(name); + if (!attr) { + values.append(defaults.begin(), defaults.end()); + return success(); + } + if (auto arrayAttr = dyn_cast(attr)) { + for (auto element : arrayAttr) { + auto floatAttr = dyn_cast(element); + if (!floatAttr) + return failure(); + FloatType t = cast(floatAttr.getType()); + if (t.getWidth() != 32) + return failure(); + values.push_back(floatAttr.getValue().convertToFloat()); + } + return success(); + } + return failure(); + } + ParseResult stringArrayAttr(llvm::SmallVector &values, StringRef nameSuffix) { SmallString<64> name("torch.onnx."); diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h index 919146c6a1c7..74c2aedcd5e5 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h @@ -32,11 +32,14 @@ class Endian { namespace mlir::torch::onnx_c { +Value createActivationByName(ImplicitLocOpBuilder &b, StringRef name, + Value input); + Value createConstantIntList(OpBinder binder, ConversionPatternRewriter &rewriter, - SmallVector cstInput); + ArrayRef cstInput); -Type getQTorchTypeFromTorchIntType(Type ty); +Torch::ValueTensorType getQTorchTypeFromTorchIntType(Type ty); template Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter, @@ -47,6 +50,10 @@ Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter, LogicalResult OnnxLstmExpander(OpBinder binder, ConversionPatternRewriter &rewriter); +LogicalResult OnnxGruExpander(OpBinder binder, + ConversionPatternRewriter &rewriter); +LogicalResult OnnxRnnExpander(OpBinder binder, + ConversionPatternRewriter &rewriter); bool areAllElementsDistinct(SmallVector array); @@ -61,12 +68,12 @@ struct onnx_list_of_constant_ints_op_binder { bool match(Operation *op) { auto constOp = dyn_cast(op); - if (!constOp || !constOp.getName().equals("onnx.Constant")) + if (!constOp || !(constOp.getName() == "onnx.Constant")) return false; if (DenseResourceElementsAttr attr = - constOp->getAttr("torch.onnx.value") - .dyn_cast_or_null()) { + dyn_cast_or_null( + constOp->getAttr("torch.onnx.value"))) { // Bytes are stored in little endian order. Big endian support will // require swizzling. 
if (!Endian::little) { @@ -83,6 +90,12 @@ struct onnx_list_of_constant_ints_op_binder { } return true; } + if (ElementsAttr attr = dyn_cast_or_null( + constOp->getAttr("torch.onnx.value"))) { + for (auto axis : attr.getValues()) + bind_values.push_back(axis.getSExtValue()); + return true; + } return false; } }; @@ -96,6 +109,16 @@ m_OnnxListOfConstantInts(SmallVectorImpl &bind_values) { std::optional onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx); +LogicalResult createTorchTransposeOp(ConversionPatternRewriter &rewriter, + Location loc, Value input, int64_t dimA, + int64_t dimB, Value &transposed); + +LogicalResult createTorchPermuteOp(OpBinder binder, + ConversionPatternRewriter &rewriter, + Location loc, Value input, + SmallVector permuteDims, + Value &permuted); + } // namespace mlir::torch::onnx_c #endif // TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H diff --git a/include/torch-mlir/Conversion/TorchToLinalg/Utils.h b/include/torch-mlir/Conversion/TorchToLinalg/Utils.h index 5d2095f04f14..b59d183b4084 100644 --- a/include/torch-mlir/Conversion/TorchToLinalg/Utils.h +++ b/include/torch-mlir/Conversion/TorchToLinalg/Utils.h @@ -97,6 +97,14 @@ getBackendTypeForScalarType(MLIRContext *context, bool isUnsignedTorchType(Type type); +LogicalResult permuteTensor(Operation *op, PatternRewriter &rewriter, + Location loc, SmallVector dimensions, + Value input, Value &result); + +// Flips an input tensor based on the values of axis list. +Value flipTensor(PatternRewriter &rewriter, Location loc, Value input, + SmallVector axis); + } // namespace torch_to_linalg } // namespace torch } // namespace mlir diff --git a/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h b/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h index 734ba81ea07a..9067b7e24665 100644 --- a/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h +++ b/include/torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h @@ -22,6 +22,10 @@ namespace hlo { using mlir::ConversionPatternRewriter; +// Create chlo::ConstantLikeOp +template +Value getConstantLike(OpBuilder &rewriter, Location loc, T constant, Value val); + // Create a 32-bit float constant operator from a float Value getStablehloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op, float val); @@ -46,10 +50,15 @@ Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter, Operation *op, Value scalarValue, Type dtype); Value promoteType(PatternRewriter &rewriter, Location loc, Value input, - TensorType outType); + Type outElementType); + +FailureOr>> +getBroadcastResultShape(PatternRewriter &rewriter, Operation *op, + ArrayRef tensors, size_t dimSizeIndexBits); Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input, - TensorType outType); + TensorType outType, + std::optional bcastSizeTensor); SmallVector toPositiveDims(ArrayRef dims, int64_t rank); @@ -64,21 +73,29 @@ FailureOr> getDimSizesOfTensor(PatternRewriter &rewriter, Operation *op, Value value, size_t dimSizeIndexBits); +// Get the dimension sizes of the input tensor, given the dimension axes +FailureOr> getDimIndexOfTensor(PatternRewriter &rewriter, + Operation *op, Value value, + ArrayRef inpDims); + +// Get the dimension sizes of the input tensor +FailureOr> +getDimIndexOfTensor(PatternRewriter &rewriter, Operation *op, Value value); + // Get a tensor that unsqueezed the specified dimensions of the input tensor FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, - Value tensor, ArrayRef 
inputUnsqzDims, - size_t dimSizeIndexBits); + Value tensor, + ArrayRef inputUnsqzDims); // Get a tensor that collapse the specified dimensions of the input tensor FailureOr collapseTensor(PatternRewriter &rewriter, Operation *op, Value tensor, int64_t collapseStartDim, - int64_t collapseEndDim, - size_t dimSizeIndexBits); + int64_t collapseEndDim); // Get a tensor that splits the specified dimensions of the input tensor FailureOr splitTensor(PatternRewriter &rewriter, Operation *op, Value tensor, int64_t splitDim, - int64_t outerLength, size_t dimSizeIndexBits); + int64_t outerLength); Value getConstantOfShape(PatternRewriter &rewriter, Location loc, const APFloat &constant, Value shape, diff --git a/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h b/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h index a6d774a64db1..221745b1c26e 100644 --- a/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h +++ b/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h @@ -12,12 +12,25 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + #include namespace mlir { namespace torch { + +/// Collect a set of legal/illegal ops for converting Torch operations to Tosa +/// dialect. +void populateTorchToTosaConversionLegalOps(ConversionTarget &target); + +/// Collect a set of patterns to convert Torch operations to Tosa dialect + +/// return the set of illegalOps +std::set +populateTorchToTosaConversionPatternsAndIllegalOps(TypeConverter &typeConverter, + RewritePatternSet &patterns); + std::unique_ptr> createConvertTorchToTosaPass(); -} +} // namespace torch } // namespace mlir #endif // TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H diff --git a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h index 9fe25cbc17f8..8beabb969ed8 100644 --- a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h +++ b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h @@ -10,7 +10,7 @@ #ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H #define TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project @@ -38,6 +38,10 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op, Value conv_val, ShapedType input_type, ShapedType weight_type, ShapedType output_type); +// Create a TOSA slice op from \p start with \p size +Value buildSlice(PatternRewriter &rewriter, Value &input, + llvm::ArrayRef start, llvm::ArrayRef size); + // Check if scale32 mode is used for given output_element_type bool isScale32(mlir::quant::UniformQuantizedType output_element_type); @@ -117,10 +121,16 @@ void CreateReplaceOpAndInfer(PatternRewriter &rewriter, Operation *op, rewriter.replaceOp(op, result->getResults()); } +TypedValue reshapeTo(Location loc, PatternRewriter &rewriter, + Value val, ArrayRef newShape); + +TypedValue transposeBy(Location loc, + PatternRewriter &rewriter, Value val, + ArrayRef permutation); + // Get accumulator type for AvgPool2dOp. 
LogicalResult getAvgPool2dAccType(PatternRewriter &rewriter, Value input, TypeAttr &accType); - } // namespace tosa } // namespace mlir diff --git a/include/torch-mlir/Conversion/Utils/Utils.h b/include/torch-mlir/Conversion/Utils/Utils.h index b76efe869a0f..264fb4966d39 100644 --- a/include/torch-mlir/Conversion/Utils/Utils.h +++ b/include/torch-mlir/Conversion/Utils/Utils.h @@ -40,6 +40,8 @@ Value createInitTensor(OpBuilder &b, Location loc, ValueRange sizes, Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes, Type elemTy); +Value createOneInitTensor(OpBuilder &b, Location loc, ValueRange sizes, + Type elemTy); Value castIntToIndex(OpBuilder &b, Location loc, Value v); @@ -95,6 +97,15 @@ Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc, Value torchOptionalInt, Value builtinInt, Value defaultValue, Value dimSize); +// Helper function to unsqueeze the input tensor at given dim. +// Returns the unsqueezed tensor or failure. +FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, + Value input, int64_t dim); + +// Helper function to squeeze the input tensor at given dim. +// Returns the squeezed tensor or failure. +FailureOr squeezeTensor(PatternRewriter &rewriter, Operation *op, + Value input, int64_t dim); } // namespace Torch } // namespace torch } // namespace mlir diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 4de41e13b80a..9a5a0b387082 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -256,6 +256,159 @@ def Torch_AtenLeakyRelu_Op : Torch_Op<"aten.leaky_relu_", [ }]; } +def Torch_AtenRreluOp : Torch_Op<"aten.rrelu", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::rrelu : (Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$lower, + AnyTorchScalarType:$upper, + Torch_BoolType:$training, + AnyTorchOptionalGeneratorType:$generator + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRreluOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); + } + void AtenRreluOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); + } + }]; +} + +def Torch_AtenRrelu_Op : Torch_Op<"aten.rrelu_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::rrelu_ : (Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self, + AnyTorchScalarType:$lower, + AnyTorchScalarType:$upper, + Torch_BoolType:$training, + AnyTorchOptionalGeneratorType:$generator + ); + let results = (outs + AnyTorchOptionalNonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRrelu_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); + } + void AtenRrelu_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); + } + }]; +} + +def Torch_AtenRreluWithNoiseOp : Torch_Op<"aten.rrelu_with_noise", [ + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::rrelu_with_noise : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) 
-> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$noise, + AnyTorchScalarType:$lower, + AnyTorchScalarType:$upper, + Torch_BoolType:$training, + AnyTorchOptionalGeneratorType:$generator + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRreluWithNoiseOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); + } + void AtenRreluWithNoiseOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); + } + }]; +} + +def Torch_AtenRreluWithNoise_Op : Torch_Op<"aten.rrelu_with_noise_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::rrelu_with_noise_ : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self, + Torch_NonValueTensorType:$noise, + AnyTorchScalarType:$lower, + AnyTorchScalarType:$upper, + Torch_BoolType:$training, + AnyTorchOptionalGeneratorType:$generator + ); + let results = (outs + AnyTorchOptionalNonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRreluWithNoise_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); + } + void AtenRreluWithNoise_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); + } + }]; +} + +def Torch_AtenCeluOp : Torch_Op<"aten.celu", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::celu : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$alpha + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenCeluOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenCeluOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenCelu_Op : Torch_Op<"aten.celu_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::celu_ : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self, + AnyTorchScalarType:$alpha + ); + let results = (outs + AnyTorchOptionalNonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenCelu_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenCelu_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenSeluOp : Torch_Op<"aten.selu", [ AllowsTypeRefinement, HasValueSemantics, @@ -841,6 +994,51 @@ def Torch_AtenExp_Op : Torch_Op<"aten.exp_", [ }]; } +def Torch_AtenExp2Op : Torch_Op<"aten.exp2", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::exp2 : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenExp2Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenExp2Op::print(OpAsmPrinter 
&printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenExp2_Op : Torch_Op<"aten.exp2_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::exp2_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self + ); + let results = (outs + AnyTorchOptionalNonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenExp2_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenExp2_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenExpm1Op : Torch_Op<"aten.expm1", [ AllowsTypeRefinement, HasValueSemantics, @@ -1338,6 +1536,51 @@ def Torch_AtenNeg_Op : Torch_Op<"aten.neg_", [ }]; } +def Torch_AtenFracOp : Torch_Op<"aten.frac", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::frac : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenFracOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenFracOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenFrac_Op : Torch_Op<"aten.frac_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::frac_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self + ); + let results = (outs + AnyTorchOptionalNonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenFrac_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenFrac_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenBitwiseNotOp : Torch_Op<"aten.bitwise_not", [ AllowsTypeRefinement, HasValueSemantics, @@ -3255,6 +3498,53 @@ def Torch_AtenFill_TensorOp : Torch_Op<"aten.fill_.Tensor", [ }]; } +def Torch_AtenCopysignTensorOp : Torch_Op<"aten.copysign.Tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::copysign.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenCopysignTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenCopysignTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenCopysign_TensorOp : Torch_Op<"aten.copysign_.Tensor", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::copysign_.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self, + Torch_NonValueTensorType:$other + ); + let results = (outs + AnyTorchOptionalNonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenCopysign_TensorOp::parse(OpAsmParser &parser, 
OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenCopysign_TensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenUnsqueezeOp : Torch_Op<"aten.unsqueeze", [ AllowsTypeRefinement, ReadOnly @@ -3325,6 +3615,7 @@ def Torch_AtenDivTensorModeOp : Torch_Op<"aten.div.Tensor_mode", [ printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasFolder = 1; let hasCanonicalizer = 1; } @@ -3704,12 +3995,12 @@ def Torch_AtenMul_ScalarOp : Torch_Op<"aten.mul_.Scalar", [ }]; } -def Torch_AtenEqTensorOp : Torch_Op<"aten.eq.Tensor", [ +def Torch_AtenLdexpTensorOp : Torch_Op<"aten.ldexp.Tensor", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::ldexp.Tensor : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, AnyTorchTensorType:$other @@ -3719,24 +4010,71 @@ def Torch_AtenEqTensorOp : Torch_Op<"aten.eq.Tensor", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenEqTensorOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenLdexpTensorOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenEqTensorOp::print(OpAsmPrinter &printer) { + void AtenLdexpTensorOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; - let hasFolder = 1; } -def Torch_AtenEq_TensorOp : Torch_Op<"aten.eq_.Tensor", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement +def Torch_AtenSignbitOp : Torch_Op<"aten.signbit", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly ]> { - let summary = "Generated op for `aten::eq_.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::signbit : (Tensor) -> (Tensor)`"; let arguments = (ins - Torch_NonValueTensorType:$self, - Torch_NonValueTensorType:$other + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenSignbitOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenSignbitOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenEqTensorOp : Torch_Op<"aten.eq.Tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenEqTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenEqTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; + let hasFolder = 1; +} + +def Torch_AtenEq_TensorOp : Torch_Op<"aten.eq_.Tensor", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::eq_.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self, + Torch_NonValueTensorType:$other ); let results = (outs AnyTorchOptionalNonValueTensorType:$result @@ -4270,6 +4608,29 @@ def Torch_AtenTrunc_Op : 
Torch_Op<"aten.trunc_", [ }]; } +def Torch_AtenSpecialExpm1Op : Torch_Op<"aten.special_expm1", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::special_expm1 : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenSpecialExpm1Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenSpecialExpm1Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenSignOp : Torch_Op<"aten.sign", [ AllowsTypeRefinement, HasValueSemantics, @@ -4495,254 +4856,295 @@ def Torch_AtenFakeQuantizePerTensorAffineOp : Torch_Op<"aten.fake_quantize_per_t }]; } -def Torch_AtenMaximumOp : Torch_Op<"aten.maximum", [ +def Torch_AtenFakeQuantizePerTensorAffineCachemaskOp : Torch_Op<"aten.fake_quantize_per_tensor_affine_cachemask", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::maximum : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::fake_quantize_per_tensor_affine_cachemask : (Tensor, float, int, int, int) -> (Tensor, Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchTensorType:$other + Torch_FloatType:$scale, + Torch_IntType:$zero_point, + Torch_IntType:$quant_min, + Torch_IntType:$quant_max ); let results = (outs - AnyTorchOptionalTensorType:$result + AnyTorchOptionalTensorType:$output, + AnyTorchOptionalTensorType:$mask ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenMaximumOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenFakeQuantizePerTensorAffineCachemaskOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 2); } - void AtenMaximumOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenFakeQuantizePerTensorAffineCachemaskOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 2); } }]; } -def Torch_AtenMinimumOp : Torch_Op<"aten.minimum", [ +def Torch_AtenFakeQuantizePerTensorAffineTensorQparamsOp : Torch_Op<"aten.fake_quantize_per_tensor_affine.tensor_qparams", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::minimum : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::fake_quantize_per_tensor_affine.tensor_qparams : (Tensor, Tensor, Tensor, int, int) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchTensorType:$other + AnyTorchTensorType:$scale, + AnyTorchTensorType:$zero_point, + Torch_IntType:$quant_min, + Torch_IntType:$quant_max ); let results = (outs AnyTorchOptionalTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenMinimumOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenFakeQuantizePerTensorAffineTensorQparamsOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); } - void AtenMinimumOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenFakeQuantizePerTensorAffineTensorQparamsOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 
5, 1); } }]; } -def Torch_AtenMishOp : Torch_Op<"aten.mish", [ +def Torch_Aten_FakeQuantizePerTensorAffineCachemaskTensorQparamsOp : Torch_Op<"aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::mish : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams : (Tensor, Tensor, Tensor, Tensor, int, int) -> (Tensor, Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self + AnyTorchTensorType:$self, + AnyTorchTensorType:$scale, + AnyTorchTensorType:$zero_point, + AnyTorchTensorType:$fake_quant_enabled, + Torch_IntType:$quant_min, + Torch_IntType:$quant_max ); let results = (outs - AnyTorchOptionalTensorType:$result + AnyTorchOptionalTensorType:$output, + AnyTorchOptionalTensorType:$mask ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenMishOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult Aten_FakeQuantizePerTensorAffineCachemaskTensorQparamsOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 2); } - void AtenMishOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void Aten_FakeQuantizePerTensorAffineCachemaskTensorQparamsOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 2); } }]; } -def Torch_AtenXlogyTensorOp : Torch_Op<"aten.xlogy.Tensor", [ +def Torch_AtenFakeQuantizePerChannelAffineOp : Torch_Op<"aten.fake_quantize_per_channel_affine", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::xlogy.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::fake_quantize_per_channel_affine : (Tensor, Tensor, Tensor, int, int, int) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchTensorType:$other + AnyTorchTensorType:$scale, + AnyTorchTensorType:$zero_point, + Torch_IntType:$axis, + Torch_IntType:$quant_min, + Torch_IntType:$quant_max ); let results = (outs AnyTorchOptionalTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenXlogyTensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenFakeQuantizePerChannelAffineOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); } - void AtenXlogyTensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenFakeQuantizePerChannelAffineOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); } }]; } -def Torch_AtenRsubScalarOp : Torch_Op<"aten.rsub.Scalar", [ +def Torch_AtenFakeQuantizePerChannelAffineCachemaskOp : Torch_Op<"aten.fake_quantize_per_channel_affine_cachemask", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::fake_quantize_per_channel_affine_cachemask : (Tensor, Tensor, Tensor, int, int, int) -> (Tensor, Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchScalarType:$other, - AnyTorchScalarType:$alpha + AnyTorchTensorType:$scale, + AnyTorchTensorType:$zero_point, + Torch_IntType:$axis, + Torch_IntType:$quant_min, + Torch_IntType:$quant_max + ); + let results = (outs + 
AnyTorchOptionalTensorType:$output, + AnyTorchOptionalTensorType:$mask + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenFakeQuantizePerChannelAffineCachemaskOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 2); + } + void AtenFakeQuantizePerChannelAffineCachemaskOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 2); + } + }]; +} + +def Torch_AtenIsfiniteOp : Torch_Op<"aten.isfinite", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::isfinite : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self ); let results = (outs AnyTorchOptionalTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenRsubScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenIsfiniteOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenRsubScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenIsfiniteOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; - let hasCanonicalizer = 1; } -def Torch_AtenGeluOp : Torch_Op<"aten.gelu", [ +def Torch_AtenMaximumOp : Torch_Op<"aten.maximum", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::gelu : (Tensor, str) -> (Tensor)`"; + let summary = "Generated op for `aten::maximum : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - Torch_StringType:$approximate + AnyTorchTensorType:$other ); let results = (outs AnyTorchOptionalTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenGeluOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenMaximumOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenGeluOp::print(OpAsmPrinter &printer) { + void AtenMaximumOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenPowTensorScalarOp : Torch_Op<"aten.pow.Tensor_Scalar", [ +def Torch_AtenMinimumOp : Torch_Op<"aten.minimum", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::minimum : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchScalarType:$exponent + AnyTorchTensorType:$other ); let results = (outs AnyTorchOptionalTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenPowTensorScalarOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenMinimumOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenPowTensorScalarOp::print(OpAsmPrinter &printer) { + void AtenMinimumOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenPowTensorTensorOp : Torch_Op<"aten.pow.Tensor_Tensor", [ +def Torch_AtenFmaxOp : Torch_Op<"aten.fmax", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::pow.Tensor_Tensor : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::fmax : (Tensor, 
Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchTensorType:$exponent + AnyTorchTensorType:$other ); let results = (outs AnyTorchOptionalTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenPowTensorTensorOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenFmaxOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenPowTensorTensorOp::print(OpAsmPrinter &printer) { + void AtenFmaxOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenPowScalarOp : Torch_Op<"aten.pow.Scalar", [ +def Torch_AtenFminOp : Torch_Op<"aten.fmin", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::pow.Scalar : (Scalar, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::fmin : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchScalarType:$self, - AnyTorchTensorType:$exponent + AnyTorchTensorType:$self, + AnyTorchTensorType:$other ); let results = (outs AnyTorchOptionalTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenPowScalarOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenFminOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenPowScalarOp::print(OpAsmPrinter &printer) { + void AtenFminOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenThresholdBackwardOp : Torch_Op<"aten.threshold_backward", [ +def Torch_AtenMishOp : Torch_Op<"aten.mish", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::threshold_backward : (Tensor, Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::mish : (Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$grad_output, - AnyTorchTensorType:$self, - AnyTorchScalarType:$threshold + AnyTorchTensorType:$self ); let results = (outs AnyTorchOptionalTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenThresholdBackwardOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenMishOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenThresholdBackwardOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenMishOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenFloorDivideOp : Torch_Op<"aten.floor_divide", [ +def Torch_AtenXlogyTensorOp : Torch_Op<"aten.xlogy.Tensor", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::floor_divide : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::xlogy.Tensor : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, AnyTorchTensorType:$other @@ -4752,106 +5154,302 @@ def Torch_AtenFloorDivideOp : Torch_Op<"aten.floor_divide", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenFloorDivideOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenXlogyTensorOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenFloorDivideOp::print(OpAsmPrinter 
&printer) { + void AtenXlogyTensorOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenSoftplusOp : Torch_Op<"aten.softplus", [ +def Torch_AtenRsubScalarOp : Torch_Op<"aten.rsub.Scalar", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchScalarType:$beta, - AnyTorchScalarType:$threshold + AnyTorchScalarType:$other, + AnyTorchScalarType:$alpha ); let results = (outs AnyTorchOptionalTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenSoftplusOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenRsubScalarOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenSoftplusOp::print(OpAsmPrinter &printer) { + void AtenRsubScalarOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasFolder = 1; + let hasCanonicalizer = 1; } -def Torch_AtenPreluOp : Torch_Op<"aten.prelu", [ +def Torch_AtenGeluOp : Torch_Op<"aten.gelu", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::prelu : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::gelu : (Tensor, str) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchTensorType:$weight + Torch_StringType:$approximate ); let results = (outs AnyTorchOptionalTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenPreluOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenGeluOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenPreluOp::print(OpAsmPrinter &printer) { + void AtenGeluOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenCeluOp : Torch_Op<"aten.celu", [ +def Torch_AtenPowTensorScalarOp : Torch_Op<"aten.pow.Tensor_Scalar", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::celu : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchScalarType:$alpha + AnyTorchScalarType:$exponent ); let results = (outs AnyTorchOptionalTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenCeluOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenPowTensorScalarOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenCeluOp::print(OpAsmPrinter &printer) { + void AtenPowTensorScalarOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenCelu_Op : Torch_Op<"aten.celu_", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement - ]> { - let summary = "Generated op for `aten::celu_ : (Tensor, Scalar) -> (Tensor)`"; - let arguments = (ins - Torch_NonValueTensorType:$self, - AnyTorchScalarType:$alpha +def Torch_AtenPowTensorTensorOp : Torch_Op<"aten.pow.Tensor_Tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::pow.Tensor_Tensor : (Tensor, 
Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$exponent ); let results = (outs - AnyTorchOptionalNonValueTensorType:$result + AnyTorchOptionalTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenCelu_Op::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenPowTensorTensorOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenCelu_Op::print(OpAsmPrinter &printer) { + void AtenPowTensorTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenPowScalarOp : Torch_Op<"aten.pow.Scalar", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::pow.Scalar : (Scalar, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchScalarType:$self, + AnyTorchTensorType:$exponent + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenPowScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenPowScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenFloatPowerTensorTensorOp : Torch_Op<"aten.float_power.Tensor_Tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::float_power.Tensor_Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$exponent + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenFloatPowerTensorTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenFloatPowerTensorTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenThresholdBackwardOp : Torch_Op<"aten.threshold_backward", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::threshold_backward : (Tensor, Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$grad_output, + AnyTorchTensorType:$self, + AnyTorchScalarType:$threshold + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenThresholdBackwardOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenThresholdBackwardOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + +def Torch_AtenFloorDivideOp : Torch_Op<"aten.floor_divide", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::floor_divide : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenFloorDivideOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenFloorDivideOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + 
+def Torch_AtenSoftplusOp : Torch_Op<"aten.softplus", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$beta, + AnyTorchScalarType:$threshold + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenSoftplusOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenSoftplusOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + +def Torch_AtenPreluOp : Torch_Op<"aten.prelu", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::prelu : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$weight + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenPreluOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenPreluOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenRad2degOp : Torch_Op<"aten.rad2deg", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::rad2deg : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRad2degOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenRad2degOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenComplexOp : Torch_Op<"aten.complex", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::complex : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$real, + AnyTorchTensorType:$imag + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenComplexOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenComplexOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; @@ -5067,6 +5665,30 @@ def Torch_AtenSoftshrinkOp : Torch_Op<"aten.softshrink", [ }]; } +def Torch_AtenPolarOp : Torch_Op<"aten.polar", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::polar : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$abs, + AnyTorchTensorType:$angle + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenPolarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenPolarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenUnbindCopyIntOp : Torch_Op<"aten.unbind_copy.int", [ AllowsTypeRefinement, HasValueSemantics, @@ -5880,6 +6502,30 @@ def Torch_AtenMmOp : 
Torch_Op<"aten.mm", [ }]; } +def Torch_Aten_IntMmOp : Torch_Op<"aten._int_mm", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::_int_mm : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$mat2 + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult Aten_IntMmOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void Aten_IntMmOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenAddmmOp : Torch_Op<"aten.addmm", [ AllowsTypeRefinement, HasValueSemantics, @@ -5955,45 +6601,94 @@ def Torch_AtenMvOp : Torch_Op<"aten.mv", [ }]; } -def Torch_AtenCosineSimilarityOp : Torch_Op<"aten.cosine_similarity", [ +def Torch_AtenDotOp : Torch_Op<"aten.dot", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::cosine_similarity : (Tensor, Tensor, int, float) -> (Tensor)`"; + let summary = "Generated op for `aten::dot : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$x1, - AnyTorchTensorType:$x2, - Torch_IntType:$dim, - Torch_FloatType:$eps + AnyTorchTensorType:$self, + AnyTorchTensorType:$tensor ); let results = (outs AnyTorchOptionalTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenCosineSimilarityOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 4, 1); + ParseResult AtenDotOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenCosineSimilarityOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 4, 1); + void AtenDotOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; + let hasCanonicalizer = 1; } -def Torch_AtenConv3dOp : Torch_Op<"aten.conv3d", [ +def Torch_AtenOuterOp : Torch_Op<"aten.outer", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::conv3d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)`"; + let summary = "Generated op for `aten::outer : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$input, - AnyTorchTensorType:$weight, - AnyTorchOptionalTensorType:$bias, - AnyTorchListOfTorchIntType:$stride, - AnyTorchListOfTorchIntType:$padding, - AnyTorchListOfTorchIntType:$dilation, + AnyTorchTensorType:$self, + AnyTorchTensorType:$vec2 + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenOuterOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenOuterOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenCosineSimilarityOp : Torch_Op<"aten.cosine_similarity", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::cosine_similarity : (Tensor, Tensor, int, float) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$x1, + AnyTorchTensorType:$x2, + Torch_IntType:$dim, + Torch_FloatType:$eps + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + 
ParseResult AtenCosineSimilarityOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenCosineSimilarityOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + +def Torch_AtenConv3dOp : Torch_Op<"aten.conv3d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::conv3d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + AnyTorchTensorType:$weight, + AnyTorchOptionalTensorType:$bias, + AnyTorchListOfTorchIntType:$stride, + AnyTorchListOfTorchIntType:$padding, + AnyTorchListOfTorchIntType:$dilation, Torch_IntType:$groups ); let results = (outs @@ -6010,6 +6705,35 @@ def Torch_AtenConv3dOp : Torch_Op<"aten.conv3d", [ }]; } +def Torch_AtenConv3dPaddingOp : Torch_Op<"aten.conv3d.padding", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::conv3d.padding : (Tensor, Tensor, Tensor?, int[], str, int[], int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + AnyTorchTensorType:$weight, + AnyTorchOptionalTensorType:$bias, + AnyTorchListOfTorchIntType:$stride, + Torch_StringType:$padding, + AnyTorchListOfTorchIntType:$dilation, + Torch_IntType:$groups + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenConv3dPaddingOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 7, 1); + } + void AtenConv3dPaddingOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 7, 1); + } + }]; +} + def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [ AllowsTypeRefinement, HasValueSemantics, @@ -6039,6 +6763,35 @@ def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [ }]; } +def Torch_AtenConv2dPaddingOp : Torch_Op<"aten.conv2d.padding", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::conv2d.padding : (Tensor, Tensor, Tensor?, int[], str, int[], int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + AnyTorchTensorType:$weight, + AnyTorchOptionalTensorType:$bias, + AnyTorchListOfTorchIntType:$stride, + Torch_StringType:$padding, + AnyTorchListOfTorchIntType:$dilation, + Torch_IntType:$groups + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenConv2dPaddingOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 7, 1); + } + void AtenConv2dPaddingOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 7, 1); + } + }]; +} + def Torch_AtenConv1dOp : Torch_Op<"aten.conv1d", [ AllowsTypeRefinement, HasValueSemantics, @@ -6068,6 +6821,35 @@ def Torch_AtenConv1dOp : Torch_Op<"aten.conv1d", [ }]; } +def Torch_AtenConv1dPaddingOp : Torch_Op<"aten.conv1d.padding", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::conv1d.padding : (Tensor, Tensor, Tensor?, int[], str, int[], int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + AnyTorchTensorType:$weight, + AnyTorchOptionalTensorType:$bias, + AnyTorchListOfTorchIntType:$stride, + Torch_StringType:$padding, + AnyTorchListOfTorchIntType:$dilation, + Torch_IntType:$groups + ); + let results = (outs + 
AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenConv1dPaddingOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 7, 1); + } + void AtenConv1dPaddingOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 7, 1); + } + }]; +} + def Torch_AtenConvTranspose1dOp : Torch_Op<"aten.conv_transpose1d", [ AllowsTypeRefinement, HasValueSemantics, @@ -6579,6 +7361,33 @@ def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [ }]; } +def Torch_AtenRenormOp : Torch_Op<"aten.renorm", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::renorm : (Tensor, Scalar, int, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$p, + Torch_IntType:$dim, + AnyTorchScalarType:$maxnorm + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRenormOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenRenormOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; + let hasVerifier = 1; +} + def Torch_AtenNormScalarOp : Torch_Op<"aten.norm.Scalar", [ AllowsTypeRefinement, HasValueSemantics, @@ -6713,6 +7522,35 @@ def Torch_AtenMaxPool1dOp : Torch_Op<"aten.max_pool1d", [ }]; } +def Torch_AtenMaxPool1dWithIndicesOp : Torch_Op<"aten.max_pool1d_with_indices", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::max_pool1d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$kernel_size, + AnyTorchListOfTorchIntType:$stride, + AnyTorchListOfTorchIntType:$padding, + AnyTorchListOfTorchIntType:$dilation, + Torch_BoolType:$ceil_mode + ); + let results = (outs + AnyTorchOptionalTensorType:$result0, + AnyTorchOptionalTensorType:$result1 + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMaxPool1dWithIndicesOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 2); + } + void AtenMaxPool1dWithIndicesOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 2); + } + }]; +} + def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [ AllowsTypeRefinement, HasValueSemantics, @@ -6741,6 +7579,31 @@ def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [ }]; } +def Torch_AtenMaxUnpool2dOp : Torch_Op<"aten.max_unpool2d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$indices, + AnyTorchListOfTorchIntType:$output_size + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMaxUnpool2dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenMaxUnpool2dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + def Torch_AtenMaxPool2dWithIndicesOp : Torch_Op<"aten.max_pool2d_with_indices", [ AllowsTypeRefinement, HasValueSemantics, @@ -6829,6 +7692,33 @@ def 
Torch_AtenMaxPool3dOp : Torch_Op<"aten.max_pool3d", [ }]; } +def Torch_AtenMaxUnpool3dOp : Torch_Op<"aten.max_unpool3d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::max_unpool3d : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$indices, + AnyTorchListOfTorchIntType:$output_size, + AnyTorchListOfTorchIntType:$stride, + AnyTorchListOfTorchIntType:$padding + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMaxUnpool3dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); + } + void AtenMaxUnpool3dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); + } + }]; +} + def Torch_AtenMaxPool3dWithIndicesOp : Torch_Op<"aten.max_pool3d_with_indices", [ AllowsTypeRefinement, HasValueSemantics, @@ -6856,6 +7746,7 @@ def Torch_AtenMaxPool3dWithIndicesOp : Torch_Op<"aten.max_pool3d_with_indices", printDefaultTorchOp(printer, *this, 6, 2); } }]; + let hasCanonicalizer = 1; } def Torch_AtenMaxPool3dWithIndicesBackwardOp : Torch_Op<"aten.max_pool3d_with_indices_backward", [ @@ -7359,6 +8250,7 @@ def Torch_Aten_AdaptiveAvgPool2dOp : Torch_Op<"aten._adaptive_avg_pool2d", [ printDefaultTorchOp(printer, *this, 2, 1); } }]; + let hasCanonicalizer = 1; } def Torch_Aten_AdaptiveAvgPool2dBackwardOp : Torch_Op<"aten._adaptive_avg_pool2d_backward", [ @@ -7582,6 +8474,7 @@ def Torch_AtenTransposeIntOp : Torch_Op<"aten.transpose.int", [ printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasFolder = 1; } def Torch_AtenPixelShuffleOp : Torch_Op<"aten.pixel_shuffle", [ @@ -7879,91 +8772,164 @@ def Torch_Aten__Or__TensorOp : Torch_Op<"aten.__or__.Tensor", [ let hasCanonicalizer = 1; } -def Torch_Aten_SoftmaxOp : Torch_Op<"aten._softmax", [ +def Torch_Aten__Lshift__ScalarOp : Torch_Op<"aten.__lshift__.Scalar", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::_softmax : (Tensor, int, bool) -> (Tensor)`"; + let summary = "Generated op for `aten::__lshift__.Scalar : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - Torch_IntType:$dim, - Torch_BoolType:$half_to_float + AnyTorchScalarType:$other ); let results = (outs AnyTorchOptionalTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult Aten_SoftmaxOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult Aten__Lshift__ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void Aten_SoftmaxOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void Aten__Lshift__ScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenMeanOp : Torch_Op<"aten.mean", [ +def Torch_Aten__Rshift__ScalarOp : Torch_Op<"aten.__rshift__.Scalar", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::mean : (Tensor, int?) 
-> (Tensor)`"; + let summary = "Generated op for `aten::__rshift__.Scalar : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchOptionalIntType:$dtype + AnyTorchScalarType:$other ); let results = (outs AnyTorchOptionalTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenMeanOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult Aten__Rshift__ScalarOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenMeanOp::print(OpAsmPrinter &printer) { + void Aten__Rshift__ScalarOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenStdOp : Torch_Op<"aten.std", [ +def Torch_Aten_SoftmaxOp : Torch_Op<"aten._softmax", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::std : (Tensor, bool) -> (Tensor)`"; + let summary = "Generated op for `aten::_softmax : (Tensor, int, bool) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - Torch_BoolType:$unbiased + Torch_IntType:$dim, + Torch_BoolType:$half_to_float ); let results = (outs AnyTorchOptionalTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenStdOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult Aten_SoftmaxOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenStdOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void Aten_SoftmaxOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenStdDimOp : Torch_Op<"aten.std.dim", [ +def Torch_Aten_SafeSoftmaxOp : Torch_Op<"aten._safe_softmax", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::std.dim : (Tensor, int[]?, bool, bool) -> (Tensor)`"; + let summary = "Generated op for `aten::_safe_softmax : (Tensor, int, int?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchOptionalListOfTorchIntType:$dim, - Torch_BoolType:$unbiased, - Torch_BoolType:$keepdim - ); + Torch_IntType:$dim, + AnyTorchOptionalIntType:$dtype + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult Aten_SafeSoftmaxOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void Aten_SafeSoftmaxOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + +def Torch_AtenMeanOp : Torch_Op<"aten.mean", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::mean : (Tensor, int?) 
-> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchOptionalIntType:$dtype + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMeanOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenMeanOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenStdOp : Torch_Op<"aten.std", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::std : (Tensor, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_BoolType:$unbiased + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenStdOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenStdOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenStdDimOp : Torch_Op<"aten.std.dim", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::std.dim : (Tensor, int[]?, bool, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchOptionalListOfTorchIntType:$dim, + Torch_BoolType:$unbiased, + Torch_BoolType:$keepdim + ); let results = (outs AnyTorchOptionalTensorType:$result ); @@ -8377,6 +9343,78 @@ def Torch_AtenLinalgQrOp : Torch_Op<"aten.linalg_qr", [ }]; } +def Torch_AtenLinalgDetOp : Torch_Op<"aten.linalg_det", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::linalg_det : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$A + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenLinalgDetOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenLinalgDetOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_Aten_LinalgDetOp : Torch_Op<"aten._linalg_det", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::_linalg_det : (Tensor) -> (Tensor, Tensor, Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$A + ); + let results = (outs + AnyTorchOptionalTensorType:$result, + AnyTorchOptionalTensorType:$LU, + AnyTorchOptionalTensorType:$pivots + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult Aten_LinalgDetOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 3); + } + void Aten_LinalgDetOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 3); + } + }]; +} + +def Torch_AtenLinalgSlogdetOp : Torch_Op<"aten.linalg_slogdet", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::linalg_slogdet : (Tensor) -> (Tensor, Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$A + ); + let results = (outs + AnyTorchOptionalTensorType:$sign, + AnyTorchOptionalTensorType:$logabsdet + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenLinalgSlogdetOp::parse(OpAsmParser &parser, OperationState &result) { + return 
parseDefaultTorchOp(parser, result, 1, 2); + } + void AtenLinalgSlogdetOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 2); + } + }]; +} + def Torch_AtenFrobeniusNormDimOp : Torch_Op<"aten.frobenius_norm.dim", [ AllowsTypeRefinement, HasValueSemantics, @@ -8453,6 +9491,31 @@ def Torch_AtenMseLossBackwardOp : Torch_Op<"aten.mse_loss_backward", [ }]; } +def Torch_AtenL1LossOp : Torch_Op<"aten.l1_loss", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::l1_loss : (Tensor, Tensor, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$target, + Torch_IntType:$reduction + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenL1LossOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenL1LossOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + def Torch_AtenUpsampleNearest2dBackwardOp : Torch_Op<"aten.upsample_nearest2d_backward", [ AllowsTypeRefinement, HasValueSemantics, @@ -8632,6 +9695,33 @@ def Torch_AtenBinaryCrossEntropyBackwardOp : Torch_Op<"aten.binary_cross_entropy }]; } +def Torch_AtenBinaryCrossEntropyWithLogitsOp : Torch_Op<"aten.binary_cross_entropy_with_logits", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::binary_cross_entropy_with_logits : (Tensor, Tensor, Tensor?, Tensor?, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$target, + AnyTorchOptionalTensorType:$weight, + AnyTorchOptionalTensorType:$pos_weight, + Torch_IntType:$reduction + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenBinaryCrossEntropyWithLogitsOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); + } + void AtenBinaryCrossEntropyWithLogitsOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); + } + }]; +} + def Torch_AtenLogSigmoidForwardOp : Torch_Op<"aten.log_sigmoid_forward", [ AllowsTypeRefinement, HasValueSemantics, @@ -8758,6 +9848,58 @@ def Torch_AtenDiagEmbedOp : Torch_Op<"aten.diag_embed", [ }]; } +def Torch_Aten_WeightNormInterfaceOp : Torch_Op<"aten._weight_norm_interface", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::_weight_norm_interface : (Tensor, Tensor, int) -> (Tensor, Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$v, + AnyTorchTensorType:$g, + Torch_IntType:$dim + ); + let results = (outs + AnyTorchOptionalTensorType:$result0, + AnyTorchOptionalTensorType:$result1 + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult Aten_WeightNormInterfaceOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 2); + } + void Aten_WeightNormInterfaceOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 2); + } + }]; +} + +def Torch_AtenRot90Op : Torch_Op<"aten.rot90", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::rot90 : (Tensor, int, int[]) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$k, + AnyTorchListOfTorchIntType:$dims + ); + let 
results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRot90Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenRot90Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; + let hasVerifier = 1; +} + def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [ AllowsTypeRefinement, HasValueSemantics, @@ -8855,6 +9997,30 @@ def Torch_AtenReflectionPad2dOp : Torch_Op<"aten.reflection_pad2d", [ }]; } +def Torch_AtenReflectionPad3dOp : Torch_Op<"aten.reflection_pad3d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::reflection_pad3d : (Tensor, int[]) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$padding + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenReflectionPad3dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenReflectionPad3dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenPadOp : Torch_Op<"aten.pad", [ AllowsTypeRefinement, HasValueSemantics, @@ -8950,6 +10116,7 @@ def Torch_AtenFlattenUsingIntsOp : Torch_Op<"aten.flatten.using_ints", [ printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasFolder = 1; } def Torch_AtenUnflattenIntOp : Torch_Op<"aten.unflatten.int", [ @@ -8974,6 +10141,8 @@ def Torch_AtenUnflattenIntOp : Torch_Op<"aten.unflatten.int", [ printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasFolder = 1; + let hasCanonicalizer = 1; } def Torch_AtenDimOp : Torch_Op<"aten.dim", [ @@ -9762,37 +10931,35 @@ def Torch_AtenOneHotOp : Torch_Op<"aten.one_hot", [ }]; } -def Torch_AtenEinsumOp : Torch_Op<"aten.einsum", [ +def Torch_AtenAtleast1dOp : Torch_Op<"aten.atleast_1d", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::einsum : (str, Tensor[], int[]?) 
-> (Tensor)`"; + let summary = "Generated op for `aten::atleast_1d : (Tensor) -> (Tensor)`"; let arguments = (ins - Torch_StringType:$equation, - AnyTorchListOfTensorType:$tensors, - AnyTorchOptionalListOfTorchIntType:$path + AnyTorchTensorType:$self ); let results = (outs AnyTorchOptionalTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenEinsumOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenAtleast1dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenEinsumOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenAtleast1dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenTraceOp : Torch_Op<"aten.trace", [ +def Torch_AtenAtleast2dOp : Torch_Op<"aten.atleast_2d", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::trace : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::atleast_2d : (Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self ); @@ -9801,25 +10968,73 @@ def Torch_AtenTraceOp : Torch_Op<"aten.trace", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenTraceOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenAtleast2dOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenTraceOp::print(OpAsmPrinter &printer) { + void AtenAtleast2dOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenBucketizeTensorOp : Torch_Op<"aten.bucketize.Tensor", [ +def Torch_AtenEinsumOp : Torch_Op<"aten.einsum", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)`"; + let summary = "Generated op for `aten::einsum : (str, Tensor[], int[]?) 
-> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$boundaries, - Torch_BoolType:$out_int32, + Torch_StringType:$equation, + AnyTorchListOfTensorType:$tensors, + AnyTorchOptionalListOfTorchIntType:$path + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenEinsumOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenEinsumOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + +def Torch_AtenTraceOp : Torch_Op<"aten.trace", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::trace : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenTraceOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenTraceOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenBucketizeTensorOp : Torch_Op<"aten.bucketize.Tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$boundaries, + Torch_BoolType:$out_int32, Torch_BoolType:$right ); let results = (outs @@ -10713,6 +11928,7 @@ def Torch_AtenReshapeOp : Torch_Op<"aten.reshape", [ printDefaultTorchOp(printer, *this, 2, 1); } }]; + let hasFolder = 1; } def Torch_AtenReshapeAsOp : Torch_Op<"aten.reshape_as", [ @@ -11158,6 +12374,32 @@ def Torch_AtenAminOp : Torch_Op<"aten.amin", [ }]; } +def Torch_AtenAminmaxOp : Torch_Op<"aten.aminmax", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::aminmax : (Tensor, int?, bool) -> (Tensor, Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchOptionalIntType:$dim, + Torch_BoolType:$keepdim + ); + let results = (outs + AnyTorchOptionalTensorType:$min, + AnyTorchOptionalTensorType:$max + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAminmaxOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 2); + } + void AtenAminmaxOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 2); + } + }]; +} + def Torch_AtenToDtypeOp : Torch_Op<"aten.to.dtype", [ AllowsTypeRefinement, ReadOnly @@ -11394,6 +12636,29 @@ def Torch_AtenViewOp : Torch_Op<"aten.view", [ let hasFolder = 1; } +def Torch_AtenViewDtypeOp : Torch_Op<"aten.view.dtype", [ + AllowsTypeRefinement, + ReadOnly + ]> { + let summary = "Generated op for `aten::view.dtype : (Tensor, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dtype + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenViewDtypeOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenViewDtypeOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_Aten_UnsafeViewOp : Torch_Op<"aten._unsafe_view", [ 
AllowsTypeRefinement, HasValueSemantics, @@ -11849,6 +13114,7 @@ def Torch_AtenIntTensorOp : Torch_Op<"aten.Int.Tensor", [ printDefaultTorchOp(printer, *this, 1, 1); } }]; + let hasFolder = 1; let hasCanonicalizer = 1; } @@ -12135,6 +13401,34 @@ def Torch_AtenBaddbmm_Op : Torch_Op<"aten.baddbmm_", [ }]; } +def Torch_AtenHannWindowPeriodicOp : Torch_Op<"aten.hann_window.periodic", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::hann_window.periodic : (int, bool, int?, int?, Device?, bool?) -> (Tensor)`"; + let arguments = (ins + Torch_IntType:$window_length, + Torch_BoolType:$periodic, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenHannWindowPeriodicOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); + } + void AtenHannWindowPeriodicOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); + } + }]; +} + def Torch_AtenFftFftOp : Torch_Op<"aten.fft_fft", [ AllowsTypeRefinement, HasValueSemantics, @@ -12161,6 +13455,58 @@ def Torch_AtenFftFftOp : Torch_Op<"aten.fft_fft", [ }]; } +def Torch_AtenFftRfftOp : Torch_Op<"aten.fft_rfft", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::fft_rfft : (Tensor, int?, int, str?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchOptionalIntType:$n, + Torch_IntType:$dim, + AnyTorchOptionalStringType:$norm + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenFftRfftOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenFftRfftOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + +def Torch_AtenFftIfftOp : Torch_Op<"aten.fft_ifft", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::fft_ifft : (Tensor, int?, int, str?) 
-> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchOptionalIntType:$n, + Torch_IntType:$dim, + AnyTorchOptionalStringType:$norm + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenFftIfftOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenFftIfftOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + def Torch_AtenFmodTensorOp : Torch_Op<"aten.fmod.Tensor", [ AllowsTypeRefinement, HasValueSemantics, @@ -12213,6 +13559,35 @@ def Torch_AtenUniqueConsecutiveOp : Torch_Op<"aten.unique_consecutive", [ }]; } +def Torch_AtenUniqueDimOp : Torch_Op<"aten.unique_dim", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::unique_dim : (Tensor, int, bool, bool, bool) -> (Tensor, Tensor, Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dim, + Torch_BoolType:$sorted, + Torch_BoolType:$return_inverse, + Torch_BoolType:$return_counts + ); + let results = (outs + AnyTorchOptionalTensorType:$result0, + AnyTorchOptionalTensorType:$result1, + AnyTorchOptionalTensorType:$result2 + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenUniqueDimOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 3); + } + void AtenUniqueDimOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 3); + } + }]; +} + def Torch_AtenLinspaceOp : Torch_Op<"aten.linspace", [ AllowsTypeRefinement, HasValueSemantics, @@ -12268,6 +13643,93 @@ def Torch_AtenLinalgCrossOp : Torch_Op<"aten.linalg_cross", [ let hasVerifier = 1; } +def Torch_AtenCol2imOp : Torch_Op<"aten.col2im", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::col2im : (Tensor, int[], int[], int[], int[], int[]) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$output_size, + AnyTorchListOfTorchIntType:$kernel_size, + AnyTorchListOfTorchIntType:$dilation, + AnyTorchListOfTorchIntType:$padding, + AnyTorchListOfTorchIntType:$stride + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenCol2imOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); + } + void AtenCol2imOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); + } + }]; +} + +def Torch_AtenKthvalueOp : Torch_Op<"aten.kthvalue", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::kthvalue : (Tensor, int, int, bool) -> (Tensor, Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$k, + Torch_IntType:$dim, + Torch_BoolType:$keepdim + ); + let results = (outs + AnyTorchOptionalTensorType:$values, + AnyTorchOptionalTensorType:$indices + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenKthvalueOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 2); + } + void AtenKthvalueOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 2); + } + }]; + let hasVerifier = 1; +} + +def Torch_AtenStftOp : Torch_Op<"aten.stft", [ + AllowsTypeRefinement, + 
HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::stft : (Tensor, int, int?, int?, Tensor?, bool, bool?, bool?, bool?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$n_fft, + AnyTorchOptionalIntType:$hop_length, + AnyTorchOptionalIntType:$win_length, + AnyTorchOptionalTensorType:$window, + Torch_BoolType:$normalized, + AnyTorchOptionalBoolType:$onesided, + AnyTorchOptionalBoolType:$return_complex, + AnyTorchOptionalBoolType:$align_to_window + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenStftOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 9, 1); + } + void AtenStftOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 9, 1); + } + }]; +} + def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [ AllowsTypeRefinement, HasValueSemantics, @@ -12340,6 +13802,59 @@ def Torch_AtenAsStridedCopyOp : Torch_Op<"aten.as_strided_copy", [ }]; } +def Torch_AtenAsStridedOp : Torch_Op<"aten.as_strided", [ + AllowsTypeRefinement, + ReadOnly + ]> { + let summary = "Generated op for `aten::as_strided : (Tensor, int[], int[], int?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$size, + AnyTorchListOfTorchIntType:$stride, + AnyTorchOptionalIntType:$storage_offset + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAsStridedOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenAsStridedOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + +def Torch_Aten_AssertTensorMetadataOp : Torch_Op<"aten._assert_tensor_metadata", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::_assert_tensor_metadata : (Tensor, int[]?, int[]?, int?, Device?, int?) 
-> ()`"; + let arguments = (ins + AnyTorchTensorType:$a, + AnyTorchOptionalListOfTorchIntType:$size, + AnyTorchOptionalListOfTorchIntType:$stride, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalIntType:$layout + ); + let results = (outs + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult Aten_AssertTensorMetadataOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 0); + } + void Aten_AssertTensorMetadataOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 0); + } + }]; + let hasFolder = 1; +} + def Torch_AtenDiagonalOp : Torch_Op<"aten.diagonal", [ AllowsTypeRefinement, ReadOnly @@ -12707,6 +14222,31 @@ def Torch_AtenViewCopyDtypeOp : Torch_Op<"aten.view_copy.dtype", [ }]; } +def Torch_AtenUnfoldOp : Torch_Op<"aten.unfold", [ + AllowsTypeRefinement, + ReadOnly + ]> { + let summary = "Generated op for `aten::unfold : (Tensor, int, int, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dimension, + Torch_IntType:$size, + Torch_IntType:$step + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenUnfoldOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenUnfoldOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + def Torch_AtenUnfoldCopyOp : Torch_Op<"aten.unfold_copy", [ AllowsTypeRefinement, HasValueSemantics, @@ -12895,6 +14435,56 @@ def Torch_AtenAsStridedScatterOp : Torch_Op<"aten.as_strided_scatter", [ }]; } +def Torch_AtenUpsampleNearest1dOp : Torch_Op<"aten.upsample_nearest1d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::upsample_nearest1d : (Tensor, int[], float?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$output_size, + AnyTorchOptionalFloatType:$scales + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenUpsampleNearest1dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenUpsampleNearest1dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + +def Torch_AtenUpsampleNearest1dVecOp : Torch_Op<"aten.upsample_nearest1d.vec", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::upsample_nearest1d.vec : (Tensor, int[]?, float[]?) 
-> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + AnyTorchOptionalListOfTorchIntType:$output_size, + AnyTorchOptionalListOfTorchFloatType:$scale_factors + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenUpsampleNearest1dVecOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenUpsampleNearest1dVecOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + def Torch_AtenUpsampleNearest2dOp : Torch_Op<"aten.upsample_nearest2d", [ AllowsTypeRefinement, HasValueSemantics, @@ -12921,12 +14511,90 @@ def Torch_AtenUpsampleNearest2dOp : Torch_Op<"aten.upsample_nearest2d", [ }]; } +def Torch_AtenUpsampleNearest2dVecOp : Torch_Op<"aten.upsample_nearest2d.vec", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::upsample_nearest2d.vec : (Tensor, int[]?, float[]?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + AnyTorchOptionalListOfTorchIntType:$output_size, + AnyTorchOptionalListOfTorchFloatType:$scale_factors + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenUpsampleNearest2dVecOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenUpsampleNearest2dVecOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + +def Torch_AtenUpsampleBilinear2dOp : Torch_Op<"aten.upsample_bilinear2d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::upsample_bilinear2d : (Tensor, int[], bool, float?, float?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$output_size, + Torch_BoolType:$align_corners, + AnyTorchOptionalFloatType:$scales_h, + AnyTorchOptionalFloatType:$scales_w + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenUpsampleBilinear2dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); + } + void AtenUpsampleBilinear2dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); + } + }]; +} + +def Torch_AtenUpsampleBilinear2dVecOp : Torch_Op<"aten.upsample_bilinear2d.vec", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::upsample_bilinear2d.vec : (Tensor, int[]?, bool, float[]?) 
-> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + AnyTorchOptionalListOfTorchIntType:$output_size, + Torch_BoolType:$align_corners, + AnyTorchOptionalListOfTorchFloatType:$scale_factors + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenUpsampleBilinear2dVecOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenUpsampleBilinear2dVecOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + def Torch_AtenScaledDotProductAttentionOp : Torch_Op<"aten.scaled_dot_product_attention", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::scaled_dot_product_attention : (Tensor, Tensor, Tensor, Tensor?, float, bool, float?) -> (Tensor)`"; + let summary = "Generated op for `aten::scaled_dot_product_attention : (Tensor, Tensor, Tensor, Tensor?, float, bool, float?, bool) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$query, AnyTorchTensorType:$key, @@ -12934,7 +14602,8 @@ def Torch_AtenScaledDotProductAttentionOp : Torch_Op<"aten.scaled_dot_product_at AnyTorchOptionalTensorType:$attn_mask, Torch_FloatType:$dropout_p, Torch_BoolType:$is_causal, - AnyTorchOptionalFloatType:$scale + AnyTorchOptionalFloatType:$scale, + Torch_BoolType:$enable_gqa ); let results = (outs AnyTorchOptionalTensorType:$result @@ -12942,10 +14611,10 @@ def Torch_AtenScaledDotProductAttentionOp : Torch_Op<"aten.scaled_dot_product_at let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ ParseResult AtenScaledDotProductAttentionOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 7, 1); + return parseDefaultTorchOp(parser, result, 8, 1); } void AtenScaledDotProductAttentionOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 7, 1); + printDefaultTorchOp(printer, *this, 8, 1); } }]; } @@ -12977,6 +14646,36 @@ def Torch_AtenGridSamplerOp : Torch_Op<"aten.grid_sampler", [ }]; } +def Torch_Aten_TrilinearOp : Torch_Op<"aten._trilinear", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::_trilinear : (Tensor, Tensor, Tensor, int[], int[], int[], int[], int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$i1, + AnyTorchTensorType:$i2, + AnyTorchTensorType:$i3, + AnyTorchListOfTorchIntType:$expand1, + AnyTorchListOfTorchIntType:$expand2, + AnyTorchListOfTorchIntType:$expand3, + AnyTorchListOfTorchIntType:$sumdim, + Torch_IntType:$unroll_dim + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult Aten_TrilinearOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 8, 1); + } + void Aten_TrilinearOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 8, 1); + } + }]; +} + def Torch_Aten__Contains__StrOp : Torch_Op<"aten.__contains__.str", [ AllowsTypeRefinement, HasValueSemantics, @@ -13155,37 +14854,83 @@ def Torch_AtenCatOp : Torch_Op<"aten.cat", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenCatOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenCatOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 
1); + } + void AtenCatOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; + let hasFolder = 1; + let hasCanonicalizer = 1; +} + +def Torch_AtenStackOp : Torch_Op<"aten.stack", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::stack : (Tensor[], int) -> (Tensor)`"; + let arguments = (ins + AnyTorchListOfTensorType:$tensors, + Torch_IntType:$dim + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenStackOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenStackOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenHstackOp : Torch_Op<"aten.hstack", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::hstack : (Tensor[]) -> (Tensor)`"; + let arguments = (ins + AnyTorchListOfTensorType:$tensors + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenHstackOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenCatOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenHstackOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; - let hasFolder = 1; - let hasCanonicalizer = 1; } -def Torch_AtenStackOp : Torch_Op<"aten.stack", [ +def Torch_AtenColumnStackOp : Torch_Op<"aten.column_stack", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::stack : (Tensor[], int) -> (Tensor)`"; + let summary = "Generated op for `aten::column_stack : (Tensor[]) -> (Tensor)`"; let arguments = (ins - AnyTorchListOfTensorType:$tensors, - Torch_IntType:$dim + AnyTorchListOfTensorType:$tensors ); let results = (outs AnyTorchOptionalTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenStackOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenColumnStackOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenStackOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenColumnStackOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } @@ -13501,6 +15246,31 @@ def Torch_AtenSplitSizesOp : Torch_Op<"aten.split.sizes", [ printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasCanonicalizer = 1; +} + +def Torch_AtenTensorSplitSectionsOp : Torch_Op<"aten.tensor_split.sections", [ + AllowsTypeRefinement, + ReadOnly + ]> { + let summary = "Generated op for `aten::tensor_split.sections : (Tensor, int, int) -> (Tensor[])`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$sections, + Torch_IntType:$dim + ); + let results = (outs + AnyTorchListOfTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenTensorSplitSectionsOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenTensorSplitSectionsOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; } def 
Torch_AtenUnbindIntOp : Torch_Op<"aten.unbind.int", [ @@ -13550,6 +15320,54 @@ def Torch_AtenChunkOp : Torch_Op<"aten.chunk", [ }]; } +def Torch_AtenMeshgridOp : Torch_Op<"aten.meshgrid", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::meshgrid : (Tensor[]) -> (Tensor[])`"; + let arguments = (ins + AnyTorchListOfTensorType:$tensors + ); + let results = (outs + AnyTorchListOfTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMeshgridOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenMeshgridOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; + let hasCanonicalizer = 1; +} + +def Torch_AtenMeshgridIndexingOp : Torch_Op<"aten.meshgrid.indexing", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::meshgrid.indexing : (Tensor[], str) -> (Tensor[])`"; + let arguments = (ins + AnyTorchListOfTensorType:$tensors, + Torch_StringType:$indexing + ); + let results = (outs + AnyTorchListOfTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMeshgridIndexingOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenMeshgridIndexingOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenAddStrOp : Torch_Op<"aten.add.str", [ AllowsTypeRefinement, HasValueSemantics, @@ -14101,6 +15919,7 @@ def Torch_AtenFloordivIntOp : Torch_Op<"aten.floordiv.int", [ } }]; let hasFolder = 1; + let hasCanonicalizer = 1; } def Torch_AtenRemainderIntOp : Torch_Op<"aten.remainder.int", [ @@ -14150,6 +15969,7 @@ def Torch_AtenRemainderScalarOp : Torch_Op<"aten.remainder.Scalar", [ printDefaultTorchOp(printer, *this, 2, 1); } }]; + let hasFolder = 1; } def Torch_AtenRemainderTensorOp : Torch_Op<"aten.remainder.Tensor", [ @@ -14249,6 +16069,32 @@ def Torch_AtenMulIntOp : Torch_Op<"aten.mul.int", [ } }]; let hasFolder = 1; + let hasCanonicalizer = 1; +} + +def Torch_AtenMulIntFloatOp : Torch_Op<"aten.mul.int_float", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::mul.int_float : (int, float) -> (float)`"; + let arguments = (ins + Torch_IntType:$a, + Torch_FloatType:$b + ); + let results = (outs + Torch_FloatType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMulIntFloatOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenMulIntFloatOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; + let hasFolder = 1; } def Torch_AtenDivIntOp : Torch_Op<"aten.div.int", [ @@ -14348,6 +16194,31 @@ def Torch_AtenAddFloatIntOp : Torch_Op<"aten.add.float_int", [ let hasFolder = 1; } +def Torch_AtenMulFloatIntOp : Torch_Op<"aten.mul.float_int", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::mul.float_int : (float, int) -> (float)`"; + let arguments = (ins + Torch_FloatType:$a, + Torch_IntType:$b + ); + let results = (outs + Torch_FloatType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMulFloatIntOp::parse(OpAsmParser &parser, OperationState &result) { + return 
parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenMulFloatIntOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; + let hasFolder = 1; +} + def Torch_AtenSubFloatOp : Torch_Op<"aten.sub.float", [ AllowsTypeRefinement, HasValueSemantics, @@ -14692,6 +16563,31 @@ def Torch_Aten__And__BoolOp : Torch_Op<"aten.__and__.bool", [ }]; } +def Torch_AtenEqBoolOp : Torch_Op<"aten.eq.bool", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::eq.bool : (bool, bool) -> (bool)`"; + let arguments = (ins + Torch_BoolType:$a, + Torch_BoolType:$b + ); + let results = (outs + Torch_BoolType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenEqBoolOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenEqBoolOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; + let hasFolder = 1; +} + def Torch_AtenNeBoolOp : Torch_Op<"aten.ne.bool", [ AllowsTypeRefinement, HasValueSemantics, @@ -14789,6 +16685,31 @@ def Torch_Aten__Not__Op : Torch_Op<"aten.__not__", [ let hasFolder = 1; } +def Torch_Aten__Or__BoolOp : Torch_Op<"aten.__or__.bool", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::__or__.bool : (bool, bool) -> (bool)`"; + let arguments = (ins + Torch_BoolType:$a, + Torch_BoolType:$b + ); + let results = (outs + Torch_BoolType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult Aten__Or__BoolOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void Aten__Or__BoolOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; + let hasFolder = 1; +} + def Torch_AtenLenTOp : Torch_Op<"aten.len.t", [ AllowsTypeRefinement, HasValueSemantics, @@ -14814,6 +16735,31 @@ def Torch_AtenLenTOp : Torch_Op<"aten.len.t", [ let hasCanonicalizer = 1; } +def Torch_AtenMulLeftTOp : Torch_Op<"aten.mul.left_t", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::mul.left_t : (t[], int) -> (t[])`"; + let arguments = (ins + AnyTorchListType:$l, + Torch_IntType:$n + ); + let results = (outs + AnyTorchListType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMulLeftTOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenMulLeftTOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; + let hasCanonicalizer = 1; +} + def Torch_Aten__Getitem__TOp : Torch_Op<"aten.__getitem__.t", [ AllowsTypeRefinement, ReadOnly @@ -15081,102 +17027,185 @@ def Torch_AtenEqDeviceOp : Torch_Op<"aten.eq.device", [ }]; } -def Torch_AtenCeilFloatOp : Torch_Op<"aten.ceil.float", [ +def Torch_AtenCeilFloatOp : Torch_Op<"aten.ceil.float", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::ceil.float : (float) -> (int)`"; + let arguments = (ins + Torch_FloatType:$a + ); + let results = (outs + Torch_IntType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenCeilFloatOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenCeilFloatOp::print(OpAsmPrinter &printer) { + 
printDefaultTorchOp(printer, *this, 1, 1); + } + }]; + let hasFolder = 1; +} + +def Torch_AtenNarrowOp : Torch_Op<"aten.narrow", [ + AllowsTypeRefinement, + ReadOnly + ]> { + let summary = "Generated op for `aten::narrow : (Tensor, int, int, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dim, + Torch_IntType:$start, + Torch_IntType:$length + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenNarrowOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenNarrowOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + +def Torch_AtenNarrowTensorOp : Torch_Op<"aten.narrow.Tensor", [ + AllowsTypeRefinement, + ReadOnly + ]> { + let summary = "Generated op for `aten::narrow.Tensor : (Tensor, int, Tensor, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dim, + AnyTorchTensorType:$start, + Torch_IntType:$length + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenNarrowTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenNarrowTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + +def Torch_AtenScalarImplicitOp : Torch_Op<"aten.ScalarImplicit", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::ceil.float : (float) -> (int)`"; + let summary = "Generated op for `aten::ScalarImplicit : (Tensor) -> (Scalar)`"; let arguments = (ins - Torch_FloatType:$a + AnyTorchTensorType:$a ); let results = (outs - Torch_IntType:$result + AnyTorchScalarType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenCeilFloatOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenScalarImplicitOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenCeilFloatOp::print(OpAsmPrinter &printer) { + void AtenScalarImplicitOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; - let hasFolder = 1; + let hasCanonicalizer = 1; } -def Torch_AtenNarrowOp : Torch_Op<"aten.narrow", [ +def Torch_AtenTriuIndicesOp : Torch_Op<"aten.triu_indices", [ AllowsTypeRefinement, + HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::narrow : (Tensor, int, int, int) -> (Tensor)`"; + let summary = "Generated op for `aten::triu_indices : (int, int, int, int?, int?, Device?, bool?) 
-> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - Torch_IntType:$dim, - Torch_IntType:$start, - Torch_IntType:$length + Torch_IntType:$row, + Torch_IntType:$col, + Torch_IntType:$offset, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory ); let results = (outs AnyTorchOptionalTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenNarrowOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 4, 1); + ParseResult AtenTriuIndicesOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 7, 1); } - void AtenNarrowOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 4, 1); + void AtenTriuIndicesOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 7, 1); } }]; + let hasVerifier = 1; } -def Torch_AtenNarrowTensorOp : Torch_Op<"aten.narrow.Tensor", [ +def Torch_AtenTrilIndicesOp : Torch_Op<"aten.tril_indices", [ AllowsTypeRefinement, + HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::narrow.Tensor : (Tensor, int, Tensor, int) -> (Tensor)`"; + let summary = "Generated op for `aten::tril_indices : (int, int, int, int?, int?, Device?, bool?) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - Torch_IntType:$dim, - AnyTorchTensorType:$start, - Torch_IntType:$length + Torch_IntType:$row, + Torch_IntType:$col, + Torch_IntType:$offset, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory ); let results = (outs AnyTorchOptionalTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenNarrowTensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 4, 1); + ParseResult AtenTrilIndicesOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 7, 1); } - void AtenNarrowTensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 4, 1); + void AtenTrilIndicesOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 7, 1); } }]; + let hasVerifier = 1; } -def Torch_AtenScalarImplicitOp : Torch_Op<"aten.ScalarImplicit", [ +def Torch_AtenDeg2radOp : Torch_Op<"aten.deg2rad", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::ScalarImplicit : (Tensor) -> (Scalar)`"; + let summary = "Generated op for `aten::deg2rad : (Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$a + AnyTorchTensorType:$self ); let results = (outs - AnyTorchScalarType:$result + AnyTorchOptionalTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenScalarImplicitOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenDeg2radOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenScalarImplicitOp::print(OpAsmPrinter &printer) { + void AtenDeg2radOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; - let hasCanonicalizer = 1; } def Torch_Aten_SoftmaxBackwardDataOp : Torch_Op<"aten._softmax_backward_data", [ @@ -15512,6 +17541,64 @@ def Torch_AtenLeakyReluBackwardOp : Torch_Op<"aten.leaky_relu_backward", [ }]; } +def 
Torch_AtenRreluWithNoiseBackwardOp : Torch_Op<"aten.rrelu_with_noise_backward", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::rrelu_with_noise_backward : (Tensor, Tensor, Tensor, Scalar, Scalar, bool, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$grad_output, + AnyTorchTensorType:$self, + AnyTorchTensorType:$noise, + AnyTorchScalarType:$lower, + AnyTorchScalarType:$upper, + Torch_BoolType:$training, + Torch_BoolType:$self_is_result + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRreluWithNoiseBackwardOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 7, 1); + } + void AtenRreluWithNoiseBackwardOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 7, 1); + } + }]; +} + +def Torch_AtenRreluWithNoiseFunctionalOp : Torch_Op<"aten.rrelu_with_noise_functional", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::rrelu_with_noise_functional : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor, Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$noise, + AnyTorchScalarType:$lower, + AnyTorchScalarType:$upper, + Torch_BoolType:$training, + AnyTorchOptionalGeneratorType:$generator + ); + let results = (outs + AnyTorchOptionalTensorType:$result0, + AnyTorchOptionalTensorType:$noise_out + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRreluWithNoiseFunctionalOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 2); + } + void AtenRreluWithNoiseFunctionalOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 2); + } + }]; +} + def Torch_AtenQuantizePerChannelOp : Torch_Op<"aten.quantize_per_channel", [ AllowsTypeRefinement, HasValueSemantics, @@ -15685,6 +17772,77 @@ def Torch_Aten_MakePerTensorQuantizedTensorOp : Torch_Op<"aten._make_per_tensor_ }]; } +def Torch_AtenSymConstrainRangeOp : Torch_Op<"aten.sym_constrain_range", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::sym_constrain_range : (Scalar, int?, int?) -> ()`"; + let arguments = (ins + AnyTorchScalarType:$size, + AnyTorchOptionalIntType:$min, + AnyTorchOptionalIntType:$max + ); + let results = (outs + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenSymConstrainRangeOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 0); + } + void AtenSymConstrainRangeOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 0); + } + }]; +} + +def Torch_AtenSymConstrainRangeForSizeOp : Torch_Op<"aten.sym_constrain_range_for_size", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::sym_constrain_range_for_size : (Scalar, int?, int?) 
-> ()`"; + let arguments = (ins + AnyTorchScalarType:$size, + AnyTorchOptionalIntType:$min, + AnyTorchOptionalIntType:$max + ); + let results = (outs + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenSymConstrainRangeForSizeOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 0); + } + void AtenSymConstrainRangeForSizeOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 0); + } + }]; +} + +def Torch_Aten_AssertScalarOp : Torch_Op<"aten._assert_scalar", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::_assert_scalar : (Scalar, str) -> ()`"; + let arguments = (ins + AnyTorchScalarType:$self, + Torch_StringType:$assert_msg + ); + let results = (outs + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult Aten_AssertScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 0); + } + void Aten_AssertScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 0); + } + }]; +} + def Torch_PrimLayoutOp : Torch_Op<"prim.layout", [ AllowsTypeRefinement, HasValueSemantics, @@ -16068,11 +18226,11 @@ def Torch_PrimsVarOp : Torch_Op<"prims.var", [ HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `prims::var : (Tensor, int[]?, float, int?) -> (Tensor)`"; + let summary = "Generated op for `prims::var : (Tensor, int[]?, float?, int?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$inp, AnyTorchOptionalListOfTorchIntType:$dims, - Torch_FloatType:$correction, + AnyTorchOptionalFloatType:$correction, AnyTorchOptionalIntType:$output_dtype ); let results = (outs @@ -16184,6 +18342,31 @@ def Torch_PrimsSqueezeOp : Torch_Op<"prims.squeeze", [ }]; } +def Torch_PrimsSumOp : Torch_Op<"prims.sum", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `prims::sum : (Tensor, int[]?, int?) 
-> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$inp, + AnyTorchOptionalListOfTorchIntType:$dims, + AnyTorchOptionalIntType:$output_dtype + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult PrimsSumOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void PrimsSumOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + def Torch_PrimsViewOfOp : Torch_Op<"prims.view_of", [ AllowsTypeRefinement, ReadOnly @@ -16262,3 +18445,121 @@ def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [ }]; } +def Torch_TorchvisionDeformConv2dOp : Torch_Op<"torchvision.deform_conv2d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `torchvision::deform_conv2d : (Tensor, Tensor, Tensor, Tensor, Tensor, int, int, int, int, int, int, int, int, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + AnyTorchTensorType:$weight, + AnyTorchTensorType:$offset, + AnyTorchTensorType:$mask, + AnyTorchTensorType:$bias, + Torch_IntType:$stride_h, + Torch_IntType:$stride_w, + Torch_IntType:$pad_h, + Torch_IntType:$pad_w, + Torch_IntType:$dilation_h, + Torch_IntType:$dilation_w, + Torch_IntType:$groups, + Torch_IntType:$offset_groups, + Torch_BoolType:$use_mask + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult TorchvisionDeformConv2dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 14, 1); + } + void TorchvisionDeformConv2dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 14, 1); + } + }]; +} + +def Torch_TorchvisionRoiAlignOp : Torch_Op<"torchvision.roi_align", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `torchvision::roi_align : (Tensor, Tensor, float, int, int, int, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + AnyTorchTensorType:$rois, + Torch_FloatType:$spatial_scale, + Torch_IntType:$pooled_height, + Torch_IntType:$pooled_width, + Torch_IntType:$sampling_ratio, + Torch_BoolType:$aligned + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult TorchvisionRoiAlignOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 7, 1); + } + void TorchvisionRoiAlignOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 7, 1); + } + }]; +} + +def Torch_TorchvisionRoiPoolOp : Torch_Op<"torchvision.roi_pool", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `torchvision::roi_pool : (Tensor, Tensor, float, int, int) -> (Tensor, Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + AnyTorchTensorType:$rois, + Torch_FloatType:$spatial_scale, + Torch_IntType:$pooled_height, + Torch_IntType:$pooled_width + ); + let results = (outs + AnyTorchOptionalTensorType:$result0, + AnyTorchOptionalTensorType:$result1 + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult TorchvisionRoiPoolOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 2); + } + void TorchvisionRoiPoolOp::print(OpAsmPrinter &printer) { + 
printDefaultTorchOp(printer, *this, 5, 2); + } + }]; +} + +def Torch_TorchvisionNmsOp : Torch_Op<"torchvision.nms", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `torchvision::nms : (Tensor, Tensor, float) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$dets, + AnyTorchTensorType:$scores, + Torch_FloatType:$iou_threshold + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult TorchvisionNmsOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void TorchvisionNmsOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.h b/include/torch-mlir/Dialect/Torch/IR/TorchOps.h index f49fef0721c2..a5a58064489a 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.h +++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.h @@ -190,7 +190,7 @@ struct torch_list_of_optional_constant_ints_op_binder { int64_t num; if (matchPattern(value, m_TorchConstantInt(&num))) bind_values.push_back(num); - else if (value.getType().isa()) + else if (isa(value.getType())) bind_values.push_back(std::nullopt); else return false; @@ -208,6 +208,37 @@ m_TorchListOfOptionalConstantInts( return detail::torch_list_of_optional_constant_ints_op_binder(bind_values); } +namespace detail { +/// Matches the constant floats stored in a `torch.prim.ListConstruct`. +struct torch_list_of_constant_floats_op_binder { + SmallVectorImpl &bind_values; + + /// Creates a matcher instance that binds the value to bvs if match succeeds. + torch_list_of_constant_floats_op_binder(SmallVectorImpl &bvs) + : bind_values(bvs) {} + + bool match(Operation *op) { + auto listConstruct = dyn_cast(op); + if (!listConstruct) + return false; + for (Value value : listConstruct.getElements()) { + double num; + if (matchPattern(value, m_TorchConstantFloat(&num))) + bind_values.push_back(num); + else + return false; + } + return true; + } +}; +} // namespace detail + +/// Matches the constant integers stored in a `torch.prim.ListConstruct`. +inline detail::torch_list_of_constant_floats_op_binder +m_TorchListOfConstantFloats(SmallVectorImpl &bind_values) { + return detail::torch_list_of_constant_floats_op_binder(bind_values); +} + namespace detail { /// Matches the constant bools stored in a `torch.ListConstruct`. struct torch_list_of_constant_bools_op_binder { @@ -238,7 +269,6 @@ inline detail::torch_list_of_constant_bools_op_binder m_TorchListOfConstantBools(SmallVectorImpl &bind_values) { return detail::torch_list_of_constant_bools_op_binder(bind_values); } - namespace detail { /// Matches the constant strs stored in a `torch.ListConstruct`. 
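The new float-list binder mirrors the existing integer and bool binders: `matchPattern` walks the defining `torch.prim.ListConstruct` and binds each `torch.constant.float` element. A minimal usage sketch, assuming the bound element type is `double` (the binder stores `double` values); the helper name and the scale-factor scenario are illustrative, not part of this patch:

```cpp
// Sketch: extract compile-time constant floats from a Torch list value.
#include "mlir/IR/Matchers.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

using namespace mlir;
using namespace mlir::torch;

static LogicalResult extractConstantScales(Value scaleFactorList,
                                           SmallVectorImpl<double> &scales) {
  // Succeeds only if `scaleFactorList` is a torch.prim.ListConstruct whose
  // elements are all torch.constant.float; the values are appended to
  // `scales` (e.g. {2.0, 2.0} for a 2x upsample).
  if (!matchPattern(scaleFactorList, Torch::m_TorchListOfConstantFloats(scales)))
    return failure();
  return success();
}
```

A conversion pattern would typically report a match failure when this returns `failure()`, since a dynamic or partially constant list cannot be folded into static attributes.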
struct torch_list_of_constant_strs_op_binder { diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td index f578cefe0297..4a83b97e6269 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td @@ -11,6 +11,7 @@ #define TORCH_OPS include "torch-mlir/Dialect/Torch/IR/TorchTypes.td" +include "mlir/IR/BuiltinAttributes.td" include "mlir/IR/OpAsmInterface.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/CastInterfaces.td" @@ -442,8 +443,8 @@ def Torch_PrimDictConstructOp: Torch_Op<"prim.DictConstruct", [ }]; let extraClassDeclaration = [{ - Type getKeyType() { return getType().cast().getKeyType(); } - Type getValueType() { return getType().cast().getValueType(); } + Type getKeyType() { return cast(getType()).getKeyType(); } + Type getValueType() { return cast(getType()).getValueType(); } }]; } @@ -1003,7 +1004,7 @@ def Torch_CopyToNonValueTensorOp : Torch_Op<"copy.to_tensor", [ DeclareOpInterfaceMethods, TypesMatchWith<"operand is corresponding !torch.vtensor", "result", "operand", - "$_self.cast().getWithValueSemantics()">, + "cast($_self).getWithValueSemantics()">, ]> { let summary = "Create a !torch.tensor with the same contents as the operand"; let description = [{ @@ -1036,7 +1037,7 @@ def Torch_CopyToValueTensorOp : Torch_Op<"copy.to_vtensor", [ DeclareOpInterfaceMethods, TypesMatchWith<"operand is corresponding !torch.tensor", "result", "operand", - "$_self.cast().getWithoutValueSemantics()">, + "cast($_self).getWithoutValueSemantics()">, ]> { let summary = "Create a !torch.vtensor with the same contents as the operand"; let description = [{ @@ -1064,7 +1065,7 @@ def Torch_CopyToValueTensorOp : Torch_Op<"copy.to_vtensor", [ def Torch_OverwriteTensorContentsOp : Torch_Op<"overwrite.tensor.contents", [ TypesMatchWith<"overwritten tensor type is corresponding !torch.tensor of value tensor type", "value", "overwritten", - "$_self.cast().getWithoutValueSemantics()"> + "cast($_self).getWithoutValueSemantics()"> ]> { let summary = "Ovewrite the contents of tensor with values from another."; let description = [{ @@ -1337,4 +1338,76 @@ def Torch_DtypeCalculateYieldDtypesOp : Torch_Op<"dtype.calculate.yield.dtypes", let hasVerifier = 1; } +//===----------------------------------------------------------------------===// +// Symbolic shape modeling ops for TorchDynamo frontend. +//===----------------------------------------------------------------------===// + +def Torch_SymbolicIntOp : Torch_Op<"symbolic_int", [Pure]> { + let summary = "Symbolic int representing a dynamic dimension"; + let description = [{ + The `torch.symbolic_int` operation captures a dynamic dimension on the + global function arguments as exported by TorchDynamo (torch.export). + It associates the shape symbols (i.e. "s0", "s1") with the + global SSA values (i.e. `%0`, `%1`) that is then referenced + to bind shapes on op results. + + Additionally, the operation annotates `min_val` and `max_val` attributes + denoting the range constraints for the dynamic dimension. This may be + useful for modeling runtime shape guards, or compile-time optimizations + based on the shape bounds (min, opt, max) on results of ops / regions. + + Example: + ``` + %0 = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int + %1 = torch.symbolic_int "s1" {min_val = 2, max_val = 20} : !torch.int + ``` + + In this case, we see that `s0` has the range [5, 10] and `s1` has the + range [2, 20]. 
When unspecified, the range constraints feeding in from + TorchDynamo default to [0, INT_MAX] (or [2, INT_MAX] in older PyTorch + releases). In either case, the interpretation (as specified by TorchDynamo) + is that the dynamic dimension is assumed to be not 0 or 1. This is not a + bug, and does not necessarily mean that the exported program will not work + for dimensions 0 or 1. For an in-depth discussion of this topic, see + [The 0/1 Specialization Problem](https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ#heading=h.ez923tomjvyk). + }]; + let arguments = (ins + StrAttr:$symbol_name, + I64Attr:$min_val, + I64Attr:$max_val + ); + let results = (outs + Torch_IntType:$result + ); + let assemblyFormat = [{ + $symbol_name ` ` `{` `min_val` `=` $min_val `,` `max_val` `=` $max_val `}` attr-dict `:` type($result) + }]; +} + +def Torch_BindSymbolicShapeOp : Torch_Op<"bind_symbolic_shape", []> { + let summary = "Binds shape expressions to tensors using an affine map indexed by shape symbols"; + let description = [{ + The `torch.bind_symbolic_shape` operation binds shape expressions + useful to compute the dynamic dimensions of a tensor. It takes a + variadic of SSA symbols that map 1:1 to the local symbols declared + in the affine map. The affine map contains a list of affine shape + expressions for each dim where the terminals are from the declared + symbols. + + Example: + ``` + torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> + torch.bind_symbolic_shape %out0, [%0, %1, %2], affine_map<()[s0, s1, s2] -> (s0, s1 * 2 + s2, 3)> : !torch.vtensor<[?,?,3],f32> + ``` + }]; + let arguments = (ins + Torch_ValueTensorType:$operand, + Variadic:$shape_symbols, + Builtin_AffineMapAttr:$shape_expressions + ); + let results = (outs); + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; +} + #endif // TORCH_OPS diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h index c8d1c5051f28..163ed6300878 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h +++ b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h @@ -53,6 +53,9 @@ class BaseTensorType : public Type { /// convenient API. Type getOptionalDtype() const; + /// Get the raw optional sparse tensor encoding. + Attribute getOptionalSparsity() const; + /// Return true if this type has a list of sizes. bool hasSizes() const { return getOptionalSizes().has_value(); } @@ -93,6 +96,10 @@ class BaseTensorType : public Type { Type getWithSizesAndDtype(std::optional> optionalSizes, Type optionalDtype) const; + Type getWithSizesAndDtypeAndSparsity( + std::optional> optionalSizes, Type optionalDtype, + Attribute optionalSparsity) const; + /// Return a type with the same shape and dtype as this one, but with /// value semantics. 
ValueTensorType getWithValueSemantics() const; @@ -129,23 +136,31 @@ namespace Torch { inline std::optional> BaseTensorType::getOptionalSizes() const { - if (auto tensor = dyn_cast()) + if (auto tensor = mlir::dyn_cast(*this)) return tensor.getOptionalSizes(); - if (auto tensor = dyn_cast()) + if (auto tensor = mlir::dyn_cast(*this)) return tensor.getOptionalSizes(); llvm_unreachable("not a BaseTensorType!"); } inline Type BaseTensorType::getOptionalDtype() const { - if (auto tensor = dyn_cast()) + if (auto tensor = mlir::dyn_cast(*this)) return tensor.getOptionalDtype(); - if (auto tensor = dyn_cast()) + if (auto tensor = mlir::dyn_cast(*this)) return tensor.getOptionalDtype(); llvm_unreachable("not a BaseTensorType!"); } +inline Attribute BaseTensorType::getOptionalSparsity() const { + if (auto tensor = mlir::dyn_cast(*this)) + return tensor.getOptionalSparsity(); + if (auto tensor = mlir::dyn_cast(*this)) + return tensor.getOptionalSparsity(); + llvm_unreachable("not a BaseTensorType!"); +} + inline bool BaseTensorType::classof(Type type) { - return type.isa(); + return mlir::isa(type); } } // namespace Torch diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td index e7fc4bc976bb..367b08610cd8 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td @@ -199,7 +199,7 @@ def Torch_ValueTensorType : AnyTorchTensorType<"ValueTensor", "vtensor"> { } def AnyTorchTensorType : Type< - CPred<"$_self.isa<::mlir::torch::Torch::BaseTensorType>()">, + CPred<"isa<::mlir::torch::Torch::BaseTensorType>($_self)">, "Any Torch tensor type" >; @@ -315,6 +315,16 @@ def Torch_QInt8Type : Torch_Type<"QInt8", "qint8"> { }]; } +def Torch_QInt16Type : Torch_Type<"QInt16", "qint16"> { + let summary = "Type modeling `ScalarType::QInt16`, which doesn't yet exist"; + let description = [{ + Pytorch does not have 16-bit integer quantization support. + + This torch type is added to provide a target for 16-bit quantization + schemes coming from imported onnx models. + }]; +} + def Torch_QUInt8Type : Torch_Type<"QUInt8", "quint8"> { let summary = "Type modeling `ScalarType::QUInt8`"; let description = [{ @@ -410,11 +420,11 @@ def AnyTorchOptionalDeviceType: def AnyTorchOptionalGeneratorType: OptionalOf; -def IsListTypePred : CPred<"$_self.isa<::mlir::torch::Torch::ListType>()">; +def IsListTypePred : CPred<"isa<::mlir::torch::Torch::ListType>($_self)">; class ListOf allowedTypes, string descr> : ContainerType, IsListTypePred, - "$_self.cast<::mlir::torch::Torch::ListType>().getContainedType()", + "cast<::mlir::torch::Torch::ListType>($_self).getContainedType()", descr, "::mlir::torch::Torch::ListType">; def AnyTorchListOfTorchBoolType : ListOf<[Torch_BoolType], "Bool list type (bool[])">; diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h index d4cceb05d59f..13d3a8de9463 100644 --- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h @@ -73,12 +73,22 @@ struct TorchLoweringPipelineOptions void createTorchScriptModuleToTorchBackendPipeline( OpPassManager &pm, const TorchLoweringPipelineOptions &options); +/// Creates a pipeline that lowers the graph IR that is produced by +/// TorchDynamo export into the form expected by torch-verify-backend-contract. 
+void createTorchDynamoExportToTorchBackendPipeline( + OpPassManager &pm, const TorchLoweringPipelineOptions &options); + /// Creates a pipeline that lowers a flat list of funcs and global slots /// with the torch and aten dialects and mutable arrays and converts it to /// the form required by torch-verify-backend-contract. void createTorchFunctionToTorchBackendPipeline( OpPassManager &pm, const TorchLoweringPipelineOptions &options); +/// Creates a pipeline that lowers the torch Onnx IR that is produced by +/// Onnx import into the form expected by torch-verify-backend-contract. +void createTorchOnnxToTorchBackendPipeline( + OpPassManager &pm, const TorchLoweringPipelineOptions &options); + /// Creates a pipeline that simplifies the computations in the program. /// This pass does not do any global program restructuring -- it works entirely /// within a single semantic model of a `builtin.module` with @@ -144,6 +154,12 @@ StringRef getAbstractInterpLibrary(); static const char kTorchOpPrefix[] = R"(torch.)"; +void populateRestructureNonConstantAxesPattern(RewritePatternSet &patterns, + MLIRContext *context); + +std::unique_ptr<OperationPass<func::FuncOp>> +createRestructureNonConstantAxesPass(); + } // namespace Torch /// Registers all Torch transformation passes. diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.td b/include/torch-mlir/Dialect/Torch/Transforms/Passes.td index 6439feb394be..e6b19201e85b 100644 --- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.td +++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.td @@ -431,4 +431,24 @@ def VerifyBackendContractNoDecompositions }]; } +def RestructureNonConstantAxes + : Pass<"torch-restructure-non-constant-axes", "func::FuncOp"> { + let summary = "Ensure that every Reduction.cpp op has a constant reduction axis."; + let constructor = [{ + mlir::torch::Torch::createRestructureNonConstantAxesPass() + }]; + let description = [{ + This pass ensures that every Reduction.cpp op has a constant reduction axis. + + It does so using reshapes. For example, a <1,2,3,4,5> tensor will be reshaped to a <leading,reduction,trailing> tensor + and reduced on axis 1 to produce a <leading,1,trailing> tensor. The resulting tensor will be reshaped back to the original shape. + + Then when the axis is supplied at runtime (say axis = -2), the shapes will be computed as follows: + <1,2,3,4,5> becomes <6,4,5> + which gets reduced to <6,1,5> + and reshaped back to the original reduction op's output shape, + <1,2,3,1,5> + }]; +} + #endif // TORCHMLIR_TORCH_PASSES diff --git a/include/torch-mlir/Dialect/Torch/Utils/SparsityUtils.h b/include/torch-mlir/Dialect/Torch/Utils/SparsityUtils.h new file mode 100644 index 000000000000..e29054790e5c --- /dev/null +++ b/include/torch-mlir/Dialect/Torch/Utils/SparsityUtils.h @@ -0,0 +1,28 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// +#ifndef TORCHMLIR_DIALECT_TORCH_SPARSITY_UTILS_H +#define TORCHMLIR_DIALECT_TORCH_SPARSITY_UTILS_H + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LogicalResult.h" + +namespace mlir { +namespace torch { +namespace Torch { + +// Create a new SparseTensorEncodingAttr based on the provided `attr`, but with +// a new dense level inserted at `dim`.
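For clarity, the shape arithmetic that the pass description above walks through can be sketched as a small standalone helper; this is illustrative only (the pass itself rewrites Torch IR rather than plain integer vectors, and the helper name is made up):

```cpp
// Collapse an arbitrary-rank shape around a (possibly negative) reduction axis
// into <leading, reduction, trailing>, so the reduction always sees axis 1.
#include <cstdint>
#include <vector>

static std::vector<int64_t> collapseAroundAxis(const std::vector<int64_t> &shape,
                                               int64_t axis) {
  int64_t rank = static_cast<int64_t>(shape.size());
  if (axis < 0)
    axis += rank; // axis = -2 on a rank-5 shape becomes axis = 3
  int64_t leading = 1, trailing = 1;
  for (int64_t i = 0; i < axis; ++i)
    leading *= shape[i];
  for (int64_t i = axis + 1; i < rank; ++i)
    trailing *= shape[i];
  // For shape <1,2,3,4,5> and axis = -2 this yields <6,4,5>; reducing axis 1
  // gives <6,1,5>, which the pass then reshapes back to <1,2,3,1,5>.
  return {leading, shape[axis], trailing};
}
```

Because the collapsed tensor is always reduced along axis 1, the downstream reduction op sees a constant axis even when the original axis is only known at runtime.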
+FailureOr getSparsityWithDenseLTAtDim(Attribute attr, Value dim); + +} // namespace Torch +} // namespace torch +} // namespace mlir + +#endif // TORCHMLIR_DIALECT_TORCH_SPARSITY_UTILS_H diff --git a/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h b/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h index 043dd92549b2..380af5f829c9 100644 --- a/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h +++ b/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h @@ -86,24 +86,34 @@ enum class TypeKind { // at:: and c10:: parts of the macro are never used within the compiler -- we // only use this for the enum values. #define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_) \ - _(uint8_t, Byte) /* 0 */ \ - _(int8_t, Char) /* 1 */ \ - _(int16_t, Short) /* 2 */ \ - _(int, Int) /* 3 */ \ - _(int64_t, Long) /* 4 */ \ - _(at::Half, Half) /* 5 */ \ - _(float, Float) /* 6 */ \ - _(double, Double) /* 7 */ \ - _(c10::complex, ComplexHalf) /* 8 */ \ - _(c10::complex, ComplexFloat) /* 9 */ \ - _(c10::complex, ComplexDouble) /* 10 */ \ - _(bool, Bool) /* 11 */ \ - _(c10::qint8, QInt8) /* 12 */ \ - _(c10::quint8, QUInt8) /* 13 */ \ - _(c10::qint32, QInt32) /* 14 */ \ - _(at::BFloat16, BFloat16) /* 15 */ \ - _(c10::quint4x2, QUInt4x2) /* 16 */ \ - _(c10::quint2x4, QUInt2x4) /* 17 */ + _(uint8_t, Byte) /* 0 */ \ + _(int8_t, Char) /* 1 */ \ + _(int16_t, Short) /* 2 */ \ + _(int, Int) /* 3 */ \ + _(int64_t, Long) /* 4 */ \ + _(at::Half, Half) /* 5 */ \ + _(float, Float) /* 6 */ \ + _(double, Double) /* 7 */ \ + _(c10::complex, ComplexHalf) /* 8 */ \ + _(c10::complex, ComplexFloat) /* 9 */ \ + _(c10::complex, ComplexDouble) /* 10 */ \ + _(bool, Bool) /* 11 */ \ + _(c10::qint8, QInt8) /* 12 */ \ + _(c10::quint8, QUInt8) /* 13 */ \ + _(c10::qint32, QInt32) /* 14 */ \ + _(at::BFloat16, BFloat16) /* 15 */ \ + _(c10::quint4x2, QUInt4x2) /* 16 */ \ + _(c10::quint2x4, QUInt2x4) /* 17 */ \ + _(c10::bits1x8, Bits1x8) /* 18 */ \ + _(c10::bits2x4, Bits2x4) /* 19 */ \ + _(c10::bits4x2, Bits4x2) /* 20 */ \ + _(c10::bits8, Bits8) /* 21 */ \ + _(c10::bits16, Bits16) /* 22 */ \ + _(c10::Float8_e5m2, Float8_e5m2) /* 23 */ \ + _(c10::Float8_e4m3fn, Float8_e4m3fn) /* 24 */ \ + _(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) /* 25 */ \ + _(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */ \ + _(c10::qint16, QInt16) /* 27 */ enum class ScalarType : int8_t { #define DEFINE_ENUM(_1, n) n, @@ -135,6 +145,8 @@ ScalarType promote_skip_undefined(ScalarType a, ScalarType b); //===----------------------------------------------------------------------===// enum Reduction { None, Mean, Sum, END }; +Reduction get_loss_reduction_enum(const llvm::StringRef &reduce); + //===----------------------------------------------------------------------===// // Possible values for `memory_format` argument in PyTorch ops that support it. // Source: diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h index 33a1c9f91fe7..168330b53e9d 100644 --- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h +++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h @@ -17,10 +17,26 @@ namespace mlir { namespace torch { namespace Torch { +class BaseTensorType; int64_t toPositiveDim(int64_t dim, int64_t inputRank); bool isValidDim(int64_t dim, int64_t inputRank); +Value toIntListConstruct(PatternRewriter &rewriter, Location loc, + ArrayRef cstInput); bool getListConstructElements(Value v, SmallVectorImpl &elems); + +/// Returns a torch.list of the given vals as torch.constant.int. 
+Value toTorchList(Location loc, PatternRewriter &rewriter, + ArrayRef vals); + +/// Broadcast the given value of tensor type to the new shape. +TypedValue broadcastTo(Location loc, PatternRewriter &rewriter, + Value val, ArrayRef newShape); + +/// Reshapes the given value of tensor type to the new shape. +TypedValue reshapeTo(Location loc, PatternRewriter &rewriter, + Value val, ArrayRef newShape); + /// Returns the index indicated by `v` for a list of given `length`. /// If the index is negative, it is adjusted to `length` + `v`. /// `None` is returned the index is not an integer in the range [0,`length). @@ -72,6 +88,9 @@ bool isBuiltInType(Type type); // std::nullopt is returned if the tensorRank can't be determined. std::optional getTensorRank(Value tensor); +// Helper function to get the number of elements in a tensor. +std::optional getTensorNumel(Value tensor); + bool isViewLikeOp(Operation *op); Value getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter, Location loc, @@ -84,6 +103,10 @@ int64_t getNumberOfElements(RankedTensorType inputType); SmallVector makeShapeLLVMCompatible(ArrayRef shape); SmallVector makeShapeTorchCompatible(ArrayRef shape); +ValueTensorType getTensorTypeFromShapeValues(ArrayRef shapes, + Type dtype); +Value getTensorDimSize(PatternRewriter &rewriter, Value tensor, int64_t dim); + // Helper function to squeeze the input tensor at given dim. // Return the squeezed tensor or failure. FailureOr squeezeTensor(PatternRewriter &rewriter, Operation *op, @@ -138,8 +161,14 @@ LogicalResult getTransposedType(BaseTensorType inType, int64_t dimA, // control the behavior. Such support would be done in coordination with // the fx_importer and APIs, which could add hints to the IR (based on // Torch flags, user options, etc). +// Note: The special case of int8 intentionally deviates from the reference, and +// uses int32 instead of int64 accumulation. Type getDefaultAccType(PatternRewriter &rewriter, Type inputType); +LogicalResult getPermutedType(BaseTensorType inType, + SmallVector permuteDims, + Type &permutedType); + } // namespace Torch } // namespace torch } // namespace mlir diff --git a/include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.td b/include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.td index bbc176feb4d4..f7bb2775385b 100644 --- a/include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.td +++ b/include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.td @@ -25,9 +25,7 @@ class TorchConversion_Op traits = []> // Conversions to backend types. //===----------------------------------------------------------------------===// -def TorchConversion_ToBuiltinTensorOp : TorchConversion_Op<"to_builtin_tensor", [ - DeclareOpInterfaceMethods - ]> { +def TorchConversion_ToBuiltinTensorOp : TorchConversion_Op<"to_builtin_tensor"> { let summary = "Convert a `!torch.vtensor` to a `tensor`"; let description = [{ This op only operates on ValueTensorType, to avoid conflating conversions diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h b/include/torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h index de188b4f4e8f..b0a085eab7f0 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h @@ -25,6 +25,11 @@ void getBackendTypeConversionDependentDialects(DialectRegistry ®istry); /// boundary (which currently consist only of builtin types). 
void setupBackendTypeConversion(ConversionTarget &target, TypeConverter &typeConverter); + +#ifdef TORCH_MLIR_ENABLE_STABLEHLO +void setupBackendTypeConversionForStablehlo(ConversionTarget &target, + TypeConverter &typeConverter); +#endif } // namespace TorchConversion } // namespace torch } // namespace mlir diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/CMakeLists.txt b/include/torch-mlir/Dialect/TorchConversion/Transforms/CMakeLists.txt index 77e46eb4be04..51126f544a42 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/CMakeLists.txt +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/CMakeLists.txt @@ -1,9 +1,7 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) -if(TORCH_MLIR_ENABLE_STABLEHLO) - mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_STABLEHLO) -else() - mlir_tablegen(Passes.h.inc -gen-pass-decls) -endif() + +mlir_tablegen(Passes.h.inc -gen-pass-decls ${TORCH_MLIR_TABLEGEN_FLAGS}) + add_public_tablegen_target(TorchMLIRTorchConversionPassIncGen) add_mlir_doc(Passes TorchMLIRTorchConversionTransforms ./ -gen-pass-doc) diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h index 2f70cf990219..67d9626cfc0b 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h @@ -22,14 +22,34 @@ class ModuleOp; namespace torch { namespace TorchConversion { +struct TorchBackendToLinalgOnTensorsBackendPipelineOptions + : public PassPipelineOptions< + TorchBackendToLinalgOnTensorsBackendPipelineOptions> { + PassOptions::Option verify{ + *this, "verify", + llvm::cl::desc("verify the backend contract after lowering"), + llvm::cl::init(true)}; + PassOptions::Option useMlprogram{ + *this, "use-mlprogram", + llvm::cl::desc("run convert-torch-conversion-to-mlprogram"), + llvm::cl::init(true)}; +}; + /// Creates a pipeline that lowers from the torch backend contract to the /// linalg-on-tensors backend contract. -void createTorchBackendToLinalgOnTensorsBackendPipeline(OpPassManager &pm); +void createTorchBackendToLinalgOnTensorsBackendPipeline( + OpPassManager &pm, + const TorchBackendToLinalgOnTensorsBackendPipelineOptions &options); +// Do not register the TOSA options if the TOSA target is disabled +#ifdef TORCH_MLIR_ENABLE_TOSA /// Creates a pipeline that lowers from the torch backend contract to the /// TOSA backend contract. void createTorchBackendToTosaBackendPipeline(OpPassManager &pm); +std::unique_ptr> createVerifyTosaBackendContractPass(); +#endif // TORCH_MLIR_ENABLE_TOSA + // Do not register the stablehlo options if the stablehlo target is disabled #ifdef TORCH_MLIR_ENABLE_STABLEHLO struct StablehloBackendPipelineOptions @@ -48,9 +68,16 @@ struct StablehloBackendPipelineOptions void createTorchBackendToStablehloBackendPipeline( OpPassManager &pm, const StablehloBackendPipelineOptions &options); + +std::unique_ptr> +createFuncBackendTypeConversionForStablehloPass(); + +std::unique_ptr> +createFinalizingBackendTypeConversionForStablehloPass(); + std::unique_ptr> createVerifyStablehloBackendContractPass(); -#endif +#endif // TORCH_MLIR_ENABLE_STABLEHLO std::unique_ptr> createFuncBackendTypeConversionPass(); @@ -70,8 +97,6 @@ createConvertCustomQuantOpPass(); std::unique_ptr> createVerifyLinalgOnTensorsBackendContractPass(); -std::unique_ptr> createVerifyTosaBackendContractPass(); - } // namespace TorchConversion /// Registers all Torch transformation passes. 
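The linalg-on-tensors pipeline entry point above now takes an options struct instead of being parameter-free. A minimal sketch of how a downstream driver might invoke it, assuming the usual MLIR PassManager setup; the wrapper function name and the particular option values are illustrative only, not part of this patch:

#include "mlir/Pass/PassManager.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"

// Illustrative driver: build a module-level pass manager and populate it with
// the linalg-on-tensors backend pipeline using the new options struct.
static void lowerToLinalgOnTensorsBackend(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  mlir::torch::TorchConversion::
      TorchBackendToLinalgOnTensorsBackendPipelineOptions options;
  options.verify = true;        // keep the backend-contract verification
  options.useMlprogram = false; // skip convert-torch-conversion-to-mlprogram
  mlir::torch::TorchConversion::
      createTorchBackendToLinalgOnTensorsBackendPipeline(pm, options);
  if (mlir::failed(pm.run(module)))
    module.emitError("torch backend to linalg-on-tensors lowering failed");
}

Where the pipeline is also registered for command-line use, the same knobs should be reachable as textual pipeline options, matching the option names declared above ("verify" and "use-mlprogram").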
diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td index 73654c6f8034..6f70a6584022 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td @@ -21,6 +21,17 @@ def FuncBackendTypeConversion : Pass<"torch-func-backend-type-conversion", "Modu }]; } +#ifdef TORCH_MLIR_ENABLE_STABLEHLO +def FuncBackendTypeConversionForStablehlo : Pass<"torch-func-backend-type-conversion-for-stablehlo", "ModuleOp"> { + let summary = "Convert functions to operate on builtin tensors for stablehlo backend"; + let constructor = "mlir::torch::TorchConversion::createFuncBackendTypeConversionForStablehloPass()"; + let description = [{ + Partial type conversion pass analogous in scope to the upstream + `func-bufferize` pass. See details there. + }]; +} +#endif // TORCH_MLIR_ENABLE_STABLEHLO + def FinalizingBackendTypeConversion : InterfacePass<"torch-finalizing-backend-type-conversion", "mlir::FunctionOpInterface"> { let summary = "Finalizes a partial conversion to builtin tensors"; @@ -32,15 +43,30 @@ def FinalizingBackendTypeConversion }]; } +#ifdef TORCH_MLIR_ENABLE_STABLEHLO +def FinalizingBackendTypeConversionForStablehlo + : InterfacePass<"torch-finalizing-backend-type-conversion-for-stablehlo", "mlir::FunctionOpInterface"> { + let summary = "Finalizes a partial conversion to builtin tensors for stablehlo"; + let constructor = + "mlir::torch::TorchConversion::createFinalizingBackendTypeConversionForStablehloPass()"; + let description = [{ + Analogous in scope to the upstream `finalizing-bufferize` pass. + See details there. + }]; +} +#endif // TORCH_MLIR_ENABLE_STABLEHLO + def VerifyLinalgOnTensorsBackendContract : Pass<"torch-verify-linalg-on-tensors-backend-contract", "ModuleOp"> { let summary = "Verifies conformity to the linalg-on-tensors backend contract"; let constructor = "mlir::torch::TorchConversion::createVerifyLinalgOnTensorsBackendContractPass()"; } +#ifdef TORCH_MLIR_ENABLE_TOSA def VerifyTosaBackendContract : Pass<"torch-verify-tosa-backend-contract", "ModuleOp"> { let summary = "Verifies conformity to the linalg-on-tensors backend contract"; let constructor = "mlir::torch::TorchConversion::createVerifyTosaBackendContractPass()"; } +#endif #ifdef TORCH_MLIR_ENABLE_STABLEHLO def VerifyStablehloBackendContract : Pass<"torch-verify-stablehlo-backend-contract", "ModuleOp"> { diff --git a/include/torch-mlir/InitAll.h b/include/torch-mlir/InitAll.h index 42eb3c6a1ffb..19b2c474d787 100644 --- a/include/torch-mlir/InitAll.h +++ b/include/torch-mlir/InitAll.h @@ -18,6 +18,9 @@ namespace torch { // Registers all dialects that this project produces and any dependencies. void registerAllDialects(mlir::DialectRegistry ®istry); +// Registers all necessary dialect extensions for this project +void registerAllExtensions(mlir::DialectRegistry ®istry); + // Registers dialects that may be needed to parse torch-mlir inputs and // test cases. 
void registerOptionalInputDialects(mlir::DialectRegistry ®istry); diff --git a/lib/CAPI/TorchTypes.cpp b/lib/CAPI/TorchTypes.cpp index f4a9ca032fce..6402e44a3701 100644 --- a/lib/CAPI/TorchTypes.cpp +++ b/lib/CAPI/TorchTypes.cpp @@ -26,7 +26,7 @@ bool torchMlirTypeIsValidSubtype(MlirType subtype, MlirType type) { //===----------------------------------------------------------------------===// bool torchMlirTypeIsATorchNnModule(MlirType t) { - return unwrap(t).isa(); + return isa(unwrap(t)); } MlirType torchMlirTorchNnModuleTypeGet(MlirContext context, @@ -43,7 +43,7 @@ MlirTypeID torchMlirTorchNnModuleTypeGetTypeID() { //===----------------------------------------------------------------------===// bool torchMlirTypeIsATorchOptional(MlirType t) { - return unwrap(t).isa(); + return isa(unwrap(t)); } MlirType torchMlirTorchOptionalTypeGet(MlirType containedType) { @@ -51,7 +51,7 @@ MlirType torchMlirTorchOptionalTypeGet(MlirType containedType) { } MlirType torchMlirTorchOptionalTypeGetContained(MlirType t) { - auto type = unwrap(t).cast(); + auto type = cast(unwrap(t)); return wrap(type.getContainedType()); } @@ -64,7 +64,7 @@ MlirTypeID torchMlirTorchOptionalTypeGetTypeID() { //===----------------------------------------------------------------------===// bool torchMlirTypeIsATorchTuple(MlirType t) { - return unwrap(t).isa(); + return isa(unwrap(t)); } MlirType torchMlirTorchTupleTypeGet(MlirContext context, @@ -77,12 +77,12 @@ MlirType torchMlirTorchTupleTypeGet(MlirContext context, } size_t torchMlirTorchTupleTypeGetNumTypes(MlirType t) { - auto type = unwrap(t).cast(); + auto type = cast(unwrap(t)); return type.getContainedTypes().size(); } MlirType torchMlirTorchTupleTypeGetType(MlirType t, intptr_t pos) { - auto type = unwrap(t).cast(); + auto type = cast(unwrap(t)); return wrap(type.getContainedTypes()[pos]); } @@ -95,7 +95,7 @@ MlirTypeID torchMlirTorchTupleTypeGetTypeID() { //===----------------------------------------------------------------------===// bool torchMlirTypeIsATorchUnion(MlirType t) { - return unwrap(t).isa(); + return isa(unwrap(t)); } MlirType torchMlirTorchUnionTypeGet(MlirContext context, @@ -108,12 +108,12 @@ MlirType torchMlirTorchUnionTypeGet(MlirContext context, } size_t torchMlirTorchUnionTypeGetNumTypes(MlirType t) { - auto type = unwrap(t).cast(); + auto type = cast(unwrap(t)); return type.getContainedTypes().size(); } MlirType torchMlirTorchUnionTypeGetType(MlirType t, intptr_t pos) { - auto type = unwrap(t).cast(); + auto type = cast(unwrap(t)); return wrap(type.getContainedTypes()[pos]); } @@ -126,7 +126,7 @@ MlirTypeID torchMlirTorchUnionTypeGetTypeID() { //===----------------------------------------------------------------------===// bool torchMlirTypeIsATorchList(MlirType t) { - return unwrap(t).isa(); + return isa(unwrap(t)); } MlirType torchMlirTorchListTypeGet(MlirType containedType) { @@ -134,7 +134,7 @@ MlirType torchMlirTorchListTypeGet(MlirType containedType) { } MlirType torchMlirTorchListTypeGetContainedType(MlirType t) { - return wrap(unwrap(t).cast().getContainedType()); + return wrap(cast(unwrap(t)).getContainedType()); } MlirTypeID torchMlirTorchListTypeGetTypeID() { @@ -146,7 +146,7 @@ MlirTypeID torchMlirTorchListTypeGetTypeID() { //===----------------------------------------------------------------------===// bool torchMlirTypeIsATorchDevice(MlirType t) { - return unwrap(t).isa(); + return isa(unwrap(t)); } MlirType torchMlirTorchDeviceTypeGet(MlirContext context) { @@ -162,7 +162,7 @@ MlirTypeID torchMlirTorchDeviceTypeGetTypeID() { 
//===----------------------------------------------------------------------===// bool torchMlirTypeIsATorchGenerator(MlirType t) { - return unwrap(t).isa(); + return isa(unwrap(t)); } MlirType torchMlirTorchGeneratorTypeGet(MlirContext context) { @@ -178,7 +178,7 @@ MlirTypeID torchMlirTorchGeneratorTypeGetTypeID() { //===----------------------------------------------------------------------===// bool torchMlirTypeIsATorchBool(MlirType t) { - return unwrap(t).isa(); + return isa(unwrap(t)); } MlirType torchMlirTorchBoolTypeGet(MlirContext context) { @@ -194,7 +194,7 @@ MlirTypeID torchMlirTorchBoolTypeGetTypeID() { //===----------------------------------------------------------------------===// bool torchMlirTypeIsATorchInt(MlirType t) { - return unwrap(t).isa(); + return isa(unwrap(t)); } MlirType torchMlirTorchIntTypeGet(MlirContext context) { @@ -210,7 +210,7 @@ MlirTypeID torchMlirTorchIntTypeGetTypeID() { //===----------------------------------------------------------------------===// bool torchMlirTypeIsATorchFloat(MlirType t) { - return unwrap(t).isa(); + return isa(unwrap(t)); } MlirType torchMlirTorchFloatTypeGet(MlirContext context) { @@ -226,7 +226,7 @@ MlirTypeID torchMlirTorchFloatTypeGetTypeID() { //===----------------------------------------------------------------------===// bool torchMlirTypeIsATorchLinearParams(MlirType t) { - return unwrap(t).isa(); + return isa(unwrap(t)); } MlirType torchMlirTorchLinearParamsTypeGet(MlirContext context) { @@ -242,7 +242,7 @@ MlirTypeID torchMlirTorchLinearParamsTypeGetTypeID() { //===----------------------------------------------------------------------===// bool torchMlirTypeIsATorchQInt8(MlirType t) { - return unwrap(t).isa(); + return isa(unwrap(t)); } MlirType torchMlirTorchQInt8TypeGet(MlirContext context) { @@ -258,7 +258,7 @@ MlirTypeID torchMlirTorchQInt8TypeGetTypeID() { //===----------------------------------------------------------------------===// bool torchMlirTypeIsATorchQUInt8(MlirType t) { - return unwrap(t).isa(); + return isa(unwrap(t)); } MlirType torchMlirTorchQUInt8TypeGet(MlirContext context) { @@ -269,12 +269,28 @@ MlirTypeID torchMlirTorchQUInt8TypeGetTypeID() { return wrap(Torch::QUInt8Type::getTypeID()); } +//===----------------------------------------------------------------------===// +// torch.qint16 type. +//===----------------------------------------------------------------------===// + +bool torchMlirTypeIsATorchQInt16(MlirType t) { + return isa(unwrap(t)); +} + +MlirType torchMlirTorchQInt16TypeGet(MlirContext context) { + return wrap(Torch::QInt16Type::get(unwrap(context))); +} + +MlirTypeID torchMlirTorchQInt16TypeGetTypeID() { + return wrap(Torch::QInt16Type::getTypeID()); +} + //===----------------------------------------------------------------------===// // torch.tensor type. 
//===----------------------------------------------------------------------===// bool torchMlirTypeIsATorchNonValueTensor(MlirType t) { - return unwrap(t).isa(); + return isa(unwrap(t)); } MlirType torchMlirTorchNonValueTensorTypeGet(MlirContext context, @@ -297,26 +313,26 @@ MlirType torchMlirTorchNonValueTensorTypeGetWithLeastStaticInformation( MlirType torchMlirTorchNonValueTensorTypeGetFromAttribute(MlirAttribute attr) { auto attrTensorType = - unwrap(attr).cast().getType().cast(); + cast(cast(unwrap(attr)).getType()); return wrap(Torch::NonValueTensorType::get(attrTensorType.getContext(), attrTensorType.getShape(), attrTensorType.getElementType())); } int64_t torchMlirTorchNonValueTensorTypeGetRank(MlirType t) { - return unwrap(t).cast().getSizes().size(); + return cast(unwrap(t)).getSizes().size(); } bool torchMlirTorchNonValueTensorTypeHasSizes(MlirType t) { - return unwrap(t).cast().hasSizes(); + return cast(unwrap(t)).hasSizes(); } bool torchMlirTorchNonValueTensorTypeHasDtype(MlirType t) { - return unwrap(t).cast().hasDtype(); + return cast(unwrap(t)).hasDtype(); } int64_t torchMlirTorchNonValueTensorTypeGetSizes(MlirType t, int64_t *sizes) { - auto tensorType = unwrap(t).cast(); + auto tensorType = cast(unwrap(t)); bool hasSizes = tensorType.hasSizes(); if (!hasSizes) return -1; @@ -329,7 +345,7 @@ int64_t torchMlirTorchNonValueTensorTypeGetSizes(MlirType t, int64_t *sizes) { } MlirType torchMlirTorchNonValueTensorTypeGetDtype(MlirType t) { - return wrap(unwrap(t).cast().getDtype()); + return wrap(cast(unwrap(t)).getDtype()); } MlirTypeID torchMlirTorchNonValueTensorTypeGetTypeID() { @@ -341,7 +357,7 @@ MlirTypeID torchMlirTorchNonValueTensorTypeGetTypeID() { //===----------------------------------------------------------------------===// bool torchMlirTypeIsATorchValueTensor(MlirType t) { - return unwrap(t).isa(); + return isa(unwrap(t)); } MlirType torchMlirTorchValueTensorTypeGet(MlirContext context, @@ -364,26 +380,26 @@ MlirType torchMlirTorchValueTensorTypeGetWithLeastStaticInformation( MlirType torchMlirTorchValueTensorTypeGetFromAttribute(MlirAttribute attr) { auto attrTensorType = - unwrap(attr).cast().getType().cast(); + cast(cast(unwrap(attr)).getType()); return wrap(Torch::ValueTensorType::get(attrTensorType.getContext(), attrTensorType.getShape(), attrTensorType.getElementType())); } int64_t torchMlirTorchValueTensorTypeGetRank(MlirType t) { - return unwrap(t).cast().getSizes().size(); + return cast(unwrap(t)).getSizes().size(); } bool torchMlirTorchValueTensorTypeHasSizes(MlirType t) { - return unwrap(t).cast().hasSizes(); + return cast(unwrap(t)).hasSizes(); } bool torchMlirTorchValueTensorTypeHasDtype(MlirType t) { - return unwrap(t).cast().hasDtype(); + return cast(unwrap(t)).hasDtype(); } int64_t torchMlirTorchValueTensorTypeGetSizes(MlirType t, int64_t *sizes) { - auto tensorType = unwrap(t).cast(); + auto tensorType = cast(unwrap(t)); bool hasSizes = tensorType.hasSizes(); if (!hasSizes) return -1; @@ -396,7 +412,7 @@ int64_t torchMlirTorchValueTensorTypeGetSizes(MlirType t, int64_t *sizes) { } MlirType torchMlirTorchValueTensorTypeGetDtype(MlirType t) { - return wrap(unwrap(t).cast().getDtype()); + return wrap(cast(unwrap(t)).getDtype()); } MlirTypeID torchMlirTorchValueTensorTypeGetTypeID() { @@ -408,7 +424,7 @@ MlirTypeID torchMlirTorchValueTensorTypeGetTypeID() { //===----------------------------------------------------------------------===// bool torchMlirTypeIsATorchNone(MlirType t) { - return unwrap(t).isa(); + return isa(unwrap(t)); } MlirType 
torchMlirTorchNoneTypeGet(MlirContext context) { @@ -424,7 +440,7 @@ MlirTypeID torchMlirTorchNoneTypeGetTypeID() { //===----------------------------------------------------------------------===// bool torchMlirTypeIsATorchString(MlirType t) { - return unwrap(t).isa(); + return isa(unwrap(t)); } MlirType torchMlirTorchStringTypeGet(MlirContext context) { @@ -440,7 +456,7 @@ MlirTypeID torchMlirTorchStringTypeGetTypeID() { //===----------------------------------------------------------------------===// bool torchMlirTypeIsATorchAny(MlirType t) { - return unwrap(t).isa(); + return isa(unwrap(t)); } MlirType torchMlirTorchAnyTypeGet(MlirContext context) { @@ -456,7 +472,7 @@ MlirTypeID torchMlirTorchAnyTypeGetTypeID() { //===----------------------------------------------------------------------===// bool torchMlirTypeIsATorchNumber(MlirType t) { - return unwrap(t).isa(); + return isa(unwrap(t)); } MlirType torchMlirTorchNumberTypeGet(MlirContext context) { @@ -472,7 +488,7 @@ MlirTypeID torchMlirTorchNumberTypeGetTypeID() { //===----------------------------------------------------------------------===// bool torchMlirTypeIsATorchDict(MlirType t) { - return unwrap(t).isa(); + return isa(unwrap(t)); } MlirType torchMlirTorchDictTypeGet(MlirType keyType, MlirType valueType) { @@ -487,12 +503,12 @@ MlirType torchMlirTorchDictTypeGetChecked(MlirContext context, MlirType keyType, } MlirType torchMlirTorchDictTypeGetKeyType(MlirType t) { - auto type = unwrap(t).cast(); + auto type = cast(unwrap(t)); return wrap(type.getKeyType()); } MlirType torchMlirTorchDictTypeGetValueType(MlirType t) { - auto type = unwrap(t).cast(); + auto type = cast(unwrap(t)); return wrap(type.getValueType()); } diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index c0b622005900..7f0924b143f4 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -13,7 +13,7 @@ set(LinkedLibs MLIRMemRefDialect MLIRSCFDialect MLIRTensorDialect - MLIRTosaDialect + MLIRTensorInferTypeOpInterfaceImpl MLIRSupport # Dialects. 
@@ -32,7 +32,11 @@ set(LinkedLibs ) if(TORCH_MLIR_ENABLE_STABLEHLO) -list(APPEND LinkedLibs StablehloLinalgTransforms StablehloPasses) + list(APPEND LinkedLibs StablehloLinalgTransforms StablehloPasses) +endif() + +if(TORCH_MLIR_ENABLE_TOSA) + list(APPEND LinkedLibs MLIRTosaDialect) endif() if(TORCH_MLIR_ENABLE_REFBACKEND) diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index 2f4e0dd1df69..0b8d8ed1d930 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -3,7 +3,9 @@ add_subdirectory(TorchToArith) add_subdirectory(TorchToLinalg) add_subdirectory(TorchToSCF) add_subdirectory(TorchToTensor) -add_subdirectory(TorchToTosa) +if(TORCH_MLIR_ENABLE_TOSA) + add_subdirectory(TorchToTosa) +endif() if(TORCH_MLIR_ENABLE_STABLEHLO) add_subdirectory(TorchToStablehlo) endif() @@ -16,13 +18,15 @@ set(linked_libs TorchMLIRTorchToArith TorchMLIRTorchToLinalg TorchMLIRTorchToSCF TorchMLIRTorchToTensor - TorchMLIRTorchToTosa TorchMLIRTorchToTMTensor TorchMLIRTorchConversionToMLProgram TorchMLIRConversionUtils) if(TORCH_MLIR_ENABLE_STABLEHLO) list(APPEND linked_libs TorchMLIRTorchToStablehlo) endif() +if(TORCH_MLIR_ENABLE_TOSA) + list(APPEND linked_libs TorchMLIRTorchToTosa) +endif() add_mlir_library(TorchMLIRConversionPasses Passes.cpp diff --git a/lib/Conversion/Passes.cpp b/lib/Conversion/Passes.cpp index 6d8adbaa146d..97b9b946abcf 100644 --- a/lib/Conversion/Passes.cpp +++ b/lib/Conversion/Passes.cpp @@ -19,7 +19,10 @@ #include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h" #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h" #include "torch-mlir/Conversion/TorchToTensor/TorchToTensor.h" + +#ifdef TORCH_MLIR_ENABLE_TOSA #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" +#endif // TORCH_MLIR_ENABLE_TOSA //===----------------------------------------------------------------------===// // Pass registration diff --git a/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp b/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp index 6a00e5190f4b..ddcfab78ac8f 100644 --- a/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp +++ b/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp @@ -13,10 +13,6 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/MLProgram/IR/MLProgram.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/MLIRContext.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" @@ -63,6 +59,13 @@ class ConvertGetNextSeedOp : public OpConversionPattern { matchAndRewrite(GetNextSeedOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); + + // Check for global seed and create if it doesn't exist. + auto module = op->getParentOfType(); + OpBuilder b(module.getBodyRegion()); + if (failed(getOrCreateGlobalVariableForSeed(b, module))) + return failure(); + // Generate sequence for getting the next seed with LCG step: // nextSeed = (multiplier * currentSeed + incrementStep) mod 2^64. // Refer to https://en.wikipedia.org/wiki/Linear_congruential_generator. 
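For reference, the LCG recurrence described in the comment above amounts to a single 64-bit multiply-add, since unsigned wraparound already provides the mod 2^64. A standalone sketch follows; the constants are Knuth's MMIX values, used purely for illustration and not necessarily the exact constants emitted by this lowering:

#include <cstdint>

// One step of a 64-bit linear congruential generator:
//   nextSeed = (multiplier * currentSeed + incrementStep) mod 2^64
uint64_t lcgStep(uint64_t currentSeed) {
  const uint64_t multiplier = 6364136223846793005ULL;    // illustrative value
  const uint64_t incrementStep = 1442695040888963407ULL; // illustrative value
  // uint64_t arithmetic wraps on overflow, which is exactly the mod 2^64.
  return multiplier * currentSeed + incrementStep;
}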
@@ -119,11 +122,6 @@ class ConvertTorchConversionToMLProgram typeConverter.addConversion([](Type type) { return type; }); TorchConversion::setupBackendTypeConversion(target, typeConverter); - auto module = getOperation(); - OpBuilder b(module.getBodyRegion()); - if (failed(getOrCreateGlobalVariableForSeed(b, module))) - signalPassFailure(); - RewritePatternSet patterns(context); target.addIllegalOp(); patterns.add(typeConverter, context); diff --git a/lib/Conversion/TorchOnnxToTorch/CMakeLists.txt b/lib/Conversion/TorchOnnxToTorch/CMakeLists.txt index ef3e51d45288..9f55ba906fc6 100644 --- a/lib/Conversion/TorchOnnxToTorch/CMakeLists.txt +++ b/lib/Conversion/TorchOnnxToTorch/CMakeLists.txt @@ -2,7 +2,7 @@ add_mlir_conversion_library(TorchMLIRTorchOnnxToTorch DefaultDomainAtoF.cpp DefaultDomainGtoP.cpp DefaultDomainQtoZ.cpp - OnnxLstmExpander.cpp + OnnxRecurrentLayerOpExpanders.cpp Passes.cpp Patterns.cpp TorchOnnxToTorch.cpp diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 889a5fe88704..d8517fbd156d 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -7,34 +7,16 @@ // //===----------------------------------------------------------------------===// -#include "mlir/IR/DialectResourceBlobManager.h" #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" #include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" -#include "llvm/Support/FormatVariadic.h" +#include using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::onnx_c; -static LogicalResult createTorchTransposeOp(ConversionPatternRewriter &rewriter, - Location loc, Value input, - int64_t dimA, int64_t dimB, - Value &transposed) { - Type transposedType; - if (failed(getTransposedType(cast(input.getType()), - dimA, dimB, transposedType))) - return failure(); - Value cstDimA = rewriter.create( - loc, rewriter.getI64IntegerAttr(dimA)); - Value cstDimB = rewriter.create( - loc, rewriter.getI64IntegerAttr(dimB)); - transposed = rewriter.create( - loc, transposedType, input, cstDimA, cstDimB); - return success(); -} - namespace { LogicalResult windowFunctionImpl(OpBinder binder, ConversionPatternRewriter &rewriter, @@ -63,7 +45,7 @@ LogicalResult windowFunctionImpl(OpBinder binder, // Create an f32 ValueTensorType with thse same size as size, the // operand auto shapeOfOperand = - size.getType().dyn_cast().getOptionalSizes(); + dyn_cast(size.getType()).getOptionalSizes(); auto f32ResultType = rewriter.getType( shapeOfOperand, rewriter.getF32Type()); Value periodicSizeFloat = b.create( @@ -355,48 +337,127 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, operand); return success(); }); - patterns.onOp("BatchNormalization", 15, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value input, weight, bias, runningMean, runningVar; - bool training; - float momentum, eps; - if (binder.s64BoolAttr(training, "training_mode", 0)) - return failure(); - if (training) { - // TODO: Add support for training = true - return rewriter.notifyMatchFailure( - binder.op, "unsupported conversion: training = true"); - } - - if (binder.tensorOperandAtIndex(input, 0) || - binder.tensorOperandAtIndex(weight, 1) || - binder.tensorOperandAtIndex(bias, 2) || - binder.tensorOperandAtIndex(runningMean, 3) || - 
binder.tensorOperandAtIndex(runningVar, 4) || - binder.f32FloatAttr(momentum, "momentum", 0.9f) || - binder.f32FloatAttr(eps, "epsilon", 1e-05f) || - binder.tensorResultType(resultType)) - return failure(); + patterns.onOp( + "BatchNormalization", 15, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value input, weight, bias, inputMean, inputVar; + bool training; + float momentum, eps; + if (binder.tensorOperandAtIndex(input, 0) || + binder.tensorOperandAtIndex(weight, 1) || + binder.tensorOperandAtIndex(bias, 2) || + binder.tensorOperandAtIndex(inputMean, 3) || + binder.tensorOperandAtIndex(inputVar, 4) || + binder.f32FloatAttr(momentum, "momentum", 0.9f) || + binder.f32FloatAttr(eps, "epsilon", 1e-05f) || + binder.s64BoolAttr(training, "training_mode", 0) || + binder.tensorResultTypeAtIndex(resultType, 0)) + return failure(); - Value cstFalse = rewriter.create( - binder.getLoc(), false); - Value cstMomentum = rewriter.create( - binder.getLoc(), rewriter.getF64FloatAttr(momentum)); - Value cstEps = rewriter.create( - binder.getLoc(), rewriter.getF64FloatAttr(eps)); - - rewriter.replaceOpWithNewOp( - binder.op, resultType, input, weight, bias, runningMean, - runningVar, /*training=*/cstFalse, cstMomentum, cstEps, - /*cudnn_enabled=*/cstFalse); - return success(); - }); + Location loc = binder.getLoc(); + Value cstFalse = rewriter.create(loc, false); + Value cstMomentum = rewriter.create( + loc, rewriter.getF64FloatAttr(momentum)); + Value cstEps = rewriter.create( + loc, rewriter.getF64FloatAttr(eps)); + + // When training_mode=False, the op outputs only Y, where + // Y = (X - input_mean) / sqrt(input_var + epsilon) * scale + + // B + if (!training) { + rewriter.replaceOpWithNewOp( + binder.op, resultType, input, weight, bias, inputMean, inputVar, + /*training=*/cstFalse, cstMomentum, cstEps, + /*cudnn_enabled=*/cstFalse); + return success(); + } + + Torch::ValueTensorType meanResultType, varResultType; + if (binder.tensorResultTypeAtIndex(meanResultType, 1) || + binder.tensorResultTypeAtIndex(varResultType, 2)) + return failure(); + + // When training_mode=True, the outputs are as follows: + // Y, running_mean, running_var. + // Y = (X - current_mean) / sqrt(current_var + epsilon) * + // scale + B + // running_mean = input_mean * momentum + current_mean * (1 - + // momentum) + // running_var = input_var * momentum + current_var * (1 - + // momentum) + // and + // current_mean = ReduceMean(X, axis=all_except_channel_index) + // current_var = ReduceVar(X, axis=all_except_channel_index) + + Torch::ValueTensorType inputType = + cast(input.getType()); + if (!inputType.hasSizes()) + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: expected input to have sizes"); + + // Computing current_mean and current_var. + int64_t inputRank = inputType.getSizes().size(); + // Reduce all dimensions except channel dim. 
+ SmallVector dimsToReduce; + for (int64_t i = 0; i < inputRank; i++) { + if (i != 1) + dimsToReduce.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + } + Value reduceDimsList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + dimsToReduce); + Value noneVal = rewriter.create(binder.getLoc()); + Value currentMean = rewriter.create( + loc, meanResultType, input, reduceDimsList, + /*keepdim=*/cstFalse, + /*dtype=*/noneVal); + Value currentVar = rewriter.create( + loc, varResultType, input, reduceDimsList, + /*unbiased=*/cstFalse, + /*keepdim=*/cstFalse); + + // Computing running_mean. + Value inputMeanMulMomentum = rewriter.create( + loc, meanResultType, inputMean, cstMomentum); + Value currentMeanMulMomentum = rewriter.create( + loc, varResultType, currentMean, cstMomentum); + Value constantOne = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + Value inpMeanMMSubCurMeanMM = rewriter.create( + loc, meanResultType, inputMeanMulMomentum, currentMeanMulMomentum, + constantOne); + Value runningMean = rewriter.create( + loc, meanResultType, inpMeanMMSubCurMeanMM, currentMean, + constantOne); + + // Computing running_var. + Value inputVarMulMomentum = rewriter.create( + loc, varResultType, inputVar, cstMomentum); + Value currentVarMulMomentum = rewriter.create( + loc, varResultType, currentVar, cstMomentum); + Value inpVarMMSubCurVarMM = rewriter.create( + loc, varResultType, inputVarMulMomentum, currentVarMulMomentum, + constantOne); + Value runningVar = rewriter.create( + loc, varResultType, inpVarMMSubCurVarMM, currentVar, constantOne); + + // Computing Y. + Value y = rewriter.create( + loc, resultType, input, weight, bias, currentMean, currentVar, + /*training=*/cstFalse, cstMomentum, cstEps, + /*cudnn_enabled=*/cstFalse); + + rewriter.replaceOp(binder.op, {y, runningMean, runningVar}); + return success(); + }); patterns.onOp( "AveragePool", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { std::string autoPad; - SmallVector dilation; + SmallVector dilations; if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET")) return failure(); if (autoPad != "NOTSET") { @@ -404,13 +465,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( return rewriter.notifyMatchFailure( binder.op, "unsupported conversion: auto_pad != NOTSET"); } - if (binder.s64IntegerArrayAttr(dilation, "dilations", {})) { - return failure(); - } - if (dilation.size() > 0) { - return rewriter.notifyMatchFailure( - binder.op, "dilation is not supported by torch.aten.avgpool op"); - } Torch::ValueTensorType resultType; Value operand; @@ -453,19 +507,42 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, "strides list size does not match the number of axes"); } - SmallVector cstKernel, cstPadding, cstStrides; + SmallVector cstKernel, cstPadding, cstStridesDilations; for (int64_t i : kernel) { cstKernel.push_back(rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(i))); } - for (int64_t i : padding) { + // Onnx pads format: [x1_begin, x2_begin…x1_end, x2_end,…] + // Pytorch pads format: [x1, x2,...] or [x], assume begin==end for all + // axes x. + int64_t paddingSizeHalf = padding.size() / 2; + for (int64_t i = 0; i < paddingSizeHalf; ++i) { + // Check if onnx padding attribute is symmetric. 
+ if (padding[i] != padding[i + paddingSizeHalf]) + return rewriter.notifyMatchFailure( + binder.op, "onnx padding attribute is not symmetric"); cstPadding.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i))); + binder.getLoc(), rewriter.getI64IntegerAttr(padding[i]))); } for (int64_t i : strides) { - cstStrides.push_back(rewriter.create( + cstStridesDilations.push_back(rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(i))); } + + // No dilations attribute in pytorch avgpool op, so use this trick to + // encode dilation into strides. Then in the following torchtolinalg + // lowering, decode strides into strides + dilation. + // [strideDim1,strideDim2,...,dilationDim1,dilationDim2,...] + if (binder.s64IntegerArrayAttr( + dilations, "dilations", + llvm::SmallVector(rank - 2, 1))) { + return failure(); + } + for (auto dilation : dilations) { + cstStridesDilations.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(dilation))); + } + Value kernelSizeList = rewriter.create( binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), @@ -474,10 +551,12 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstPadding); - Value stridesList = rewriter.create( - binder.getLoc(), - Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), - cstStrides); + Value stridesDilationsList = + rewriter.create( + binder.getLoc(), + Torch::ListType::get( + Torch::IntType::get(binder.op->getContext())), + cstStridesDilations); Value cstCeilMode = rewriter.create(binder.getLoc(), ceilMode); Value cstCountIncludePad = rewriter.create( @@ -486,19 +565,22 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( if (rank == 3) { rewriter.replaceOpWithNewOp( - binder.op, resultType, operand, kernelSizeList, stridesList, - paddingList, cstCeilMode, cstCountIncludePad); + binder.op, resultType, operand, kernelSizeList, + stridesDilationsList, paddingList, cstCeilMode, + cstCountIncludePad); return success(); } else if (rank == 4) { rewriter.replaceOpWithNewOp( - binder.op, resultType, operand, kernelSizeList, stridesList, - paddingList, cstCeilMode, cstCountIncludePad, + binder.op, resultType, operand, kernelSizeList, + stridesDilationsList, paddingList, cstCeilMode, + cstCountIncludePad, /*divisor_override=*/cstNone); return success(); } else if (rank == 5) { rewriter.replaceOpWithNewOp( - binder.op, resultType, operand, kernelSizeList, stridesList, - paddingList, cstCeilMode, cstCountIncludePad, + binder.op, resultType, operand, kernelSizeList, + stridesDilationsList, paddingList, cstCeilMode, + cstCountIncludePad, /*divisor_override=*/cstNone); return success(); } @@ -725,6 +807,130 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, maxExpression, minExpression, constantOne); return success(); }); + patterns.onOp( + "CenterCropPad", 18, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value input, shape; + if (binder.tensorOperands(input, shape) || + binder.tensorResultType(resultType)) + return failure(); + + auto inputTy = cast(input.getType()); + SmallVector inputShape(inputTy.getSizes()); + SmallVector resultShape(resultType.getSizes()); + int64_t rank = inputShape.size(); + + SmallVector axes, defaultAxes(rank); + std::iota(defaultAxes.begin(), defaultAxes.end(), 0); + if (binder.s64IntegerArrayAttr(axes, "axes", defaultAxes)) { + return failure(); + } + 
int64_t axesSize = axes.size(); + + Value none = rewriter.create(binder.getLoc()); + Value cstZero = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value cstOne = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value cstTwo = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(2)); + auto scalarTensorType = rewriter.getType( + ArrayRef{}, rewriter.getIntegerType(64, /*signed*/ 1)); + auto selectTensorType = rewriter.getType( + ArrayRef{1}, rewriter.getIntegerType(64, /*signed*/ 1)); + + int64_t lastChangeDim = 0; + llvm::SmallVector interShape(inputShape); + for (int i = 0; i < rank; i++) { + if (inputShape[i] != resultShape[i]) { + interShape[i] = -1; + lastChangeDim = i; + } + if (interShape[i] == ShapedType::kDynamic) + interShape[i] = Torch::kUnknownSize; + } + auto interType = rewriter.getType( + interShape, resultType.getOptionalDtype()); + + Value modeVal = rewriter.create( + binder.getLoc(), rewriter.getStringAttr("floor")); + for (int i = 0; i < axesSize; i++) { + if (axes[i] < 0) + axes[i] += rank; + if (inputShape[axes[i]] == resultShape[axes[i]]) + continue; + + auto opType = axes[i] == lastChangeDim ? resultType : interType; + Value axis = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(axes[i])); + Value k = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i)); + Value kTensor = rewriter.create( + binder.getLoc(), scalarTensorType, k); + Value sel = rewriter.create( + binder.getLoc(), selectTensorType, shape, cstZero, kTensor); + Value outputDimSize = rewriter.create( + binder.getLoc(), rewriter.getType(), sel); + Value inputDimSize = rewriter.create( + binder.getLoc(), input, + rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(axes[i]))); + + if (inputShape[axes[i]] > resultShape[axes[i]]) { + Value sub = rewriter.create( + binder.getLoc(), inputDimSize, outputDimSize); + Value subTensor = rewriter.create( + binder.getLoc(), scalarTensorType, sub); + Value div = rewriter.create( + binder.getLoc(), scalarTensorType, subTensor, cstTwo, modeVal); + Value start = rewriter.create( + binder.getLoc(), rewriter.getType(), div); + Value end = rewriter.create( + binder.getLoc(), start, outputDimSize); + input = rewriter.create( + binder.getLoc(), opType, input, axis, start, end, cstOne); + } else { + Value sub = rewriter.create( + binder.getLoc(), outputDimSize, inputDimSize); + Value subTensor = rewriter.create( + binder.getLoc(), scalarTensorType, sub); + Value div = rewriter.create( + binder.getLoc(), scalarTensorType, subTensor, cstTwo, modeVal); + Value start = rewriter.create( + binder.getLoc(), rewriter.getType(), div); + Value end = rewriter.create( + binder.getLoc(), start, inputDimSize); + + SmallVector zerosShapeValues; + for (int j = 0; j < rank; j++) { + if (j == axes[i]) { + zerosShapeValues.push_back(outputDimSize); + } else { + Value dimSize = rewriter.create( + binder.getLoc(), input, + rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(j))); + zerosShapeValues.push_back(dimSize); + } + } + Value zerosShapeList = rewriter.create( + binder.getLoc(), + rewriter.getType( + rewriter.getType()), + zerosShapeValues); + Value zeros = rewriter.create( + binder.getLoc(), opType, zerosShapeList, none, none, none, + none); + input = rewriter.create( + binder.getLoc(), opType, zeros, input, axis, start, end, + cstOne); + } + } + + rewriter.replaceOp(binder.op, input); + return success(); + }); patterns.onOp( "Clip", 1, [](OpBinder binder, ConversionPatternRewriter 
&rewriter) { // https://onnx.ai/onnx/operators/onnx__Clip.html @@ -754,7 +960,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( std::numeric_limits::lowest())) return failure(); auto minSplatAttr = SplatElementsAttr::get( - resultType.toBuiltinTensor().clone(resultDtype), + resultType.toBuiltinTensor(), rewriter.getFloatAttr(resultDtype, minValue)); min = rewriter.create( binder.getLoc(), resultType, minSplatAttr); @@ -765,7 +971,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( std::numeric_limits::max())) return failure(); auto maxSplatAttr = SplatElementsAttr::get( - resultType.toBuiltinTensor().clone(resultDtype), + resultType.toBuiltinTensor(), rewriter.getFloatAttr(resultDtype, maxValue)); max = rewriter.create( binder.getLoc(), resultType, maxSplatAttr); @@ -846,7 +1052,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( return success(); }); patterns.onOp( - "Concat", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + "Concat", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; SmallVector tensors; int64_t dim; @@ -878,7 +1084,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( if (binder.op->hasAttr("torch.onnx.value_float") && !binder.f32FloatAttr(floatValue, "value_float", 0.0)) { auto splatAttr = - SplatElementsAttr::get(resultType.toBuiltinTensor().clone(dtype), + SplatElementsAttr::get(resultType.toBuiltinTensor(), rewriter.getFloatAttr(dtype, floatValue)); rewriter.replaceOpWithNewOp( binder.op, resultType, splatAttr); @@ -889,7 +1095,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( if (binder.op->hasAttr("torch.onnx.value_int") && !binder.s64IntegerAttr(intValue, "value_int", 0)) { auto splatAttr = - SplatElementsAttr::get(resultType.toBuiltinTensor().clone(dtype), + SplatElementsAttr::get(resultType.toBuiltinTensor(), rewriter.getIntegerAttr(dtype, intValue)); rewriter.replaceOpWithNewOp( binder.op, resultType, splatAttr); @@ -897,8 +1103,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( } if (DenseResourceElementsAttr attr = - binder.op->getAttr("torch.onnx.value") - .dyn_cast_or_null()) { + dyn_cast_or_null( + binder.op->getAttr("torch.onnx.value"))) { // Bytes are stored in little endian order. Big endian support will // require swizzling. 
if (!Endian::little) { @@ -909,16 +1115,25 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( auto ty = cast(attr.getType()); ElementsAttr denseAttr; - auto ptr = attr.getRawHandle().getBlob()->getData(); + auto ptr = attr.getRawHandle().getBlob(); + if (!ptr) { + denseAttr = DenseResourceElementsAttr::get( + ty, "__onnx_constant_not_found_possibly_due_to_being_elided__", + AsmResourceBlob()); + rewriter.replaceOpWithNewOp( + binder.op, resultType, denseAttr); + return success(); + } + auto data = ptr->getData(); if (cast(attr.getType()).getElementType().isInteger(1)) { llvm::SmallVector newContents; - for (auto val : ptr) { + for (auto val : data) { APInt apval(1, val); newContents.push_back(apval); } denseAttr = DenseElementsAttr::get(ty, newContents); } else { - denseAttr = DenseElementsAttr::getFromRawBuffer(ty, ptr); + denseAttr = DenseElementsAttr::getFromRawBuffer(ty, data); } rewriter.replaceOpWithNewOp( @@ -926,8 +1141,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( return success(); } - if (ElementsAttr attr = binder.op->getAttr("torch.onnx.value") - .dyn_cast_or_null()) { + if (ElementsAttr attr = dyn_cast_or_null( + binder.op->getAttr("torch.onnx.value"))) { rewriter.replaceOpWithNewOp( binder.op, resultType, attr); return success(); @@ -940,8 +1155,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( for (auto intVal : intValues) { apValues.push_back(APInt(dtype.getIntOrFloatBitWidth(), intVal)); } - auto attr = DenseElementsAttr::get( - resultType.toBuiltinTensor().clone(dtype), apValues); + auto attr = + DenseElementsAttr::get(resultType.toBuiltinTensor(), apValues); rewriter.replaceOpWithNewOp( binder.op, resultType, attr); return success(); @@ -950,16 +1165,132 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( return failure(); }); patterns.onOp( - "Conv", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { - std::string autoPad; - if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET")) + "Col2Im", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value input, blockShape, imageShape; + SmallVector dilations, strides, pads; + + // TODO: The length of dilations should be len(imageShape), and the same + // goes for strides. The length of pads should be 2 * len(imageShape). + // But, as at the moment we are only supporting 3D or 4D input, + // len(imageShape) must necessarily be 2, hence the lengths of the + // default values. + if (binder.tensorOperandAtIndex(input, 0) || + binder.tensorOperandAtIndex(imageShape, 1) || + binder.tensorOperandAtIndex(blockShape, 2) || + binder.tensorResultType(resultType) || + binder.s64IntegerArrayAttr(dilations, "dilations", + SmallVector{1, 1}) || + binder.s64IntegerArrayAttr(strides, "strides", + SmallVector{1, 1}) || + binder.s64IntegerArrayAttr(pads, "pads", + SmallVector{0, 0, 0, 0})) return failure(); - if (autoPad != "NOTSET") { - // TODO: Add support for `auto_pad` != "NOTSET" + + auto imageShapeTy = cast(imageShape.getType()); + auto imageShapeSizes = imageShapeTy.getSizes(); + + auto blockShapeTy = cast(blockShape.getType()); + auto blockShapeSizes = blockShapeTy.getSizes(); + + // Check that neither imageShape nor blockShape have dynamic shapes. 
+ if (imageShapeSizes[0] == Torch::kUnknownSize || + blockShapeSizes[0] == Torch::kUnknownSize) { return rewriter.notifyMatchFailure( - binder.op, "unsupported conversion: auto_pad != NOTSET"); + binder.op, + "Dynamic shapes are not allowed for imageShape and blockShape"); } + // TODO: Add support for 5D input tensors. + if (imageShapeSizes[0] != 2) { + return rewriter.notifyMatchFailure( + binder.op, "Expected length of imageShape to be equal to 2"); + } + if (blockShapeSizes[0] != 2) { + return rewriter.notifyMatchFailure( + binder.op, "Expected length of blockShape to be equal to 2"); + } + if (dilations.size() != 2) { + return rewriter.notifyMatchFailure( + binder.op, "Expected length of dilations to be equal to 2"); + } + if (strides.size() != 2) { + return rewriter.notifyMatchFailure( + binder.op, "Expected length of strides to be equal to 2"); + } + + // TODO: Disable this check and add support for different + // paddings on lower and higher ends of each axis. + // Because we have already checked that imageShape has 2 elements, + // we can safely assume that len(padding) will be 4. + if (pads[0] != pads[2] || pads[1] != pads[3]) + return rewriter.notifyMatchFailure( + binder.op, "padding on the lower end and the higher end " + "on each axis should be the same"); + + // Since we know that the padding on the lower end and the higher + // end on each axis is the same, we can reduce the size of the + // padding list, and filter out the duplicate elements. + // (Also, Torch::AtenCol2imOp requires len(padding) to be 2). + SmallVector padOnEachAxis = {pads[0], pads[1]}; + Value dilationsList = + createConstantIntList(binder, rewriter, dilations); + Value stridesList = createConstantIntList(binder, rewriter, strides); + Value paddingList = + createConstantIntList(binder, rewriter, padOnEachAxis); + + Value zero = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + + // Index the imageShape and blockShape tensors, as AtenCol2imOp expects + // them to be int lists. + auto select = [&](Value v, Value k, + Torch::ValueTensorType ty) -> Value { + Value kTensor = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get( + binder.op->getContext(), ArrayRef{1}, + rewriter.getIntegerType(64, /*signed*/ 1)), + k); + + auto sel = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get(ty.getContext(), ArrayRef{1}, + ty.getOptionalDtype()), + v, zero, kTensor); + Value item = rewriter.create( + binder.getLoc(), rewriter.getType(), sel); + return item; + }; + + SmallVector imageShapeContainer, blockShapeContainer; + for (int64_t i = 0; i < imageShapeSizes[0]; ++i) { + Value k = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i)); + + // Passing in the shapeType of each of these tensors avoids + // repeated casts, as these have already been calculated. 
+ imageShapeContainer.push_back(select(imageShape, k, imageShapeTy)); + blockShapeContainer.push_back(select(blockShape, k, blockShapeTy)); + } + + Value imageShapeAsList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + imageShapeContainer); + Value blockShapeAsList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + blockShapeContainer); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, input, imageShapeAsList, blockShapeAsList, + dilationsList, paddingList, stridesList); + return success(); + }); + patterns.onOp( + "Conv", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Location loc = binder.getLoc(); Torch::ValueTensorType resultType; Value input, weight; int64_t group; @@ -984,14 +1315,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, "unsupported conversion: kernel_shape list size should have " "number of values equal to weight_rank - 2"); - } else { - for (unsigned i = 0; i < kernelShape.size(); i++) { - if (weightShape[i + 2] != kernelShape[i]) { - return rewriter.notifyMatchFailure( - binder.op, "unsupported conversion: kernel_shape value " - "should be equal to the weight tensor shape"); - } - } } } @@ -1009,20 +1332,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( defaultStrides.push_back(1); defaultDilations.push_back(1); } - // Padding for the beginning and ending along each spatial axis, it can - // take any value greater than or equal to 0. The value represent the - // number of pixels added to the beginning and end part of the - // corresponding axis. pads format should be as follow [x1_begin, - // x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added - // at the beginning of axis i and xi_end, the number of pixels added at - // the end of axis i. - if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) { - return failure(); - } - if (padding.size() != rank - 2 && padding.size() != 2 * (rank - 2)) { - return rewriter.notifyMatchFailure( - binder.op, "padding list size does not match the number of axes"); - } if (binder.s64IntegerArrayAttr(dilations, "dilations", defaultDilations)) { return failure(); @@ -1039,71 +1348,181 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( return rewriter.notifyMatchFailure( binder.op, "strides list size does not match the number of axes"); } + std::string autoPad; + if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET")) + return failure(); + auto inputTensorType = cast(input.getType()); + // Padding for the beginning and ending along each spatial axis, it can + // take any value greater than or equal to 0. The value represent the + // number of pixels added to the beginning and end part of the + // corresponding axis. pads format should be as follow [x1_begin, + // x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added + // at the beginning of axis i and xi_end, the number of pixels added at + // the end of axis i. 
+ if (autoPad == "NOTSET") { + if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) { + return failure(); + } + } else if (autoPad == "VALID") { + padding = defaultPadding; + } else { + const bool isSameLower = autoPad == "SAME_LOWER"; + const unsigned spatialRank = rank - 2; + ArrayRef inputShape = inputTensorType.getSizes(); + padding.resize_for_overwrite(2 * spatialRank); + for (unsigned dimIdx = 0; dimIdx < spatialRank; dimIdx++) { + if (weightShape[dimIdx + 2] == Torch::kUnknownSize || + inputShape[dimIdx + 2] == Torch::kUnknownSize) + return rewriter.notifyMatchFailure( + binder.op, + "expected weight and input tensor to have static shape"); + const int64_t dilatedKernelSize = + dilations[dimIdx] * (weightShape[dimIdx + 2] - 1) + 1; + int64_t totalPad = ((inputShape[dimIdx + 2] + strides[dimIdx] - 1) / + strides[dimIdx] - + 1) * + strides[dimIdx] + + dilatedKernelSize - inputShape[dimIdx + 2]; + totalPad = totalPad >= 0 ? totalPad : 0; + padding[dimIdx] = + isSameLower ? ((totalPad + 1) / 2) : (totalPad / 2); + padding[spatialRank + dimIdx] = totalPad - padding[dimIdx]; + } + } + if (padding.size() != rank - 2 && padding.size() != 2 * (rank - 2)) { + return rewriter.notifyMatchFailure( + binder.op, "padding list size does not match the number of axes"); + } SmallVector cstPadding, cstStrides, cstDilations, cstOutputPadding; + Value paddedInput = input; + Value paddingList; if (padding.size() != 2 * (rank - 2)) { for (int64_t i : padding) { cstPadding.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i))); + loc, rewriter.getI64IntegerAttr(i))); } + paddingList = rewriter.create( + loc, + Torch::ListType::get( + Torch::IntType::get(binder.op->getContext())), + cstPadding); } else { + // ONNX offers pads in the format listing all starting dims, then all + // ending dims, e.g. {t, l, b, r} for conv2d. Torch by default accepts + // only starting dims, e.g. {t, l}. However, we can support padding at + // the beginning and end of each dimension by first performing + // torch.nn.functional.pad on the input. But this requires the pad + // values to be rearranged since torch pad() takes pads in the order + // rightmost dim start and end, then next to last, and so on, e.g. {l, + // r, t, b}. 
+ bool matchedPads = true; for (unsigned i = 0; i < padding.size() / 2; i++) { if (padding[i] != padding[i + (padding.size() / 2)]) { - // TODO: Add support for different padding values for the - // beginning and ending along each spatial axis - return rewriter.notifyMatchFailure( - binder.op, - "unsupported conversion: padding values for the beginning " - "and ending along each spatial axis must be equal"); + matchedPads = false; + break; } - cstPadding.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(padding[i]))); + } + if (matchedPads) { + for (unsigned i = 0; i < padding.size() / 2; i++) { + cstPadding.push_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(padding[i]))); + } + paddingList = rewriter.create( + loc, + Torch::ListType::get( + Torch::IntType::get(binder.op->getContext())), + cstPadding); + } else { + SmallVector padsRearrange; + SmallVector inputPaddingList; + for (uint32_t i = 0; i < padding.size() / 2; i++) { + padsRearrange.emplace_back(rewriter.create( + loc, rewriter.getI64IntegerAttr( + padding[padding.size() / 2 - i - 1]))); + padsRearrange.emplace_back(rewriter.create( + loc, + rewriter.getI64IntegerAttr(padding[padding.size() - i - 1]))); + inputPaddingList.emplace_back( + rewriter.create( + loc, rewriter.getI64IntegerAttr(0))); + } + // The conv op itself will have no padding since the actual padding + // is performed using the torch.pad preceding it. + paddingList = rewriter.create( + loc, + Torch::ListType::get( + Torch::IntType::get(binder.op->getContext())), + inputPaddingList); + Value padsSizeList = + rewriter + .create( + loc, + Torch::ListType::get( + rewriter.getType()), + padsRearrange) + .getResult(); + Value modeVal = rewriter.create( + loc, rewriter.getStringAttr("constant")); + Value constantValue; + + if (isa(inputTensorType.getDtype())) + constantValue = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + if (isa(inputTensorType.getDtype())) + constantValue = rewriter.create( + loc, rewriter.getF64FloatAttr(0.0f)); + // Pad output shape must be computed explicitly from the pad values + SmallVector newInputShape(inputTensorType.getSizes()); + for (uint32_t i = 0; i < padding.size() / 2; i++) { + newInputShape[2 + i] += + padding[i] + padding[(padding.size() / 2) + i]; + } + auto padTy = rewriter.getType( + newInputShape, inputTensorType.getDtype()); + paddedInput = rewriter.create( + loc, padTy, input, padsSizeList, modeVal, constantValue); } } for (int64_t i : dilations) { cstDilations.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i))); + loc, rewriter.getI64IntegerAttr(i))); } for (int64_t i : strides) { cstStrides.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i))); + loc, rewriter.getI64IntegerAttr(i))); } Value cstZero = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(0)); + loc, rewriter.getI64IntegerAttr(0)); cstOutputPadding = {cstZero, cstZero}; - Value paddingList = rewriter.create( - binder.getLoc(), - Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), - cstPadding); Value dilationsList = rewriter.create( - binder.getLoc(), + loc, Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstDilations); Value stridesList = rewriter.create( - binder.getLoc(), + loc, Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstStrides); Value outputPaddingList = rewriter.create( - binder.getLoc(), + loc, Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstOutputPadding); - Value 
transposed = - rewriter.create(binder.getLoc(), false); + Value transposed = rewriter.create(loc, false); Value bias; if (binder.op->getNumOperands() == 3) { if (binder.tensorOperandAtIndex(bias, 2)) { return failure(); } } else { - bias = rewriter.create(binder.getLoc()); + bias = rewriter.create(loc); } Value cstGroup = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(group)); + loc, rewriter.getI64IntegerAttr(group)); rewriter.replaceOpWithNewOp( - binder.op, resultType, input, weight, bias, stridesList, + binder.op, resultType, paddedInput, weight, bias, stridesList, paddingList, dilationsList, transposed, outputPaddingList, cstGroup); return success(); @@ -1265,20 +1684,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( std::string autoPad; if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET")) return failure(); - if (autoPad != "NOTSET") { - // TODO: Add support for `auto_pad` != "NOTSET" - return rewriter.notifyMatchFailure( - binder.op, "unsupported conversion: auto_pad != NOTSET"); - } - SmallVector outputShape; - if (binder.s64IntegerArrayAttr(outputShape, "output_shape", {})) - return failure(); - if (outputShape.size()) { - // TODO: Add support for non-None output_shape value. - return rewriter.notifyMatchFailure( - binder.op, - "unsupported conversion: output_shape should be absent"); - } Torch::ValueTensorType resultType; Value input, weight; int64_t group; @@ -1312,6 +1717,10 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( } } } + } else { + for (unsigned i = 0; i < weightShape.size() - 2; i++) { + kernelShape.push_back(weightShape[i + 2]); + } } // Determine the rank of input tensor. @@ -1321,7 +1730,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( "Unimplemented: unranked tensor"); unsigned rank = *maybeRank; - SmallVector padding, strides, dilations, outputPadding; + SmallVector padding, strides, dilations, outputPadding, + outputShape; SmallVector defaultPadding, defaultStrides, defaultDilations, defaultOutputPadding; for (unsigned i = 0; i < rank - 2; i++) { @@ -1337,13 +1747,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( // x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added // at the beginning of axis i and xi_end, the number of pixels added at // the end of axis i. - if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) { - return failure(); - } - if (padding.size() != rank - 2 && padding.size() != 2 * (rank - 2)) { - return rewriter.notifyMatchFailure( - binder.op, "padding list size does not match the number of axes"); - } if (binder.s64IntegerArrayAttr(dilations, "dilations", defaultDilations)) { return failure(); @@ -1369,7 +1772,60 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, "output_padding list size does not match the number of axes"); } + auto inputTensorType = cast(input.getType()); + if (!inputTensorType || !inputTensorType.hasSizes()) { + return rewriter.notifyMatchFailure( + binder.op, "Expected input type having sizes"); + } + ArrayRef inputShape = inputTensorType.getSizes(); + if (autoPad == "VALID") { + // Zero padding. + padding = defaultPadding; + } else if (autoPad == "NOTSET") { + // Explicit padding; read pads with defaults. + if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) + return failure(); + } else { // autopad == SAME_UPPER or SAME_LOWER + // Auto-padding; output_shape defaults to input_shape * strides. 
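+          // For each spatial axis i the padding to distribute works out to
+          //   strides[i] * (inputShape[2 + i] - 1) + outputPadding[i]
+          //     + ((kernelShape[i] - 1) * dilations[i] + 1) - outputShape[i],
+          // split between the start and end of the axis (SAME_UPPER vs.
+          // SAME_LOWER); an odd total, i.e. asymmetric padding, is rejected
+          // below.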
+ SmallVector defaultOutputShape; + for (unsigned i = 0; i < rank - 2; i++) { + defaultOutputShape.push_back(inputShape[2 + i] * strides[i]); + } + if (binder.s64IntegerArrayAttr(outputShape, "output_shape", + defaultOutputShape)) + return failure(); + SmallVector paddingEnd; + for (unsigned i = 0; i < rank - 2; i++) { + int64_t totalPadding = + strides[i] * (inputShape[2 + i] - 1) + outputPadding[i] + + ((kernelShape[i] - 1) * dilations[i] + 1) - outputShape[i]; + if (totalPadding % 2) { + // TODO: Add support for different padding values for the + // beginning and ending along each spatial axis. + return rewriter.notifyMatchFailure( + binder.op, + "unsupported conversion: the combination of stride, " + "input_shape, kernel_shape, dilation, output_padding and " + "output_shape caused auto-padding to produce asymmetric " + "padding which isn't currently supported."); + } + int64_t half = totalPadding / 2; + int64_t remainder = totalPadding - half; + if (autoPad == "SAME_UPPER") { + padding.push_back(half); + paddingEnd.push_back(remainder); + } else { + padding.push_back(remainder); + paddingEnd.push_back(half); + } + } + padding.insert(padding.end(), paddingEnd.begin(), paddingEnd.end()); + } + if (padding.size() != rank - 2 && padding.size() != 2 * (rank - 2)) { + return rewriter.notifyMatchFailure( + binder.op, "padding list size does not match the number of axes"); + } SmallVector cstPadding, cstStrides, cstDilations, cstOutputPadding; if (padding.size() != 2 * (rank - 2)) { @@ -1634,6 +2090,151 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, transposedInput, reshapeSizesList); return success(); }); + patterns.onOp( + "DeformConv", 19, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + auto loc = binder.getLoc(); + + // get operands + llvm::SmallVector operands; + Torch::ValueTensorType resultType; + if (binder.tensorOperandsList(operands) || + binder.tensorResultType(resultType)) + return failure(); + if (operands.size() < 3 || operands.size() > 5) + return failure(); + auto inputType = + dyn_cast(operands[0].getType()); + if (!inputType || !inputType.hasSizes() || + inputType.getSizes().size() != 4) + return rewriter.notifyMatchFailure( + binder.op, "Unsupported: DeformConv with input rank != 4"); + unsigned rank = inputType.getSizes().size(); + auto weightType = + dyn_cast(operands[1].getType()); + if (!weightType || !weightType.hasSizes()) + return failure(); + auto offsetType = + dyn_cast(operands[2].getType()); + if (!offsetType || !offsetType.hasSizes()) + return failure(); + + // get attributes + SmallVector dilations, kernelShape, pads, strides; + SmallVector defaultDilations(rank - 2, 0); + SmallVector defaultPads(2 * (rank - 2), 0); + SmallVector defaultStrides(rank - 2, 1); + int64_t group, offsetGroup; + if (binder.s64IntegerArrayAttr(dilations, "dilations", + defaultDilations) || + binder.s64IntegerArrayAttr(kernelShape, "kernel_shape", {}) || + binder.s64IntegerArrayAttr(pads, "pads", defaultPads) || + binder.s64IntegerArrayAttr(strides, "strides", defaultStrides) || + binder.s64IntegerAttr(group, "group", 1) || + binder.s64IntegerAttr(offsetGroup, "offset_group", 1)) + return failure(); + + for (unsigned i = 0; i < rank - 2; i++) { + if (pads[i] != pads[rank + i - 2]) + return rewriter.notifyMatchFailure( + binder.op, "unsupported: asymmetric padding"); + } + + // Identify and assign names to operands + Value input, weight, offset, bias, mask; + bool useMask = false; + input = operands[0]; + weight = operands[1]; + offset = 
operands[2]; + if (operands.size() == 4) { + auto unknownOpdRank = Torch::getTensorRank(operands[3]); + if (!unknownOpdRank) + return failure(); + if (*unknownOpdRank == 1) + bias = operands[3]; + else if (*unknownOpdRank == rank) { + mask = operands[3]; + useMask = true; + } else + llvm_unreachable("onnx.DeformConv: optional 4th operand of " + "unexpected rank encountered"); + } + if (operands.size() == 5) { + bias = operands[3]; + mask = operands[4]; + useMask = true; + } + + // assign default operand values if necessary + ArrayRef weightSizes = weightType.getSizes(); + ArrayRef offsetSizes = offsetType.getSizes(); + if (!bias) { + int64_t outputChannels = weightSizes[0]; + SmallVector biasShape(1, outputChannels); + Value biasShapeList = mlir::torch::onnx_c::createConstantIntList( + binder, rewriter, biasShape); + Value cstZero = Torch::getConstantWithGivenDtypeAndValue( + rewriter, loc, 0.0f, inputType.getDtype()); + bias = + Torch::createInitTensor(rewriter, loc, + rewriter.getType( + biasShape, inputType.getDtype()), + cstZero, biasShapeList); + } + if (!mask) { + int64_t batchSize = inputType.getSizes()[0]; + int64_t kernelHeight = weightSizes[2]; + int64_t kernelWidth = weightSizes[3]; + int64_t outputHeight = offsetSizes[2]; + int64_t outputWidth = offsetSizes[3]; + int64_t maskDimOne = offsetGroup * kernelHeight * kernelWidth; + SmallVector maskShape( + {batchSize, maskDimOne, outputHeight, outputWidth}); + Value cstOne = Torch::getConstantWithGivenDtypeAndValue( + rewriter, loc, 1.0f, inputType.getDtype()); + Value maskShapeList = mlir::torch::onnx_c::createConstantIntList( + binder, rewriter, maskShape); + mask = + Torch::createInitTensor(rewriter, loc, + rewriter.getType( + maskShape, inputType.getDtype()), + cstOne, maskShapeList); + } + + // get attributes as constant values + SmallVector dilationValues, padValues, strideValues; + for (auto i : dilations) + dilationValues.push_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(i))); + for (auto i : pads) + padValues.push_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(i))); + for (auto i : strides) + strideValues.push_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(i))); + Value groupValue = rewriter.create( + loc, rewriter.getI64IntegerAttr(group)); + Value offsetGroupValue = rewriter.create( + loc, rewriter.getI64IntegerAttr(offsetGroup)); + Value useMaskValue = rewriter.create( + loc, rewriter.getBoolAttr(useMask)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, input, weight, offset, mask, bias, + strideValues[0], strideValues[1], padValues[0], padValues[1], + dilationValues[0], dilationValues[1], groupValue, offsetGroupValue, + useMaskValue); + return success(); + }); + patterns.onOp( + "Det", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value input; + if (binder.tensorOperand(input) || binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp(binder.op, + resultType, input); + return success(); + }); patterns.onOp( "DequantizeLinear", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { @@ -1643,50 +2244,73 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.tensorResultType(resultType)) return failure(); + auto loc = binder.getLoc(); Value operand = operands[0]; Value scale = operands[1]; Value zeropoint = operands[2]; auto operandTy = cast(operand.getType()); + auto operandETy = operandTy.getDtype(); auto scaleTy = dyn_cast(scale.getType()); if (!scaleTy || !scaleTy.hasSizes()) return 
rewriter.notifyMatchFailure(binder.op, "requires known rank"); if (!resultType.hasDtype()) return rewriter.notifyMatchFailure(binder.op, "requires known result dtype"); - if (scaleTy.getSizes().size() == 0 || - (scaleTy.getSizes().size() == 1 && scaleTy.getSizes()[0] == 1)) { - Type qTy = operandTy.getDtype(); - - if (qTy.isUnsignedInteger(8)) { - qTy = rewriter.getType(); - } else if (qTy.isSignedInteger(8)) { - qTy = rewriter.getType(); - } else if (qTy.isSignedInteger(32)) { - qTy = rewriter.getType(); - } else { - return rewriter.notifyMatchFailure(binder.op, - "unsupported result dtype"); - } - auto qTensorTy = rewriter.getType( - resultType.getOptionalSizes(), qTy); - scale = rewriter.create( - binder.getLoc(), rewriter.getType(), scale); - zeropoint = rewriter.create( - binder.getLoc(), rewriter.getType(), zeropoint); - - auto quantize = - rewriter.create( - binder.getLoc(), qTensorTy, operand, scale, zeropoint); - rewriter.replaceOpWithNewOp( - binder.op, resultType, quantize); + bool rank0 = scaleTy.getSizes().size() == 0; + bool length1 = + scaleTy.getSizes().size() == 1 && scaleTy.getSizes()[0] == 1; + + if (!rank0 && !length1) + return rewriter.notifyMatchFailure(binder.op, + "unimplemented: non-scalar scale"); + auto qTensorTy = getQTorchTypeFromTorchIntType(operandTy); + if (!qTensorTy) { + return rewriter.notifyMatchFailure(binder.op, + "unsupported result dtype"); + } + + scale = rewriter.create( + loc, rewriter.getType(), scale); + + bool fpOperand = isa(operandETy); + Type zeropointTy = rewriter.getType(); + if (fpOperand) + zeropointTy = rewriter.getType(); + + zeropoint = + rewriter.create(loc, zeropointTy, zeropoint); + + if (fpOperand) { + Value none = rewriter.create(loc); + Value cstFalse = rewriter.create(loc, false); + auto tyVal = Torch::getScalarTypeForType(resultType.getDtype()); + Value tyConst = rewriter.create( + loc, rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + static_cast(tyVal))); + Value toDtype = rewriter.create( + loc, resultType, operand, tyConst, + /*non_blocking=*/cstFalse, /*copy=*/cstFalse, + /*memory_format=*/none); + + Value one = rewriter.create( + loc, rewriter.getF64FloatAttr(1.0)); + Value sub = rewriter.create( + loc, resultType, toDtype, zeropoint, one); + rewriter.replaceOpWithNewOp( + binder.op, resultType, sub, scale); return success(); } - return rewriter.notifyMatchFailure(binder.op, - "unimplemented: non-scalar scale"); + auto quantize = + rewriter.create( + loc, qTensorTy, operand, scale, zeropoint); + rewriter.replaceOpWithNewOp( + binder.op, resultType, quantize); + return success(); }); patterns.onOp("Div", 7, [](OpBinder binder, ConversionPatternRewriter &rewriter) { @@ -1877,7 +2501,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( Value input; float alpha; if (binder.tensorOperand(input) || - binder.f32FloatAttr(alpha, "alpha") || + binder.f32FloatAttr(alpha, "alpha", 1.0) || binder.tensorResultType(resultType)) return failure(); Value cstAlpha = rewriter.create( @@ -1928,7 +2552,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( return failure(); auto shapeSizes = shapeType.getSizes(); - int64_t dataRank = dataType.getSizes().size(); + ArrayRef dataShape = dataType.getSizes(); + int64_t dataRank = dataShape.size(); int64_t shapeRank = shapeSizes.size(); if (shapeRank != 1 || shapeSizes[0] == Torch::kUnknownSize) return failure(); @@ -1950,22 +2575,43 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( // we are using torch implementation Torch::AtenBroadcastToOp which // 
takes list of int for (int i = 0; i < shapeSizes[0]; i++) { + // extract dim from shape Value selectIndex = rewriter.create( loc, rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); Value extract = rewriter.create( loc, selectResultType, shape, zero, selectIndex); - Value dim = rewriter.create( + Value selectDim = rewriter.create( loc, rewriter.getType(), extract); - - if (i + rankDifference >= 0) { + // compute dim to pass to broadcast op. For non-broadcastable dims, + // pass -1 + Value dim; + if (i + rankDifference >= 0 && dataShape[i + rankDifference] != 1) { + // 1. if dataShape[i + rankDiff] > 1, then this cannot be + // broadcasted + // 2. we will explicitly disallow broadcasting dynamic dims that are + // secretly 1. + dim = rewriter.create(loc, -1); + // Assert dataShape[i + rankDiff] >= selectDim. If both are + // constant, this should fold out. Value iv = rewriter.create(loc, i + rankDifference); auto sz = rewriter.create( loc, rewriter.getType(), data, iv); - dim = rewriter.create(loc, dim, sz); + Value gtSelect = + rewriter.create(loc, sz, selectDim); + rewriter.create( + loc, gtSelect, + rewriter.getStringAttr( + "onnx.Expand input has a dim that is not statically 1; " + "expected this dim >= dim provided shape.")); + } else { + // 1. excess selectDims get included in broadcast (shapeSizes[0] > + // dataRank) + // 2. selectDims which correspond to dataShape == 1 get included in + // broadcast + dim = selectDim; } - dimList.push_back(dim); } Value dimValueList = rewriter.create( @@ -2072,7 +2718,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( return success(); }); patterns.onOp( - "Flatten", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + "Flatten", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { // Flatten means to partition the input tensor's dimensions // into a "left range" spanning 0 to axis - 1 and a "right range" // spanning axis to rank - 1. Each range is then collapsed @@ -2209,14 +2855,14 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( // Get fill_value if it is present. // Assumption : resultDType and value attr type match. 
auto attr = binder.op->getAttr("torch.onnx.value"); - auto resultDType = resultType.getDtype(); // Extract the fill value and dtype // ONNX requires value attr to be a tensor + Value splatvalue; + // if no value attr is provided, default is 0.0 float value if (!attr) { - attr = DenseElementsAttr::get( - resultType.toBuiltinTensor().clone(resultDType), - rewriter.getFloatAttr(resultDType, 0.0)); + splatvalue = rewriter.create( + binder.getLoc(), rewriter.getF64FloatAttr(0.0)); } // If its a dense resource attr we need to convert to a dense type: @@ -2237,19 +2883,18 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( } Attribute splattr; - if (isa(attr)) { + if (attr && isa(attr)) { auto denseAttr = cast(attr); splattr = denseAttr.getSplatValue(); } - if (!isa(splattr)) { + if (splattr && !isa(splattr)) { return rewriter.notifyMatchFailure( binder.op, "`value` attr tensor only supports types int and float for now."); } - Value splatvalue; - if (auto intattr = dyn_cast(splattr)) { + if (auto intattr = dyn_cast_or_null(splattr)) { IntegerType intty = cast(intattr.getType()); int64_t value; if (intty.isUnsignedInteger()) { @@ -2263,7 +2908,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( rewriter.create(binder.getLoc(), value); } - if (auto fpattr = dyn_cast(splattr)) + if (auto fpattr = dyn_cast_or_null(splattr)) splatvalue = rewriter.create( binder.getLoc(), rewriter.getF64FloatAttr(fpattr.getValueAsDouble())); @@ -2283,9 +2928,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.tensorResultType(resultType)) return failure(); Type listElemType = - tensors[0] - .getType() - .cast() + cast(tensors[0].getType()) .getWithSizesAndDtype(/*optionalSizes=*/std::nullopt, /*optionalDtype=*/nullptr); Type listType = Torch::ListType::get(listElemType); @@ -2391,4 +3034,95 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( return success(); }); + + patterns.onOp( + "DFT", 20, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Value inTensor, dftLength, axis; + Torch::ValueTensorType resultType; + int64_t inverse, onesided; + if (binder.tensorOperandAtIndex(inTensor, 0) || + binder.s64IntegerAttr(inverse, "inverse", 0) || + binder.s64IntegerAttr(onesided, "onesided", 0) || + binder.tensorResultType(resultType)) + return rewriter.notifyMatchFailure( + binder.op, "Input Tensor / attrs / resultType bind failed"); + if (!binder.tensorOperandAtIndex(dftLength, 1)) { + // Convert to int and pass as n + dftLength = rewriter.create( + binder.getLoc(), rewriter.getType(), dftLength); + } else { + // Default for torch is None + dftLength = rewriter.create(binder.getLoc()); + } + // Default is same for onnx and torch + if (!binder.tensorOperandAtIndex(axis, 2)) { + // convert to int and pass to dims + axis = rewriter.create( + binder.getLoc(), rewriter.getType(), axis); + } else { + // Default in torch is -1 and onnx is -2 (since -1 is for real / img) + axis = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(-2)); + } + + if (onesided == 1) + return rewriter.notifyMatchFailure(binder.op, + "Unsupported option : onesided"); + // norm default string attr + Value norm = rewriter.create( + binder.getLoc(), rewriter.getStringAttr(Twine("backward"))); + // Convert from [....., 2] complex number repr for fft consumption. 
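+          // ONNX DFT takes a real tensor whose last dimension is 1 (real
+          // only) or 2 (real, imaginary), while the torch fft ops operate on
+          // complex tensors, so a [..., 1] input is zero-padded to [..., 2]
+          // below and then reinterpreted as a complex tensor.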
+ Torch::ValueTensorType inType = + binder.toValidTensorType(inTensor.getType()); + int64_t lastIndex = inType.getSizes().back(); + if (lastIndex != 1 && lastIndex != 2) + return rewriter.notifyMatchFailure( + binder.op, + "Expected input tensor to have dims [..., 1] or [..., 2]"); + + // concat with zeros to make it [..., 2] + Value inForComplexVal = inTensor; + ArrayRef inForComplexSizes = inType.getSizes().drop_back(); + if (lastIndex == 1) { + Value constZeroVal = rewriter.create( + binder.getLoc(), rewriter.getF64FloatAttr(0)); + Value constOne = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value constZero = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value padSizeList = + rewriter + .create( + binder.getLoc(), + Torch::ListType::get(rewriter.getType()), + SmallVector({constZero, constOne})) + .getResult(); + Value modeVal = rewriter.create( + binder.getLoc(), rewriter.getStringAttr("constant")); + SmallVector resSize(inForComplexSizes); + resSize.push_back(2); + inForComplexVal = rewriter.create( + binder.getLoc(), + inType.getWithSizesAndDtype(resSize, inType.getOptionalDtype()), + inTensor, padSizeList, modeVal, constZeroVal); + } + Type inComplexTensorType = Torch::ValueTensorType::get( + binder.op->getContext(), inForComplexSizes, + mlir::ComplexType::get(inType.getDtype())); + Value inComplexTensor = rewriter.create( + binder.getLoc(), inComplexTensorType, inForComplexVal); + Value ftOp; + if (inverse == 0) { + ftOp = rewriter.create( + binder.getLoc(), inComplexTensorType, inComplexTensor, + /*n = */ dftLength, /*dim = */ axis, /*norm = */ norm); + } else { + ftOp = rewriter.create( + binder.getLoc(), inComplexTensorType, inComplexTensor, + /*n = */ dftLength, /*dim = */ axis, /*norm = */ norm); + } + rewriter.replaceOpWithNewOp(binder.op, + resultType, ftOp); + return success(); + }); } diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 64ffd2378feb..3db33aee1f1c 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -46,29 +46,31 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( Value constAlpha = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getF64FloatAttr(alpha)); - Value constBeta = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getF64FloatAttr(beta)); // Expression: alpha * x + beta - Value alpha_x_plus_beta = rewriter.create( - binder.getLoc(), resultType, tensorOperand, constBeta, - /*alpha=*/constAlpha); + Value alphaMulX = rewriter.create( + binder.getLoc(), resultType, tensorOperand, constAlpha); + Value constOne = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getF64FloatAttr(1.0)); + Value alphaMulXPlusBeta = rewriter.create( + binder.getLoc(), resultType, alphaMulX, constBeta, + /*alpha=*/constOne); // Expression: min(1, alpha * x + beta) - Value constantOne = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(1)); - Value oneTensor = createRank0Tensor(rewriter, binder.getLoc(), - resultType, constantOne); + Value oneTensor = + createRank0Tensor(rewriter, binder.getLoc(), resultType, constOne); Value minExpression = rewriter.create( - binder.getLoc(), resultType, oneTensor, alpha_x_plus_beta); + binder.getLoc(), resultType, oneTensor, alphaMulXPlusBeta); // Expression: max(0, min(1, alpha * x + beta)) - Value constantZero = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(0)); - 
Value zeroTensor = createRank0Tensor(rewriter, binder.getLoc(), - resultType, constantZero); + Value constZero = rewriter.create( + binder.getLoc(), rewriter.getF64FloatAttr(0.0)); + Value zeroTensor = + createRank0Tensor(rewriter, binder.getLoc(), resultType, constZero); rewriter.replaceOpWithNewOp( binder.op, resultType, zeroTensor, minExpression); return success(); @@ -167,6 +169,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( alignCorners); return success(); }); + patterns.onOp("GRU", 1, onnx_c::OnnxGruExpander); patterns.onOp( "If", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Value conditionTensor; @@ -176,8 +179,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( } auto conditionType = - conditionTensor.getType().cast(); - if (!conditionType || conditionType.getSizes().size() != 1) + cast(conditionTensor.getType()); + if (!conditionType || conditionType.getSizes().size() > 1) return rewriter.notifyMatchFailure( binder.op, "condition must have one single element per " "https://onnx.ai/onnx/operators/onnx__If.html"); @@ -208,15 +211,39 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( inlineIfCase(*thenRegion, primIfOp.getThenRegion()); inlineIfCase(*elseRegion, primIfOp.getElseRegion()); - auto replaceTerminator = [&](Region ®ion) { + auto replaceTerminator = [&](Region ®ion) -> LogicalResult { PatternRewriter::InsertionGuard guard(rewriter); Operation *terminator = region.front().getTerminator(); rewriter.setInsertionPoint(terminator); - rewriter.replaceOpWithNewOp( - terminator, terminator->getOperands()); + + // cast result shape if there is static/dynamic difference + llvm::SmallVector terOperands = terminator->getOperands(); + if (terOperands.size() != resultTypes.size()) + return failure(); + for (size_t i = 0; i < terOperands.size(); i++) { + mlir::Type terType = terOperands[i].getType(); + int64_t terOpRank = + dyn_cast(terType).getSizes().size(); + int64_t resRank = dyn_cast(resultTypes[i]) + .getSizes() + .size(); + if (terOpRank != resRank) + return failure(); + if (terType != resultTypes[i]) { + Value cast = rewriter.create( + binder.getLoc(), resultTypes[i], terOperands[i]); + terOperands[i] = cast; + } + } + + rewriter.replaceOpWithNewOp(terminator, + terOperands); + return success(); }; - replaceTerminator(primIfOp.getThenRegion()); - replaceTerminator(primIfOp.getElseRegion()); + if (failed(replaceTerminator(primIfOp.getThenRegion())) || + failed(replaceTerminator(primIfOp.getElseRegion()))) + return rewriter.notifyMatchFailure(binder.op, + "terminator replace failure"); rewriter.replaceOp(binder.op, primIfOp.getResults()); return success(); @@ -257,6 +284,159 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, resultType, operand); return success(); }); + patterns.onOp( + "Loop", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + // Get all operands (maxTripCount, cond, ....inits....) 
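+        // ONNX Loop takes (max_trip_count, cond, v_initial...) and its body
+        // yields (cond_out, v_final..., scan_outputs...). It is lowered to
+        // torch.prim.Loop, which carries a max trip count, a continue
+        // condition and the loop-carried values; scan outputs are rejected
+        // further below.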
+ llvm::SmallVector operands; + if (binder.tensorOperandsList(operands) || operands.size() == 0 || + binder.getNumOperands() < 2) { + return rewriter.notifyMatchFailure(binder.op, + "Failed to get required operands"); + } + + llvm::SmallVector operandTypeVec; + if (binder.tensorOperandTypes(operandTypeVec) || + operandTypeVec.size() == 0) { + return rewriter.notifyMatchFailure(binder.op, + "Failed to get operandTypes"); + } + + Region *loopBodyIn; + if (binder.getRegionAtIndex(loopBodyIn, 0)) { + return rewriter.notifyMatchFailure(binder.op, + "Failed getting LoopBody Region"); + } + + // MaxTripCount - tensor int64 scalar (or empty) + Value maxTripCountTensor = operands[0]; + auto maxTripCountInt = rewriter.create( + binder.getLoc(), rewriter.getType(), + maxTripCountTensor); + + // Condition - tensor bool scalar (or empty) + Value conditionTensor = operands[1]; + auto conditionInt = rewriter.create( + binder.getLoc(), rewriter.getType(), + conditionTensor); + auto conditionBool = rewriter.create( + binder.getLoc(), rewriter.getType(), conditionInt); + // To be used for "for like" loop case + auto constBoolTrue = rewriter.create( + binder.getLoc(), rewriter.getBoolAttr(true)); + + // Others (if present) - variadic (can be tensors and scalar values) + if (binder.getNumOperands() > 2) { + operandTypeVec.erase(operandTypeVec.begin(), + operandTypeVec.begin() + 2); + operands.erase(operands.begin(), operands.begin() + 2); + } + + auto getOpName = [](Operation *op) -> std::string { + std::string name = op->getName().getStringRef().str(); + if (name != "torch.operator") + return name; + // for unconverted onnx ops + return mlir::dyn_cast(op->getAttr("name")) + .getValue() + .str(); + }; + + // PrimLoop Op expectes inputCondition to be boolConstantTrue + // to decide if the loopOp is `forlike`. Use loopIsForLike to + // ensure appropriate inputCondition is set + // Case 1 : loopCondInp -> identity -> terminator(loopCondOut) + bool loopIsForLike = false; + auto case1ForLike = [&getOpName](Region *loopBody) -> bool { + Value onnxLoopBodyCondIn = loopBody->front().getArgument(1); + if (!onnxLoopBodyCondIn.hasOneUse()) + return false; + Operation *inpCondUser = *onnxLoopBodyCondIn.getUsers().begin(); + if (getOpName(inpCondUser) != "onnx.Identity") { + return false; + } + if (!inpCondUser->hasOneUse() || + getOpName(*(inpCondUser->getUsers().begin())) != + "torch.operator_terminator") + return false; + return true; + }; + loopIsForLike = case1ForLike(loopBodyIn); + + Value loopInitCondition = + loopIsForLike ? constBoolTrue : conditionBool.getResult(); + auto loc = binder.getLoc(); + mlir::ImplicitLocOpBuilder b(loc, rewriter); + auto loop = b.create( + TypeRange(operandTypeVec), maxTripCountInt, loopInitCondition, + ValueRange(operands)); + + rewriter.cloneRegionBefore(*loopBodyIn, loop.getRegion(), + loop.getRegion().begin()); + + // primLoopOp loopBody expects torch.int as first arg + // insert torch.int arg in loop body, convert to tensor, + // replace all uses of old arg, delete old arg. 
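+        // torch.prim.Loop feeds the iteration number into its body as a
+        // plain !torch.int, whereas the cloned ONNX body expects a scalar
+        // i64 tensor, hence the int argument and PrimNumToTensorScalarOp
+        // bridging below.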
+ auto loopVarArg = loop.getRegion().front().getArgument(0); + // insert new Arg + loop.getRegion().front().insertArgument( + 0U, rewriter.getType(), binder.getLoc()); + auto newLoopVarArg = loop.getRegion().front().getArgument(0); + + // convert int arg to tensor of original Type + rewriter.setInsertionPointToStart(&loop.getRegion().front()); + Value loopVarVal = BlockArgument::Value(loopVarArg); + auto newTensor = rewriter.create( + loop.getRegion().op_begin()->getLoc(), loopVarVal.getType(), + newLoopVarArg); + + loopVarArg.replaceAllUsesWith(newTensor); + loop.getRegion().eraseArgument(1); + + // primLoopOp loopBody has no condition arg + auto condArg = loop.getRegion().front().getArgument(1); + if (!condArg.use_empty()) + condArg.replaceAllUsesWith(conditionTensor); + + // replace terminator + PatternRewriter::InsertionGuard guard(rewriter); + Operation *terminator = loop.getRegion().front().getTerminator(); + rewriter.setInsertionPoint(terminator); + + // results - n loop carried dependencies and k scan outputs + // Fail when there are scanOutputs in onnxLoop (K>0); + // unsupported for now + if (terminator->getNumOperands() != + loop.getRegion().getNumArguments() - 1) { + return rewriter.notifyMatchFailure( + binder.op, "scanOutputs in loop body unsupported"); + } + + // Get remaining operands from onnxLoopBody's terminator Op + // these are all the loop carried dependencies in the loop body + auto terminatorOperands = terminator->getOperands(); + llvm::SmallVector remTerminatorOperands( + terminatorOperands.begin() + 1, terminatorOperands.end()); + Value terminatorCond; + if (loopIsForLike) { + terminatorCond = constBoolTrue; + } else { + // Only use when loop is not forlike + Value terminatorCondTensor = terminatorOperands[0]; + auto terminatorCondInt = rewriter.create( + binder.getLoc(), rewriter.getType(), + terminatorCondTensor); + auto terminatorCondBool = rewriter.create( + binder.getLoc(), rewriter.getType(), + terminatorCondInt); + terminatorCond = terminatorCondBool.getResult(); + } + rewriter.replaceOpWithNewOp( + terminator, terminatorCond, remTerminatorOperands); + + loop.getRegion().eraseArgument(1); + rewriter.replaceOp(binder.op, loop); + return success(); + }); patterns.onOp("LSTM", 1, onnx_c::OnnxLstmExpander); patterns.onOp( "LogSoftmax", 13, @@ -352,7 +532,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( rightDimsPrimList); return success(); }); - patterns.onOp("MatMul", 13, + patterns.onOp("MatMul", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value lhs, rhs; @@ -408,20 +588,11 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.getLoc(), rewriter.getType(), rewriter.getF64FloatAttr(1.0)); - auto q = [&](Type qty) -> Type { - if (qty.isSignedInteger(8)) - return rewriter.getType(); - if (qty.isUnsignedInteger(8)) - return rewriter.getType(); - if (qty.isSignedInteger(32)) - return rewriter.getType(); - return {}; - }; + auto lhsQTy = getQTorchTypeFromTorchIntType(lhsTy); + auto rhsQTy = getQTorchTypeFromTorchIntType(rhsTy); - Type lhsQTy = rewriter.getType( - lhsTy.getOptionalSizes(), q(lhsTy.getDtype())); - Type rhsQTy = rewriter.getType( - rhsTy.getOptionalSizes(), q(rhsTy.getDtype())); + if (!lhsQTy || !rhsQTy) + return rewriter.notifyMatchFailure(binder.op, "failed to get qtype"); lhs = rewriter.create( binder.getLoc(), lhsQTy, lhs, scale, lhsZp); @@ -444,37 +615,528 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, resultType, lhs, rhs); return success(); }); - 
patterns.onOp("NonZero", 13, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value operand; - if (binder.tensorOperand(operand) || - binder.tensorResultType(resultType)) { - return failure(); - } - rewriter.replaceOpWithNewOp( - binder.op, resultType, operand); - return success(); - }); + + patterns.onOp( + "MelWeightMatrix", 17, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + llvm::SmallVector operands; + Torch::ValueTensorType resultType; + int64_t output_dtype_attr; + if (binder.tensorOperands(operands, 5) || + binder.tensorResultType(resultType) || operands.size() != 5 || + binder.s64IntegerAttr(output_dtype_attr, "output_datatype", 1)) { + return failure(); + } + // operands sequence : + // num_mel_bins, dft_length, sample_rate -> int32/64 tensors + // lower_edge_hertz, upper_edge_hertz -> f16/32/64 + + // Need to backtrack the values of num_mel_bins and dft_length//2+1 from + // result shape since the inputs are tensors and we cannot know their + // values at compile time. if the result type does not contain static + // shapes, then the implementation will be unsupported. + if (!resultType.areAllSizesKnown()) + return rewriter.notifyMatchFailure( + binder.op, "Unknown result sizes, not supported."); + + ArrayRef resShape = resultType.getSizes(); + if (resShape.size() != 2) + return rewriter.notifyMatchFailure( + binder.op, + "Expected result rank to be 2, not supported for other ranks."); + + std::optional torchDTypeInt = + onnxDtypeIntToTorchDtypeInt(output_dtype_attr); + if (!torchDTypeInt.has_value()) { + return rewriter.notifyMatchFailure( + binder.op, "conversion to given output dtype unsupported"); + } + + // Here Onwards all shapes will be computed using these sizes + int64_t numSpectrogramBinsInt = resShape[0]; + int64_t numMelBinsInt = resShape[1]; + Torch::ValueTensorType inputIntType = binder.toValidTensorType( + operands[0].getType()); // Since operands[0 / 1 / 2] will have the + // same int type. 
+ Torch::ValueTensorType inputFloatType = binder.toValidTensorType( + operands[3].getType()); // Since operands[3 / 4] will have the same + // float type + + Value numMelBinsItem = + getItemOp(binder, rewriter, operands[0]); + Value sampleRateItem = + getItemOp(binder, rewriter, operands[2]); + Value lowerEdgeHzItem = + getItemOp(binder, rewriter, operands[3]); + Value upperEdgeHzItem = + getItemOp(binder, rewriter, operands[4]); + + // Helpers + ImplicitLocOpBuilder b(binder.getLoc(), rewriter); + auto ctx = binder.op->getContext(); + + // Recurring shapes + SmallVector unranked({}); + SmallVector shapeNMB({numMelBinsInt}); + SmallVector shape1xNMB({1, numMelBinsInt}); + SmallVector shapeNSB({numSpectrogramBinsInt}); + SmallVector shapeNSBx1({numSpectrogramBinsInt, 1}); + SmallVector shapeNSBxNMB( + {numSpectrogramBinsInt, numMelBinsInt}); + + // Recurring DTypes + Type inpFpDType = inputFloatType.getDtype(); + Type inpIntDType = inputIntType.getDtype(); + Type si32Ty = rewriter.getIntegerType(32, true); + Type f32Ty = rewriter.getF32Type(); + Type i1Ty = rewriter.getI1Type(); + + // Value constants + Value noneConst = b.create(); + Value zeroConst = + b.create(rewriter.getI64IntegerAttr(0)); + Value oneConst = + b.create(rewriter.getI64IntegerAttr(1)); + Value twoConst = + b.create(rewriter.getI64IntegerAttr(2)); + Value int32DTypeConst = + b.create(rewriter.getI64IntegerAttr(3)); + Value float32DTypeConst = + b.create(rewriter.getI64IntegerAttr(6)); + + Torch::ValueTensorType dftLenType = + Torch::ValueTensorType::get(ctx, unranked, inpIntDType); + Type freqBinsIntType = + Torch::ValueTensorType::get(ctx, shapeNMB, si32Ty); + Type freqBinsFltType = + Torch::ValueTensorType::get(ctx, shapeNMB, f32Ty); + + Value dftLengthDivTwoTensor = b.create( + dftLenType, operands[1], twoConst); + Value numSpectrogramBinsTensor = b.create( + dftLenType, dftLengthDivTwoTensor, oneConst, /*alpha =*/oneConst); + Value numSpectrogramBinsItem = getItemOp( + binder, rewriter, numSpectrogramBinsTensor); + + // From Ref Impl of Onnx.MelWeightMatrix: + // https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_mel_weight_matrix.py#L25-L32 + // convert input Freq Hz to Mel + Value twoFiveNineFiveConst = + b.create(rewriter.getF64FloatAttr(2595)); + Value sevenHConst = + b.create(rewriter.getF64FloatAttr(700)); + Value tenConst = + b.create(rewriter.getF64FloatAttr(10)); + Value oneFltConst = + b.create(rewriter.getF64FloatAttr(1)); + Value LnToLog10Const = b.create( + rewriter.getF64FloatAttr(M_LOG10E)); + + Value lfDiv7Hfloat = + b.create(lowerEdgeHzItem, sevenHConst); + Type freqType = Torch::ValueTensorType::get(ctx, unranked, inpFpDType); + Value lfDiv7H = + b.create(freqType, lfDiv7Hfloat); + Value lfDiv7HAdd1 = b.create( + freqType, lfDiv7H, oneConst, /*alpha =*/oneConst); + Value lfDiv7HAdd1Ln = b.create(freqType, lfDiv7HAdd1); + Value lfDiv7HAdd1Log10 = b.create( + freqType, lfDiv7HAdd1Ln, LnToLog10Const); + + Value lfMel = b.create( + freqType, lfDiv7HAdd1Log10, twoFiveNineFiveConst); + + Value hfDiv7Hfloat = + b.create(upperEdgeHzItem, sevenHConst); + Value hfDiv7H = + b.create(freqType, hfDiv7Hfloat); + Value hfDiv7HAdd1 = b.create( + freqType, hfDiv7H, oneConst, /*alpha =*/oneConst); + Value hfDiv7HAdd1Ln = b.create(freqType, hfDiv7HAdd1); + Value hfDiv7HAdd1Log10 = b.create( + freqType, hfDiv7HAdd1Ln, LnToLog10Const); + + Value hfMel = b.create( + freqType, hfDiv7HAdd1Log10, twoFiveNineFiveConst); + + Value hfSubLf = b.create( + hfMel.getType(), hfMel, lfMel, /*alpha=*/oneConst); + Value 
numMelBinsPlus2 = + b.create(numMelBinsItem, twoConst); + Value melStep = b.create( + hfSubLf.getType(), hfSubLf, numMelBinsPlus2); + + Value lowBinsInit = b.create( + freqBinsIntType, numMelBinsItem, /*dtype=*/int32DTypeConst, + /*layout=*/noneConst, /*device=*/noneConst, + /*pin_memory=*/noneConst); + + Value centerBinsInit = b.create( + freqBinsIntType, numMelBinsItem, /*dtype=*/int32DTypeConst, + /*layout=*/noneConst, /*device=*/noneConst, + /*pin_memory=*/noneConst); + + Value highBinsInit = b.create( + freqBinsIntType, numMelBinsItem, /*dtype=*/int32DTypeConst, + /*layout=*/noneConst, /*device=*/noneConst, + /*pin_memory=*/noneConst); + + // Common values used in conversion + Value dftLenPlusOne = b.create( + dftLenType, operands[1], oneConst, /*alpha=*/oneConst); + Value dftLenPlusOneItem = + getItemOp(binder, rewriter, dftLenPlusOne); + Value falseConst = b.create(false); + Torch::ValueTensorType unsqueezeBinsResType = + Torch::ValueTensorType::get(ctx, shape1xNMB, si32Ty); + + // Low bins Mel to hz + Value lowBinsMulMelStep = b.create( + freqBinsFltType, lowBinsInit, melStep); + Value lowBinsScaled = b.create( + freqBinsFltType, lowBinsMulMelStep, lfMel, /*alpha=*/oneConst); + Value lbDiv = b.create( + freqBinsFltType, lowBinsScaled, twoFiveNineFiveConst); + Value lbClone = b.create( + freqBinsFltType, lowBinsScaled, /*memory_format=*/noneConst); + Value lbTenTensor = b.create( + freqBinsFltType, lbClone, tenConst); + Value lbPow = b.create( + freqBinsFltType, lbTenTensor, lbDiv); + Value lbPowSubOne = b.create( + freqBinsFltType, lbPow, oneConst, /*alpha=*/oneConst); + Value lowBinsHz = b.create( + freqBinsFltType, lbPowSubOne, sevenHConst); + // Normalize freqBinsHz + Value lbMulDft = b.create( + freqBinsFltType, lowBinsHz, dftLenPlusOneItem); + Value lowBinsNormalized = b.create( + freqBinsFltType, lbMulDft, sampleRateItem); + // cast to int32 + Value lowBinsInt = b.create( + freqBinsIntType, lowBinsNormalized, /*dtype=*/int32DTypeConst, + /*non_blocking=*/falseConst, /*copy=*/falseConst, + /*memory_format=*/noneConst); + Value lowBins = b.create( + unsqueezeBinsResType, lowBinsInt, /*dim=*/zeroConst); + + // Center bins mel to hz + Value centerBinsInitInc = b.create( + freqBinsIntType, centerBinsInit, oneConst, /*alpha=*/oneConst); + Value centerBinsMulMelStep = b.create( + freqBinsFltType, centerBinsInitInc, melStep); + Value centerBinsScaled = b.create( + freqBinsFltType, centerBinsMulMelStep, lfMel, /*alpha=*/oneConst); + Value cbDiv = b.create( + freqBinsFltType, centerBinsScaled, twoFiveNineFiveConst); + Value cbClone = b.create( + freqBinsFltType, centerBinsScaled, /*memory_format=*/noneConst); + Value cbTenTensor = b.create( + freqBinsFltType, cbClone, tenConst); + Value cbPow = b.create( + freqBinsFltType, cbTenTensor, cbDiv); + Value cbPowSubOne = b.create( + freqBinsFltType, cbPow, oneConst, /*alpha=*/oneConst); + Value centerBinsHz = b.create( + freqBinsFltType, cbPowSubOne, sevenHConst); + // Normalize freqBinsHz + Value cbMulDft = b.create( + freqBinsFltType, centerBinsHz, dftLenPlusOneItem); + Value centerBinsNormalized = b.create( + freqBinsFltType, cbMulDft, sampleRateItem); + // cast to int32 + Value centerBinsInt = b.create( + freqBinsIntType, centerBinsNormalized, /*dtype=*/int32DTypeConst, + /*non_blocking=*/falseConst, /*copy=*/falseConst, + /*memory_format=*/noneConst); + Value centerBins = b.create( + unsqueezeBinsResType, centerBinsInt, /*dim=*/zeroConst); + + // High bins mel to hz + Value highBinsInitInc = b.create( + freqBinsIntType, highBinsInit, 
twoConst, /*alpha=*/oneConst); + Value highBinsMulMelStep = b.create( + freqBinsFltType, highBinsInitInc, melStep); + Value highBinsScaled = b.create( + freqBinsFltType, highBinsMulMelStep, lfMel, /*alpha=*/oneConst); + Value hbDiv = b.create( + freqBinsFltType, highBinsScaled, twoFiveNineFiveConst); + Value hbClone = b.create( + freqBinsFltType, highBinsScaled, /*memory_format=*/noneConst); + Value hbTenTensor = b.create( + freqBinsFltType, hbClone, tenConst); + Value hbPow = b.create( + freqBinsFltType, hbTenTensor, hbDiv); + Value hbPowSubOne = b.create( + freqBinsFltType, hbPow, oneConst, /*alpha=*/oneConst); + Value highBinsHz = b.create( + freqBinsFltType, hbPowSubOne, sevenHConst); + // Normalize freqBinsHz + Value hbMulDft = b.create( + freqBinsFltType, highBinsHz, dftLenPlusOneItem); + Value highBinsNormalized = b.create( + freqBinsFltType, hbMulDft, sampleRateItem); + // cast to int32 + Value highBinsInt = b.create( + freqBinsIntType, highBinsNormalized, /*dtype=*/int32DTypeConst, + /*non_blocking=*/falseConst, /*copy=*/falseConst, + /*memory_format=*/noneConst); + Value highBins = b.create( + unsqueezeBinsResType, highBinsInt, /*dim=*/zeroConst); + + Type iotaInitType = inputIntType.getWithSizesAndDtype(shapeNSB, si32Ty); + Value iotaInit = b.create( + iotaInitType, numSpectrogramBinsItem, + /*dtype=*/int32DTypeConst, + /*layout=*/noneConst, /*device=*/noneConst, + /*pin_memory=*/noneConst); + + Torch::ValueTensorType unsqueezeIotaResType = + Torch::ValueTensorType::get(ctx, shapeNSBx1, si32Ty); + Value iota = b.create( + unsqueezeIotaResType, iotaInit, /*dim=*/oneConst); + + Value lowToCenter = b.create( + unsqueezeBinsResType, centerBins, lowBins, /*alpha=*/oneConst); + Value centerToHigh = b.create( + unsqueezeBinsResType, highBins, centerBins, /*alpha=*/oneConst); + + Value oneConstTensor = Torch::createRank0Tensor( + rewriter, binder.getLoc(), + Torch::ValueTensorType::get(ctx, std::nullopt, f32Ty), oneConst); + + Type scaledType = inputIntType.getWithSizesAndDtype(shape1xNMB, f32Ty); + Value upscaleInit = b.create( + unsqueezeBinsResType, oneConstTensor, lowToCenter); + Value upscale = b.create( + scaledType, upscaleInit, /*dtype=*/float32DTypeConst, + /*non_blocking=*/falseConst, /*copy=*/falseConst, + /*memory_format=*/noneConst); + + Value downscaleInit = b.create( + unsqueezeBinsResType, oneConstTensor, centerToHigh); + Value downscale = b.create( + scaledType, downscaleInit, /*dtype=*/float32DTypeConst, + /*non_blocking=*/falseConst, /*copy=*/falseConst, + /*memory_format=*/noneConst); + + Torch::ValueTensorType binsDiffType = + Torch::ValueTensorType::get(ctx, shapeNSBxNMB, si32Ty); + Torch::ValueTensorType diffFloatType = + Torch::ValueTensorType::get(ctx, shapeNSBxNMB, f32Ty); + + Value iotaSubLBInt = b.create( + binsDiffType, iota, lowBins, /*alpha=*/oneConst); + Value iotaSubLB = b.create( + diffFloatType, iotaSubLBInt, /*dtype=*/float32DTypeConst, + /*non_blocking=*/falseConst, /*copy=*/falseConst, + /*memory_format=*/noneConst); + Value rampUp = + b.create(diffFloatType, iotaSubLB, upscale); + + Value hbSubIotaInt = b.create( + binsDiffType, highBins, iota, /*alpha=*/oneConst); + Value hbSubIota = b.create( + diffFloatType, hbSubIotaInt, /*dtype=*/float32DTypeConst, + /*non_blocking=*/falseConst, /*copy=*/falseConst, + /*memory_format=*/noneConst); + Value rampDown = b.create(diffFloatType, + hbSubIota, downscale); + + // ramp values + Type iotaCmpBinsType = + inputIntType.getWithSizesAndDtype(shapeNSBxNMB, i1Ty); + + // Iota Cmp Bins + Value iotaGtEqCBins = + 
b.create(iotaCmpBinsType, iota, centerBins); + Value iotaEqCBins = + b.create(iotaCmpBinsType, iota, centerBins); + Value iotaLtLBins = + b.create(iotaCmpBinsType, iota, lowBins); + Value iotaGtLBins = + b.create(iotaCmpBinsType, iota, highBins); + + // Create output freq ramps Low-Center-High + Type rampInitType = + inputIntType.getWithSizesAndDtype(shapeNSBxNMB, f32Ty); + Value rampInit = b.create( + rampInitType, iotaGtEqCBins, rampDown, rampUp); + Value rampInitLt = b.create( + rampInitType, iotaLtLBins, zeroConst, rampInit); + Value rampInitLtGt = b.create( + rampInitType, iotaGtLBins, zeroConst, rampInitLt); + + Type C2HCmpBinsType = + inputIntType.getWithSizesAndDtype(shape1xNMB, i1Ty); + Value C2HEqZero = b.create( + C2HCmpBinsType, centerToHigh, zeroConst); + Value cornerCases = b.create( + iotaCmpBinsType, iotaEqCBins, C2HEqZero); + Value rampOutput = b.create( + rampInitType, cornerCases, oneFltConst, rampInitLtGt); + + Value outputDTypeConst = b.create( + rewriter.getType(), + rewriter.getI64IntegerAttr(torchDTypeInt.value())); + Value finalOutput = b.create( + resultType, rampOutput, /*dtype=*/outputDTypeConst, + /*non_blocking=*/falseConst, /*copy=*/falseConst, + /*memory_format=*/noneConst); + + rewriter.replaceOp(binder.op, finalOutput); + return success(); + }); + + patterns.onOp( + "Multinomial", 7, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value self; + int64_t onnxDtype, sampleSize; + + if (binder.tensorOperand(self) || + binder.s64IntegerAttr(onnxDtype, "dtype", 6) || + binder.s64IntegerAttr(sampleSize, "sample_size", 1) || + binder.tensorResultType(resultType)) { + return failure(); + } + + if (binder.op->hasAttr("torch.onnx.seed")) { + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented: support not present for seed attribute"); + } + + if (sampleSize <= 0) { + return rewriter.notifyMatchFailure(binder.op, + "unsupported: sample_size <= 0"); + } + + std::optional torchDtype = + onnxDtypeIntToTorchDtypeInt(onnxDtype); + if (!torchDtype.has_value()) { + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented support for the given dtype conversion"); + } + + Value torchDtypeIntValue = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(torchDtype.value())); + Value numSamples = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(sampleSize)); + + // PRG is seeded globally by default + Value none = rewriter.create(binder.getLoc()); + // Sample with replacement by default (no onnx equivalent in arguments) + Value cstTrue = rewriter.create( + binder.getLoc(), rewriter.getBoolAttr(true)); + + // Torch Multinomial always produces a LongTensor + Torch::ValueTensorType selfType = + cast(self.getType()); + Type int64Dtype = + IntegerType::get(selfType.getContext(), 64, IntegerType::Signed); + int64_t batchSize = selfType.getSizes()[0]; + SmallVector outShapes({batchSize, sampleSize}); + Torch::ValueTensorType multinomialOutputType = + Torch::ValueTensorType::get(selfType.getContext(), outShapes, + int64Dtype); + Value multinomialTensor = rewriter.create( + binder.getLoc(), multinomialOutputType, self, numSamples, cstTrue, + none); + + Value cstFalse = rewriter.create( + binder.getLoc(), rewriter.getBoolAttr(false)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, multinomialTensor, torchDtypeIntValue, + cstFalse, cstFalse, none); + + return success(); + }); + patterns.onOp( + "NegativeLogLikelihoodLoss", 13, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { 
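+        // ONNX NegativeLogLikelihoodLoss(input, target[, weight]) with the
+        // reduction / ignore_index attributes is lowered to
+        // torch.aten.nll_loss_forward; the string reduction attribute is
+        // first mapped onto torch's integer reduction enum.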
+ Torch::ValueTensorType resultType; + Value self, target, weight, reduction, ignore_index; + int64_t ignore_index_int; + std::string reduction_str; + + if (binder.tensorOperandAtIndex(self, 0) || + binder.tensorOperandAtIndex(target, 1) || + binder.s64IntegerAttr(ignore_index_int, "ignore_index", -100) || + binder.customOpNameStringAttr(reduction_str, "reduction", "mean") || + binder.tensorResultType(resultType)) { + return failure(); + } + + // optional third tensor argument + if (binder.tensorOperandAtIndex(weight, 2)) { + weight = rewriter.create(binder.getLoc()); + } + + ignore_index = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(ignore_index_int)); + + // convert string reduction attr to standardized integer enum value + int reduction_value = + torch_upstream::get_loss_reduction_enum(reduction_str); + reduction = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(reduction_value)); + + Value nllLoss = rewriter + .create( + binder.getLoc(), resultType, resultType, self, + target, weight, reduction, ignore_index) + ->getResult(0); + + rewriter.replaceOp(binder.op, nllLoss); + return success(); + }); + patterns.onOp( + "NonZero", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) { + return failure(); + } + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + Value one = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1)); + auto rawSize = resultType.getSizes(); + SmallVector torchResultSize(rawSize.rbegin(), rawSize.rend()); + auto torchResultType = rewriter.getType( + torchResultSize, resultType.getDtype()); + auto nonZero = rewriter.create( + binder.getLoc(), torchResultType, operand); + // The output tensor has a shape of ((n, z)), where (n) is the + // number of dimensions in the input tensor and (z) is the + // number of non-zero elements2. This is different from + // PyTorch's default behavior, where the dimensions are + // reversed. 
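+        // torch.aten.nonzero therefore yields a (z, n) tensor, so the result
+        // is transposed over dims 0 and 1 to recover the ONNX (n, z) layout.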
+ rewriter.replaceOpWithNewOp( + binder.op, resultType, nonZero, zero, one); + return success(); + }); patterns.onOp( "MaxPool", 12, [](OpBinder binder, ConversionPatternRewriter &rewriter) { std::string autoPad; if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET")) return rewriter.notifyMatchFailure(binder.op, "auto_pad bind failure"); - if (autoPad != "NOTSET") - return rewriter.notifyMatchFailure( - binder.op, "unsupported conversion: auto_pad != NOTSET"); - Torch::ValueTensorType resultType; + Torch::ValueTensorType resultTypeOut; Value operand; - bool ceilMode; - int64_t storageOrder; + int64_t ceilMode, storageOrder; // TODO: Add support for indices output and storage_order if (binder.tensorOperand(operand) || - binder.s64BoolAttr(ceilMode, "ceil_mode", false) || + binder.s64IntegerAttr(ceilMode, "ceil_mode", 0) || binder.s64IntegerAttr(storageOrder, "storage_order", 0) || - binder.tensorResultType(resultType)) + binder.tensorResultTypeAtIndex(resultTypeOut, 0)) return rewriter.notifyMatchFailure( binder.op, "operand/ceil_mode/storage_order/resultType bind failure"); @@ -512,6 +1174,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( return rewriter.notifyMatchFailure(binder.op, "dilations bind failure"); + // set default padding if (padding.empty()) padding.resize(spatial, 0); if (strides.empty()) @@ -519,6 +1182,34 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( if (dilations.empty()) dilations.resize(spatial, 1); + auto inputTensorType = cast(operand.getType()); + + // Padding for the beginning and ending along each spatial axis, it can + // take any value greater than or equal to 0. The value represent the + // number of pixels added to the beginning and end part of the + // corresponding axis. pads format should be as follow [x1_begin, + // x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added + // at the beginning of axis i and xi_end, the number of pixels added at + // the end of axis i. + if (autoPad != "NOTSET" && autoPad != "VALID") { + const bool isSameLower = autoPad == "SAME_LOWER"; + ArrayRef inputShape = inputTensorType.getSizes(); + padding.resize_for_overwrite(2 * spatial); + for (unsigned dimIdx = 0; dimIdx < spatial; dimIdx++) { + const int64_t dilatedKernelSize = + dilations[dimIdx] * (kernel[dimIdx] - 1) + 1; + int64_t totalPad = ((inputShape[dimIdx + 2] + strides[dimIdx] - 1) / + strides[dimIdx] - + 1) * + strides[dimIdx] + + dilatedKernelSize - inputShape[dimIdx + 2]; + totalPad = totalPad >= 0 ? totalPad : 0; + padding[dimIdx] = + isSameLower ? ((totalPad + 1) / 2) : (totalPad / 2); + padding[spatial + dimIdx] = totalPad - padding[dimIdx]; + } + } + // If the padding is symmetric we can push the padding operation to the // torch operator. 
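+        // When the padding is asymmetric, the input is instead padded
+        // explicitly further below with the dtype's lowest value (so the
+        // padded elements can never win the max), with the pad list
+        // reordered into torch's innermost-axis-first order.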
if (padding.size() == static_cast(2 * spatial)) { @@ -537,22 +1228,21 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( auto operandTy = cast(operand.getType()); llvm::SmallVector shuffledPadding(spatial * 2); llvm::SmallVector paddedShape(operandTy.getSizes()); - shuffledPadding.resize(2 * rank); for (int i = 0; i < spatial; ++i) { paddedShape[i + 2] += padding[i] + padding[i + spatial]; - shuffledPadding[2 * i] = padding[i]; - shuffledPadding[2 * i + 1] = padding[i + spatial]; + shuffledPadding[2 * i] = padding[spatial - i - 1]; + shuffledPadding[2 * i + 1] = padding[2 * spatial - i - 1]; } Value shuffledPaddingList = - createConstantIntList(binder, rewriter, padding); + createConstantIntList(binder, rewriter, shuffledPadding); Value zero; - if (resultType.getDtype().isa()) { + if (isa(resultTypeOut.getDtype())) { zero = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getF64FloatAttr( std::numeric_limits::lowest())); - } else if (resultType.getDtype().isa()) { + } else if (isa(resultTypeOut.getDtype())) { zero = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr( std::numeric_limits::lowest())); @@ -575,22 +1265,281 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( Value cstCeilMode = rewriter.create(binder.getLoc(), ceilMode); - if (rank == 3) + if (binder.op->getNumResults() == 2) { + Torch::ValueTensorType resultTypeIndices; + if (binder.tensorResultTypeAtIndex(resultTypeIndices, 1)) + return failure(); + + if (rank == 3) + return rewriter.notifyMatchFailure( + binder.op, "Unimplemented: AtenMaxPool1dWithIndicesOp"); + + if (rank == 4) { + rewriter.replaceOpWithNewOp( + binder.op, resultTypeOut, resultTypeIndices, operand, + kernelSizeList, stridesList, paddingList, dilationsList, + cstCeilMode); + return success(); + } + if (rank == 5) { + rewriter.replaceOpWithNewOp( + binder.op, resultTypeOut, resultTypeIndices, operand, + kernelSizeList, stridesList, paddingList, dilationsList, + cstCeilMode); + return success(); + } + } else { + if (rank == 3) { + rewriter.replaceOpWithNewOp( + binder.op, resultTypeOut, operand, kernelSizeList, stridesList, + paddingList, dilationsList, cstCeilMode); + return success(); + } + if (rank == 4) { + rewriter.replaceOpWithNewOp( + binder.op, resultTypeOut, operand, kernelSizeList, stridesList, + paddingList, dilationsList, cstCeilMode); + return success(); + } + if (rank == 5) { + rewriter.replaceOpWithNewOp( + binder.op, resultTypeOut, operand, kernelSizeList, stridesList, + paddingList, dilationsList, cstCeilMode); + return success(); + } + } + return rewriter.notifyMatchFailure(binder.op, "No rank is matched."); + }); + patterns.onOp( + "MaxRoiPool", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + SmallVector pooledShape; + float spatialScale; + if (binder.s64IntegerArrayAttr(pooledShape, "pooled_shape", {}) || + binder.f32FloatAttr(spatialScale, "spatial_scale", 1.0f)) { return rewriter.notifyMatchFailure(binder.op, - "Unimplemented: AtenMaxPool1dOp"); - if (rank == 4) { - rewriter.replaceOpWithNewOp( - binder.op, resultType, operand, kernelSizeList, stridesList, - paddingList, dilationsList, cstCeilMode); - return success(); + "Attribute bind failure"); } - if (rank == 5) { - rewriter.replaceOpWithNewOp( - binder.op, resultType, operand, kernelSizeList, stridesList, - paddingList, dilationsList, cstCeilMode); - return success(); + Torch::ValueTensorType resultTy; + Value input, rois; + if (binder.tensorOperands(input, rois) || + binder.tensorResultType(resultTy)) { + return 
rewriter.notifyMatchFailure(binder.op, + "Operand or result type mismatch"); } - return rewriter.notifyMatchFailure(binder.op, "No rank is matched."); + + Value outputShapeList = + createConstantIntList(binder, rewriter, pooledShape); + Location loc = binder.getLoc(); + + auto inputTy = cast(input.getType()); + auto roisTy = cast(rois.getType()); + if (!inputTy || !inputTy.hasSizes()) + return failure(); + if (!roisTy || !roisTy.hasSizes()) + return failure(); + + auto intTy = rewriter.getIntegerType(64, true); + auto floatTy = roisTy.getDtype(); + auto torchIntTy = rewriter.getType(); + + Value spatialScaleValue = rewriter.create( + loc, rewriter.getF64FloatAttr(spatialScale)); + + Value boolTrue = rewriter.create( + loc, rewriter.getBoolAttr(true)); + + ArrayRef inputShape = inputTy.getSizes(); + int64_t inputRank = inputShape.size(); + if (inputRank < 4) { + return rewriter.notifyMatchFailure( + binder.op, "Rank of input tensor must be >= 4"); + } + + ArrayRef roisShape = roisTy.getSizes(); + if (!roisTy.areAllSizesKnown() || roisShape.size() != 2 || + roisShape[1] != 5) { + return rewriter.notifyMatchFailure( + binder.op, "Expected ROIs to be statically sized tensor of shape " + "(num_rois, 5)"); + } + int64_t numRois = roisShape[0]; + + /* The implementation is based on the following algorithm: + MaxRoiPool ( + input : tensor, rois : tensor) => (output) + { + * Step 1: Extract ROI specification + - Each ROI is represented as [batch_id, x1, y1, x2, y2], where + range is inclusive of x1, y1, x2, and y2 + - The range values are scaled by spatial_scale + + BatchIdxsFloat = Select(rois, dim=1, index=0) + BatchIdxs = CastLong(BatchIdxsFloat) + RoiBBsFloat = Slice(rois, dim=1, start=1, end=5, stride=1) + RoiBBsScaledFloat = MulScalar(RoiBBsFloat, spatial_scale) + RoiBBsScaled = CastLong(RoiBBsScaledFloat) + + * Step 2: Iteratively pool ROIs + pooledROIs = [] + for (roiIdx = 0; roiIdx < len(rois); roiIdx++) { + * Step 2a: For each ROI, we extract batch_id, x1, y1, x2, & y2 + RoiSpec = Select(RoiBBsScaled, 0, roiIdx) : tensor<4xint> + roiValues = [] + for (specIdx = 0; specIdx < 5; specIdx++) { + if (specIdx == 0) + SpecTensor = Select(BatchIdxs, 1, roiIdx) : tensor + else + SpecTensor = Select(RoiSpec, 0, specIdx-1) : tensor + SpecValue = Item(specTensor) : torch.int + roiValues.push(SpecValue) + } + BatchIdx, X1, Y1, X2, Y2 = roiValues + + * Step 2b: extract image from input and extract region + - X2 and Y2 are incremented by 1 to make range inclusive + - width and height dimension are calculated once outside of loop + but intuition is expressed more clearly below + + image = Select(input, 0, BatchIdx) + widthDim = rank(image) - 1 + heightDim = rank(image) - 2 + + imageExtractedY = Slice(image, heightDim, Y1, Y2 + 1, 1) + region = Slice(image, widthDim, X1, X2 + 1, 1) + + * Step 2c: apply adaptive max pooling to pool region of interest + into final pooled size + pooledROI = AdaptiveMaxPool2d(region, pooled_shape) + pooledROIs.push(pooledROI) + } + + * Step 3: Stack pooled regions and return final output + return output = Stack(pooledRois, dim=0) + } + */ + + SmallVector constInts(6); + for (int i = 0; i <= 5; i++) { + constInts[i] = rewriter.create( + loc, rewriter.getI64IntegerAttr(i)); + } + + int64_t widthDim = inputRank - 2; + Value widthDimValue = rewriter.create( + loc, rewriter.getI64IntegerAttr(widthDim)); + + int64_t heightDim = inputRank - 3; + Value heightDimValue = rewriter.create( + loc, rewriter.getI64IntegerAttr(heightDim)); + + // extract indices of images within batch + auto 
batchIdxsShape = SmallVector{Torch::kUnknownSize}; + auto batchIdxsFloatTy = + rewriter.getType(batchIdxsShape, floatTy); + Value batchIdxsFloat = rewriter.create( + loc, batchIdxsFloatTy, rois, constInts[1], constInts[0]); + auto batchIdxsIntTy = + rewriter.getType(batchIdxsShape, intTy); + Value batchIdxs = rewriter.create( + loc, batchIdxsIntTy, batchIdxsFloat, boolTrue); + + // extract scaled ranges for regions of interest + auto roiBBsShape = SmallVector{Torch::kUnknownSize, 4}; + auto roiBBsFloatTy = + rewriter.getType(roiBBsShape, floatTy); + Value roiBBs = rewriter.create( + loc, roiBBsFloatTy, rois, constInts[1], constInts[1], constInts[5], + constInts[1]); + Value roiBBsScaledFloat = rewriter.create( + loc, roiBBsFloatTy, roiBBs, spatialScaleValue); + auto roiBBsTy = + rewriter.getType(roiBBsShape, intTy); + Value roiBBsScaled = rewriter.create( + loc, roiBBsTy, roiBBsScaledFloat, boolTrue); + + SmallVector pooledRois; + + for (int64_t i = 0; i < numRois; i++) { + Value roiIdx = rewriter.create( + loc, rewriter.getI64IntegerAttr(i)); + + auto roiSpecTy = rewriter.getType( + roiBBsTy.getSizes().slice(1), intTy); + Value roiSpec = rewriter.create( + loc, roiSpecTy, roiBBsScaled, constInts[0], roiIdx); + + // Load individual ROI specification values + SmallVector roiValues(5); + for (int specIdx = 0; specIdx < 5; specIdx++) { + auto intEmptyTensorTy = rewriter.getType( + SmallVector{}, intTy); + Value specTensor; + if (specIdx == 0) { // batch index + specTensor = rewriter.create( + loc, intEmptyTensorTy, batchIdxs, constInts[0], roiIdx); + } else { // roi dimension + specTensor = rewriter.create( + loc, intEmptyTensorTy, roiSpec, constInts[0], + constInts[specIdx - 1]); + } + Value specValue = + rewriter.create(loc, torchIntTy, specTensor); + roiValues[specIdx] = specValue; + } + Value batchIdx = roiValues[0], roiX1 = roiValues[1], + roiY1 = roiValues[2], roiX2 = roiValues[3], + roiY2 = roiValues[4]; + + // add 1 to make range ends inclusive as per ONNX implementation + roiX2 = rewriter.create(loc, torchIntTy, roiX2, + constInts[1]); + roiY2 = rewriter.create(loc, torchIntTy, roiY2, + constInts[1]); + + auto imageTy = rewriter.getType( + inputShape.slice(1), inputTy.getDtype()); + Value image = rewriter.create( + loc, imageTy, input, constInts[0], batchIdx); // (NC x H x W) + + SmallVector imageUnknownShape(imageTy.getSizes()); + imageUnknownShape[heightDim] = Torch::kUnknownSize; + imageUnknownShape[widthDim] = Torch::kUnknownSize; + auto imageUnknownTy = rewriter.getType( + imageUnknownShape, imageTy.getDtype()); + + // extract ROI from image + Value imageExtractedY = rewriter.create( + loc, imageUnknownTy, image, heightDimValue, roiY1, roiY2, + constInts[1]); + Value region = rewriter.create( + loc, imageUnknownTy, imageExtractedY, widthDimValue, roiX1, roiX2, + constInts[1]); + + SmallVector pooledRegionShape(imageTy.getSizes()); + pooledRegionShape[heightDim] = pooledShape[0]; + pooledRegionShape[widthDim] = pooledShape[1]; + auto pooledRegionTy = rewriter.getType( + pooledRegionShape, imageTy.getDtype()); + auto pooledRegionIndicesTy = rewriter.getType( + pooledRegionShape, intTy); + + // apply pooling on ROI + Value pooledRegion = + rewriter + .create( + loc, pooledRegionTy, pooledRegionIndicesTy, region, + outputShapeList) + .getResult0(); + pooledRois.push_back(pooledRegion); + } + + Value pooledRoisList = rewriter.create( + loc, Torch::ListType::get(pooledRois[0].getType()), pooledRois); + rewriter.replaceOpWithNewOp( + binder.op, resultTy, pooledRoisList, constInts[0]); 
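For reference, the step-by-step decomposition spelled out in the algorithm comment above corresponds roughly to this eager PyTorch sketch (illustrative only; the example shapes and values are assumptions, and the pooling step follows the AdaptiveMaxPool2d call used in the lowering):

import torch
import torch.nn.functional as F

def max_roi_pool(x, rois, pooled_shape=(2, 2), spatial_scale=1.0):
    # x: (N, C, H, W); rois: (num_rois, 5) rows of [batch_idx, x1, y1, x2, y2]
    outputs = []
    for roi in rois:
        batch_idx = int(roi[0])
        x1, y1, x2, y2 = (roi[1:] * spatial_scale).long().tolist()
        region = x[batch_idx, :, y1 : y2 + 1, x1 : x2 + 1]   # ranges are inclusive
        outputs.append(F.adaptive_max_pool2d(region, pooled_shape))
    return torch.stack(outputs)                              # (num_rois, C, *pooled_shape)

x = torch.arange(2 * 3 * 8 * 8, dtype=torch.float32).reshape(2, 3, 8, 8)
rois = torch.tensor([[0, 0, 0, 3, 3], [1, 2, 2, 7, 7]], dtype=torch.float32)
print(max_roi_pool(x, rois).shape)   # torch.Size([2, 3, 2, 2])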
+ + return success(); }); patterns.onOp("Greater", 16, [](OpBinder binder, ConversionPatternRewriter &rewriter) { @@ -933,10 +1882,16 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( flattenedIndices = rewriter.create( loc, flattenIndicesTy, reshapedIndices, constZero); } else if (indicesRank > 1) { - Value endDim = rewriter.create( - loc, rewriter.getI64IntegerAttr(indicesRank - 2)); - flattenedIndices = rewriter.create( - loc, flattenIndicesTy, reshapedIndices, batchDimCountVal, endDim); + if (batchDimCount > indicesRank - 2) { + flattenedIndices = rewriter.create( + loc, flattenIndicesTy, reshapedIndices, batchDimCountVal); + } else { + Value endDim = rewriter.create( + loc, rewriter.getI64IntegerAttr(indicesRank - 2)); + flattenedIndices = rewriter.create( + loc, flattenIndicesTy, reshapedIndices, batchDimCountVal, + endDim); + } } // step 8. Expand `r-b-indices_shape[-1]` dims of flattened indices. @@ -958,8 +1913,12 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( Value endDim = rewriter.create( loc, rewriter.getI64IntegerAttr(batchDimCount + indicesLastDim - 1)); - Value flattenedData = rewriter.create( - loc, flattenDataTy, data, batchDimCountVal, endDim); + Value flattenedData = data; + + if (indicesLastDim != 1) { + flattenedData = rewriter.create( + loc, flattenDataTy, data, batchDimCountVal, endDim); + } // step 10. Now we have flattenedData and expandedIndices of same rank // to perform gather operation. @@ -975,6 +1934,13 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, resultType, gather, /*dim=*/constZero); return success(); } + + if (unflattenIndicesDims.empty()) { + rewriter.replaceOpWithNewOp( + binder.op, resultType, gather, /*dim=*/batchDimCountVal); + return success(); + } + Value unflattenSizeList = rewriter.create( loc, intListTy, unflattenIndicesDims); rewriter.replaceOpWithNewOp( @@ -1042,7 +2008,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( indicesCt = Torch::kUnknownSize; break; } - indicesCt *= sz; } @@ -1077,8 +2042,10 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( return success(); } - rewriter.replaceOpWithNewOp(binder.op, resultType, - gather); + // indicesRank = 0 will select 1 from the axis dim and squeeze it + // Use AtenSqueezeDimOp for the case of result with dynamic shape + rewriter.replaceOpWithNewOp( + binder.op, resultType, gather, index); return success(); }); patterns.onOp( @@ -1095,6 +2062,24 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( Value constAxis = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis)); + + auto indicesTy = cast(indices.getType()); + Value constZero = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value constOne = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value axisSize = rewriter.create(binder.getLoc(), + data, constAxis); + Value indicesAdd = rewriter.create( + binder.getLoc(), indicesTy, indices, axisSize, constOne); + + auto boolTy = rewriter.getType( + indicesTy.getSizes(), rewriter.getI1Type()); + Value lt = rewriter.create( + binder.getLoc(), boolTy, indices, constZero); + indices = rewriter.create( + binder.getLoc(), indicesTy, lt, indicesAdd, indices); + Value sparseGrad = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getBoolAttr(false)); @@ -1126,11 +2111,12 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( auto transpose = [&](Value m) -> Value { auto tty = cast(m.getType()); - auto shape = tty.getOptionalSizes(); + 
std::optional> shape = tty.getOptionalSizes(); + llvm::SmallVector newShape; if (shape.has_value()) { - llvm::SmallVector newShape(shape.value()); + newShape.append(shape.value().begin(), shape.value().end()); std::reverse(newShape.begin(), newShape.end()); - shape = std::move(newShape); + shape = newShape; } auto oty = Torch::ValueTensorType::get(tty.getContext(), shape, tty.getOptionalDtype()); @@ -1266,70 +2252,401 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( return failure(); }); patterns.onOp( - "LayerNormalization", 17, + "GlobalMaxPool", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType yType, meanType, invStdDevType; - Value x, scale, b; - int64_t axis, stashType; - float epsilon; - if (binder.tensorOperandAtIndex(x, 0) || - binder.tensorOperandAtIndex(scale, 1) || - binder.tensorOperandAtIndex(b, 2) || - binder.tensorResultTypeAtIndex(yType, 0) || - binder.s64IntegerAttr(axis, "axis", -1) || - binder.f32FloatAttr(epsilon, "epsilon", 0.00001f) || - binder.s64IntegerAttr(stashType, "stash_type", 1)) + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) return failure(); - Value constEpsilon = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getF64FloatAttr(epsilon)); - unsigned rank = 1; - if (std::optional maybeRank = Torch::getTensorRank(x)) - rank = *maybeRank; - SmallVector normalized; - axis = Torch::toPositiveDim(axis, rank); - auto xType = cast(x.getType()); - if (!xType.hasSizes()) { + + auto inputTensorType = cast(operand.getType()); + if (!inputTensorType || !inputTensorType.hasSizes()) { return rewriter.notifyMatchFailure( - binder.op, "Expected input (X) to have sizes"); + binder.op, "Expected input type having sizes"); } - ArrayRef xShape = xType.getSizes(); - for (int64_t n = axis; n < rank; n++) { - normalized.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(xShape[n]))); + ArrayRef inputShape = inputTensorType.getSizes(); + unsigned inputRank = inputShape.size(); + if (!resultType || !resultType.hasSizes()) { + return rewriter.notifyMatchFailure( + binder.op, "Expected result type having sizes"); } - Value normalized_shape = rewriter.create( - binder.getLoc(), - Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + SmallVector cstKernel, cstPadding, cstStrides, cstDilations; + Value cstZero = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value cstOne = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); + for (unsigned i = 2; i < inputRank; i++) { + if (inputShape[i] == Torch::kUnknownSize) { + Value dim = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i)); + Value inputDimSize = rewriter.create( + binder.getLoc(), operand, dim); + cstKernel.push_back(inputDimSize); + } else { + cstKernel.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(inputShape[i]))); + } + cstPadding.push_back(cstZero); + cstDilations.push_back(cstOne); + cstStrides.push_back(cstOne); + } + Value kernelSizeList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstKernel); + Value paddingList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstPadding); + Value dilationsList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstDilations); + Value stridesList = 
rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstStrides); + Value cstCeilMode = + rewriter.create(binder.getLoc(), false); + + if (inputRank == 3) { + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand, kernelSizeList, stridesList, + paddingList, dilationsList, cstCeilMode); + return success(); + } else if (inputRank == 4) { + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand, kernelSizeList, stridesList, + paddingList, dilationsList, cstCeilMode); + return success(); + } else if (inputRank == 5) { + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand, kernelSizeList, stridesList, + paddingList, dilationsList, cstCeilMode); + return success(); + } + return failure(); + }); + patterns.onOp( + "GlobalLpPool", 2, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + int64_t p; + if (binder.tensorOperand(operand) || binder.s64IntegerAttr(p, "p", 2) || + binder.tensorResultType(resultType)) + return failure(); + + auto inputTensorType = cast(operand.getType()); + if (!inputTensorType || !inputTensorType.hasSizes()) { + return rewriter.notifyMatchFailure( + binder.op, "Expected input type having sizes"); + } + ArrayRef inputShape = inputTensorType.getSizes(); + unsigned inputRank = inputShape.size(); + // only handle 2D, 3D and 5D pooling cases + if (inputRank > 5 || inputRank < 3) { + return failure(); + } + if (!resultType || !resultType.hasSizes()) { + return rewriter.notifyMatchFailure( + binder.op, "Expected result type having sizes"); + } + ArrayRef resultShape = resultType.getSizes(); + + SmallVector cstKernel, cstPadding, cstStrides; + Value cstZero = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value cstOne = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value numElements = cstOne; + for (unsigned i = 2; i < inputRank; i++) { + if (inputShape[i] == Torch::kUnknownSize) { + Value dim = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i)); + Value inputDimSize = rewriter.create( + binder.getLoc(), operand, dim); + cstKernel.push_back(inputDimSize); + } else { + int64_t kernelSize = inputShape[i] - resultShape[i] + 1; + cstKernel.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(kernelSize))); + } + numElements = rewriter.create( + binder.getLoc(), rewriter.getType(), + cstKernel.back(), numElements); + cstPadding.push_back(cstZero); + cstStrides.push_back(cstOne); + } + Value kernelSizeList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstKernel); + Value paddingList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstPadding); + Value stridesList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstStrides); + Value cstFalse = + rewriter.create(binder.getLoc(), false); + Value cstCeilMode = cstFalse; + Value cstCountIncludePad = cstFalse; + Value abs = rewriter.create(binder.getLoc(), + inputTensorType, operand); + Value pv = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), p)); + Value pow = rewriter.create( + binder.getLoc(), inputTensorType, abs, pv); + Value avgPool; + if (inputRank == 3) { + avgPool = rewriter.create( + binder.getLoc(), resultType, pow, kernelSizeList, stridesList, + paddingList, 
cstCeilMode, cstCountIncludePad); + avgPool = rewriter.create( + binder.getLoc(), resultType, avgPool, numElements); + } else if (inputRank == 4) { + avgPool = rewriter.create( + binder.getLoc(), resultType, pow, kernelSizeList, stridesList, + paddingList, cstCeilMode, cstCountIncludePad, + /*divisor_override=*/cstOne); + } else { // inputRank == 5 + avgPool = rewriter.create( + binder.getLoc(), resultType, pow, kernelSizeList, stridesList, + paddingList, cstCeilMode, cstCountIncludePad, + /*divisor_override=*/cstOne); + } + Value invP = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getF64FloatAttr(double{1.0 / p})); + rewriter.replaceOpWithNewOp( + binder.op, resultType, avgPool, invP); + return success(); + }); + + patterns.onOp( + "LpPool", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + std::string autoPad; + if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET")) + return failure(); + if (autoPad != "NOTSET") { + // TODO: Add support for `auto_pad` != "NOTSET" + return rewriter.notifyMatchFailure( + binder.op, "unsupported conversion: auto_pad != NOTSET"); + } + + Torch::ValueTensorType resultType; + Value operand; + int64_t ceilMode, p; + if (binder.tensorOperand(operand) || + binder.s64IntegerAttr(ceilMode, "ceil_mode", 0) || + binder.s64IntegerAttr(p, "p", 2) || + binder.tensorResultType(resultType)) + return failure(); + // Determine the rank of input tensor. + std::optional maybeRank = Torch::getTensorRank(operand); + if (!maybeRank) + return rewriter.notifyMatchFailure(binder.op, + "Unimplemented: unranked tensor"); + unsigned rank = *maybeRank; + // only 1D, 2D and 3D LpPool is supported. + if (rank > 5 || rank < 3) { + return failure(); + } + + SmallVector kernel, padding, strides, dilations; + SmallVector defaultPadding(2 * (rank - 2), 0); + if (binder.s64IntegerArrayAttr(kernel, "kernel_shape", {}) || + binder.s64IntegerArrayAttr(padding, "pads", defaultPadding) || + binder.s64IntegerArrayAttr( + strides, "strides", llvm::SmallVector(rank - 2, 1)) || + binder.s64IntegerArrayAttr(dilations, "dilations", {})) { + return failure(); + } + if (kernel.size() != rank - 2) { + return rewriter.notifyMatchFailure( + binder.op, "kernel list size does not match the number of axes"); + } + if (padding.size() != 2 * (rank - 2)) { + return rewriter.notifyMatchFailure( + binder.op, + "padding list size does not match twice the number of axes"); + } + if (strides.size() != rank - 2) { + return rewriter.notifyMatchFailure( + binder.op, "strides list size does not match the number of axes"); + } + if (dilations.size() > 0) { + return rewriter.notifyMatchFailure( + binder.op, "dilation is not supported by torch.aten.avgpool op " + "and therefore it is not supported for LpPool."); + } + + SmallVector cstKernel, cstPadding, cstStrides; + Value cstOne = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value numElements = cstOne; + for (int64_t i : kernel) { + cstKernel.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + numElements = rewriter.create( + binder.getLoc(), rewriter.getType(), + cstKernel.back(), numElements); + } + Value kernelSizeList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstKernel); + Value paddingList = createConstantIntList(binder, rewriter, padding); + Value stridesList = createConstantIntList(binder, rewriter, strides); + Value cstCeilMode = + rewriter.create(binder.getLoc(), ceilMode); + // onnx lp pool doesn't 
have countIncludePad attribute but set it to + // true so that in 1D case numElements is correctly undoes divison. For + // 2D/3D case, division is avoided by divison_override. + Value cstCountIncludePad = + rewriter.create(binder.getLoc(), true); + Value pv = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), p)); + auto inputTensorType = cast(operand.getType()); + Value abs = rewriter.create(binder.getLoc(), + inputTensorType, operand); + Value pow = rewriter.create( + binder.getLoc(), inputTensorType, abs, pv); + Value avgPool; + if (rank == 3) { + avgPool = rewriter.create( + binder.getLoc(), resultType, pow, kernelSizeList, stridesList, + paddingList, cstCeilMode, cstCountIncludePad); + avgPool = rewriter.create( + binder.getLoc(), resultType, avgPool, numElements); + } else if (rank == 4) { + avgPool = rewriter.create( + binder.getLoc(), resultType, pow, kernelSizeList, stridesList, + paddingList, cstCeilMode, cstCountIncludePad, + /*divisor_override=*/cstOne); + } else { // rank == 5 + avgPool = rewriter.create( + binder.getLoc(), resultType, pow, kernelSizeList, stridesList, + paddingList, cstCeilMode, cstCountIncludePad, + /*divisor_override=*/cstOne); + } + Value invP = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getF64FloatAttr(double{1.0 / p})); + rewriter.replaceOpWithNewOp( + binder.op, resultType, avgPool, invP); + return success(); + }); + + patterns.onOp( + "LayerNormalization", 17, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType yType, meanType, invStdDevType; + Value x, scale, b; + int64_t axis, stashType; + float epsilon; + if (binder.tensorOperandAtIndex(x, 0) || + binder.tensorOperandAtIndex(scale, 1) || + binder.tensorOperandAtIndex(b, 2) || + binder.tensorResultTypeAtIndex(yType, 0) || + binder.s64IntegerAttr(axis, "axis", -1) || + binder.f32FloatAttr(epsilon, "epsilon", 0.00001f) || + binder.s64IntegerAttr(stashType, "stash_type", 1)) + return failure(); + + std::optional stashTypeIntTorch = + onnxDtypeIntToTorchDtypeInt(stashType); + if (!stashTypeIntTorch.has_value()) + return rewriter.notifyMatchFailure( + binder.op, "unimplemented support for the given stash_type"); + FailureOr stashDtype = Torch::getTypeForScalarType( + binder.op->getContext(), + (torch_upstream::ScalarType)stashTypeIntTorch.value()); + if (failed(stashDtype)) + return failure(); + + // Convert dtype if stash_type is different from input dtype + auto xType = cast(x.getType()); + Value cstFalse = + rewriter.create(binder.getLoc(), false); + Value none = rewriter.create(binder.getLoc()); + if (*stashDtype != xType.getOptionalDtype()) { + auto newXType = + xType.getWithSizesAndDtype(xType.getOptionalSizes(), *stashDtype); + Value dtypeValue = rewriter.create( + binder.getLoc(), + rewriter.getI64IntegerAttr(stashTypeIntTorch.value())); + x = rewriter.create( + binder.getLoc(), newXType, x, /*dtype=*/dtypeValue, + /*non_blocking=*/cstFalse, /*copy=*/cstFalse, + /*memory_format=*/none); + } + + Value constEpsilon = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getF64FloatAttr(epsilon)); + unsigned rank = 1; + if (std::optional maybeRank = Torch::getTensorRank(x)) + rank = *maybeRank; + SmallVector normalized; + axis = Torch::toPositiveDim(axis, rank); + if (!xType.hasSizes()) { + return rewriter.notifyMatchFailure( + binder.op, "Expected input (X) to have sizes"); + } + ArrayRef xShape = xType.getSizes(); + for (int64_t n = axis; n < rank; n++) { + 
normalized.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(xShape[n]))); + } + Value normalized_shape = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), normalized); + SmallVector reducedShape(rank, 1); + for (int64_t i = 0; i < axis; i++) + reducedShape[i] = xShape[i]; + auto reducedType = + xType.getWithSizesAndDtype(reducedShape, *stashDtype); + auto y = rewriter.create( + binder.getLoc(), yType, /*meanType=*/reducedType, + /*invStdDevType=*/reducedType, x, normalized_shape, scale, b, + constEpsilon); + int64_t numResults = binder.op->getNumResults(); if (numResults == 1) { - SmallVector reducedShape(rank, 1); - for (int64_t i = 0; i < axis; i++) - reducedShape[i] = xShape[i]; - auto reducedType = xType.getWithSizesAndDtype( - reducedShape, xType.getOptionalDtype()); - Value y = rewriter - .create( - binder.getLoc(), yType, /*meanType=*/reducedType, - /*invStdDevType=*/reducedType, x, normalized_shape, - scale, b, constEpsilon) - .getResult0(); - rewriter.replaceOp(binder.op, y); + rewriter.replaceOp(binder.op, y.getResult0()); return success(); } - if (numResults == 3) { - if (binder.tensorResultTypeAtIndex(meanType, 1) || - binder.tensorResultTypeAtIndex(invStdDevType, 2)) - return failure(); - rewriter.replaceOpWithNewOp( - binder.op, yType, meanType, invStdDevType, x, normalized_shape, - scale, b, constEpsilon); - return success(); + + Value meanOutput = y.getResult1(); + Value varOutput = y.getResult2(); + // Convert meanType and varType back if stash_dtype is different + if (binder.tensorResultTypeAtIndex(meanType, 1) || + binder.tensorResultTypeAtIndex(invStdDevType, 2)) + return failure(); + if (*stashDtype != meanType.getOptionalDtype()) { + Value constDtype = Torch::getDtypeIntValueForType( + rewriter, binder.getLoc(), meanType.getDtype()); + meanOutput = rewriter.create( + binder.getLoc(), meanType, meanOutput, /*dtype=*/constDtype, + /*non_blocking=*/cstFalse, /*copy=*/cstFalse, + /*memory_format=*/none); + varOutput = rewriter.create( + binder.getLoc(), invStdDevType, varOutput, /*dtype=*/constDtype, + /*non_blocking=*/cstFalse, /*copy=*/cstFalse, + /*memory_format=*/none); } - return rewriter.notifyMatchFailure( - binder.op, "Unimplemented: expected either 1 or 3 results"); + rewriter.replaceOp(binder.op, {y.getResult0(), meanOutput, varOutput}); + + return success(); }); patterns.onOp("LeakyRelu", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { @@ -1347,53 +2664,209 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, resultType, operand, constAlpha); return success(); }); + patterns.onOp( + "LRN", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + int64_t size; + float alpha, beta, bias; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(size, "size", 2) || + binder.f32FloatAttr(alpha, "alpha", 0.0001f) || + binder.f32FloatAttr(beta, "beta", 0.75f) || + binder.f32FloatAttr(bias, "bias", 1.0f)) + return failure(); + Type dtype = resultType.getOptionalDtype(); + Value constAlpha = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getF64FloatAttr(alpha)); + Value constBeta = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getF64FloatAttr(beta)); + Value constBias = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getF64FloatAttr(bias)); + // Please refer to the operator description + // for more 
info on the lowering + // https://onnx.ai/onnx/operators/onnx__LRN.html + + // squared = operand^2 + Location loc = binder.getLoc(); + Torch::ValueTensorType inTy = + cast(operand.getType()); + Value sqOperand = rewriter.create( + loc, inTy, operand, operand); + // view it as n x 1 x c x d0 x d.. + if (!inTy.hasSizes()) { + return rewriter.notifyMatchFailure(binder.op, + "Expected input to have sizes"); + } + ArrayRef inTyShape = inTy.getSizes(); + if (inTyShape.size() < 3) { + return rewriter.notifyMatchFailure( + binder.op, "Unsupported: the input dimensions should be >= 3"); + } + if (inTyShape[1] == Torch::kUnknownSize) { + return rewriter.notifyMatchFailure( + binder.op, "Unsupported: the second dimension size must be " + "statically known"); + } + SmallVector viewShapeInt{inTyShape[0], 1, inTyShape[1], + inTyShape[2], Torch::kUnknownSize}; + Torch::ValueTensorType reshapeType = + rewriter.getType(viewShapeInt, dtype); + Value viewShapeListVal = + createConstantIntList(binder, rewriter, viewShapeInt); + auto view = rewriter.create( + loc, reshapeType, sqOperand, viewShapeListVal); + // padding + int64_t highPad = (size - 1) / 2; + int64_t lowPad = (size - 1) - highPad; + SmallVector paddingInt{0, 0, 0, 0, lowPad, highPad}; + auto constPadVal = rewriter.create( + loc, rewriter.getType(), + rewriter.getF64FloatAttr(0.0)); + Value paddingListVal = + createConstantIntList(binder, rewriter, paddingInt); + SmallVector paddedShapeInt = viewShapeInt; + paddedShapeInt[2] += size - 1; + Torch::ValueTensorType paddedType = + rewriter.getType(paddedShapeInt, dtype); + auto padded = rewriter.create( + loc, paddedType, view, paddingListVal, constPadVal); + // avg_pool3d + SmallVector kernelSize{size, 1, 1}; + Value kernelSizeList = + createConstantIntList(binder, rewriter, kernelSize); + SmallVector strides{1, 1, 1}; + Value stridesList = createConstantIntList(binder, rewriter, strides); + SmallVector padding{0, 0, 0}; + Value paddingList = createConstantIntList(binder, rewriter, padding); + auto cstCeilMode = + rewriter.create(binder.getLoc(), false); + auto cstCountIncludeMode = + rewriter.create(binder.getLoc(), true); + Value cstNone = rewriter.create(binder.getLoc()); + // Output of pooling is same reshape(view) type because + // of the padding done on the dimensions being pooled. + auto pool = rewriter.create( + loc, reshapeType, padded, kernelSizeList, stridesList, paddingList, + cstCeilMode, cstCountIncludeMode, /*divisor_override=*/cstNone); + // squeeze + auto one = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + SmallVector squeezeShapeInt{ + viewShapeInt[0], viewShapeInt[2], viewShapeInt[3], viewShapeInt[4]}; + Torch::ValueTensorType squeezeType = + rewriter.getType(squeezeShapeInt, dtype); + auto squeeze = rewriter.create( + loc, squeezeType, pool, one); + // view as input Type + Value intTyShapeList = + createConstantIntList(binder, rewriter, inTyShape); + auto viewAsInput = rewriter.create( + loc, inTy, squeeze, intTyShapeList); + // mul + add + pow + div + auto mul = rewriter.create( + loc, resultType, viewAsInput, constAlpha); + auto add = rewriter.create(loc, resultType, mul, + constBias, one); + auto pow = rewriter.create( + loc, resultType, add, constBeta); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand, pow); + return success(); + }); patterns.onOp( "Pad", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value data, pads, axes; std::string mode; - // TODO: The `axes` parameter is not supported yet. 
- if (!binder.tensorOperandAtIndex(axes, 3)) { - return rewriter.notifyMatchFailure( - binder.op, "The axes parameter is not supported yet"); - } if (binder.tensorOperandAtIndex(data, 0) || - binder.tensorOperandAtIndex(pads, 1) || binder.tensorResultType(resultType) || binder.customOpNameStringAttr(mode, "mode", "constant")) return failure(); + + (void)binder.tensorOperandAtIndex(axes, 3); + + bool cstMode = (mode == "constant"); + + // get input rank + auto dataOpTy = cast(data.getType()); + TensorType dataTensor = dataOpTy.toBuiltinTensor(); + if (!dataTensor || !dataTensor.hasRank()) + return rewriter.notifyMatchFailure( + binder.op, "pad length unknown and data operand unranked"); + int64_t dataRank = dataTensor.getRank(); + int64_t padsSize = 2 * dataRank; + Location loc = binder.getLoc(); - // Get pads shape and rank. The pads tensor is expected to be 1-D - // tensor. - auto padsTensorType = cast(pads.getType()); - if (!padsTensorType || !padsTensorType.hasSizes()) { - return rewriter.notifyMatchFailure(binder.op, - "Expect non empty pad tensor"); - } - ArrayRef padsShape = padsTensorType.getSizes(); - int64_t padsRank = padsShape.size(); - if (padsRank != 1) - return rewriter.notifyMatchFailure(binder.op, - "expect 1-d pad tensor"); - - int64_t padsSize = padsShape[0]; - if (padsSize == Torch::kUnknownSize) { - // As per onnx.Pad documentation, padSize = 2*num_data_axes - // (if axes param not passed). Need to be updated when adding - // support for `axes` param. - auto dataOpTy = cast(data.getType()); - TensorType dataTensor = dataOpTy.toBuiltinTensor(); - if (!dataTensor || !dataTensor.hasRank()) - return rewriter.notifyMatchFailure( - binder.op, "pad length unknown and data operand unranked"); - int64_t dataRank = dataTensor.getRank(); - padsSize = 2 * dataRank; + // get pads (earlier versions use an attribute, newer versions use a + // tensor input) + SmallVector padsTensorValue; + if (binder.tensorOperandAtIndex(pads, 1)) { + SmallVector defaultPads(2 * dataRank, 0); + SmallVector padInts; + if (binder.s64IntegerArrayAttr(padInts, "pads", defaultPads)) + return rewriter.notifyMatchFailure(binder.op, + "pads binder failure"); + // opset_version 1 uses the attribute name "paddings" + if (padInts == defaultPads) { + SmallVector paddingsInts; + if (binder.s64IntegerArrayAttr(paddingsInts, "paddings", + defaultPads)) + return rewriter.notifyMatchFailure(binder.op, + "paddings binder failure"); + padInts = paddingsInts; + } + for (auto p : padInts) + padsTensorValue.push_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(p))); + } else { + // Get pads shape and rank. The pads tensor is expected to be 1-D + // tensor. + auto padsTensorType = cast(pads.getType()); + if (!padsTensorType || !padsTensorType.hasSizes()) { + return rewriter.notifyMatchFailure(binder.op, + "Expect non empty pad tensor"); + } + ArrayRef padsShape = padsTensorType.getSizes(); + int64_t padsRank = padsShape.size(); + if (padsRank != 1) + return rewriter.notifyMatchFailure(binder.op, + "expect 1-d pad tensor"); + if (padsShape[0] != Torch::kUnknownSize) { + // As per onnx.Pad documentation, padSize = 2*num_data_axes + // (if axes param not passed). Need to be updated when adding + // support for `axes` param. + padsSize = padsShape[0]; + } + + // Extract all the values of 1-D pad tensor and create a list of all + // these values as torch.pad op expects pad list. 
+ Value constZero = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + SmallVector emptyShape; + Type padsElemType = Torch::ValueTensorType::get( + padsTensorType.getContext(), emptyShape, + padsTensorType.getOptionalDtype()); + for (uint32_t i = 0; i < padsSize; ++i) { + Value index = rewriter.create( + loc, rewriter.getI64IntegerAttr(i)); + auto select = rewriter.create( + loc, padsElemType, pads, constZero, index); + Value selectInt = rewriter.create( + loc, rewriter.getType(), select); + padsTensorValue.push_back(selectInt); + } } Value constantValue; - if (binder.getNumOperands() >= 3) { + if (binder.getNumOperands() >= 3 && cstMode) { if (!binder.tensorOperandAtIndex(constantValue, 2)) { auto constTy = dyn_cast(constantValue.getType()); @@ -1409,46 +2882,123 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( } } - if (!constantValue) { + if (!constantValue && cstMode) { auto dataTensorType = cast(data.getType()); - if (dataTensorType.getDtype().isa()) + if (isa(dataTensorType.getDtype())) constantValue = rewriter.create( loc, rewriter.getI64IntegerAttr(0)); - if (dataTensorType.getDtype().isa()) + // Earlier versions used a FLOAT attribute to store the constant + // value. The following will pick up on any non-default value attr if + // provided. + float constantFloat; + if (isa(dataTensorType.getDtype()) && + !binder.f32FloatAttr(constantFloat, "value", 0.0f)) constantValue = rewriter.create( - loc, rewriter.getF64FloatAttr(0.0f)); + loc, rewriter.getF64FloatAttr(constantFloat)); if (!constantValue) return rewriter.notifyMatchFailure( binder.op, "expected integer or float data tensor"); } - // Extract all the values of 1-D pad tensor and create a list of all - // these values as torch.pad op expects pad list. - Value constZero = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); - SmallVector padsTensorValue; - SmallVector emptyShape; - Type padsElemType = - Torch::ValueTensorType::get(padsTensorType.getContext(), emptyShape, - padsTensorType.getOptionalDtype()); - for (uint32_t i = 0; i < padsSize; ++i) { - Value index = rewriter.create( - loc, rewriter.getI64IntegerAttr(i)); - auto select = rewriter.create( - loc, padsElemType, pads, constZero, index); - Value selectInt = rewriter.create( - loc, rewriter.getType(), select); - padsTensorValue.push_back(selectInt); + // for modes other than "constant" a value is not required + if (!cstMode) + constantValue = rewriter.create(loc); + + llvm::SmallVector begins; + llvm::SmallVector ends; + for (uint32_t i = 0; i < padsSize / 2; ++i) + begins.push_back(padsTensorValue[i]); + for (uint32_t i = padsSize / 2; i < padsSize; ++i) + ends.push_back(padsTensorValue[i]); + + // If we have the axes we need to compute the appropriate pads: + if (axes) { + auto axesTy = cast(axes.getType()); + assert(axesTy.getSizes().size() == 1); + assert(axesTy.getSizes()[0] != Torch::kUnknownSize); + + auto dataTensorType = cast(data.getType()); + int64_t rank = dataTensorType.getSizes().size(); + auto boolTy = rewriter.getType(); + auto intTy = rewriter.getType(); + Value constZero = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + + // Extract the values: + int64_t numAxes = axesTy.getSizes()[0]; + Type axesElemType = Torch::ValueTensorType::get( + axesTy.getContext(), ArrayRef{}, + axesTy.getOptionalDtype()); + llvm::SmallVector axesExtracted; + Value rankV = rewriter.create( + loc, rewriter.getI64IntegerAttr(rank)); + for (uint32_t i = 0; i < numAxes; ++i) { + Value index = rewriter.create( + loc, rewriter.getI64IntegerAttr(i)); + 
auto select = rewriter.create( + loc, axesElemType, axes, constZero, index); + Value selectInt = rewriter.create( + loc, rewriter.getType(), select); + + Value negAxis = rewriter.create( + loc, boolTy, selectInt, constZero); + negAxis = + rewriter.create(loc, intTy, negAxis); + Value axis = rewriter.create(loc, intTy, + negAxis, rankV); + axis = rewriter.create(loc, intTy, axis, + selectInt); + axesExtracted.push_back(axis); + } + + llvm::SmallVector newBegins; + llvm::SmallVector newEnds; + + for (int j = 0; j < rank; ++j) { + Value newBegin = constZero; + Value newEnd = constZero; + Value iv = rewriter.create( + loc, rewriter.getI64IntegerAttr(j)); + + for (size_t i = 0; i < axesExtracted.size(); ++i) { + Value begin = begins[i]; + Value end = ends[i]; + + Value sameAxis = rewriter.create( + loc, boolTy, axesExtracted[i], iv); + sameAxis = + rewriter.create(loc, intTy, sameAxis); + + begin = rewriter.create(loc, intTy, sameAxis, + begin); + end = rewriter.create(loc, intTy, sameAxis, + end); + + newBegin = rewriter.create(loc, intTy, + newBegin, begin); + newEnd = + rewriter.create(loc, intTy, newEnd, end); + } + + newBegins.push_back(newBegin); + newEnds.push_back(newEnd); + } + + begins = std::move(newBegins); + ends = std::move(newEnds); } // The torch.pad op expects a different arrangement of padding pairs for - // each dimension as compared to the onnx.pad op. So, rearranging pad - // tensor to satisfy torch.pad op semantics. + // each dimension as compared to the onnx.pad op. Rearrange the pad + // tensor as shown below: + // + // [x1_begin, x2_begin, ..., x1_end, x2_end,...] -> + // [xn_begin, xn_end, ...., x2_begin, x2_end, x1_begin, x1_end] SmallVector padsRearrange; - for (uint32_t i = 0; i < padsSize / 2; i++) { - padsRearrange.emplace_back(padsTensorValue[i]); - padsRearrange.emplace_back(padsTensorValue[(padsSize / 2) + i]); + for (int32_t i = begins.size() - 1; i >= 0; i--) { + padsRearrange.emplace_back(begins[i]); + padsRearrange.emplace_back(ends[i]); } Value padsSizeList = @@ -1458,6 +3008,20 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( Torch::ListType::get(rewriter.getType()), padsRearrange) .getResult(); + + // lowering to AtenConstantPadNdOp directly allows passing any torch + // scalar type for the value, whereas AtenPadOp takes an optional float + // type. + if (cstMode && !isa(constantValue.getType())) { + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, padsSizeList, constantValue); + return success(); + } + + // translate a few mismatching mode names ONNX -> Torch + mode = (mode == "edge") ? "replicate" : mode; + mode = (mode == "wrap") ? "circular" : mode; + Value modeVal = rewriter.create( loc, rewriter.getStringAttr(mode)); @@ -1465,20 +3029,50 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, resultType, data, padsSizeList, modeVal, constantValue); return success(); }); - patterns.onOp("Pow", 1, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value lhs, rhs; - if (binder.tensorOperands(lhs, rhs) || - binder.tensorResultType(resultType)) { - return failure(); - } - rewriter.replaceOpWithNewOp( - binder.op, resultType, lhs, rhs); - return success(); - }); patterns.onOp( - "Identity", 14, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + "Pow", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + // ONNX specifies that the result types matches the type of lhs. 
+ // In torch, the result type is integer when both operands are integer, + // and otherwise operand types are promoted to f64. + Torch::ValueTensorType resultType; + Value lhs, rhs; + if (binder.tensorOperands(lhs, rhs) || + binder.tensorResultType(resultType)) { + return failure(); + } + + auto loc = binder.getLoc(); + Value cstFalse = rewriter.create( + loc, rewriter.getBoolAttr(false)); + Value none = rewriter.create(loc); + + auto powType = resultType; + if (isa(resultType.getDtype())) { + powType = rewriter.getType( + resultType.getSizes(), rewriter.getF64Type()); + } + + Value pow = rewriter.create(loc, powType, + lhs, rhs); + + if (!isa(resultType.getDtype())) { + rewriter.replaceOp(binder.op, pow); + return success(); + } + + auto outDtype = Torch::getScalarTypeForType(resultType.getDtype()); + auto outTyConst = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + static_cast(outDtype))); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, pow, outTyConst, cstFalse, cstFalse, none); + + return success(); + }); + patterns.onOp( + "Identity", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value tensor; if (binder.tensorOperand(tensor) || @@ -1663,23 +3257,32 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( depth = rewriter.create( loc, rewriter.getType(), depth); - auto selectTy = rewriter.getType( - llvm::SmallVector{1}, valuesTy.getDtype()); - + Type boolTy = rewriter.getType( + indicesTy.getSizes(), rewriter.getI1Type()); Value zero = rewriter.create( loc, rewriter.getI64IntegerAttr(0)); Value one = rewriter.create( loc, rewriter.getI64IntegerAttr(1)); + Value lt = + rewriter.create(loc, boolTy, indices, zero); + Value add = rewriter.create( + loc, indicesTy, indices, depth, one); + indices = rewriter.create(loc, indicesTy, lt, + add, indices); + + auto selectTy = rewriter.getType( + llvm::SmallVector{1}, valuesTy.getDtype()); + + bool valuesAreInt = isa(valuesTy.getDtype()); + Type valuesETy = valuesAreInt ? 
intTy : floatTy; Value off = rewriter.create(loc, selectTy, values, zero, zero); - off = rewriter.create( - loc, rewriter.getType(), off); + off = rewriter.create(loc, valuesETy, off); Value on = rewriter.create(loc, selectTy, values, zero, one); - on = rewriter.create( - loc, rewriter.getType(), on); + on = rewriter.create(loc, valuesETy, on); auto i32Ty = rewriter.getIntegerType(32, true); llvm::SmallVector onehotShape(indicesTy.getSizes()); @@ -1690,7 +3293,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( Value onehot = rewriter.create( binder.getLoc(), onehotTy, indices, depth); - for (int i = valuesTy.getSizes().size(); i > axis; ++i) { + for (int i = indicesTy.getSizes().size(); i > axis; --i) { std::swap(onehotShape[i - 1], onehotShape[i]); Value iv0 = rewriter.create( loc, rewriter.getI64IntegerAttr(i)); @@ -1719,9 +3322,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none); - onehotTy = rewriter.getType( - onehotShape, resultType.getDtype()); - onehot = rewriter.create(loc, onehotTy, + onehot = rewriter.create(loc, resultType, onehot, on, off); rewriter.replaceOp(binder.op, onehot); @@ -1739,4 +3340,557 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, resultType, input); return success(); }); + + patterns.onOp( + "Hardmax", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + // onnx.Hardmax can be expanded into the following python code: + // + // import torch.nn.functional as F + // def hardmax(tensor, dim=-1): + // maximums = torch.argmax(tensor, dim=dim, keepdim=False) + // return F.one_hot(maximums) + // + // Given an example input: + // tensor([[1, 2, 3], + // [4, 6, 5], + // [9, 8, 7]]) + // Above code yields the following: + // tensor([[0, 0, 1], + // [0, 1, 0], + // [1, 0, 0]]) + + Torch::ValueTensorType resultType; + int64_t axisValue; + Value input, axis; + if (binder.tensorOperand(input) || + binder.s64IntegerAttr(axisValue, "axis", -1) || + binder.tensorResultType(resultType)) + return failure(); + + auto loc = binder.getLoc(); + auto inputTy = cast(input.getType()); + + if (axisValue < 0) + axisValue += inputTy.getSizes().size(); + + axis = rewriter.create( + loc, rewriter.getI64IntegerAttr(axisValue)); + + // torch.argmax + Value constKeepDims = rewriter.create( + loc, rewriter.getType(), + rewriter.getBoolAttr(false)); + + SmallVector argmaxShape; + for (int i = 0, s = inputTy.getSizes().size(); i < s; ++i) { + if (i == axisValue) + continue; + argmaxShape.push_back(inputTy.getSizes()[i]); + } + + auto argmaxTy = rewriter.getType( + argmaxShape, rewriter.getIntegerType(32, IntegerType::Signed)); + Value argmax = rewriter.create( + loc, argmaxTy, input, axis, constKeepDims); + + // one_hot + SmallVector onehotShape(argmaxShape); + onehotShape.push_back(inputTy.getSizes()[axisValue]); + auto onehotTy = rewriter.getType( + onehotShape, resultType.getDtype()); + Value numClasses = + rewriter.create(binder.getLoc(), input, axis); + Value onehot = rewriter.create( + binder.getLoc(), onehotTy, argmax, numClasses); + + SmallVector permutation; + for (int i = 0; i < axisValue; ++i) + permutation.push_back(i); + permutation.push_back(onehotShape.size() - 1); + for (int i = axisValue, s = onehotShape.size(); i < s - 1; ++i) + permutation.push_back(i); + + SmallVector permValues; + for (auto d : permutation) { + permValues.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(d))); + } + + Value permuteDims = rewriter.create( + loc, 
Torch::ListType::get(rewriter.getType()), + permValues); + rewriter.replaceOpWithNewOp(binder.op, resultType, + onehot, permuteDims); + return success(); + }); + patterns.onOp("LpNormalization", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + int64_t axis, p; + Value input; + if (binder.tensorOperand(input) || + binder.s64IntegerAttr(axis, "axis", -1) || + binder.s64IntegerAttr(p, "p", 2) || + binder.tensorResultType(resultType)) + return failure(); + + auto loc = binder.getLoc(); + Value cstAxis = rewriter.create( + loc, rewriter.getI64IntegerAttr(axis)); + Value cstP = rewriter.create( + loc, rewriter.getI64IntegerAttr(p)); + Value cstKeepDim = rewriter.create( + loc, rewriter.getBoolAttr(true)); + Value axisPrimList = + rewriter.create( + binder.getLoc(), + rewriter.getType( + rewriter.getType()), + llvm::ArrayRef{cstAxis}); + + SmallVector normSizes(resultType.getSizes()); + int64_t rank = normSizes.size(); + axis = axis % rank; + axis = (axis < 0) ? axis + rank : axis; + normSizes[axis] = 1; + auto normType = rewriter.getType( + normSizes, resultType.getDtype()); + Value norm = rewriter.create( + loc, normType, input, cstP, axisPrimList, cstKeepDim); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, input, norm); + return success(); + }); + patterns.onOp( + "MaxUnpool", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + // TODO: Add support for `output_shape` arg. + if (binder.op->getNumOperands() == 3) + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: output_shape arg is not supported"); + + Torch::ValueTensorType resultType; + Value data, indices; + if (binder.tensorOperandAtIndex(data, 0) || + binder.tensorOperandAtIndex(indices, 1) || + binder.tensorResultType(resultType)) + return rewriter.notifyMatchFailure( + binder.op, "data/indices/resultType bind failure"); + std::optional maybeRank = Torch::getTensorRank(data); + if (!maybeRank) + return rewriter.notifyMatchFailure(binder.op, + "Unimplemented: unranked tensor"); + int64_t rank = *maybeRank; + int64_t spatial = rank - 2; + + if (rank <= 3 || rank > 5) + return rewriter.notifyMatchFailure(binder.op, + "Unimplemented: MaxUnpool support " + "only present for rank 4/5 input"); + + if (!(resultType.hasSizes() && resultType.areAllSizesKnown())) + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: expected result to have all shapes " + "statically known"); + + SmallVector resultShape(resultType.getSizes()); + Value resultShapeList = + createConstantIntList(binder, rewriter, resultShape); + if (rank == 4) { + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, indices, resultShapeList); + return success(); + } + + SmallVector padding, strides; + if (binder.s64IntegerArrayAttr(padding, "pads", {})) + return rewriter.notifyMatchFailure(binder.op, "pads bind failure"); + if (!padding.empty() && + padding.size() != static_cast(2 * spatial)) + return rewriter.notifyMatchFailure( + binder.op, "padding list must contain (begin,end) pair for each " + "spatial axis"); + if (binder.s64IntegerArrayAttr(strides, "strides", {})) + return rewriter.notifyMatchFailure(binder.op, "strides bind failure"); + if (!strides.empty() && strides.size() != static_cast(spatial)) + return rewriter.notifyMatchFailure( + binder.op, "strides list size does not match the number of axes"); + + if (padding.empty()) + padding.resize(spatial, 0); + if (strides.empty()) + strides.resize(spatial, 1); + + // If the padding is symmetric we can push the 
padding + // operation to the torch operator. + if (padding.size() == static_cast(2 * spatial)) { + bool equal = true; + for (int i = 0; i < spatial; ++i) { + equal = equal && (padding[i] == padding[i + spatial]); + } + if (equal) + padding.resize(spatial); + } + + Value paddingList = createConstantIntList(binder, rewriter, padding); + Value stridesList = createConstantIntList(binder, rewriter, strides); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, indices, resultShapeList, stridesList, + paddingList); + return success(); + }); + patterns.onOp( + "GroupNormalization", 18, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value input, scale, bias; + int64_t numGroups, stashType; + float epsilon; + if (binder.tensorOperandAtIndex(input, 0) || + binder.tensorOperandAtIndex(scale, 1) || + binder.tensorOperandAtIndex(bias, 2) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(numGroups, "num_groups") || + binder.f32FloatAttr(epsilon, "epsilon", 1e-5) || + binder.s64IntegerAttr(stashType, "stash_type", 1)) + return failure(); + + // Since the support for `stash_type` arg does not exist in + // the torch op so we just check for the stash_type to be same + // as the input dtype since that won't require us to do any + // input type conversion and hence can be supported. + std::optional stashTypeIntTorch = + onnxDtypeIntToTorchDtypeInt(stashType); + if (!stashTypeIntTorch.has_value()) + return rewriter.notifyMatchFailure( + binder.op, "unimplemented support for the given stash_type"); + + FailureOr stashDtype = Torch::getTypeForScalarType( + binder.op->getContext(), + (torch_upstream::ScalarType)stashTypeIntTorch.value()); + if (failed(stashDtype)) + return failure(); + auto inputDtype = + cast(input.getType()).getOptionalDtype(); + if (*stashDtype != inputDtype) + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: stash_type != input dtype"); + + Value cstEpsilon = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getF64FloatAttr((double)epsilon)); + Value cstNumGroups = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(numGroups)); + Value cstFalse = rewriter.create( + binder.getLoc(), rewriter.getBoolAttr(false)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, input, cstNumGroups, scale, bias, cstEpsilon, + /*cudnn_enabled=*/cstFalse); + return success(); + }); + patterns.onOp( + "Optional", 15, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::OptionalType resultType; + Value input; + + if (binder.getNumOperands() == 0) + return rewriter.notifyMatchFailure( + binder.op, "unimplemented support for missing input element"); + + if (binder.tensorListOperand(input)) + if (binder.tensorOperand(input)) + return failure(); + + if (binder.optionalResultType(resultType)) + return failure(); + + rewriter.replaceOpWithNewOp(binder.op, resultType, + input); + return success(); + }); + patterns.onOp("OptionalGetElement", 15, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ListType tensorListResultType; + Torch::ValueTensorType tensorResultType; + Value input; + + if (binder.tensorListResultType(tensorListResultType)) { + if (binder.tensorResultType(tensorResultType)) + return failure(); + + if (binder.optionalTensorOperand(input)) { + if (binder.tensorOperand(input)) + return failure(); + + // It means the input is a tensor. 
+ rewriter.replaceOp(binder.op, input); + return success(); + } + + // It means the input is an optional tensor. + rewriter.replaceOpWithNewOp( + binder.op, tensorResultType, input); + return success(); + } + + if (binder.optionalTensorListOperand(input)) { + if (binder.tensorListOperand(input)) + return failure(); + + // It means the input is a tensor list. + rewriter.replaceOp(binder.op, input); + return success(); + } + + // It means the input is an optional tensor list. + rewriter.replaceOpWithNewOp( + binder.op, tensorListResultType, input); + return success(); + }); + patterns.onOp( + "OptionalHasElement", 15, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + if (binder.tensorResultType(resultType)) + return rewriter.notifyMatchFailure(binder.op, + "result type bind failed"); + + Value input; + bool output; + if (!binder.tensorListOperand(input) || !binder.tensorOperand(input) || + !binder.optionalTensorListOperand(input) || + !binder.optionalTensorOperand(input)) + output = true; + else + output = false; + + Value cstOutput = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr((int64_t)output)); + Value cstDtype = rewriter.create( + binder.getLoc(), + rewriter.getI64IntegerAttr((int)torch_upstream::ScalarType::Bool)); + Value cstFalse = rewriter.create( + binder.getLoc(), rewriter.getBoolAttr(false)); + Value cstNone = rewriter.create(binder.getLoc()); + + Value dataList = rewriter.create( + binder.getLoc(), + rewriter.getType( + rewriter.getType()), + SmallVector{cstOutput}); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, dataList, /*dtype=*/cstDtype, + /*layout=*/cstNone, /*requires_grad=*/cstFalse); + return success(); + }); + patterns.onOp( + "NonMaxSuppression", 10, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Location loc = binder.getLoc(); + Torch::ValueTensorType resultType; + SmallVector operands; + int64_t centerPointBox; + if (binder.tensorOperandsList(operands) || + binder.s64IntegerAttr(centerPointBox, "center_point_box", 0) || + binder.tensorResultType(resultType)) + return failure(); + + if (centerPointBox != 0 && centerPointBox != 1) + return rewriter.notifyMatchFailure( + binder.op, "expected center_point_box attribute to be 0 or 1"); + + // TODO: Support multiple batches and classes + // Squeeze the boxes and scores tensor. + // In Onnx, the shape of boxes is [BxNx4] while the + // torchvision expects it to be of shape [Nx4]. Similarly, for + // the scores tensor shape in Onnx is [BxCxN] while the + // torchvision expects it to be of shape [N]. + Value boxes = operands[0], scores = operands[1]; + FailureOr squeezedBoxes = + Torch::squeezeTensor(rewriter, binder.op, loc, 0, boxes); + if (failed(squeezedBoxes)) + return rewriter.notifyMatchFailure(binder.op, + "failed to squeeze boxes tensor"); + FailureOr squeezedScores = + Torch::squeezeTensor(rewriter, binder.op, loc, 0, scores); + if (failed(squeezedScores)) + return rewriter.notifyMatchFailure(binder.op, + "failed to squeeze scores tensor"); + squeezedScores = Torch::squeezeTensor(rewriter, binder.op, loc, 0, + squeezedScores.value()); + if (failed(squeezedScores)) + return rewriter.notifyMatchFailure(binder.op, + "failed to squeeze scores tensor"); + boxes = squeezedBoxes.value(); + scores = squeezedScores.value(); + if (centerPointBox == 1) { + // When center_point_box is 1, the box data is supplied as + // [[x_center, y_center, width, height], ...]. Slice it to + // [[x_center, y_center], ...] 
and [[width, height], ...], + // calculate the [[x1, y1], ...] and [[x2, y2], ...], and concatnate + // to [[x1, y1, x2, y2], ...] + auto boxesTensorType = + dyn_cast(boxes.getType()); + Value const0 = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + Value const1 = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + Value const2 = rewriter.create( + loc, rewriter.getI64IntegerAttr(2)); + Value const4 = rewriter.create( + loc, rewriter.getI64IntegerAttr(4)); + Value const2F = rewriter.create( + loc, rewriter.getF64FloatAttr(2.0)); + + // extract scaled ranges for regions of interest + auto sliceShape = SmallVector{Torch::kUnknownSize, 2}; + auto sliceTensorType = rewriter.getType( + sliceShape, boxesTensorType.getDtype()); + Value centers = rewriter.create( + loc, sliceTensorType, boxes, const1, const0, const2, const1); + Value sizes = rewriter.create( + loc, sliceTensorType, boxes, const1, const2, const4, const1); + Value halfSizes = rewriter.create( + loc, sizes.getType(), sizes, const2F); + Value x1y1s = rewriter.create( + loc, centers.getType(), centers, halfSizes, const1); + Value x2y2s = rewriter.create( + loc, centers.getType(), centers, halfSizes, const1); + + Type listElemType = boxesTensorType.getWithSizesAndDtype( + /*optionalSizes=*/std::nullopt, + /*optionalDtype=*/nullptr); + Type listType = Torch::ListType::get(listElemType); + Value tensorList = rewriter.create( + loc, listType, SmallVector{x1y1s, x2y2s}); + boxes = rewriter.create(loc, boxesTensorType, + tensorList, const1); + } + + // TODO: Support score_threshold input + // Filter out the boxes if the score < score_threshold + if (operands.size() == 5) { + Value scoreThreshold = rewriter.create( + loc, rewriter.getType(), operands[4]); + Value minScores = rewriter.create( + loc, + Torch::ValueTensorType::get(binder.op->getContext(), + SmallVector{}, + rewriter.getF32Type()), + scores); + minScores = rewriter.create( + loc, rewriter.getType(), minScores); + + Value scoresCond = rewriter.create( + loc, minScores, scoreThreshold); + rewriter.create( + loc, scoresCond, + rewriter.getStringAttr( + "unimplemented: score_threshold should be <= min(scores)")); + } + + // Get max_output_boxes_per_class and iou_threshold + Value cst0 = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + Value cst1 = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + Value maxOutputBoxesPerClass = cst0; + Value iouThreshold = rewriter.create( + loc, rewriter.getF64FloatAttr(0.0)); + if (operands.size() > 3 && + !isa(operands[3].getType())) { + iouThreshold = rewriter.create( + loc, rewriter.getType(), operands[3]); + } + if (operands.size() > 2 && + !isa(operands[2].getType())) { + maxOutputBoxesPerClass = rewriter.create( + loc, rewriter.getType(), operands[2]); + } + + auto nmsTy = Torch::ValueTensorType::get( + binder.op->getContext(), SmallVector{-1}, + rewriter.getIntegerType(64, /*signed=*/true)); + Value result = rewriter.create( + loc, nmsTy, boxes, scores, iouThreshold); + + // Slice the result if numOutputBoxes (N) > max_output_boxes_per_class + Value numOutputBoxes = + rewriter.create(loc, result, cst0); + Value boxesCond = rewriter.create( + loc, numOutputBoxes, maxOutputBoxesPerClass); + + auto nmsResultTy = Torch::ValueTensorType::get( + binder.op->getContext(), + SmallVector{resultType.getSizes()[0]}, + rewriter.getIntegerType(64, /*signed=*/true)); + auto ifSlice = rewriter.create( + loc, TypeRange({nmsResultTy}), boxesCond); + { + PatternRewriter::InsertionGuard guard(rewriter); + 
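For reference, when center_point_box == 1 the boxes arrive as [x_center, y_center, width, height] and are re-expressed as corner coordinates before calling torchvision's nms; the slices and add/sub above implement exactly this arithmetic, shown here standalone:

#include <array>

// [x_center, y_center, w, h] -> [x1, y1, x2, y2]
std::array<float, 4> centerToCorners(const std::array<float, 4> &box) {
  float halfW = box[2] / 2.0f, halfH = box[3] / 2.0f;
  return {box[0] - halfW, box[1] - halfH, box[0] + halfW, box[1] + halfH};
}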
rewriter.createBlock(&ifSlice.getThenRegion(), + ifSlice.getThenRegion().begin()); + + Value curResult = rewriter.create( + loc, nmsResultTy, result, /*dim=*/cst0, /*start=*/cst0, + /*end=*/maxOutputBoxesPerClass, /*step=*/cst1); + rewriter.create(loc, curResult); + } + { + PatternRewriter::InsertionGuard guard(rewriter); + rewriter.createBlock(&ifSlice.getElseRegion(), + ifSlice.getElseRegion().begin()); + + Value curResult = rewriter.create( + loc, nmsResultTy, result); + rewriter.create(loc, curResult); + } + result = ifSlice.getResult(0); + + // The result generated by torchvision.nms op is of shape [n], while the + // onnx expects it to be of shape [n, 3]. Hence, we unsqueeze the tensor + // and make it of shape [n, 1] and then concatenate it with a zero + // tensor of shape [n, 2] to make it of shape [n, 3]. + FailureOr unsqueezedResult = + Torch::unsqueezeTensor(rewriter, binder.op, result, cst1); + if (failed(unsqueezedResult)) + return rewriter.notifyMatchFailure( + binder.op, "failed to unsqueeze result tensor"); + result = unsqueezedResult.value(); + + numOutputBoxes = + rewriter.create(loc, result, cst0); + SmallVector zerosShapeValues{numOutputBoxes}; + zerosShapeValues.push_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(2))); + Value zerosShapeList = rewriter.create( + loc, + rewriter.getType( + rewriter.getType()), + zerosShapeValues); + std::optional> resultShape = + cast(result.getType()).getOptionalSizes(); + if (!resultShape.has_value()) + return rewriter.notifyMatchFailure( + binder.op, "expected result tensor to have shape"); + llvm::SmallVector zerosShape = {resultShape->front(), 2}; + auto zerosTy = Torch::ValueTensorType::get( + resultType.getContext(), zerosShape, resultType.getOptionalDtype()); + Value cstNone = rewriter.create(loc); + Value zeros = rewriter.create( + loc, zerosTy, zerosShapeList, cstNone, cstNone, cstNone, cstNone); + + Type listElemType = + cast(resultType) + .getWithSizesAndDtype(/*optionalSizes=*/std::nullopt, + /*optionalDtype=*/nullptr); + Type listType = Torch::ListType::get(listElemType); + Value tensorList = rewriter.create( + loc, listType, SmallVector{zeros, result}); + rewriter.replaceOpWithNewOp(binder.op, resultType, + tensorList, cst1); + return success(); + }); } diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 55fb132989ca..ce65cb6ce1c4 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -36,21 +36,24 @@ namespace { // we provide the original operand through storeResult, which will be modified // if the result will be passed onto another operation, and will be used for // noop_with_empty_axes handling before that. 
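For reference, torchvision.nms returns a rank-1 tensor of kept box indices, while onnx.NonMaxSuppression expects rows of [batch_index, class_index, box_index]. With the single-batch, single-class restriction above, the unsqueeze-and-concat-with-zeros sequence amounts to prepending two zero columns; a sketch with illustrative names:

#include <array>
#include <cstdint>
#include <vector>

// Kept indices from nms (shape [n]) -> ONNX selected_indices rows
// [batch_index, class_index, box_index] (shape [n, 3]); batch and class
// are both 0 because only a single batch and class are supported here.
std::vector<std::array<int64_t, 3>>
toOnnxSelectedIndices(const std::vector<int64_t> &keptBoxIndices) {
  std::vector<std::array<int64_t, 3>> rows;
  rows.reserve(keptBoxIndices.size());
  for (int64_t box : keptBoxIndices)
    rows.push_back({0, 0, box});
  return rows;
}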
-LogicalResult reducedSumImpl(OpBinder binder, - ConversionPatternRewriter &rewriter, Value data, - Torch::ValueTensorType resultType, - Value &storeResult, int64_t keepDims, - int64_t noop_with_empty_axes, - bool isIntermediateOp) { - +template +LogicalResult reduceOpImpl(OpBinder binder, ConversionPatternRewriter &rewriter, + Value data, Torch::ValueTensorType resultType, + Value &storeResult, int64_t keepDims, + int64_t noop_with_empty_axes, + bool isIntermediateOp) { + + auto inputType = dyn_cast(data.getType()); + if (!inputType) + return failure(); SmallVector axesList; Value axesVal; if (!binder.tensorOperandAtIndex(axesVal, 1)) { - auto inputType = dyn_cast(data.getType()); - if (!inputType.hasSizes() || !resultType.hasSizes()) { - return rewriter.notifyMatchFailure( - binder.op, "unimplemented: expected input and result to have shapes"); - } + auto axesTy = dyn_cast(axesVal.getType()); + if (!axesTy || !axesTy.areAllSizesKnown() || axesTy.getSizes().size() > 1) + return failure(); + auto axesShape = axesTy.getSizes(); + uint64_t numAxes = (axesShape.empty()) ? 1 : axesShape.front(); if (inputType.areAllSizesKnown() && resultType.areAllSizesKnown()) { SmallVector inputShape{inputType.getSizes()}; @@ -77,22 +80,25 @@ LogicalResult reducedSumImpl(OpBinder binder, } else { reduceDims.push_back(i); if (resultShapeCounter < resultShape.size() && - resultShape[resultShapeCounter] == 1) + resultShape[resultShapeCounter] == 1 && keepDims == 1) resultShapeCounter++; } } - for (auto i : reduceDims) { - axesList.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i))); - } + if (reduceDims.size() == numAxes) { + for (auto i : reduceDims) { + axesList.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + } + } else + binder.op->emitWarning( + "Number of inferred reduce dims, " + + std::to_string(reduceDims.size()) + + ", does not match the provided number of axes, " + + std::to_string(numAxes) + "."); } } if (axesList.empty()) { - Torch::BaseTensorType axesType = - cast(axesVal.getType()); - auto axesTy = dyn_cast(axesVal.getType()); - auto axesShape = axesTy.getSizes(); - if (axesShape.size() != 1 || axesShape[0] == Torch::kUnknownSize) + if (axesTy.getSizes()[0] == Torch::kUnknownSize) return failure(); Value zero = rewriter.create( @@ -100,9 +106,8 @@ LogicalResult reducedSumImpl(OpBinder binder, rewriter.getI64IntegerAttr(0)); SmallVector selectSizes{1}; auto selType = rewriter.getType( - selectSizes, axesType.getOptionalDtype()); - int64_t numAxes = axesShape[0]; - for (int64_t i = 0; i < numAxes; ++i) { + selectSizes, axesTy.getOptionalDtype()); + for (uint64_t i = 0; i < numAxes; ++i) { Value iv = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getI64IntegerAttr(i)); @@ -117,41 +122,126 @@ LogicalResult reducedSumImpl(OpBinder binder, SmallVector axesInts; if (!binder.s64IntegerArrayAttr(axesInts, "axes", {})) { - for (int64_t i = 0, s = axesInts.size(); i < s; ++i) { - Value iv = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getI64IntegerAttr(axesInts[i])); - axesList.push_back(iv); + for (int64_t i : axesInts) { + axesList.push_back( + rewriter.create(binder.getLoc(), i)); } } // Do not include absolute value in the noop - if (axesList.empty() && noop_with_empty_axes) { - rewriter.replaceOp(binder.op, storeResult); + if (axesList.empty() && noop_with_empty_axes == 1) { + if (!isIntermediateOp) + rewriter.replaceOp(binder.op, data); + else + storeResult = data; return success(); } + // if the axes list 
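For reference, the static-shape branch above infers which dimensions were reduced by walking the input and result shapes in lockstep (it is only used when the input sizes are pairwise distinct, and it now also bails out with a warning when the inferred count differs from the number of axes). In isolation the inference looks like:

#include <cstdint>
#include <vector>

// Infer reduced dims by comparing static input/result shapes. With
// keepDims, a reduced dimension shows up as a 1 in the result shape,
// so it is consumed from the result shape as well.
std::vector<int64_t> inferReduceDims(const std::vector<int64_t> &inputShape,
                                     const std::vector<int64_t> &resultShape,
                                     bool keepDims) {
  std::vector<int64_t> reduceDims;
  size_t resultIdx = 0;
  for (size_t i = 0; i < inputShape.size(); ++i) {
    if (resultIdx < resultShape.size() &&
        inputShape[i] == resultShape[resultIdx]) {
      ++resultIdx;
      continue;
    }
    reduceDims.push_back(i);
    if (keepDims && resultIdx < resultShape.size() &&
        resultShape[resultIdx] == 1)
      ++resultIdx;
  }
  return reduceDims;
}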
is still empty, reduce everything. + if (axesList.empty()) { + if (keepDims == 0 && !resultType.getSizes().empty()) + return rewriter.notifyMatchFailure( + binder.op, + "no axes provided & no keepdim: expected result to be rank zero."); + if (keepDims == 1 && + (resultType.getSizes().size() != inputType.getSizes().size() || + llvm::any_of(resultType.getSizes(), + [](int64_t size) { return size != 1; }))) + return rewriter.notifyMatchFailure( + binder.op, "no axes provided & keepdim: expected result to have all " + "dimensions equal to 1."); + for (uint64_t i = 0; i < inputType.getSizes().size(); i++) { + axesList.push_back( + rewriter.create(binder.getLoc(), i)); + } + } + Value dimValueList = rewriter.create( binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), axesList); Value keepDimBool = rewriter.create(binder.getLoc(), keepDims); - Value dType = rewriter.create(binder.getLoc()); - // If we are using the ReducedSum as an intermediate op to be passed into + // If we are using the reduction op as an intermediate op to be passed into // another operation, we might not want to replace the Op. So we create a new // Op and store the result in a variable. + SmallVector operands = {data, dimValueList, keepDimBool}; + if (llvm::is_one_of()) + operands.push_back( + /*dtype=*/rewriter.create(binder.getLoc())); if (!isIntermediateOp) { - rewriter.replaceOpWithNewOp( - binder.op, resultType, data, dimValueList, keepDimBool, - /*dtype=*/dType); + rewriter.replaceOpWithNewOp(binder.op, resultType, + operands); } else { - storeResult = rewriter.create( - binder.getLoc(), resultType, data, dimValueList, keepDimBool, - /*dtype=*/dType); + storeResult = rewriter.create(binder.getLoc(), + resultType, operands); } return success(); } + +Type getTorchScalarType( + /* forElementIn */ Torch::BaseTensorType givenTensorType, + /* using */ ConversionPatternRewriter &rewriter) { + auto elementTypeForGivenTensor = givenTensorType.getDtype(); + + if (isa(elementTypeForGivenTensor)) + return rewriter.getType(); + if (isa(elementTypeForGivenTensor)) + return rewriter.getType(); + + llvm_unreachable("dtype for given tensor expected to be either int or float"); +} + +Value extractTorchScalar( + /* at */ Location givenLoc, + /* from */ int64_t givenIndex, + /* in */ Value given1DTensor, + /* using */ ConversionPatternRewriter &rewriter) { + auto some1DTensorType = cast(given1DTensor.getType()); + + Type selectionTypeForSome1DTensor = some1DTensorType.getWithSizesAndDtype( + ArrayRef{1}, some1DTensorType.getOptionalDtype()); + + Value frontDim = rewriter.create(givenLoc, 0); + + Value selectionIndex = + rewriter.create(givenLoc, givenIndex); + + auto someTorchScalarType = getTorchScalarType(some1DTensorType, rewriter); + + Value selectionFromGiven1DTensor = rewriter.create( + givenLoc, selectionTypeForSome1DTensor, given1DTensor, frontDim, + selectionIndex); + + return rewriter.create(givenLoc, someTorchScalarType, + selectionFromGiven1DTensor); +} + +Value createScalarSublist( + /* at */ Location givenLoc, + /* movingForwardsThrough */ Value given1DTensor, + /* startingAt */ int64_t givenIndex, + /* using */ ConversionPatternRewriter &rewriter) { + auto some1DTensorType = cast(given1DTensor.getType()); + auto sizesOfSome1DTensor = some1DTensorType.getSizes(); + auto lengthOfFullList = sizesOfSome1DTensor[0]; + + SmallVector runningScalarSublist; + + for (int indexOfEachScalar = givenIndex; indexOfEachScalar < lengthOfFullList; + indexOfEachScalar++) { + Value eachScalar = 
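For reference, the block above encodes the ONNX convention for an empty axes list: with noop_with_empty_axes set the reduction is an identity, otherwise every dimension is reduced, and keepdims decides whether the result keeps its rank with all-1 sizes or drops to rank 0. The expected result shape, as a sketch:

#include <cstdint>
#include <vector>

// Expected output shape of an ONNX reduction when the axes list is empty.
std::vector<int64_t>
reduceShapeForEmptyAxes(const std::vector<int64_t> &inputShape, bool keepDims,
                        bool noopWithEmptyAxes) {
  if (noopWithEmptyAxes)
    return inputShape; // identity: nothing is reduced
  if (keepDims)
    return std::vector<int64_t>(inputShape.size(), 1); // every dim becomes 1
  return {}; // rank-0 result
}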
extractTorchScalar(givenLoc, indexOfEachScalar, + given1DTensor, rewriter); + runningScalarSublist.push_back(eachScalar); + } + + auto someTorchScalarType = runningScalarSublist.front().getType(); + Type someTorchScalarListType = Torch::ListType::get(someTorchScalarType); + + return rewriter.create( + givenLoc, someTorchScalarListType, runningScalarSublist); +} } // namespace void mlir::torch::onnx_c::populateDefaultDomainQtoZ( @@ -165,6 +255,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.tensorResultType(resultType)) return failure(); + auto loc = binder.getLoc(); Value operand = operands[0]; Value scale = operands[1]; Value zeropoint = operands[2]; @@ -176,42 +267,61 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return rewriter.notifyMatchFailure(binder.op, "requires known result dtype"); - if (scaleTy.getSizes().size() == 0) { - Type qTy = resultType.getDtype(); + auto resultETy = resultType.getDtype(); - if (qTy.isUnsignedInteger(8)) { - qTy = rewriter.getType(); - } else if (qTy.isSignedInteger(8)) { - qTy = rewriter.getType(); - } else if (qTy.isSignedInteger(32)) { - qTy = rewriter.getType(); - } else { - return rewriter.notifyMatchFailure(binder.op, - "unsupported result dtype"); - } + bool rank0 = scaleTy.getSizes().size() == 0; + bool length1 = + scaleTy.getSizes().size() == 1 && scaleTy.getSizes()[0] == 1; - auto qTensorTy = rewriter.getType( - resultType.getOptionalSizes(), qTy); - auto torchqTy = Torch::getScalarTypeForType(qTy); + if (!rank0 && !length1) + return rewriter.notifyMatchFailure(binder.op, + "unimplemented: non-scalar scale"); - Value tyConst = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), - static_cast(torchqTy))); - - scale = rewriter.create( - binder.getLoc(), rewriter.getType(), scale); - zeropoint = rewriter.create( - binder.getLoc(), rewriter.getType(), zeropoint); - - auto quantize = rewriter.create( - binder.getLoc(), qTensorTy, operand, scale, zeropoint, tyConst); - rewriter.replaceOpWithNewOp( - binder.op, resultType, quantize); + auto qTensorTy = getQTorchTypeFromTorchIntType(resultType); + if (!qTensorTy) { + return rewriter.notifyMatchFailure(binder.op, + "unsupported result dtype"); + } + + auto torchqTy = Torch::getScalarTypeForType(qTensorTy.getDtype()); + + Value tyConst = rewriter.create( + loc, rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + static_cast(torchqTy))); + + scale = rewriter.create( + loc, rewriter.getType(), scale); + + bool fpResult = isa(resultETy); + Type zeropointTy = rewriter.getType(); + if (fpResult) + zeropointTy = rewriter.getType(); + zeropoint = + rewriter.create(loc, zeropointTy, zeropoint); + + if (fpResult) { + Value none = rewriter.create(loc); + Value cstFalse = rewriter.create(loc, false); + Value one = rewriter.create( + loc, rewriter.getF64FloatAttr(1.0)); + Value div = rewriter.create( + loc, operand.getType(), operand, scale); + Value add = rewriter.create( + loc, operand.getType(), div, zeropoint, one); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, add, tyConst, + /*non_blocking=*/cstFalse, /*copy=*/cstFalse, + /*memory_format=*/none); return success(); } - return failure(); + auto quantize = rewriter.create( + loc, qTensorTy, operand, scale, zeropoint, tyConst); + rewriter.replaceOpWithNewOp(binder.op, resultType, + quantize); + return success(); }); patterns.onOp( "QLinearConv", 1, @@ -311,8 +421,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( c = 
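For reference, this handler (most likely QuantizeLinear, judging by the data/scale/zero_point operands) now also accepts a floating-point result type, in which case it just computes x / scale + zero_point and casts instead of going through the per-tensor quantize path. For the integer case, the ONNX semantics being modelled are, as a scalar sketch for int8:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Scalar model of ONNX QuantizeLinear to a signed 8-bit type:
// y = saturate(round(x / scale) + zero_point), rounding half to even.
int8_t quantizeLinearScalar(float x, float scale, int8_t zeroPoint) {
  float q = std::nearbyint(x / scale) + static_cast<float>(zeroPoint);
  q = std::clamp(q, -128.0f, 127.0f);
  return static_cast<int8_t>(q);
}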
rewriter.create(binder.getLoc(), cTy, c); - cTy = dyn_cast( - getQTorchTypeFromTorchIntType(resultType)); + cTy = getQTorchTypeFromTorchIntType(resultType); Value dtyVal = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr( @@ -469,6 +578,44 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, resultType, operand); return success(); }); + patterns.onOp("RNN", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + return OnnxRnnExpander(binder, rewriter); + }); + patterns.onOp( + "Scatter", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + int64_t axis; + if (binder.s64IntegerAttr(axis, "axis", {})) + return rewriter.notifyMatchFailure(binder.op, "axis bind failure"); + + Torch::ValueTensorType resultTy; + Value data, indices, updates; + if (binder.tensorOperandAtIndex(data, 0) || + binder.tensorOperandAtIndex(indices, 1) || + binder.tensorOperandAtIndex(updates, 2) || + binder.tensorResultType(resultTy)) + return failure(); + + auto dataTy = cast(data.getType()), + indicesTy = cast(indices.getType()), + updatesTy = cast(updates.getType()); + + int64_t dataRank = dataTy.getSizes().size(), + indicesRank = indicesTy.getSizes().size(), + updatesRank = updatesTy.getSizes().size(); + + if ((dataRank < 1) || (indicesRank < 1) || (updatesRank < 1) || + (axis < -dataRank) || (axis >= dataRank)) + return failure(); + + Value axisValue = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(axis)); + + rewriter.replaceOpWithNewOp( + binder.op, resultTy, data, axisValue, indices, updates); + + return success(); + }); patterns.onOp( "ScatterElements", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { @@ -483,6 +630,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.tensorResultType(resultType)) return failure(); + auto loc = binder.getLoc(); Value data = valList[0]; Value indices = valList[1]; Value updates = valList[2]; @@ -493,9 +641,33 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( cast(data.getType()).getSizes().size(); Value constAxis = rewriter.create( - binder.getLoc(), rewriter.getType(), + loc, rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis)); + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getI64IntegerAttr(0)); + Value one = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getI64IntegerAttr(1)); + + Value axisSize = rewriter.create( + binder.getLoc(), rewriter.getType(), data, + constAxis); + + auto indicesTy = cast(indices.getType()); + Value indicesAdd = rewriter.create( + loc, indicesTy, indices, axisSize, one); + + Value inputNeg = rewriter.create( + loc, + rewriter.getType(indicesTy.getSizes(), + rewriter.getI1Type()), + indices, zero); + + indices = rewriter.create( + loc, indicesTy, inputNeg, indicesAdd, indices); + if (reduction == "none") { rewriter.replaceOpWithNewOp( binder.op, resultType, data, constAxis, indices, updates); @@ -504,18 +676,59 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // TODO: Implement max and min cases if (reduction == "mul") { - reduction = "multiply"; + reduction = "prod"; } else if (reduction == "max" || reduction == "min") { return rewriter.notifyMatchFailure( binder.op, "max/min reduction unsupported for scatter elements"); + } else if (reduction == "add") { + reduction = "sum"; } Value cstStrReduction = rewriter.create(binder.getLoc(), reduction); - - rewriter.replaceOpWithNewOp( + Value cstTrue = + rewriter.create(binder.getLoc(), true); + 
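For reference, the new ScatterElements code above maps negative ONNX indices into the valid range before scattering: any index < 0 is shifted by the size of the scatter axis, since the torch scatter ops expect non-negative indices. A sketch of that normalization plus a 1-D scatter with reduction == "none", names illustrative:

#include <cstdint>
#include <vector>

// Normalize possibly-negative indices along an axis of size data.size(),
// then scatter `updates` into a copy of `data`.
std::vector<float> scatter1D(std::vector<float> data,
                             std::vector<int64_t> indices,
                             const std::vector<float> &updates) {
  int64_t axisSize = static_cast<int64_t>(data.size());
  for (auto &idx : indices)
    if (idx < 0)
      idx += axisSize; // mirrors where(indices < 0, indices + size, indices)
  for (size_t i = 0; i < indices.size(); ++i)
    data[indices[i]] = updates[i];
  return data;
}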
rewriter.replaceOpWithNewOp( binder.op, resultType, data, constAxis, indices, updates, - cstStrReduction); + cstStrReduction, cstTrue); + return success(); + }); + patterns.onOp( + "SequenceConstruct", 11, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + SmallVector operands; + Torch::ListType resultType; + + if (binder.tensorOperands(operands, binder.getNumOperands()) || + binder.tensorListResultType(resultType)) + return failure(); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, operands); + return success(); + }); + patterns.onOp( + "SequenceLength", 11, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + // onnx.SequenceLength takes a sequence(list) of tensors, and returns + // a zero rank tensor with the length. + Torch::ValueTensorType resultType; + Value x; + if (binder.tensorListOperand(x) || binder.tensorResultType(resultType)) + return failure(); + + Value cstFalse = + rewriter.create(binder.getLoc(), false); + Value none = rewriter.create(binder.getLoc()); + + Value len = rewriter.create( + binder.getLoc(), rewriter.getType(), x); + + // AtenLenTOp returns a torch.int, so we have to + // put that in a tensor. + rewriter.replaceOpWithNewOp( + binder.op, resultType, len, none, none, cstFalse); + return success(); }); patterns.onOp( @@ -867,25 +1080,25 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, resultType, operand, vAlpha, vScale, vInputScale); return success(); }); - patterns.onOp("ReduceL1", 1, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - int64_t keepDims, noop_with_empty_axes; - Value operand; - if (binder.tensorOperandAtIndex(operand, 0) || - binder.tensorResultType(resultType) || - binder.s64IntegerAttr(keepDims, "keepdims", 1) || - binder.s64IntegerAttr(noop_with_empty_axes, - "noop_with_empty_axes", 0)) - return failure(); + patterns.onOp( + "ReduceL1", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + int64_t keepDims, noop_with_empty_axes; + Value operand; + if (binder.tensorOperandAtIndex(operand, 0) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(keepDims, "keepdims", 1) || + binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", + 0)) + return failure(); - Value data = rewriter.create( - binder.getLoc(), operand.getType(), operand); + Value data = rewriter.create( + binder.getLoc(), operand.getType(), operand); - return reducedSumImpl(binder, rewriter, data, resultType, - /*storeValue=*/operand, keepDims, - noop_with_empty_axes, false); - }); + return reduceOpImpl( + binder, rewriter, data, resultType, + /*storeValue=*/operand, keepDims, noop_with_empty_axes, false); + }); patterns.onOp( "ReduceL2", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; @@ -903,9 +1116,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value squareOfOperand = rewriter.create( binder.getLoc(), operand.getType(), operand, operand); - auto reducedSum = - reducedSumImpl(binder, rewriter, squareOfOperand, resultType, - operand, keepDims, noop_with_empty_axes, true); + auto reducedSum = reduceOpImpl( + binder, rewriter, squareOfOperand, resultType, operand, keepDims, + noop_with_empty_axes, true); if (failed(reducedSum)) return rewriter.notifyMatchFailure( binder.op, @@ -940,33 +1153,97 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( /*memory_format=*/noneVal); return success(); }); - patterns.onOp("ReduceLogSum", 1, - [](OpBinder binder, 
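For reference, ReduceL1 above is lowered as abs followed by a sum reduction, and ReduceL2 as square followed by a sum (with the square root applied afterwards, outside this hunk). In scalar form over one axis:

#include <cmath>
#include <vector>

// ReduceL1: sum of absolute values; ReduceL2: sqrt of sum of squares.
float reduceL1(const std::vector<float> &x) {
  float acc = 0.0f;
  for (float v : x)
    acc += std::fabs(v);
  return acc;
}

float reduceL2(const std::vector<float> &x) {
  float acc = 0.0f;
  for (float v : x)
    acc += v * v;
  return std::sqrt(acc);
}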
ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value data; - int64_t keepDims, noop_with_empty_axes; - if (binder.tensorOperandAtIndex(data, 0) || - binder.tensorResultType(resultType) || - binder.s64IntegerAttr(keepDims, "keepdims", 1) || - binder.s64IntegerAttr(noop_with_empty_axes, - "noop_with_empty_axes", 0)) - return failure(); + patterns.onOp( + "ReduceLogSum", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value data; + int64_t keepDims, noop_with_empty_axes; + if (binder.tensorOperandAtIndex(data, 0) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(keepDims, "keepdims", 1) || + binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", + 0)) + return failure(); - auto reducedSumBool = - reducedSumImpl(binder, rewriter, data, resultType, - /*storeValue=*/data, keepDims, - noop_with_empty_axes, true); + auto reducedSumBool = reduceOpImpl( + binder, rewriter, data, resultType, + /*storeValue=*/data, keepDims, noop_with_empty_axes, true); - if (failed(reducedSumBool)) - return rewriter.notifyMatchFailure( - binder.op, - "Failed to perform sum operation on square of operand"); + if (failed(reducedSumBool)) + return rewriter.notifyMatchFailure( + binder.op, + "Failed to perform sum operation on square of operand"); - rewriter.replaceOpWithNewOp( - binder.op, resultType, data); - return success(); - }); - patterns.onOp("ReduceSum", 1, + rewriter.replaceOpWithNewOp(binder.op, resultType, + data); + return success(); + }); + patterns.onOp( + "ReduceLogSumExp", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value data; + int64_t keepDims, noop_with_empty_axes; + if (binder.tensorOperandAtIndex(data, 0) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(keepDims, "keepdims", 1) || + binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", + 0)) + return failure(); + + // out = Log(reducesum(exp(data))) + Value castDType = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(/*Float64Type*/ 7)); + Value noneVal = rewriter.create(binder.getLoc()); + Value constFalse = + rewriter.create(binder.getLoc(), false); + auto size = + dyn_cast(data.getType()).getOptionalSizes(); + auto f64ResultType = rewriter.getType( + size, rewriter.getF64Type()); + Value dataCast = rewriter.create( + binder.getLoc(), f64ResultType, data, castDType, + /*non_blocking=*/constFalse, /*copy=*/constFalse, + /*memory_format=*/noneVal); + Value dataExp = rewriter.create( + binder.getLoc(), f64ResultType, dataCast); + auto f64ReduceType = rewriter.getType( + resultType.getOptionalSizes(), rewriter.getF64Type()); + auto reducedSumBool = reduceOpImpl( + binder, rewriter, dataExp, f64ReduceType, + /*storeValue=*/data, keepDims, noop_with_empty_axes, true); + if (failed(reducedSumBool)) + return rewriter.notifyMatchFailure( + binder.op, + "Failed to perform sum operation on square of operand"); + Value finalResult = rewriter.create( + binder.getLoc(), f64ReduceType, data); + Value resultDtype = Torch::getDtypeIntValueForType( + rewriter, binder.getLoc(), resultType.getDtype()); + rewriter.replaceOpWithNewOp( + binder.op, resultType, finalResult, resultDtype, + /*non_blocking=*/constFalse, /*copy=*/constFalse, + /*memory_format=*/noneVal); + return success(); + }); + patterns.onOp( + "ReduceSum", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value data; + int64_t 
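For reference, ReduceLogSumExp is expanded above as log(sum(exp(x))), with the intermediate cast to f64 before the exp to reduce overflow; a direct scalar transcription of that expansion (no max-subtraction trick is applied, matching the lowering):

#include <cmath>
#include <vector>

// log(sum(exp(x))) with the accumulation widened to double, mirroring
// the f64 cast in the lowering above.
float reduceLogSumExp(const std::vector<float> &x) {
  double sum = 0.0;
  for (float v : x)
    sum += std::exp(static_cast<double>(v));
  return static_cast<float>(std::log(sum));
}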
keepDims, noop_with_empty_axes; + if (binder.tensorOperandAtIndex(data, 0) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(keepDims, "keepdims", 1) || + binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", + 0)) + return failure(); + + return reduceOpImpl( + binder, rewriter, data, resultType, + /*storeValue=*/data, keepDims, noop_with_empty_axes, false); + }); + patterns.onOp("ReduceSumSquare", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value data; @@ -978,11 +1255,15 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( "noop_with_empty_axes", 0)) return failure(); - return reducedSumImpl(binder, rewriter, data, resultType, - /*storeValue=*/data, keepDims, - noop_with_empty_axes, false); + Value dataSquare = rewriter.create( + binder.getLoc(), data.getType(), data, data); + + return reduceOpImpl( + binder, rewriter, dataSquare, resultType, + /*storeValue=*/data, keepDims, noop_with_empty_axes, + false); }); - patterns.onOp("ReduceSumSquare", 1, + patterns.onOp("ReduceMean", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value data; @@ -994,140 +1275,18 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( "noop_with_empty_axes", 0)) return failure(); - Value dataSquare = rewriter.create( - binder.getLoc(), data.getType(), data, data); - - return reducedSumImpl(binder, rewriter, dataSquare, - resultType, - /*storeValue=*/data, keepDims, - noop_with_empty_axes, false); + Value reduceSum = data; + return reduceOpImpl( + binder, rewriter, data, resultType, + /*storeValue=*/reduceSum, keepDims, noop_with_empty_axes, + false); }); - patterns.onOp( - "ReduceMean", 1, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value data; - int64_t keepDims, noop_with_empty_axes; - if (binder.tensorOperandAtIndex(data, 0) || - binder.tensorResultType(resultType) || - binder.s64IntegerAttr(keepDims, "keepdims", 1) || - binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", - 0)) - return failure(); - - SmallVector axesList; - - Value axesVal; - if (!binder.tensorOperandAtIndex(axesVal, 1)) { - auto inputType = dyn_cast(data.getType()); - if (!inputType.hasSizes() || !resultType.hasSizes()) { - return rewriter.notifyMatchFailure( - binder.op, - "unimplemented: expected input and result to have shapes"); - } - - // If the input shape and result shape is statically known then the - // list of dims to be squeezed can be derived from those shapes. As a - // result, we don't have to wait for the dim values to be known at - // runtime which is also expected by the downstream pipeline. - if (inputType.areAllSizesKnown() && resultType.areAllSizesKnown()) { - SmallVector inputShape{inputType.getSizes()}; - SmallVector resultShape{resultType.getSizes()}; - if (llvm::equal(inputShape, resultShape)) { - // Case: none of the dimension is reduced. - rewriter.replaceOp(binder.op, data); - return success(); - } - if (areAllElementsDistinct(inputShape)) { - // The check for the input shape elements to be distinct is added - // for the cases like: - // Input: [3, 2, 2] -> Output: [3, 2] - // For the above case, from the input and output shape it can't be - // inferred whether the dim:1 is reduced or dim:2. To avoid these - // type of cases, the check has been placed. 
- SmallVector reduceDims; - unsigned resultShapeCounter = 0; - for (unsigned i = 0; i < inputShape.size(); i++) { - if (resultShapeCounter < resultShape.size() && - inputShape[i] == resultShape[resultShapeCounter]) { - resultShapeCounter++; - } else { - reduceDims.push_back(i); - if (resultShapeCounter < resultShape.size() && - resultShape[resultShapeCounter] == 1) - resultShapeCounter++; - } - } - for (auto i : reduceDims) { - axesList.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i))); - } - } - } - - if (axesList.empty()) { - Torch::BaseTensorType axesType = - cast(axesVal.getType()); - auto axesTy = dyn_cast(axesVal.getType()); - auto axesShape = axesTy.getSizes(); - if (axesShape.size() != 1 || axesShape[0] == Torch::kUnknownSize) - return failure(); - - Value zero = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getI64IntegerAttr(0)); - SmallVector selectSizes{1}; - auto selType = rewriter.getType( - selectSizes, axesType.getOptionalDtype()); - int64_t numAxes = axesShape[0]; - for (int64_t i = 0; i < numAxes; ++i) { - Value iv = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getI64IntegerAttr(i)); - Value extract = rewriter.create( - binder.getLoc(), selType, axesVal, zero, iv); - Value dim = rewriter.create( - binder.getLoc(), rewriter.getType(), extract); - axesList.push_back(dim); - } - } - } - - SmallVector axesInts; - if (!binder.s64IntegerArrayAttr(axesInts, "axes", {})) { - for (int64_t i = 0, s = axesInts.size(); i < s; ++i) { - Value iv = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getI64IntegerAttr(axesInts[i])); - axesList.push_back(iv); - } - } - - // deal with case when axes is empty - if (axesList.empty() && noop_with_empty_axes) { - rewriter.replaceOp(binder.op, data); - return success(); - } - - Value dimValueList = rewriter.create( - binder.getLoc(), - Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), - axesList); - Value keepDimBool = - rewriter.create(binder.getLoc(), keepDims); - Value noneVal = rewriter.create(binder.getLoc()); - rewriter.replaceOpWithNewOp( - binder.op, resultType, data, dimValueList, keepDimBool, - /*dtype=*/noneVal); - return success(); - }); patterns.onOp( "ReduceMax", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { // AtenAmaxOp allows us to pass a list of dims Torch::ValueTensorType resultType; Value data; - Value axes; int64_t keepDims; int64_t noop_with_empty_axes; if (binder.tensorOperandAtIndex(data, 0) || @@ -1140,7 +1299,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( auto dataTy = cast(data.getType()); Torch::IntType torchIntTy = rewriter.getType(); - // If any of the input dims are 0 we set to the upper limit: + // If any of the input dims are 0 we set to the lower limit: if (llvm::any_of(dataTy.getSizes(), [](int64_t d) { return d == 0; }) && (llvm::any_of(dataTy.getSizes(), [](int64_t d) { return d == Torch::kUnknownSize; }) || @@ -1148,7 +1307,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( auto dty = dataTy.getDtype(); Value scalar; if (FloatType fpTy = dyn_cast(dty)) { - auto inf = APFloat::getInf(fpTy.getFloatSemantics()); + auto inf = + APFloat::getInf(fpTy.getFloatSemantics(), /*Negative=*/true); scalar = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getFloatAttr(rewriter.getF64Type(), @@ -1156,14 +1316,14 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( } if (IntegerType intTy = dyn_cast(dty)) { - auto mx = + auto minInt = intTy.isSigned() - ? 
APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth()) - : APInt::getMaxValue(intTy.getIntOrFloatBitWidth()); + ? APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth()) + : APInt::getMinValue(intTy.getIntOrFloatBitWidth()); scalar = rewriter.create( binder.getLoc(), torchIntTy, rewriter.getIntegerAttr(rewriter.getIntegerType(64), - mx.getSExtValue())); + minInt.getSExtValue())); } llvm::SmallVector fillDims; @@ -1191,87 +1351,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return success(); } - // Previous version of the operation had the axes as an attribute: - SmallVector axesList; - llvm::SmallVector axesAttr; - if (!binder.s64IntegerArrayAttr(axesAttr, "axes", {})) { - for (int i = 0, s = axesAttr.size(); i < s; ++i) { - axesList.push_back(rewriter.create( - binder.getLoc(), torchIntTy, - rewriter.getI64IntegerAttr(axesAttr[i]))); - } - } - - // Extract the axes values from the axes operand: - if (!binder.tensorOperandAtIndex(axes, 1)) { - Torch::BaseTensorType axesType = - cast(axes.getType()); - SmallVector selectSizes{1}; - Type selectResultType = axesType.getWithSizesAndDtype( - selectSizes, axesType.getOptionalDtype()); - auto sizes = axesType.getSizes(); - - Value zero = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); - - // Extract the value of each axes: - for (int i = 0; i < sizes[0]; i++) { - // Go through the axes list and get each dim in the list - Value selectIndex = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); - Value extract = rewriter.create( - binder.getLoc(), selectResultType, axes, zero, selectIndex); - Value dim = rewriter.create( - binder.getLoc(), rewriter.getType(), extract); - axesList.push_back(dim); - } - } - - // Handle the noop case: - if (axesList.empty() && noop_with_empty_axes) { - rewriter.replaceOp(binder.op, data); - return success(); - } - - // Deal with case when no axes arg is passed but not a noop: - if (axesList.empty()) { - int64_t numDims = dyn_cast(data.getType()) - .getSizes() - .size(); - for (int i = 0; i < numDims; i++) { - Value curr = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); - axesList.push_back(curr); - } - } - - // Handle negative axis: - Value rankVal = rewriter.create(binder.getLoc(), - torchIntTy, data); - Value zero = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getI64IntegerAttr(0)); - for (Value &axes : axesList) { - Value isNegative = - rewriter.create(binder.getLoc(), axes, zero); - isNegative = rewriter.create(binder.getLoc(), - isNegative); - Value finalOffset = rewriter.create( - binder.getLoc(), isNegative, rankVal); - axes = rewriter.create(binder.getLoc(), axes, - finalOffset); - } - - Value dimValueList = rewriter.create( - binder.getLoc(), Torch::ListType::get(torchIntTy), axesList); - Value keepDimBool = - rewriter.create(binder.getLoc(), keepDims); - rewriter.replaceOpWithNewOp( - binder.op, resultType, data, dimValueList, keepDimBool); - return success(); + return reduceOpImpl( + binder, rewriter, data, resultType, + /*storeValue=*/data, keepDims, noop_with_empty_axes, false); }); patterns.onOp( @@ -1280,7 +1362,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // AtenAminOp allows us to pass a list of dims Torch::ValueTensorType resultType; Value data; - Value axes; int64_t keepDims; int64_t noop_with_empty_axes; if (binder.tensorOperandAtIndex(data, 0) || @@ 
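For reference, the change above fixes the fill value used when a reduced dimension has static size 0: the identity of a max reduction is the smallest representable value (negative infinity for floats, the signed or unsigned minimum for integers), not the largest. A sketch for the float case:

#include <limits>
#include <vector>

// Max over a possibly-empty range, seeded with the reduction's identity
// so an empty input yields -infinity; the integer analogue seeds with the
// signed minimum, as in the APInt change above.
double maxOrIdentity(const std::vector<double> &x) {
  double acc = -std::numeric_limits<double>::infinity();
  for (double v : x)
    acc = v > acc ? v : acc;
  return acc;
}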
-1344,101 +1425,59 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return success(); } - // Previous version of the operation had the axes as an attribute: - SmallVector axesList; - llvm::SmallVector axesAttr; - if (!binder.s64IntegerArrayAttr(axesAttr, "axes", {})) { - for (int i = 0, s = axesAttr.size(); i < s; ++i) { - axesList.push_back(rewriter.create( - binder.getLoc(), torchIntTy, - rewriter.getI64IntegerAttr(axesAttr[i]))); - } - } + return reduceOpImpl( + binder, rewriter, data, resultType, + /*storeValue=*/data, keepDims, noop_with_empty_axes, false); + }); - // Extract the axes values from the axes operand: - if (!binder.tensorOperandAtIndex(axes, 1)) { - Torch::BaseTensorType axesType = - cast(axes.getType()); - SmallVector selectSizes{1}; - Type selectResultType = axesType.getWithSizesAndDtype( - selectSizes, axesType.getOptionalDtype()); - auto sizes = axesType.getSizes(); + patterns.onOp( + "Shape", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + int64_t start, end; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); - Value zero = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + auto inputType = dyn_cast(operand.getType()); + if (!inputType || !inputType.hasSizes()) + return failure(); - // Extract the value of each axes: - for (int i = 0; i < sizes[0]; i++) { - // Go through the axes list and get each dim in the list - Value selectIndex = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); - Value extract = rewriter.create( - binder.getLoc(), selectResultType, axes, zero, selectIndex); - Value dim = rewriter.create( - binder.getLoc(), rewriter.getType(), extract); - axesList.push_back(dim); - } - } + int64_t inputRank = inputType.getSizes().size(); - // Handle the noop case: - if (axesList.empty() && noop_with_empty_axes) { - rewriter.replaceOp(binder.op, data); + if (binder.s64IntegerAttr(start, "start", 0) || + binder.s64IntegerAttr(end, "end", inputRank)) + return failure(); + + auto shapeType = Torch::ValueTensorType::get( + binder.op->getContext(), SmallVector{inputRank}, + resultType.getOptionalDtype()); + Value shape = rewriter.create( + binder.getLoc(), shapeType, operand); + + if (inputRank == 0) { + rewriter.replaceOpWithNewOp( + binder.op, resultType, shape); return success(); } - // Deal with case when no axes arg is passed but not a noop: - if (axesList.empty()) { - int64_t numDims = dyn_cast(data.getType()) - .getSizes() - .size(); - for (int i = 0; i < numDims; i++) { - Value curr = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); - axesList.push_back(curr); - } + if (start == 0 && end == inputRank) { + rewriter.replaceOp(binder.op, shape); + return success(); } - // Handle negative axis: - Value rankVal = rewriter.create(binder.getLoc(), - torchIntTy, data); - Value zero = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getI64IntegerAttr(0)); - for (Value &axes : axesList) { - Value isNegative = - rewriter.create(binder.getLoc(), axes, zero); - isNegative = rewriter.create(binder.getLoc(), - isNegative); - Value finalOffset = rewriter.create( - binder.getLoc(), isNegative, rankVal); - axes = rewriter.create(binder.getLoc(), axes, - finalOffset); - } + Value sv = rewriter.create( + binder.getLoc(), 
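For reference, the rewritten Shape handler honours the start/end attributes by materialising the full shape tensor and then slicing dims[start:end], short-circuiting the rank-0 and full-range cases. A plain restatement, assuming non-negative in-range start/end (the actual lowering leaves negative handling to the slice op it emits):

#include <cstdint>
#include <vector>

// ONNX Shape with start/end: return inputShape[start:end].
std::vector<int64_t> shapeSlice(const std::vector<int64_t> &inputShape,
                                int64_t start, int64_t end) {
  std::vector<int64_t> out;
  for (int64_t i = start; i < end && i < (int64_t)inputShape.size(); ++i)
    out.push_back(inputShape[i]);
  return out;
}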
rewriter.getI64IntegerAttr(start)); + Value ev = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(end)); + Value step = rewriter.create(binder.getLoc(), 1); + Value dim = rewriter.create(binder.getLoc(), 0); - Value dimValueList = rewriter.create( - binder.getLoc(), Torch::ListType::get(torchIntTy), axesList); - Value keepDimBool = - rewriter.create(binder.getLoc(), keepDims); - rewriter.replaceOpWithNewOp( - binder.op, resultType, data, dimValueList, keepDimBool); + rewriter.replaceOpWithNewOp( + binder.op, resultType, shape, dim, sv, ev, step); return success(); }); - patterns.onOp("Shape", 9, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value operand; - if (binder.tensorOperand(operand) || - binder.tensorResultType(resultType)) - return failure(); - rewriter.replaceOpWithNewOp( - binder.op, resultType, operand); - return success(); - }); - patterns.onOp("Sinh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; @@ -1475,7 +1514,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( if (binder.s64IntegerAttr(axis, "axis", 0)) return rewriter.notifyMatchFailure(binder.op, "Failed to get axis attribute"); - if (binder.s64IntegerAttr(numOutputs, "num_outputs", 2)) + + numOutputs = binder.op->getNumResults(); + if (binder.s64IntegerAttr(numOutputs, "num_outputs", numOutputs)) return rewriter.notifyMatchFailure( binder.op, "Failed to get num_outputs attribute"); @@ -1621,8 +1662,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( }); patterns.onOp( - "Transpose", 13, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { + "Transpose", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { auto loc = binder.getLoc(); Torch::ValueTensorType resultType; Value operand; @@ -1789,10 +1829,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, "Axes should be the same size of starts and ends"); } - auto stepsTy = steps.getType() - .cast() - .toBuiltinTensor() - .dyn_cast(); + auto stepsTy = dyn_cast( + cast(steps.getType()).toBuiltinTensor()); if (!(stepsTy && stepsTy.getDimSize(0) == endsTy.getDimSize(0))) return rewriter.notifyMatchFailure( @@ -2318,6 +2356,28 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( exp); return success(); }); + patterns.onOp("Softsign", 22, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value input; + if (binder.tensorOperand(input) || + binder.tensorResultType(resultType)) { + return failure(); + } + + Value absX = rewriter.create( + binder.getLoc(), resultType, input); + + Value constOne = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); + + Value absXPlusOne = rewriter.create( + binder.getLoc(), resultType, absX, constOne, constOne); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, input, absXPlusOne); + return success(); + }); patterns.onOp( "Trilu", 14, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; @@ -2353,7 +2413,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value input; float alpha; if (binder.tensorOperand(input) || - binder.f32FloatAttr(alpha, "alpha", 1.0)) { + binder.f32FloatAttr(alpha, "alpha", 1.0) || + binder.tensorResultType(resultType)) { return failure(); } Value cstAlpha = rewriter.create( @@ -2367,70 +2428,252 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return success(); }); patterns.onOp( - "RandomNormal", 1, - [](OpBinder binder, 
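For reference, the new Softsign pattern above is built from abs, an add of the constant 1, and a division, i.e.:

#include <cmath>

// softsign(x) = x / (1 + |x|), matching the abs/add/div sequence above.
float softsign(float x) { return x / (1.0f + std::fabs(x)); }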
ConversionPatternRewriter &rewriter) { - SmallString<64> name("torch.onnx.seed"); - auto seedAttr = binder.op->getAttr(name); - if (seedAttr) + "Resize", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + llvm::SmallVector operands; + std::string mode, nearest_mode, coordTfMode; + Value noneVal = rewriter.create(binder.getLoc()); + + if (auto attr = binder.op->getAttr("torch.onnx.antialias")) { return rewriter.notifyMatchFailure( binder.op, - "unimplemented: support not present for seed attribute"); - - Torch::ValueTensorType resultType; - int64_t dtypeIntOnnx; - float mean, scale; - SmallVector shape; - if (binder.s64IntegerAttr(dtypeIntOnnx, "dtype", 1) || - binder.f32FloatAttr(mean, "mean", 0.0) || - binder.f32FloatAttr(scale, "scale", 1.0) || - binder.s64IntegerArrayAttr(shape, "shape", {}) || - binder.tensorResultType(resultType)) { - return failure(); + "unimplemented: support not present for antialias attribute"); } - - std::optional dtypeIntTorch = - onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx); - if (!dtypeIntTorch.has_value()) { + if (auto attr = binder.op->getAttr("torch.onnx.axes")) { return rewriter.notifyMatchFailure( binder.op, - "unimplemented support for the given dtype conversion"); + "unimplemented: support not present for axes attribute"); + } + if (auto attr = binder.op->getAttr("torch.onnx.exclude_outside")) { + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: support not present for " + "exclude_outside attribute"); + } + if (auto attr = binder.op->getAttr("torch.onnx.extrapolation_value")) { + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: support not present for " + "extrapolation_value attribute"); + } + if (auto attr = + binder.op->getAttr("torch.onnx.keep_aspect_ratio_policy")) { + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: support not present for " + "keep_aspect_ratio_policy attribute"); } - Value constDtype = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value())); - Value shapeList = createConstantIntList(binder, rewriter, shape); - Value cstNone = rewriter.create(binder.getLoc()); + if (binder.tensorOperandsList(operands) || + binder.tensorResultType(resultType) || + binder.customOpNameStringAttr(mode, "mode", "nearest") || + binder.customOpNameStringAttr( + coordTfMode, "coordinate_transformation_mode", "half_pixel") || + binder.customOpNameStringAttr(nearest_mode, "nearest_mode", + "round_prefer_floor")) + return failure(); + if (coordTfMode == "tf_crop_and_resize") + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: coordinate transformation mode: " + "tf_crop_and_resize"); - Value self = rewriter.create( - binder.op->getLoc(), resultType, shapeList, - /*dtype=*/constDtype, - /*layout=*/cstNone, - /*device=*/cstNone, /*pinMemory=*/cstNone, - /*memoryFormat=*/cstNone); + if (mode == "nearest" && coordTfMode != "asymmetric" && + coordTfMode != "half_pixel") { + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: support not present for coord tf mode " + "except asymmetric and half_pixel"); + } - Value cstMean = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getFloatAttr(rewriter.getF64Type(), mean)); - Value cstStd = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getFloatAttr(rewriter.getF64Type(), scale)); + unsigned rank = dyn_cast(operands[0].getType()) + .getSizes() + .size(); - rewriter.replaceOpWithNewOp( - binder.op, resultType, self, cstMean, 
cstStd, - /*generator=*/cstNone); - return success(); - }); - patterns.onOp( - "RandomNormalLike", 1, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - SmallString<64> name("torch.onnx.seed"); - auto seedAttr = binder.op->getAttr(name); - if (seedAttr) - return rewriter.notifyMatchFailure( - binder.op, - "unimplemented: support not present for seed attribute"); + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); - Torch::ValueTensorType resultType; + Value cstFalse = + rewriter.create(binder.getLoc(), false); + Value cstTrue = + rewriter.create(binder.getLoc(), true); + Value modeStrValue; + + auto extract = [&rewriter, &binder](Value x, Value v) { + auto xTy = cast(x.getType()); + Type extractTy = rewriter.getType(); + if (isa(xTy.getDtype())) + extractTy = rewriter.getType(); + + return rewriter.create(binder.getLoc(), extractTy, + v); + }; + + auto getValueList = [&](Value operand) { + SmallVector itemList; + auto sizes = + dyn_cast(operand.getType()).getSizes(); + Torch::BaseTensorType operandType = + cast(operand.getType()); + + SmallVector selectSizes; + selectSizes.push_back(1); + Type selectResultType = operandType.getWithSizesAndDtype( + llvm::ArrayRef(selectSizes), operandType.getOptionalDtype()); + + MLIRContext *context = binder.op->getContext(); + for (int i = 2; i < sizes[0]; i++) { + Value selectIndex = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + Value ext = rewriter.create( + binder.getLoc(), selectResultType, operand, zero, selectIndex); + Value item = extract(operand, ext); + itemList.push_back(item); + } + auto xTy = cast(operand.getType()); + Value ValueList; + if (isa(xTy.getDtype())) { + ValueList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(context)), itemList); + } else { + ValueList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::FloatType::get(context)), itemList); + } + return ValueList; + }; + + Value scalesValueList = noneVal; + Value sizesValueList = noneVal; + Value alignCorners = + coordTfMode == "align_corners" ? 
cstTrue : cstFalse; + if (mode == "cubic") { + return rewriter.notifyMatchFailure(binder.op, + "unimplemented: bicubic mode"); + } + // supported modes: + // bilinear (half_pixel), bilinear with align_corners, + // bilinear_pytorch_half_pixel, bilinear_asymmetric nearest + // (asymmetric), nearest with align_corners, nearest_half_pixel, + // nearest_pytorch_half_pixel + if (mode == "linear") { + std::string modeStr; + switch (rank) { + case 3: + modeStr = "linear"; + break; + case 4: + modeStr = "bilinear"; + break; + case 5: + modeStr = "trilinear"; + break; + default: + return failure(); + } + // Confusingly enough, the default coordTfMode for pytorch bilinear + // mode is apparently half_pixel, NOT pytorch_half_pixel + if (coordTfMode != "half_pixel" && coordTfMode != "align_corners") + modeStr = (modeStr + "_") + coordTfMode; + modeStrValue = + rewriter.create(binder.getLoc(), modeStr); + } + if (mode == "nearest") { + std::string modeStr = "nearest"; + // The default coordTfMode for pytorch with mode = nearest is + // apparently asymmetric + if (coordTfMode != "asymmetric" && coordTfMode != "align_corners") + modeStr = (modeStr + "_") + coordTfMode; + if (nearest_mode != "floor") + modeStr = modeStr + "," + nearest_mode; + modeStrValue = + rewriter.create(binder.getLoc(), modeStr); + } + if (operands.size() < 4) { + Value scaleOperand = operands[2]; + scalesValueList = getValueList(scaleOperand); + sizesValueList = noneVal; + } else { + Value sizeOperand = operands[3]; + scalesValueList = noneVal; + sizesValueList = getValueList(sizeOperand); + } + if (isa(scalesValueList.getType()) && + isa(sizesValueList.getType())) { + return rewriter.notifyMatchFailure(binder.op, "unknown scaling mode"); + } + rewriter + .replaceOpWithNewOp( + binder.op, resultType, operands[0], sizesValueList, + scalesValueList, modeStrValue, + /* AnyTorchOptionalBoolType:$align_corners */ alignCorners, + /* AnyTorchOptionalBoolType:$recompute_scale_factor */ noneVal, + /*Torch_BoolType:$antialias*/ cstFalse); + return success(); + }); + patterns.onOp( + "RandomNormal", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + SmallString<64> name("torch.onnx.seed"); + auto seedAttr = binder.op->getAttr(name); + if (seedAttr) + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented: support not present for seed attribute"); + + Torch::ValueTensorType resultType; + int64_t dtypeIntOnnx; + float mean, scale; + SmallVector shape; + if (binder.s64IntegerAttr(dtypeIntOnnx, "dtype", 1) || + binder.f32FloatAttr(mean, "mean", 0.0) || + binder.f32FloatAttr(scale, "scale", 1.0) || + binder.s64IntegerArrayAttr(shape, "shape", {}) || + binder.tensorResultType(resultType)) { + return failure(); + } + + std::optional dtypeIntTorch = + onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx); + if (!dtypeIntTorch.has_value()) { + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented support for the given dtype conversion"); + } + Value constDtype = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value())); + + Value shapeList = createConstantIntList(binder, rewriter, shape); + Value cstNone = rewriter.create(binder.getLoc()); + + Value self = rewriter.create( + binder.op->getLoc(), resultType, shapeList, + /*dtype=*/constDtype, + /*layout=*/cstNone, + /*device=*/cstNone, /*pinMemory=*/cstNone, + /*memoryFormat=*/cstNone); + + Value cstMean = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getFloatAttr(rewriter.getF64Type(), mean)); + Value cstStd = rewriter.create( 
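For reference, the mode handling above folds the ONNX (mode, coordinate_transformation_mode, nearest_mode) triple together with the input rank into a single torch interpolation mode string. A compact restatement of that mapping (strings as they appear in the code, not an exhaustive spec; an empty result stands for a rejected combination):

#include <cstdint>
#include <string>

std::string torchResizeMode(const std::string &mode, int64_t rank,
                            const std::string &coordTfMode,
                            const std::string &nearestMode) {
  std::string s;
  if (mode == "linear") {
    if (rank == 3)
      s = "linear";
    else if (rank == 4)
      s = "bilinear";
    else if (rank == 5)
      s = "trilinear";
    else
      return "";
    // The default coordinate mode for these is half_pixel; others are suffixed.
    if (coordTfMode != "half_pixel" && coordTfMode != "align_corners")
      s += "_" + coordTfMode;
  } else if (mode == "nearest") {
    s = "nearest";
    // The default coordinate mode for nearest is asymmetric.
    if (coordTfMode != "asymmetric" && coordTfMode != "align_corners")
      s += "_" + coordTfMode;
    if (nearestMode != "floor")
      s += "," + nearestMode;
  }
  return s;
}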
+ binder.getLoc(), rewriter.getType(), + rewriter.getFloatAttr(rewriter.getF64Type(), scale)); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, self, cstMean, cstStd, + /*generator=*/cstNone); + return success(); + }); + patterns.onOp( + "RandomNormalLike", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + SmallString<64> name("torch.onnx.seed"); + auto seedAttr = binder.op->getAttr(name); + if (seedAttr) + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented: support not present for seed attribute"); + + Torch::ValueTensorType resultType; int64_t dtypeIntOnnx; float mean, scale; SmallVector shape; @@ -2590,7 +2833,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value scores, labels, weight; if (binder.tensorOperandAtIndex(scores, 0) || binder.tensorOperandAtIndex(labels, 1) || - binder.s64IntegerAttr(ignoreIndex, "ignore_index ", -100) || + binder.s64IntegerAttr(ignoreIndex, "ignore_index", -100) || binder.customOpNameStringAttr(reduction, "reduction", "mean") || binder.tensorResultTypeAtIndex(resultType, 0)) { return failure(); @@ -2639,31 +2882,17 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( }); patterns.onOp( "Resize", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; + Torch::ValueTensorType outputTensorType; llvm::SmallVector operands; std::string mode, nearest_mode, coordTfMode; - Value noneVal = rewriter.create(binder.getLoc()); + int64_t antialias, exclude_outside; + float extrapolation_value, cubic_coeff_a; - if (auto attr = binder.op->getAttr("torch.onnx.antialias")) { - return rewriter.notifyMatchFailure( - binder.op, - "unimplemented: support not present for antialias attribute"); - } if (auto attr = binder.op->getAttr("torch.onnx.axes")) { return rewriter.notifyMatchFailure( binder.op, "unimplemented: support not present for axes attribute"); } - if (auto attr = binder.op->getAttr("torch.onnx.exclude_outside")) { - return rewriter.notifyMatchFailure( - binder.op, "unimplemented: support not present for " - "exclude_outside attribute"); - } - if (auto attr = binder.op->getAttr("torch.onnx.extrapolation_value")) { - return rewriter.notifyMatchFailure( - binder.op, "unimplemented: support not present for " - "extrapolation_value attribute"); - } if (auto attr = binder.op->getAttr("torch.onnx.keep_aspect_ratio_policy")) { return rewriter.notifyMatchFailure( @@ -2672,121 +2901,1944 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( } if (binder.tensorOperandsList(operands) || - binder.tensorResultType(resultType) || + binder.tensorResultType(outputTensorType) || binder.customOpNameStringAttr(mode, "mode", "nearest") || binder.customOpNameStringAttr( coordTfMode, "coordinate_transformation_mode", "half_pixel") || - binder.customOpNameStringAttr(nearest_mode, "nearest_mode", "")) + binder.s64IntegerAttr(antialias, "antialias", 0) || + binder.s64IntegerAttr(exclude_outside, "exclude_outside", 0) || + binder.f32FloatAttr(extrapolation_value, "extrapolation_value", + 0.0) || + binder.customOpNameStringAttr(nearest_mode, "nearest_mode", + "round_prefer_floor") || + binder.f32FloatAttr(cubic_coeff_a, "cubic_coeff_a", -0.75)) return failure(); - if (mode == "nearest" && nearest_mode != "floor") { - return rewriter.notifyMatchFailure( - binder.op, "unimplemented: support not present for nearest_mode " - "except floor"); - } + int64_t const /* */ batchDim = 0; + int64_t const /**/ channelDim = 1; - Value zero = rewriter.create( - binder.getLoc(), rewriter.getType(), - 
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + SmallVector nonResizableDims{ + batchDim, + channelDim, + }; - Value cstFalse = - rewriter.create(binder.getLoc(), false); - Value cstTrue = - rewriter.create(binder.getLoc(), true); - Value modeStrValue; + Value inputTensor = operands[0]; + auto inputTensorType = + cast(inputTensor.getType()); + auto sizesOfInputTensor = inputTensorType.getSizes(); + auto sizesOfOutputTensor = outputTensorType.getSizes(); - auto extract = [&rewriter, &binder](Value x, Value v) { - auto xTy = x.getType().cast(); - Type extractTy = rewriter.getType(); - if (isa(xTy.getDtype())) - extractTy = rewriter.getType(); + auto unknownSize = Torch::kUnknownSize; - return rewriter.create(binder.getLoc(), extractTy, - v); - }; + // Compile-time check for dimensions of static size + for (auto &eachDim : nonResizableDims) { + auto eachSizeOfInputTensor = sizesOfInputTensor[eachDim]; + auto eachSizeOfOutputTensor = sizesOfOutputTensor[eachDim]; - auto getValueList = [&](Value operand) { - SmallVector itemList; - auto sizes = - dyn_cast(operand.getType()).getSizes(); - Torch::BaseTensorType operandType = - operand.getType().cast(); + if (eachSizeOfInputTensor == unknownSize || + eachSizeOfOutputTensor == unknownSize) + continue; + if (eachSizeOfInputTensor == eachSizeOfOutputTensor) + continue; - SmallVector selectSizes; - selectSizes.push_back(1); - Type selectResultType = operandType.getWithSizesAndDtype( - llvm::ArrayRef(selectSizes), operandType.getOptionalDtype()); + auto resizingIntentErrorMessage = + "unsupported: non-trivial intent to resize dimension: " + + std::to_string(eachDim); - MLIRContext *context = binder.op->getContext(); - for (int i = sizes[0] - 2; i < sizes[0]; i++) { - Value selectIndex = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); - Value ext = rewriter.create( - binder.getLoc(), selectResultType, operand, zero, selectIndex); - Value item = extract(operand, ext); - itemList.push_back(item); - } - auto xTy = operand.getType().cast(); - Value ValueList; - if (isa(xTy.getDtype())) { - ValueList = rewriter.create( - binder.getLoc(), - Torch::ListType::get(Torch::IntType::get(context)), itemList); - } else { - ValueList = rewriter.create( - binder.getLoc(), - Torch::ListType::get(Torch::FloatType::get(context)), itemList); - } - return ValueList; + return rewriter.notifyMatchFailure(binder.op, + resizingIntentErrorMessage); }; - Value scalesValueList = noneVal; - Value sizesValueList = noneVal; + if (antialias != 0) { + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented: support not present for antialias attribute"); + } + if (exclude_outside != 0) { + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: support not present for " + "exclude_outside attribute"); + } + if (extrapolation_value != 0.0) { + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: support not present for " + "extrapolation_value attribute"); + } + if (coordTfMode == "tf_crop_and_resize") + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: coordinate transformation mode: " + "tf_crop_and_resize"); + + if (mode == "nearest" && coordTfMode != "asymmetric" && + coordTfMode != "half_pixel") { + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: support not present for coord tf mode " + "except asymmetric and half_pixel"); + } + + if (mode == "cubic" && cubic_coeff_a != -0.75) { + return rewriter.notifyMatchFailure( + binder.op, 
"unimplemented: cubic coeff must be -0.75"); + } + + auto loc = binder.getLoc(); + + Value cstFalse = rewriter.create(loc, false); + Value cstTrue = rewriter.create(loc, true); + Value modeStrValue; + Value alignCorners = coordTfMode == "align_corners" ? cstTrue : cstFalse; - if (mode == "cubic") { - return rewriter.notifyMatchFailure(binder.op, - "unimplemented: bicubic mode"); + std::string modeStr = "cubic"; + if (coordTfMode != "half_pixel") + modeStr = modeStr + "_" + coordTfMode; + modeStrValue = rewriter.create(loc, modeStr); } + + auto rankOfInputTensor = sizesOfInputTensor.size(); + + // supported modes: + // bilinear (half_pixel), bilinear with align_corners, + // bilinear_pytorch_half_pixel, bilinear_asymmetric nearest + // (asymmetric), nearest with align_corners, nearest_half_pixel, + // nearest_pytorch_half_pixel if (mode == "linear") { - modeStrValue = rewriter.create(binder.getLoc(), - "bilinear"); - if (operands.size() < 4) { - Value scaleOperand = operands[2]; - scalesValueList = getValueList(scaleOperand); - sizesValueList = noneVal; - } else { - Value sizeOperand = operands[3]; - scalesValueList = noneVal; - sizesValueList = getValueList(sizeOperand); + std::string modeStr; + switch (rankOfInputTensor) { + case 3: + modeStr = "linear"; + break; + case 4: + modeStr = "bilinear"; + break; + case 5: + modeStr = "trilinear"; + break; + default: + return failure(); } + // Confusingly enough, the default coordTfMode for pytorch bilinear + // mode is apparently half_pixel, NOT pytorch_half_pixel + if (coordTfMode != "half_pixel" && coordTfMode != "align_corners") + modeStr = (modeStr + "_") + coordTfMode; + modeStrValue = rewriter.create(loc, modeStr); } if (mode == "nearest") { - modeStrValue = - rewriter.create(binder.getLoc(), "nearest"); - if (operands.size() < 4) { - Value scaleOperand = operands[2]; - scalesValueList = getValueList(scaleOperand); - sizesValueList = noneVal; - } else { - Value sizesOperand = operands[3]; - scalesValueList = noneVal; - sizesValueList = getValueList(sizesOperand); - } - } - if (scalesValueList.getType().isa() && - sizesValueList.getType().isa()) { - return rewriter.notifyMatchFailure(binder.op, "unknown scaling mode"); + std::string modeStr = "nearest"; + // The default coordTfMode for pytorch with mode = nearest is + // apparently asymmetric + if (coordTfMode != "asymmetric" && coordTfMode != "align_corners") + modeStr = (modeStr + "_") + coordTfMode; + if (nearest_mode != "floor" && nearest_mode != "") + modeStr = modeStr + "," + nearest_mode; + modeStrValue = rewriter.create(loc, modeStr); } - rewriter - .replaceOpWithNewOp( - binder.op, resultType, operands[0], sizesValueList, - scalesValueList, modeStrValue, - /* AnyTorchOptionalBoolType:$align_corners */ alignCorners, - /* AnyTorchOptionalBoolType:$recompute_scale_factor */ noneVal, - /*Torch_BoolType:$antialias*/ cstFalse); + + auto numberOfOperands = operands.size(); + + Type boolType = rewriter.getType(); + + int64_t assumedForemostSpatialDim = 1 + nonResizableDims.back(); + + Value supportedScaleFactors; + Value supportedSizes; + + Value noneVal = rewriter.create(loc); + + if (numberOfOperands == 3) { + Value proposedScaleFactors = operands[2]; + + Value scaleIdentity = rewriter.create( + loc, rewriter.getF64FloatAttr(1.0)); + + // run-time scale factor check for dynamic sizes + for (auto &eachDim : nonResizableDims) { + Value eachProposedScaleFactor = extractTorchScalar( + loc, eachDim, proposedScaleFactors, rewriter); + + Value eachScaleFactorIsIdentity = + rewriter.create( + 
loc, boolType, eachProposedScaleFactor, scaleIdentity); + + auto errorMessageForEachDim = + "Unsupported: non-trivial scale factor for dimension " + + std::to_string(eachDim); + + rewriter.create( + loc, eachScaleFactorIsIdentity, + rewriter.getStringAttr(errorMessageForEachDim)); + }; + + supportedScaleFactors = createScalarSublist( + loc, proposedScaleFactors, assumedForemostSpatialDim, rewriter); + supportedSizes = noneVal; + } else if (numberOfOperands == 4) { + Value proposedSizes = operands[3]; + + // run-time target size check for dynamic sizes + for (auto &eachDimAsInt : nonResizableDims) { + Value eachDimAsValue = + rewriter.create(loc, eachDimAsInt); + + Value eachSizeOfInputTensor = rewriter.create( + loc, inputTensor, eachDimAsValue); + + Value eachProposedSize = + extractTorchScalar(loc, eachDimAsInt, proposedSizes, rewriter); + + Value eachProposedSizeIsTrivial = + rewriter.create( + loc, boolType, eachProposedSize, eachSizeOfInputTensor); + + auto errorMessageForEachDim = + "Unsupported: non-trivial resizing of dimension " + + std::to_string(eachDimAsInt); + + rewriter.create( + loc, eachProposedSizeIsTrivial, + rewriter.getStringAttr(errorMessageForEachDim)); + }; + + supportedScaleFactors = noneVal; + supportedSizes = createScalarSublist( + loc, proposedSizes, assumedForemostSpatialDim, rewriter); + } else + return rewriter.notifyMatchFailure(binder.op, "unknown scaling mode"); + + rewriter + .replaceOpWithNewOp( + binder.op, outputTensorType, inputTensor, supportedSizes, + supportedScaleFactors, modeStrValue, + /* AnyTorchOptionalBoolType:$align_corners */ alignCorners, + /* AnyTorchOptionalBoolType:$recompute_scale_factor */ noneVal, + /*Torch_BoolType:$antialias*/ cstFalse); + return success(); + }); + patterns.onOp( + "RoiAlign", 16, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + // operands = input, rois, batch_indices + SmallVector operands; + std::string coordTfMode, mode; + int64_t outHInt, outWInt, samplingRatioInt; + float spatialScaleFloat; + Torch::ValueTensorType resultType; + if (binder.tensorOperands(operands, 3) || + binder.customOpNameStringAttr( + coordTfMode, "coordinate_transformation_mode", "half_pixel") || + binder.customOpNameStringAttr(mode, "mode", "avg") || + binder.s64IntegerAttr(outHInt, "output_height", 1) || + binder.s64IntegerAttr(outWInt, "output_width", 1) || + binder.s64IntegerAttr(samplingRatioInt, "sampling_ratio", 0) || + binder.f32FloatAttr(spatialScaleFloat, "spatial_scale", 1.0f) || + binder.tensorResultType(resultType)) + return failure(); + Value input = operands[0]; + Value rois = operands[1]; + Value batchIndices = operands[2]; + + // the torchvision roi_pool op does not support these features: + if (mode == "max" && + (coordTfMode != "half_pixel" || samplingRatioInt != 0)) + return rewriter.notifyMatchFailure( + binder.op, "unsupported: roi max pooling without default " + "coordTfMode and sampling_ratio"); + + Location loc = binder.getLoc(); + // concatenate the batchIndices to the rois to get rois as a num_roisx5 + // tensor. The batchIndices tensor is an int64 tensor, and needs to be + // converted to float before concatenation. 
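As the comment notes, torchvision-style roi ops expect every roi row to carry its own batch index, so the separate ONNX `batch_indices` input is cast to the roi dtype and concatenated in front of the four box coordinates to form a `num_rois x 5` tensor. A minimal sketch of that layout in plain C++ (the helper name and scalar containers are illustrative stand-ins, not the tensor ops used in this pattern):

```cpp
#include <array>
#include <cstdint>
#include <cstdio>
#include <vector>

// Sketch: merge ONNX RoiAlign's separate batch_indices input into
// torchvision-style rois, where each row is [batch_idx, x1, y1, x2, y2].
std::vector<std::array<float, 5>>
assembleRois(const std::vector<int64_t> &batchIndices,
             const std::vector<std::array<float, 4>> &boxes) {
  std::vector<std::array<float, 5>> rois;
  rois.reserve(boxes.size());
  for (size_t i = 0; i < boxes.size(); ++i) {
    // The int64 batch index is converted to the roi dtype before concat.
    rois.push_back({static_cast<float>(batchIndices[i]), boxes[i][0],
                    boxes[i][1], boxes[i][2], boxes[i][3]});
  }
  return rois;
}

int main() {
  std::vector<int64_t> batchIndices = {0, 1};
  std::vector<std::array<float, 4>> boxes = {{{1, 2, 3, 4}}, {{5, 6, 7, 8}}};
  for (auto &r : assembleRois(batchIndices, boxes))
    std::printf("[%g %g %g %g %g]\n", r[0], r[1], r[2], r[3], r[4]);
  return 0;
}
```

Each resulting row matches the `[batch_idx, x1, y1, x2, y2]` layout that the torchvision roi_align/roi_pool style ops consume.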
+ auto roisType = dyn_cast(rois.getType()); + if (!roisType || !roisType.hasSizes()) + return failure(); + Value cstDim = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); + FailureOr unsqueezeIndices = + Torch::unsqueezeTensor(rewriter, binder.op, batchIndices, cstDim); + if (failed(unsqueezeIndices)) + return failure(); + batchIndices = unsqueezeIndices.value(); + auto batchIndicesType = + cast(batchIndices.getType()); + Value dTypeInt = + Torch::getDtypeIntValueForType(rewriter, loc, roisType.getDtype()); + Value none = rewriter.create(binder.getLoc()); + Value cstFalse = + rewriter.create(binder.getLoc(), false); + Value newBatchIndices = rewriter.create( + loc, + batchIndicesType.getWithSizesAndDtype( + batchIndicesType.getOptionalSizes(), + roisType.getOptionalDtype()), + batchIndices, dTypeInt, cstFalse, cstFalse, none); + SmallVector roiSizes(roisType.getSizes()); + roiSizes.back() = 5; + auto catType = rewriter.getType( + roiSizes, roisType.getDtype()); + Type listElemType = + roisType.getWithSizesAndDtype(/*optionalSizes=*/std::nullopt, + /*optionalDtype=*/nullptr); + Type listType = Torch::ListType::get(listElemType); + Value tensorList = rewriter.create( + binder.op->getLoc(), listType, ValueRange{newBatchIndices, rois}); + Value newRois = + rewriter.create(loc, catType, tensorList, cstDim); + + // make constants from attributes + Value cstSpatialScale = rewriter.create( + loc, rewriter.getF64FloatAttr(spatialScaleFloat)); + Value pooledHeight = rewriter.create( + loc, rewriter.getI64IntegerAttr(outHInt)); + Value pooledWidth = rewriter.create( + loc, rewriter.getI64IntegerAttr(outWInt)); + // this is for consistency with the default pytorch sampling ratio value + samplingRatioInt = (samplingRatioInt == 0) ? -1 : samplingRatioInt; + Value samplingRatio = rewriter.create( + loc, rewriter.getI64IntegerAttr(samplingRatioInt)); + bool aligned = coordTfMode == "half_pixel"; + Value cstAligned = rewriter.create(loc, aligned); + + if (mode == "avg") { + rewriter.replaceOpWithNewOp( + binder.op, resultType, input, newRois, cstSpatialScale, + pooledHeight, pooledWidth, samplingRatio, cstAligned); + return success(); + } + // mode == "max" + auto indicesType = resultType.getWithSizesAndDtype( + resultType.getOptionalSizes(), batchIndicesType.getDtype()); + auto roiPool = rewriter.create( + loc, TypeRange{resultType, indicesType}, input, newRois, + cstSpatialScale, pooledHeight, pooledWidth); + rewriter.replaceOp(binder.op, roiPool.getResult(0)); + return success(); + }); + patterns.onOp( + "SpaceToDepth", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value input; + int64_t blockSize; + std::string mode; + if (binder.tensorOperand(input) || + binder.s64IntegerAttr(blockSize, "blocksize") || + binder.customOpNameStringAttr(mode, "mode", "DCR") || + binder.tensorResultType(resultType)) + return failure(); + auto inputTy = dyn_cast(input.getType()); + if (!inputTy || !inputTy.hasSizes()) { + return rewriter.notifyMatchFailure( + binder.op, "Expected input type having sizes"); + } + SmallVector inputSizes{inputTy.getSizes()}; + if (inputSizes.size() != 4) { + return rewriter.notifyMatchFailure(binder.op, + "Expected input rank to be 4"); + } + + Value b = rewriter.create( + binder.getLoc(), input, + rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0))); + Value c = rewriter.create( + binder.getLoc(), input, + rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1))); + Value h = rewriter.create( 
+ binder.getLoc(), input, + rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(2))); + Value w = rewriter.create( + binder.getLoc(), input, + rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(3))); + Value cstBlockSize = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(blockSize)); + Value cstBlockSizeSquare = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(blockSize * blockSize)); + Value hDivBlockSize = rewriter.create( + binder.getLoc(), h, cstBlockSize); + Value wDivBlockSize = rewriter.create( + binder.getLoc(), w, cstBlockSize); + hDivBlockSize = rewriter.create(binder.getLoc(), + hDivBlockSize); + wDivBlockSize = rewriter.create(binder.getLoc(), + wDivBlockSize); + + // The implementation is as follows: + // tmp = np.reshape( + // x, [b, c, h // blocksize, blocksize, w // blocksize, blocksize] + // ) + // tmp = np.transpose(tmp, [0, 3, 5, 1, 2, 4]) + // y = np.reshape(tmp, [b, c * (blocksize**2), h // blocksize, w // + // blocksize]) + Value reshapeSizesList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(input.getContext())), + llvm::SmallVector{b, c, hDivBlockSize, cstBlockSize, + wDivBlockSize, cstBlockSize}); + int64_t hDivBlockSizeInt = inputSizes[2] == Torch::kUnknownSize + ? Torch::kUnknownSize + : inputSizes[2] / blockSize; + int64_t wDivBlockSizeInt = inputSizes[3] == Torch::kUnknownSize + ? Torch::kUnknownSize + : inputSizes[3] / blockSize; + SmallVector reshapeSizesInt{inputSizes[0], inputSizes[1], + hDivBlockSizeInt, blockSize, + wDivBlockSizeInt, blockSize}; + Value reshapedInput = rewriter.create( + binder.getLoc(), + inputTy.getWithSizesAndDtype(reshapeSizesInt, + inputTy.getOptionalDtype()), + input, reshapeSizesList); + + SmallVector permuteDimsInt{0, 3, 5, 1, 2, 4}; + Value permutedInput; + if (failed(createTorchPermuteOp(binder, rewriter, binder.getLoc(), + reshapedInput, permuteDimsInt, + permutedInput))) + return rewriter.notifyMatchFailure( + binder.op, "Failed to create Torch Permute op"); + + Value cMulBlockSizeSquare = rewriter.create( + binder.getLoc(), c, cstBlockSizeSquare); + reshapeSizesList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(input.getContext())), + llvm::SmallVector{b, cMulBlockSizeSquare, hDivBlockSize, + wDivBlockSize}); + rewriter.replaceOpWithNewOp( + binder.op, resultType, permutedInput, reshapeSizesList); + return success(); + }); + patterns.onOp( + "Shrink", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Location loc = binder.getLoc(); + Torch::ValueTensorType resultType; + Value input; + float bias, lambd; + if (binder.tensorOperand(input) || + binder.f32FloatAttr(bias, "bias", 0.0) || + binder.f32FloatAttr(lambd, "lambd", 0.5) || + binder.tensorResultType(resultType)) { + return failure(); + } + + Torch::ValueTensorType inputType = + cast(input.getType()); + if (!isa(inputType.getDtype())) + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: non-floating point dtype"); + + Torch::ValueTensorType comparisonResultType = + rewriter.getType( + ArrayRef(inputType.getSizes()), rewriter.getI1Type()); + + // The formula of this operator is: If x < -lambd, y = x + bias; If x > + // lambd, y = x - bias; Otherwise, y = 0. 
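Before the decomposition that follows, here is the same Shrink formula as a scalar reference (a plain C++ sketch of the operator's element-wise semantics, not the tensor IR the pattern emits):

```cpp
#include <cassert>

// Reference semantics of ONNX Shrink for one element:
//   x < -lambd  ->  x + bias
//   x >  lambd  ->  x - bias
//   otherwise   ->  0
float shrink(float x, float bias = 0.0f, float lambd = 0.5f) {
  if (x < -lambd)
    return x + bias;
  if (x > lambd)
    return x - bias;
  return 0.0f;
}

int main() {
  assert(shrink(-2.0f, 1.0f) == -1.0f); // below -lambd: add bias
  assert(shrink(2.0f, 1.0f) == 1.0f);   // above lambd: subtract bias
  assert(shrink(0.25f, 1.0f) == 0.0f);  // inside [-lambd, lambd]: zero
  return 0;
}
```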
+ // The implementation is based on the following algorithm: + // Shrink (input) => (output) + // { + // Lambd = Constant () + // LambdCast = CastLike (Lambd, input) + // Bias = Constant () + // BiasCast = CastLike (Bias, input) + // Zero = Constant () + // ZeroCast = CastLike (Zero, input) + // NegLmbda = Neg (LambdCast) + // InputLessThanNegLambda = Less (input, NegLmbda) + // InputAddBias = Add (input, BiasCast) + // InputSubBias = Sub (input, BiasCast) + // LambdaLessThanInput = Less (LambdCast, input) + // InputSubBiasOrZero = Where (LambdaLessThanInput, InputSubBias, + // ZeroCast) output = Where (InputLessThanNegLambda, InputAddBias, + // InputSubBiasOrZero) + // } + Value constLambd = rewriter.create( + loc, rewriter.getFloatAttr(rewriter.getF64Type(), lambd)); + Value constBias = rewriter.create( + loc, rewriter.getFloatAttr(rewriter.getF64Type(), bias)); + Value constZero = rewriter.create( + loc, rewriter.getFloatAttr(rewriter.getF64Type(), 0.0)); + Value constOne = rewriter.create( + loc, rewriter.getFloatAttr(rewriter.getF64Type(), 1.0)); + Value constNegLambd = rewriter.create( + loc, rewriter.getFloatAttr(rewriter.getF64Type(), -lambd)); + + Value inputLTNegLambd = rewriter.create( + loc, comparisonResultType, input, constNegLambd); + Value inputPlusBias = rewriter.create( + loc, inputType, input, constBias, /*alpha=*/constOne); + Value inputSubBias = rewriter.create( + loc, inputType, input, constBias, /*alpha=*/constOne); + Value inputGTLambd = rewriter.create( + loc, comparisonResultType, input, constLambd); + + Value inputSubBiasOrZero = + rewriter.create( + loc, resultType, inputGTLambd, inputSubBias, constZero); + rewriter.replaceOpWithNewOp( + binder.op, resultType, inputLTNegLambd, inputPlusBias, + inputSubBiasOrZero); + return success(); + }); + patterns.onOp("SequenceAt", 11, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value inputSequence, position; + if (binder.tensorListOperandAtIndex(inputSequence, 0) || + binder.tensorOperandAtIndex(position, 1) || + binder.tensorResultType(resultType)) + return failure(); + + Value index = rewriter.create( + binder.getLoc(), rewriter.getType(), + position); + rewriter.replaceOpWithNewOp( + binder.op, resultType, inputSequence, index); + return success(); + }); + patterns.onOp( + "SequenceEmpty", 11, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ListType resultType; + int64_t dtypeIntOnnx; + if (binder.s64IntegerAttr(dtypeIntOnnx, "dtype", 1) || + binder.tensorListResultType(resultType)) + return failure(); + + std::optional dtypeIntTorch = + onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx); + if (!dtypeIntTorch.has_value()) { + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented support for the given dtype conversion"); + } + Value constDtype = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value())); + + Value shapeList = createConstantIntList(binder, rewriter, {}); + Value cstNone = rewriter.create(binder.getLoc()); + + Value self = rewriter.create( + binder.op->getLoc(), resultType.getContainedType(), shapeList, + /*dtype=*/constDtype, + /*layout=*/cstNone, + /*device=*/cstNone, /*pinMemory=*/cstNone, + /*memoryFormat=*/cstNone); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, llvm::SmallVector{self}); + return success(); + }); + patterns.onOp( + "SequenceErase", 11, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ListType resultType; + Value inputSequence, position; + if 
(binder.tensorListOperandAtIndex(inputSequence, 0) || + binder.tensorListResultType(resultType)) + return failure(); + + Value length = rewriter.create( + binder.getLoc(), rewriter.getType(), inputSequence); + + Value cstNone = rewriter.create(binder.getLoc()); + Value cstOne = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); + if (binder.op->getNumOperands() == 1) { + // If True, it means that the `position` arg is missing and + // the last tensor from the list has to be erased. + Value lengthMinusOne = rewriter.create( + binder.getLoc(), length, cstOne); + rewriter.replaceOpWithNewOp( + binder.op, resultType, inputSequence, /*start=*/cstNone, + /*end=*/lengthMinusOne, /*step=*/cstOne); + return success(); + } + + if (binder.tensorOperandAtIndex(position, 1)) + return failure(); + + Value positionInt = rewriter.create( + binder.getLoc(), rewriter.getType(), position); + // Handling negative position value. + Value cstZero = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value isPositionNegative = rewriter.create( + binder.getLoc(), positionInt, cstZero); + isPositionNegative = rewriter.create( + binder.getLoc(), isPositionNegative); + Value finalOffset = rewriter.create( + binder.getLoc(), isPositionNegative, length); + positionInt = rewriter.create( + binder.getLoc(), positionInt, finalOffset); + + Value listBeforePosition = rewriter.create( + binder.getLoc(), resultType, inputSequence, /*start=*/cstNone, + /*end=*/positionInt, /*step=*/cstOne); + Value positionPlusOne = rewriter.create( + binder.getLoc(), positionInt, cstOne); + Value listAfterPosition = rewriter.create( + binder.getLoc(), resultType, inputSequence, + /*start=*/positionPlusOne, + /*end=*/length, /*step=*/cstOne); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, listBeforePosition, listAfterPosition); + return success(); + }); + patterns.onOp( + "SequenceInsert", 11, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ListType resultType; + Value inputSequence, position, insertValue; + if (binder.tensorListOperandAtIndex(inputSequence, 0) || + binder.tensorOperandAtIndex(insertValue, 1) || + binder.tensorListResultType(resultType)) + return failure(); + + if (binder.op->getNumOperands() == 1) { + // If True, it means that the `position` arg is missing and + // the tensor has to be inserted at the end of the list. 
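The SequenceErase and SequenceInsert lowerings here follow Python-style list semantics: a missing `position` means erase the last element or append at the end, and negative positions are offset by the sequence length. A compact sketch of that behavior on a `std::vector` stand-in (the names and the use of `int64_t` elements in place of tensors are assumptions made for illustration):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Stand-in "sequence of tensors"; int64_t plays the role of a tensor here.
using Sequence = std::vector<int64_t>;

// SequenceErase: no position -> drop the last element; negative positions
// count from the end, mirroring the length-offset logic emitted above.
Sequence sequenceErase(Sequence seq, int64_t position, bool hasPosition) {
  int64_t len = static_cast<int64_t>(seq.size());
  if (!hasPosition)
    position = len - 1;
  else if (position < 0)
    position += len;
  seq.erase(seq.begin() + position);
  return seq;
}

// SequenceInsert: no position -> append at the end.
Sequence sequenceInsert(Sequence seq, int64_t value, int64_t position,
                        bool hasPosition) {
  int64_t len = static_cast<int64_t>(seq.size());
  if (!hasPosition)
    position = len;
  else if (position < 0)
    position += len;
  seq.insert(seq.begin() + position, value);
  return seq;
}

int main() {
  assert(sequenceErase({10, 20, 30}, -1, true) == (Sequence{10, 20}));
  assert(sequenceErase({10, 20, 30}, 0, false) == (Sequence{10, 20}));
  assert(sequenceInsert({10, 20}, 30, 0, false) == (Sequence{10, 20, 30}));
  return 0;
}
```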
+ Value length = rewriter.create( + binder.getLoc(), rewriter.getType(), + inputSequence); + rewriter.replaceOpWithNewOp( + binder.op, inputSequence, /*idx=*/length, + /*el=*/insertValue); + return success(); + } + + if (binder.tensorOperandAtIndex(position, 2)) + return failure(); + + Value positionInt = rewriter.create( + binder.getLoc(), rewriter.getType(), position); + rewriter.create(binder.getLoc(), inputSequence, + /*idx=*/positionInt, + /*el=*/insertValue); + rewriter.replaceOp(binder.op, inputSequence); + return success(); + }); + patterns.onOp( + "SequenceMap", 17, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + llvm::SmallVector operands; + Torch::ListType resultType; + if (binder.tensorOperandsList(operands) || operands.size() == 0 || + binder.tensorListResultType(resultType)) { + return failure(); + } + + Region *bodyRegion; + if (binder.getRegionAtIndex(bodyRegion, 0)) { + return rewriter.notifyMatchFailure(binder.op, + "Failed getting Body Region"); + } + + // construct an empty list, append results through the loop + auto resultTensorType = + dyn_cast(resultType.getContainedType()); + Value shapeList = createConstantIntList(binder, rewriter, + resultTensorType.getSizes()); + Value cstNone = rewriter.create(binder.getLoc()); + Value self = rewriter.create( + binder.op->getLoc(), resultType.getContainedType(), shapeList, + /*dtype=*/cstNone, /*layout=*/cstNone, /*device=*/cstNone, + /*pinMemory=*/cstNone, /*memoryFormat=*/cstNone); + Value result = rewriter.create( + binder.getLoc(), resultType, llvm::SmallVector{self}); + + // create a for-like primLoopOp + // with the length of sequence as max iter_num + Value len = rewriter.create( + binder.getLoc(), rewriter.getType(), operands[0]); + auto cstTrue = rewriter.create( + binder.getLoc(), rewriter.getBoolAttr(true)); + mlir::ImplicitLocOpBuilder b(binder.getLoc(), rewriter); + auto loop = + b.create(resultType, len, cstTrue, result); + rewriter.cloneRegionBefore(*bodyRegion, loop.getRegion(), + loop.getRegion().begin()); + + // primLoopOp loopBody expects torch.int as first arg + // remove inputs from the region and use it from outside + loop.getRegion().front().insertArgument(0U, resultType, + binder.getLoc()); + Value sequenceArg = loop.getRegion().front().getArgument(0); + loop.getRegion().front().insertArgument( + 0U, rewriter.getType(), binder.getLoc()); + Value indexArg = loop.getRegion().front().getArgument(0); + + // get sequence[i] (and addtionalInput[i]) in each iteration + rewriter.setInsertionPointToStart(&loop.getRegion().front()); + for (size_t i = 0; i < operands.size(); i++) { + Value argInput = loop.getRegion().front().getArgument(2); + if (isa(operands[i].getType())) { + auto tensorType = dyn_cast( + dyn_cast(operands[i].getType()) + .getContainedType()); + Value item = rewriter.create( + binder.getLoc(), tensorType, operands[i], indexArg); + argInput.replaceAllUsesWith(item); + } else { + argInput.replaceAllUsesWith(operands[i]); + } + loop.getRegion().eraseArgument(2); + } + + // replace terminator + PatternRewriter::InsertionGuard guard(rewriter); + Operation *terminator = loop.getRegion().front().getTerminator(); + rewriter.setInsertionPoint(terminator); + // update sequence input + auto terminatorOperands = terminator->getOperands(); + Value append = rewriter.create( + binder.getLoc(), resultType, sequenceArg, terminatorOperands[0]); + rewriter.replaceOpWithNewOp( + terminator, cstTrue, append); + + rewriter.replaceOp(binder.op, loop); + return success(); + }); + patterns.onOp( + 
"Upsample", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + std::string mode; + Value input, scales; + if (binder.tensorOperands(input, scales) || + binder.customOpNameStringAttr(mode, "mode", "nearest") || + binder.tensorResultType(resultType)) { + return failure(); + } + + if (mode != "nearest" && mode != "linear") + return rewriter.notifyMatchFailure( + binder.op, + R"(Expected valid interpolation mode: "nearest" | "linear")"); + + int64_t resultRank = resultType.getSizes().size(); + if (resultRank > 5) + return rewriter.notifyMatchFailure( + binder.op, "supports upto 3d upsampling only"); + + int64_t assumedForemostSpatialDim = 2; + Value scalesValueList = createScalarSublist( + binder.getLoc(), scales, assumedForemostSpatialDim, rewriter); + if (mode == "linear") { + if (resultRank == 4) + mode = "bilinear"; + if (resultRank == 5) + mode = "trilinear"; + } + Value modeStrValue = + rewriter.create(binder.getLoc(), mode); + Value cstNone = rewriter.create(binder.getLoc()); + Value cstFalse = rewriter.create( + binder.getLoc(), rewriter.getBoolAttr(false)); + + rewriter + .replaceOpWithNewOp( + binder.op, resultType, input, /*size=*/cstNone, scalesValueList, + modeStrValue, + /* AnyTorchOptionalBoolType:$align_corners */ cstNone, + /* AnyTorchOptionalBoolType:$recompute_scale_factor */ cstNone, + /*Torch_BoolType:$antialias*/ cstFalse); + return success(); + }); + patterns.onOp( + "STFT", 17, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + // operands in order ->(signal, frameStep, window, frameLength*) + SmallVector operands; + int64_t onesided; + Torch::ValueTensorType resultType; + + if (binder.tensorOperandsList(operands) || + binder.s64IntegerAttr(onesided, "onesided", 1) || + binder.tensorResultType(resultType)) + return failure(); + + Value signal = operands[0]; + Value frameStep = operands[1]; + auto signalTy = cast(signal.getType()); + if (!signalTy || !signalTy.hasSizes()) { + return rewriter.notifyMatchFailure( + binder.op, "Expected signal type having sizes"); + } + auto signalShape = signalTy.getSizes(); + // The infrastructure of ONNX and onnxruntime supports a rank-2. + // For reference: + // https://github.com/onnx/onnx/blob/060589cb81dfb081ed912c9e722b15fe1dbc1a14/onnx/defs/math/defs.cc#L3475-L3477 + if (signalShape.size() != 2 && signalShape.size() != 3) { + return rewriter.notifyMatchFailure(binder.op, + "signal has invalid shape."); + } + if (!resultType || !resultType.hasSizes()) { + return rewriter.notifyMatchFailure( + binder.op, "Expected result type having sizes"); + } + auto resultShape = resultType.getSizes(); + if (resultShape.size() != 4) { + return rewriter.notifyMatchFailure(binder.op, + "result has invalid shape."); + } + + // There are two possible cases for optional inputs frameLength and + // window, which are that either 4 operands will be passed with window + // being !torch.none, or three operands will be passed, with window + // present and frameLength absent. In the former case, we simply create + // a rectangular window consisting of ones, and in the latter, we set + // frameLength equal to the the inputShape[1] or windowShape[0] + // depending upon whether window was present or not. Note that it is + // possible that both window and frameLength can be none, which would + // mean that either only two operands were passed, or, in case of three + // operands, window was passed in as none, and frameLength was absent. 
+ Value window = nullptr, frameLength = nullptr; + bool windowIsNone = true, frameLengthIsNone = true; + if (operands.size() == 3) { + window = operands[2]; + windowIsNone = isa(window.getType()); + } + if (operands.size() == 4) { + window = operands[2]; + frameLength = operands[3]; + windowIsNone = isa(window.getType()); + frameLengthIsNone = isa(frameLength.getType()); + } + + ArrayRef windowShape; + if (!windowIsNone) { + windowShape = + cast(window.getType()).getSizes(); + if (windowShape.size() != 1) { + return rewriter.notifyMatchFailure(binder.op, + "window has invalid shape."); + } + } + if (frameLengthIsNone) { + if (windowIsNone) { + frameLength = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(signalShape[1])); + } else { + frameLength = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(windowShape[0])); + } + } + + Value frameLengthItem; + if (!frameLengthIsNone || windowIsNone) { + frameLengthItem = + getItemOp(binder, rewriter, frameLength); + } else { + frameLengthItem = frameLength; + } + Value frameStepItem = + getItemOp(binder, rewriter, frameStep); + + if (windowIsNone) { + auto onesResultTy = rewriter.getType( + ArrayRef({-1}), signalTy.getDtype()); + + Value none = rewriter.create(binder.getLoc()); + Value sizes = rewriter.create( + binder.getLoc(), + Torch::ListType::get( + Torch::IntType::get(binder.op->getContext())), + SmallVector{frameLengthItem}); + window = rewriter.create( + binder.getLoc(), onesResultTy, sizes, none, none, none, none); + } + + FailureOr complexDtype; + if (signalTy.getDtype().isBF16()) { + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented: support for bfloat16 type is unimplemented."); + } + if (signalTy.getDtype().isF16()) { + complexDtype = Torch::getTypeForScalarType( + binder.op->getContext(), + torch::torch_upstream::ScalarType::ComplexHalf); + } else if (signalTy.getDtype().isF32()) { + complexDtype = Torch::getTypeForScalarType( + binder.op->getContext(), + torch::torch_upstream::ScalarType::ComplexFloat); + } else { + complexDtype = Torch::getTypeForScalarType( + binder.op->getContext(), + torch::torch_upstream::ScalarType::ComplexDouble); + } + + auto complexSignalTy = rewriter.getType( + ArrayRef({signalShape[0], signalShape[1]}), + complexDtype.value()); + + // The onnx STFT op always passes in a float input, and if the input + // is intended to be complex, its shape will be [batch][length][2], + // where [...][0] is the real component, and [...][1] is the complex + // component. This complex input has to be made torch compatible before + // being passed into torch.stft, so it is necessary to call + // AtenViewAsComplexOp. In case of real input, the shape of the signal + // will be [batch][length] or [batch][length][1], and therefore it will + // have to be squeezed at dim=2 in the latter case, before being passed + // into torch.stft. + if (signalShape.size() == 3) { + if (signalShape[2] == 2) { + signal = rewriter.create( + binder.getLoc(), complexSignalTy, signal); + } else { + Value two = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(2)); + auto newSignalTy = signalTy.getWithSizesAndDtype( + ArrayRef({signalShape[0], signalShape[1]}), + signalTy.getDtype()); + signal = rewriter.create( + binder.getLoc(), newSignalTy, signal, two); + } + } + + // In case the window is not given, we use frameLength + // as the length of the window. 
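For the signal layout handling described above, a small illustration of how the trailing dimension is interpreted (plain C++ with `std::complex`; the real case is shown with zero imaginary parts only so the two layouts are comparable, whereas the lowering keeps real signals real and merely squeezes the unit dimension):

```cpp
#include <cassert>
#include <complex>
#include <cstdint>
#include <vector>

// ONNX STFT signal layouts (per the comment above):
//   [batch][length][2] -> complex: [...][0] is the real, [...][1] the imag part
//   [batch][length][1] -> real, squeezed to [batch][length]
//   [batch][length]    -> real
// Sketch for a single batch row; the lowering uses view_as_complex / squeeze.
std::vector<std::complex<float>>
toComplexRow(const std::vector<float> &row, int64_t lastDimSize) {
  std::vector<std::complex<float>> out;
  if (lastDimSize == 2) {
    for (size_t i = 0; i + 1 < row.size(); i += 2)
      out.emplace_back(row[i], row[i + 1]);
  } else {
    for (float v : row) // size-1 (or absent) trailing dim: purely real
      out.emplace_back(v, 0.0f);
  }
  return out;
}

int main() {
  auto c = toComplexRow({1.0f, -1.0f, 2.0f, 0.5f}, 2);
  assert(c.size() == 2 && c[0] == std::complex<float>(1.0f, -1.0f));
  auto r = toComplexRow({1.0f, 2.0f}, 1);
  assert(r.size() == 2 && r[1] == std::complex<float>(2.0f, 0.0f));
  return 0;
}
```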
+ Value windowLen; + if (!windowIsNone) { + windowLen = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(windowShape[0])); + } else { + windowLen = frameLengthItem; + } + + Value falseVal = + rewriter.create(binder.getLoc(), false); + Value trueVal = + rewriter.create(binder.getLoc(), true); + auto stftTy = complexSignalTy.getWithSizesAndDtype( + ArrayRef({resultShape[0], resultShape[2], resultShape[1]}), + complexSignalTy.getDtype()); + + // After torch.stft is called and the result is stored into the value + // stft, there is one thing to note: The resultType for the onnx op + // will have shape [batch][num_frames][length][2], while the shape of + // stft will be [batch][length][num_frames]. Before the value is + // converted to real through torch.view_as_real, we must permute the + // shape of stft to match the shape of resultType. Also, it is + // immaterial whether torch.view_as_real is called after or before the + // permutation; both outputs will be equivalent. + Value stft = rewriter.create( + binder.getLoc(), stftTy, signal, frameLengthItem, frameStepItem, + windowLen, window, falseVal, onesided ? trueVal : falseVal, trueVal, + falseVal); + + auto permuteStftTy = complexSignalTy.getWithSizesAndDtype( + ArrayRef({resultShape[0], resultShape[1], resultShape[2]}), + complexSignalTy.getDtype()); + Value permuteDims = createConstantIntList(binder, rewriter, {0, 2, 1}); + Value permutedStft = rewriter.create( + binder.getLoc(), permuteStftTy, stft, permuteDims); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, permutedStft); + return success(); + }); + patterns.onOp( + "ReverseSequence", 10, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value input, sequenceLens; + int64_t batchAxis, timeAxis; + if (binder.tensorOperandAtIndex(input, 0) || + binder.tensorOperandAtIndex(sequenceLens, 1) || + binder.s64IntegerAttr(batchAxis, "batch_axis", 1) || + binder.s64IntegerAttr(timeAxis, "time_axis", 0) || + binder.tensorResultType(resultType)) + return failure(); + + auto inputTy = cast(input.getType()); + SmallVector inputShape(inputTy.getSizes()); + auto dtype = resultType.getDtype(); + + Value cstZero = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value cstOne = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value batchAxisVal = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(batchAxis)); + Value timeAxisVal = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(timeAxis)); + + SmallVector sliceShape(inputShape); + sliceShape[batchAxis] = 1; + auto sliceType = + rewriter.getType(sliceShape, dtype); + SmallVector flipShape(sliceShape); + flipShape[timeAxis] = Torch::kUnknownSize; + auto flipType = + rewriter.getType(flipShape, dtype); + auto scalarTensorType = rewriter.getType( + ArrayRef{1}, rewriter.getIntegerType(64, /*signed*/ 1)); + + for (int i = 0; i < inputShape[batchAxis]; i++) { + // slice i iterating on batch axis + Value k = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i)); + Value end = + rewriter.create(binder.getLoc(), k, cstOne); + Value sliceBatch = rewriter.create( + binder.getLoc(), sliceType, input, batchAxisVal, k, end, cstOne); + + // get sequence length and slice the reversing part + Value kTensor = rewriter.create( + binder.getLoc(), scalarTensorType, k); + Value sel = rewriter.create( + binder.getLoc(), scalarTensorType, sequenceLens, cstZero, + kTensor); + Value len = rewriter.create( + binder.getLoc(), 
rewriter.getType(), sel); + Value sliceTime = rewriter.create( + binder.getLoc(), flipType, sliceBatch, timeAxisVal, cstZero, len, + cstOne); + // flip the sliced reversing tensor + Value dims = rewriter.create( + binder.getLoc(), + rewriter.getType( + rewriter.getType()), + SmallVector{timeAxisVal}); + Value flip = rewriter.create( + binder.getLoc(), flipType, sliceTime, dims); + + // embeds the reversed tensor to the input + Value embedTime = rewriter.create( + binder.getLoc(), sliceType, sliceBatch, flip, timeAxisVal, + /*start=*/cstZero, /*end=*/len, /*step=*/cstOne); + input = rewriter.create( + binder.getLoc(), resultType, input, embedTime, batchAxisVal, + /*start=*/k, /*end=*/end, /*step=*/cstOne); + } + + rewriter.replaceOp(binder.op, input); + return success(); + }); + patterns.onOp( + "ScatterND", 11, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value data, indices, updates; + std::string reduction; + if (binder.tensorOperandAtIndex(data, 0) || + binder.tensorOperandAtIndex(indices, 1) || + binder.tensorOperandAtIndex(updates, 2) || + binder.tensorResultType(resultType)) + return failure(); + + // Previous to version 16 of ScatterND, reduction attribute was not + // supported. Setting it as "none" for unsupported versions. + if (binder.customOpNameStringAttr(reduction, "reduction", "none")) { + reduction = "none"; + } + + // Map onnx reduction type to torch reduction type. + if (reduction == "add") { + reduction = "sum"; + } else if (reduction == "mul") { + reduction = "prod"; + } else if (reduction == "max") { + reduction = "amax"; + } else if (reduction == "min") { + reduction = "amin"; + } else if (reduction != "none") { + return rewriter.notifyMatchFailure( + binder.op, "expects reduction to be one of add, mul, max, min, " + "none(default)"); + } + + Location loc = binder.getLoc(); + auto dataTy = dyn_cast(data.getType()); + auto indicesTy = dyn_cast(indices.getType()); + auto updatesTy = dyn_cast(updates.getType()); + if (!dataTy || !indicesTy || !updatesTy || !dataTy.hasSizes() || + !indicesTy.hasSizes() || !updatesTy.hasSizes()) + return failure(); + + // step 1. Get shapes and ranks of data, indices and updates. + // The last dimension of indices is expected to be static. + ArrayRef dataShape = dataTy.getSizes(); + int64_t dataRank = dataShape.size(); + ArrayRef updatesShape = updatesTy.getSizes(); + int64_t updatesRank = updatesShape.size(); + ArrayRef indicesShape = indicesTy.getSizes(); + int64_t indicesRank = indicesShape.size(); + int64_t indicesLastDim = indicesShape.back(); + // Given data tensor of rank r >= 1, indices tensor of rank q >= 1, and + // updates tensor of rank q + r - indices_shape[-1] - 1, the output is + // produced by creating a copy of the input data, and then updating + // its value to values specified by updates at specific index positions + // specified by indices. Its output shape is the same as the shape of + // data. + // indices_shape[-1] must be static to have deterministic ranks. + if (dataRank < 1 || indicesRank < 1 || updatesRank < 1) + return rewriter.notifyMatchFailure( + binder.op, "expected data, indices and updates rank to be >= 1"); + if (indicesLastDim == Torch::kUnknownSize || indicesLastDim <= 0) + return rewriter.notifyMatchFailure( + binder.op, "expected last dimension of indices to be static and " + "greater than zero"); + + // step 2. Get dimension list of data. 
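Stepping back to the ReverseSequence lowering completed just above: its per-batch slice, flip, and slice-scatter sequence implements the following reference semantics (sketched here with batch_axis = 0 and time_axis = 1 for readability; the attribute defaults in the pattern are batch_axis = 1 and time_axis = 0):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Reference semantics of ONNX ReverseSequence: for every batch row b, the
// first sequence_lens[b] time steps are reversed and the rest are untouched.
using Batch = std::vector<std::vector<int64_t>>;

Batch reverseSequence(Batch input, const std::vector<int64_t> &sequenceLens) {
  for (size_t b = 0; b < input.size(); ++b)
    std::reverse(input[b].begin(), input[b].begin() + sequenceLens[b]);
  return input;
}

int main() {
  Batch x = {{0, 1, 2, 3}, {4, 5, 6, 7}};
  Batch y = reverseSequence(x, {4, 2});
  assert(y[0] == (std::vector<int64_t>{3, 2, 1, 0}));
  assert(y[1] == (std::vector<int64_t>{5, 4, 6, 7}));
  return 0;
}
```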
+ SmallVector dataDims; + for (int64_t i = 0; i < dataRank; ++i) { + Value k = rewriter.create(loc, i); + Value dataDim = rewriter.create(loc, data, k); + dataDims.push_back(dataDim); + } + + // step 3. Get dimension list of indices. + Value constZero = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + Value constOne = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + SmallVector indicesDimsMinusOne; + Value indicesFlattenDim = constOne; + for (int64_t i = 0; i < indicesRank - 1; ++i) { + Value k = rewriter.create(loc, i); + Value indicesDim = + rewriter.create(loc, indices, k); + indicesDimsMinusOne.push_back(indicesDim); + indicesFlattenDim = rewriter.create( + loc, indicesFlattenDim, indicesDim); + } + ArrayRef indicesShapeMinusOne = indicesShape.drop_back(); + + // Algorithm: We can not directly perform torch.scatter as it requires + // the ranks of data(`r`), indices(`q`) and updates to be same. + // So we will perform collapse and expand operations to match the + // ranks of data, indices and updates(making sure the semantic of the + // onnx.scatter_nd is preserved), then perform torch.scatter operation, + // later unflatten the scatter result to match onnx.scatter_nd output. + // For example, assuming + // indices is of shape (4, 5, 3, 2), data is (4, 10, 11, 7, 4) and + // updates is (4, 5, 3, 11, 7, 4). Firstly, modify indices to 1-D + // indexing as the torch.scatter op supports only single dimensional + // indexing(this algorithm would have been simpler if we can get a + // torch op that supports indexing at multiple dimensions + // simultaneously). 1-D indexed indices will be of shape (4, 5, 3, 1), + // now materialize it to `r-indices_shape[-1]` dimension of data i.e. + // reshaping it to the shape (4, 5, 3, 1, 1, 1). Next step is to + // flatten+expand the indices and flatten the data to (60, 11, 7, 4) and + // (40, 11, 7, 4) shapes respectively and then perform the torch.scatter + // operation. Post the scatter operation, unflatten the first dimension + // of result to (4, 10, 11, 7, 4) which is our required result. + + // step 4. Convert indices_shape[-1] dimensional indexing to 1D + // indexing. + Value sliceDim = rewriter.create( + loc, rewriter.getI64IntegerAttr(indicesRank - 1)); + SmallVector indicesSliceShape(indicesShapeMinusOne); + indicesSliceShape.push_back(1); + auto indicesSliceTy = rewriter.getType( + indicesSliceShape, indicesTy.getOptionalDtype()); + + Value start = constZero; + Value updatedIndices; + for (int64_t i = 0; i < indicesLastDim; ++i) { + Value end = rewriter.create( + loc, rewriter.getI64IntegerAttr(i + 1)); + Value indicesSlice = rewriter.create( + loc, indicesSliceTy, indices, sliceDim, start, end, + /*step=*/constOne); + start = end; + // Apply bounds checking on the indices slice. + auto boolTy = rewriter.getType( + indicesSliceShape, rewriter.getI1Type()); + Value lt = rewriter.create( + loc, boolTy, indicesSlice, constZero); + Value add = rewriter.create( + loc, indicesSliceTy, indicesSlice, dataDims[i], + /*alpha=*/constOne); + indicesSlice = rewriter.create( + loc, indicesSliceTy, lt, add, indicesSlice); + if (i == 0) { + updatedIndices = indicesSlice; + continue; + } + updatedIndices = rewriter.create( + loc, indicesSliceTy, indicesSlice, updatedIndices, dataDims[i]); + } + + // step 5. Compute all the required result types here. 
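Steps 3 and 4 fold the trailing `indices_shape[-1]` index components into a single row-major offset into the flattened data, after wrapping negative entries by the corresponding dimension size. A scalar sketch of that computation (hypothetical helper, not the emitted tensor ops):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Row-major flattening of one ScatterND index tuple, with the same
// negative-index wrap the lowering applies (index < 0 ? index + dim : index).
int64_t flattenIndex(const std::vector<int64_t> &index,
                     const std::vector<int64_t> &dims) {
  int64_t flat = 0;
  for (size_t k = 0; k < index.size(); ++k) {
    int64_t i = index[k];
    if (i < 0)
      i += dims[k]; // bounds wrap for negative indices
    flat = flat * dims[k] + i;
  }
  return flat;
}

int main() {
  // With data shape (4, 10, ...) and indices_shape[-1] == 2, the pair (i0, i1)
  // addresses a slice of the data flattened over its first two dims (4 * 10).
  assert(flattenIndex({2, 3}, {4, 10}) == 23);
  assert(flattenIndex({-1, 0}, {4, 10}) == 30);
  return 0;
}
```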
+ SmallVector reshapeIndicesShape(indicesShapeMinusOne); + SmallVector reshapeIndicesDims(indicesDimsMinusOne); + // Determine the collapsed dim size of indices(index_shape[-1] is not + // part of collapsing as we already removed it by 1-D indexing). + SmallVector flattenIndicesShape; + auto indicesCt = 1; + for (int64_t i = 0; i < indicesRank - 1; ++i) { + if (indicesShape[i] == Torch::kUnknownSize) { + indicesCt = Torch::kUnknownSize; + break; + } + indicesCt *= indicesShape[i]; + } + flattenIndicesShape.push_back(indicesCt); + // Compute the shape of expand op. + SmallVector expandIndicesDims; + expandIndicesDims.push_back(indicesFlattenDim); + SmallVector expandIndicesShape; + expandIndicesShape.push_back(indicesCt); + // Determine the collapsed dim size of data. + SmallVector flattenDataShape; + auto dataCt = 1; + for (int64_t i = 0; i < indicesLastDim; ++i) { + if (dataShape[i] == Torch::kUnknownSize) { + dataCt = Torch::kUnknownSize; + break; + } + dataCt *= dataShape[i]; + } + flattenDataShape.push_back(dataCt); + // Determine the collapsed dim size of updates. + SmallVector flattenUpdatesShape; + auto updatesCt = 1; + for (int64_t i = 0; i < indicesRank - 1; ++i) { + if (updatesShape[i] == Torch::kUnknownSize) { + updatesCt = Torch::kUnknownSize; + break; + } + updatesCt *= updatesShape[i]; + } + flattenUpdatesShape.push_back(updatesCt); + flattenUpdatesShape.insert(flattenUpdatesShape.end(), + updatesShape.begin() + indicesRank - 1, + updatesShape.end()); + // Append `r-indices_shape[-1]` unit or data dims appropriately to all + // result types. + for (int64_t i = indicesLastDim; i < dataRank; ++i) { + reshapeIndicesShape.push_back(1); + flattenIndicesShape.push_back(1); + flattenDataShape.push_back(dataShape[i]); + expandIndicesShape.push_back(dataShape[i]); + reshapeIndicesDims.push_back(constOne); + expandIndicesDims.push_back(dataDims[i]); + } + + // step 6. Reshape 1-D indexed indices to match the rank of flattened + // data by inserting unit dimensions. + auto intListTy = rewriter.getType( + rewriter.getType()); + Value reshapeIndicesSizeList = + rewriter.create(loc, intListTy, + reshapeIndicesDims); + auto reshapeIndicesTy = rewriter.getType( + reshapeIndicesShape, indicesTy.getOptionalDtype()); + Value reshapedIndices = rewriter.create( + loc, reshapeIndicesTy, updatedIndices, reshapeIndicesSizeList); + + // step 7. Flatten `q-1` dimensions of the indices and updates. + auto flattenIndicesTy = rewriter.getType( + flattenIndicesShape, indicesTy.getOptionalDtype()); + auto flattenUpdatesTy = rewriter.getType( + flattenUpdatesShape, updatesTy.getOptionalDtype()); + Value flattenedIndices = reshapedIndices; + Value flattenedUpdates = updates; + if (indicesRank == 1) { + flattenedIndices = rewriter.create( + loc, flattenIndicesTy, reshapedIndices, constZero); + flattenedUpdates = rewriter.create( + loc, flattenUpdatesTy, updates, constZero); + } else if (indicesRank > 1) { + Value endDim = rewriter.create( + loc, rewriter.getI64IntegerAttr(indicesRank - 2)); + flattenedIndices = rewriter.create( + loc, flattenIndicesTy, reshapedIndices, constZero, endDim); + flattenedUpdates = rewriter.create( + loc, flattenUpdatesTy, updates, constZero, endDim); + } + + // step 8. Expand `r-indices_shape[-1]` dims of flattened indices. 
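The shape bookkeeping in steps 5 through 7 can be sanity-checked against the concrete example in the comment above, with indices (4, 5, 3, 2), data (4, 10, 11, 7, 4) and updates (4, 5, 3, 11, 7, 4). A small sketch of the collapsed sizes, using -1 as a stand-in for `Torch::kUnknownSize`:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;
constexpr int64_t kUnknown = -1; // stand-in for Torch::kUnknownSize

// Multiply a run of dims, propagating unknown sizes.
int64_t collapse(const Shape &s, size_t begin, size_t end) {
  int64_t prod = 1;
  for (size_t i = begin; i < end; ++i) {
    if (s[i] == kUnknown)
      return kUnknown;
    prod *= s[i];
  }
  return prod;
}

int main() {
  Shape data = {4, 10, 11, 7, 4};
  Shape indices = {4, 5, 3, 2};
  Shape updates = {4, 5, 3, 11, 7, 4};
  size_t lastDim = static_cast<size_t>(indices.back()); // indices_shape[-1]

  // Flattened indices: collapse the leading q-1 dims -> 60, then expand
  // over the trailing data dims (11, 7, 4).
  assert(collapse(indices, 0, indices.size() - 1) == 60);
  // Flattened data: collapse the first indices_shape[-1] dims -> (40, 11, 7, 4).
  assert(collapse(data, 0, lastDim) == 40);
  // Flattened updates: collapse the leading q-1 dims -> (60, 11, 7, 4).
  assert(collapse(updates, 0, indices.size() - 1) == 60);
  return 0;
}
```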
+ auto expandIndicesTy = rewriter.getType( + expandIndicesShape, indicesTy.getOptionalDtype()); + Value expandIndicesSizeList = + rewriter.create(loc, intListTy, + expandIndicesDims); + Value constFalse = rewriter.create( + loc, rewriter.getType(), + rewriter.getBoolAttr(false)); + Value expandedIndices = rewriter.create( + loc, expandIndicesTy, flattenedIndices, expandIndicesSizeList, + /*implicit=*/constFalse); + + // step 9. Flatten indices_shape[-1] dimensions of data. + auto flattenDataTy = rewriter.getType( + flattenDataShape, dataTy.getOptionalDtype()); + Value endDim = rewriter.create( + loc, rewriter.getI64IntegerAttr(indicesLastDim - 1)); + Value flattenedData = rewriter.create( + loc, flattenDataTy, data, constZero, endDim); + + // step 10. Now we have flattenedData, expandedIndices and + // flattenedUpdates of same rank to perform scatter operation. + auto scatterTy = rewriter.getType( + flattenDataShape, dataTy.getOptionalDtype()); + + Value scatter; + if (reduction == "none") { + scatter = rewriter.create( + loc, scatterTy, flattenedData, /*axis=*/constZero, + expandedIndices, flattenedUpdates); + } else { + Value cstReduction = + rewriter.create(loc, reduction); + Value constTrue = rewriter.create( + loc, rewriter.getType(), + rewriter.getBoolAttr(true)); + scatter = rewriter.create( + loc, scatterTy, flattenedData, /*axis=*/constZero, + expandedIndices, flattenedUpdates, cstReduction, + /*include_self=*/constTrue); + } + + // step 11. Unflatten the collapsed data dims of scatter result. + if (indicesLastDim == 1) { + rewriter.replaceOp(binder.op, scatter); + return success(); + } + Value unflattenSizeList = rewriter.create( + loc, intListTy, dataDims); + rewriter.replaceOpWithNewOp( + binder.op, resultType, scatter, constZero, unflattenSizeList); + return success(); + }); + // split to sequence + // Arguments: + // - input: the tensor to split + // -Split(optional): Length of each output + // Attributes: + // - axis: the axis along which to split the input + // - keepdims: to keep the split dimension or not. 
Ignored when 'split' is + // specified Outputs: + // - outputs: sequence of tensor + // + + patterns.onOp( + "SplitToSequence", 11, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Value self; + Value split; + int64_t axis; + int64_t keepdims; + Torch::ListType resultType; + + if (binder.op->getNumOperands() == 1) + return rewriter.notifyMatchFailure( + binder.op, "No of operands should be two.Keepdims attribute is " + "not yet implemented"); + + if (binder.tensorOperandAtIndex(self, 0) || + binder.tensorListResultType(resultType) || + binder.s64IntegerAttr(keepdims, "keepdims", 1) || + binder.tensorOperandAtIndex(split, 1) || + binder.s64IntegerAttr(axis, "axis", 0)) + return rewriter.notifyMatchFailure( + binder.op, + "Not converting to AtenSplitToSequenceOp due to inputs "); + + Value axisValue = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getI64IntegerAttr(axis)); + auto splitTy = cast(split.getType()); + + if (!splitTy || !splitTy.hasSizes()) + return failure(); + + auto splitSizes = splitTy.getSizes(); + unsigned splitDim = splitTy.getSizes().size(); + + if (splitDim > 1) + return rewriter.notifyMatchFailure( + binder.op, "Split should be scalar or 1-D Tensor "); + + if (splitDim == 1) { + if (splitSizes[0] == Torch::kUnknownSize) { + return rewriter.notifyMatchFailure( + binder.op, "Dynamic shapes for Split is not yet supported"); + } else if (splitSizes[0] <= + 1) { // dealing with 1/0 element in 1-D tensor + Value splitInt = rewriter.create( + binder.getLoc(), rewriter.getType(), split); + rewriter.replaceOpWithNewOp( + binder.op, resultType, self, splitInt, axisValue); + return success(); + } else { + // Handling multiple elment in split + Value shapeList = + createConstantIntList(binder, rewriter, splitSizes); + rewriter.replaceOpWithNewOp( + binder.op, resultType, self, shapeList, axisValue); + return success(); + } + } else if (splitDim == 0) { // Handle 0-D tensor + Value splitInt = rewriter.create( + binder.getLoc(), rewriter.getType(), split); + rewriter.replaceOpWithNewOp( + binder.op, resultType, self, splitInt, axisValue); + return success(); + } else { + return rewriter.notifyMatchFailure( + binder.op, "Handling of this kind of inputs is not there"); + } + }); + patterns.onOp( + "Unique", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Value input; + int64_t axis, sorted; + SmallVector resultTypes; + + if (binder.tensorOperand(input) || + binder.s64IntegerAttr(sorted, "sorted", 1) || + binder.tensorResultTypes(resultTypes)) + return failure(); + + Value zero = rewriter.create(binder.getLoc(), 0); + + auto inputTy = cast(input.getType()); + if (!inputTy.hasSizes()) { + return rewriter.notifyMatchFailure( + binder.op, "Expected input type to have sizes"); + } + auto inputShape = inputTy.getSizes(); + int64_t inputDim = static_cast(inputShape.size()); + + Value axisVal; + SmallVector outputTensorSizes(inputDim); + bool axisWasNone; + if (!binder.optionalS64IntegerAttr(axis, "axis")) { + if (axis < -1 * inputDim || axis > inputDim - 1) + return rewriter.notifyMatchFailure(binder.op, + "invalid value for axis"); + axisVal = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(axis)); + axisWasNone = false; + } else { + axisVal = zero; + axisWasNone = true; + } + + Value sortedVal = rewriter.create( + binder.getLoc(), rewriter.getBoolAttr(sorted)); + Value trueVal = + rewriter.create(binder.getLoc(), true); + + // The shape of inverse_indices is the same as input shape, but + // resulTypes[2] must be used to 
avoid live value after conversion. + Torch::ValueTensorType outputTy; + outputTy = cast(resultTypes[0]); + Torch::ValueTensorType countsTy = + cast(resultTypes[3]); + Torch::ValueTensorType inverseTy = + cast(resultTypes[2]); + + if (axisWasNone) { + int64_t inputNumel = 1; + for (auto elem : inputShape) { + if (elem == Torch::kUnknownSize) { + return rewriter.notifyMatchFailure( + binder.op, + "Expected all sizes in input shape to be statically known"); + } + inputNumel *= elem; + } + auto flattenResultTy = rewriter.getType( + ArrayRef({inputNumel}), inputTy.getDtype()); + Value negativeOne = + rewriter.create(binder.getLoc(), -1); + input = rewriter.create( + binder.getLoc(), flattenResultTy, input, zero, negativeOne); + } + + Torch::AtenUniqueDimOp intermResults = + rewriter.create( + binder.getLoc(), outputTy, inverseTy, countsTy, input, axisVal, + sortedVal, trueVal, trueVal); + + SmallVector uniqueResults = intermResults.getResults(); + + // Calculate the indices where each of the unique elements first + // appeared in the original input tensor. Also, the counts tensor and + // the indices tensor have the same Dtype, int64, so reuse that here. + auto arangeResultType = rewriter.getType( + ArrayRef({inputShape[0]}), countsTy.getOptionalDtype()); + + Value inputDimZero = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(inputShape[0])); + Value int64Type = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(4)); + Value noneVal = rewriter.create(binder.getLoc()); + + Value perm = rewriter.create( + binder.getLoc(), arangeResultType, inputDimZero, + /*dtype=*/int64Type, + /*layout=*/noneVal, /*device=*/noneVal, /*pin_memory=*/noneVal); + + // Inverse has the same shape as input, but the dtype is not the same. + Value flipDims = createConstantIntList(binder, rewriter, {0}); + Value inverse = rewriter.create( + binder.getLoc(), + inputTy.getWithSizesAndDtype(inputShape, countsTy.getDtype()), + uniqueResults[1], flipDims); + perm = rewriter.create( + binder.getLoc(), cast(perm.getType()), perm, + flipDims); + + auto newInverseTy = rewriter.getType( + ArrayRef({outputTy.getSizes()[0]}), countsTy.getDtype()); + Value newInverseSize = + createConstantIntList(binder, rewriter, {outputTy.getSizes()[0]}); + Value newInverse = rewriter.create( + binder.getLoc(), newInverseTy, inverse, newInverseSize, + /*dtype=*/int64Type, /*layout=*/noneVal, /*device=*/noneVal, + /*pin_memory=*/noneVal); + + Value firstOccurIndices = rewriter.create( + binder.getLoc(), resultTypes[1], newInverse, zero, inverse, perm); + + rewriter.replaceOp(binder.op, {uniqueResults[0], firstOccurIndices, + uniqueResults[1], uniqueResults[2]}); + return success(); + }); + patterns.onOp( + "TfIdfVectorizer", 9, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + llvm::SmallVector ngram_counts; + llvm::SmallVector ngram_indexes; + llvm::SmallVector pool_int64s; + llvm::SmallVector weights; + std::string mode; + int64_t min_gram_length; + int64_t max_gram_length; + int64_t max_skip_count; + Value input; + Torch::ValueTensorType resultType; + + if (binder.s64IntegerArrayAttr(ngram_counts, "ngram_counts", {}) || + binder.s64IntegerArrayAttr(ngram_indexes, "ngram_indexes", {}) || + binder.s64IntegerArrayAttr(pool_int64s, "pool_int64s", {}) || + binder.customOpNameStringAttr(mode, "mode", "") || + binder.s64IntegerAttr(min_gram_length, "min_gram_length", 0) || + binder.s64IntegerAttr(max_gram_length, "max_gram_length", 0) || + binder.s64IntegerAttr(max_skip_count, "max_skip_count", 0) || + 
binder.tensorOperand(input) || binder.tensorResultType(resultType)) + return failure(); + + llvm::SmallVector defaultWeights(ngram_indexes.size(), 1.0f); + if (binder.f32FloatArrayAttr(weights, "weights", defaultWeights)) + return failure(); + + if (pool_int64s.size() == 0) + return rewriter.notifyMatchFailure( + binder.op, "pool_int64s empty, only integers supported"); + auto inputType = dyn_cast(input.getType()); + auto inputSizes = + dyn_cast(input.getType()).getSizes(); + SmallVector inputShape(inputSizes); + bool is_2d = (inputShape.size() > 1) ? true : false; + if (is_2d && inputShape[0] == ShapedType::kDynamic) + return rewriter.notifyMatchFailure( + binder.op, "input batch dimension cannot be dynamic"); + int batch_size = (is_2d) ? inputShape[0] : 1; + + Value none = rewriter.create(binder.getLoc()); + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + Value one = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value cstFalse = rewriter.create( + binder.getLoc(), rewriter.getBoolAttr(false)); + + auto intType = rewriter.getType(); + Value loopConditionTrue = rewriter.create( + binder.getLoc(), rewriter.getBoolAttr(true)); + Type loopIndexType = intType; + // create a zero tensor for output + SmallVector resultShape(resultType.getSizes()); + int64_t rank = resultShape.size(); + SmallVector zerosShapeValues; + for (int j = 0; j < rank; j++) { + Value dimSize = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(resultShape[j])); + zerosShapeValues.push_back(dimSize); + } + Value zerosShapeList = rewriter.create( + binder.getLoc(), + rewriter.getType( + rewriter.getType()), + zerosShapeValues); + Value output = rewriter.create( + binder.getLoc(), resultType, zerosShapeList, none, none, none, + none); + + Value batchSize = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(batch_size)); + auto batchLoop = rewriter.create( + binder.getLoc(), TypeRange({output.getType()}), batchSize, + loopConditionTrue, ValueRange({output})); + { + PatternRewriter::InsertionGuard guard(rewriter); + Block *batchLoopBody = rewriter.createBlock( + &batchLoop.getRegion(), batchLoop.getRegion().begin(), + TypeRange({loopIndexType, output.getType()}), + {binder.getLoc(), binder.getLoc()}); + Value batchValue = batchLoopBody->getArgument(0); + Value output = batchLoopBody->getArgument(1); + Value outputForBatch = output; + Value inputSequence = input; + if (is_2d) { + // get input sequence from input (ex: [[0,1],[2,3]] -> [[0,1]] -> + // [0,1]) + SmallVector inputSequenceShape; + inputSequenceShape.push_back(1); + inputSequenceShape.push_back(inputShape[1]); + auto inputSequenceType = rewriter.getType( + inputSequenceShape, inputType.getOptionalDtype()); + Value batchPlusOne = rewriter.create( + binder.getLoc(), batchValue, one); + inputSequence = rewriter.create( + binder.getLoc(), inputSequenceType, input, /*dim=*/zero, + batchValue, batchPlusOne, one); + inputSequence = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get(binder.op->getContext(), + ArrayRef{inputShape[1]}, + inputType.getOptionalDtype()), + inputSequence, zero); + + SmallVector outputForBatchShape; + outputForBatchShape.push_back(1); + outputForBatchShape.push_back(resultShape[1]); + auto outputForBatchType = rewriter.getType( + outputForBatchShape, resultType.getOptionalDtype()); + outputForBatch = rewriter.create( + binder.getLoc(), outputForBatchType, output, + /*dim=*/zero, batchValue, batchPlusOne, 
one); + outputForBatch = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get(binder.op->getContext(), + ArrayRef{resultShape[1]}, + resultType.getOptionalDtype()), + outputForBatch, zero); + } + // ngram_counts[j] records the starting position of ngrams within the + // pool_int64's of length j+1. The loop below is iterating through the + // different n-gram sizes + // ngram_i keeps track of which ngram we are looking at in the pool. + // The frequency of this ngram will be stored in the output tensor at + // the position ngram_indexes[ngram_i] + int ngram_i = 0; + for (int j = 0; j < (int)ngram_counts.size(); j++) { + int ngram_length = j + 1; + int start_idx = ngram_counts[j]; + int end_idx = (j + 1) < (int)ngram_counts.size() + ? ngram_counts[j + 1] + : pool_int64s.size(); + if (j + 1 < min_gram_length || j + 1 > max_gram_length) { + // progress the ngram counter for the skipped (j+1)grams + ngram_i += (end_idx - start_idx) / ngram_length; + continue; + } + + Value ngramLength = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(ngram_length)); + for (int start = start_idx; start < end_idx; + start += ngram_length, ngram_i++) { + Value count = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + // for 1-grams, there is no skipping (skip = gap between + // consecutive values in the n-gram pulled from the input + // sequence), so we default to skip_count_bound = 1 in that case + // to avoid repeating the same count multiple times. + int skip_count_bound = + (ngram_length == 1) ? 1 : (max_skip_count + 1); + Value skipCountBound = rewriter.create( + binder.getLoc(), intType, + rewriter.getI64IntegerAttr(skip_count_bound)); + // given a n-gram to search for, and the input sequence to search + // in, we need to count how many times that n-gram appears in the + // input for each skip between 0 and max_skip_count (inclusive). 
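+ // A rough pure-Python sketch of the counting rule that the skip/count
+ // loops below implement (`seq`, `ngram`, and `max_skip_count` are
+ // illustrative names, not values taken from this lowering):
+ // @code{.py}
+ // def count_ngram(seq, ngram, max_skip_count):
+ //     total = 0
+ //     skip_bound = 1 if len(ngram) == 1 else max_skip_count + 1
+ //     for skip in range(skip_bound):
+ //         stride = skip + 1
+ //         max_start = len(seq) - (len(ngram) - 1) * stride
+ //         for start in range(max_start):
+ //             window = [seq[start + i * stride] for i in range(len(ngram))]
+ //             total += int(window == list(ngram))
+ //     return total
+ //
+ // count_ngram([1, 3, 2, 1, 2], (1, 2), 1)  # -> 2: one hit at skip 0, one at skip 1
+ // @endcode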
+ auto skipLoop = rewriter.create( + binder.getLoc(), TypeRange({count.getType()}), skipCountBound, + loopConditionTrue, ValueRange({count})); + { + PatternRewriter::InsertionGuard guard(rewriter); + Block *skipLoopBody = rewriter.createBlock( + &skipLoop.getRegion(), skipLoop.getRegion().begin(), + TypeRange({loopIndexType, count.getType()}), + {binder.getLoc(), binder.getLoc()}); + Value skipCount = skipLoopBody->getArgument(0); + Value skipCountPlusOne = rewriter.create( + binder.getLoc(), skipCount, one); + count = skipLoopBody->getArgument(1); + + // max_start_index = + // inputSizes.back() - ((ngram_length - 1) * (skip_count + 1)); + // the index one higher than the last possible start index + // without the input ngram going out of bounds + Value seqLen = rewriter.create( + binder.getLoc(), intType, + rewriter.getI64IntegerAttr(inputSizes.back())); + Value ngramLengthMinusOne = + rewriter.create(binder.getLoc(), + ngramLength, one); + Value ngramSkipLength = rewriter.create( + binder.getLoc(), ngramLengthMinusOne, skipCountPlusOne); + Value maxStartIndex = rewriter.create( + binder.getLoc(), seqLen, ngramSkipLength); + // This loop will extract each n-gram with the given skip_count + // from the input sequence from start input index, and increment + // the count if the n-gram matches the one gotten from the + // pool_int64s + auto countLoop = rewriter.create( + binder.getLoc(), TypeRange({count.getType()}), + maxStartIndex, loopConditionTrue, ValueRange({count})); + { + PatternRewriter::InsertionGuard guard(rewriter); + Block *countLoopBody = rewriter.createBlock( + &countLoop.getRegion(), countLoop.getRegion().begin(), + TypeRange({loopIndexType, count.getType()}), + {binder.getLoc(), binder.getLoc()}); + + Value startInputIdx = countLoopBody->getArgument(0); + count = countLoopBody->getArgument(1); + + // extract input ngram and compare to pool ngram + Torch::BaseTensorType inputSequenceType = + cast(inputSequence.getType()); + SmallVector selectSizes; + selectSizes.push_back(1); + Type selectResultType = + inputSequenceType.getWithSizesAndDtype( + llvm::ArrayRef(selectSizes), + inputSequenceType.getOptionalDtype()); + Value foundNgram = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); + for (int i = 0; i < ngram_length; i++) { + Value selectIndex = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + i)); + selectIndex = rewriter.create( + binder.getLoc(), selectIndex, skipCountPlusOne); + selectIndex = rewriter.create( + binder.getLoc(), selectIndex, startInputIdx); + Value inputExtract = + rewriter.create( + binder.getLoc(), selectResultType, inputSequence, + zero, selectIndex); + Value inputNgram_i = rewriter.create( + binder.getLoc(), rewriter.getType(), + inputExtract); + + Value poolNgram_i = rewriter.create( + binder.getLoc(), + rewriter.getI64IntegerAttr(pool_int64s[start + i])); + Value isEqual = rewriter.create( + binder.getLoc(), inputNgram_i, poolNgram_i); + isEqual = rewriter.create( + binder.getLoc(), isEqual); + foundNgram = rewriter.create( + binder.getLoc(), isEqual, foundNgram); + } + + count = rewriter.create( + binder.getLoc(), count, foundNgram); + rewriter.create( + binder.getLoc(), loopConditionTrue, ValueRange({count})); + } + count = countLoop.getResult(0); + rewriter.create( + binder.getLoc(), loopConditionTrue, ValueRange({count})); + } + count = skipLoop.getResult(0); + Value countFloat = rewriter.create( + binder.getLoc(), count); + if (mode == "IDF" || mode == "TFIDF") { + 
// both IDF and TFIDF modes use weights + float weight = weights[ngram_i]; + Value constWeight = rewriter.create( + binder.getLoc(), rewriter.getF64FloatAttr(weight)); + + // TFIDF + Value multiplier = countFloat; + if (mode == "IDF") { + // All the counts larger than 1 would be truncated to 1 + // and the i-th element in weights would be used to scale + // (by multiplication) the count of the i-th n-gram in pool. + + Value intCount = rewriter.create( + binder.getLoc(), count); + // compare intCount > 0 + Value gtZeroCount = rewriter.create( + binder.getLoc(), intCount, zero); + gtZeroCount = rewriter.create( + binder.getLoc(), gtZeroCount); + Value gtZeroCountFloat = + rewriter.create(binder.getLoc(), + gtZeroCount); + multiplier = gtZeroCountFloat; + } + countFloat = rewriter.create( + binder.getLoc(), multiplier, constWeight); + } + Value dataList = rewriter.create( + binder.getLoc(), + rewriter.getType( + rewriter.getType()), + SmallVector{countFloat}); + Value cstDtype = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr( + (int)torch_upstream::ScalarType::Float)); + SmallVector countShape{1}; + auto countType = rewriter.getType( + countShape, resultType.getOptionalDtype()); + Value countTensor = rewriter.create( + binder.getLoc(), countType, dataList, /*dtype=*/cstDtype, + /*layout=*/none, /*requires_grad=*/cstFalse); + + Value insertStart = rewriter.create( + binder.getLoc(), + rewriter.getI64IntegerAttr(ngram_indexes[ngram_i])); + Value insertEnd = rewriter.create( + binder.getLoc(), insertStart, one); + outputForBatch = rewriter.create( + binder.getLoc(), outputForBatch.getType(), outputForBatch, + countTensor, + /*dim=*/zero, insertStart, insertEnd, /*step=*/one); + } // start + } + if (is_2d) { + Value batchPlusOne = rewriter.create( + binder.getLoc(), batchValue, one); + outputForBatch = rewriter.create( + binder.getLoc(), + rewriter.getType( + llvm::SmallVector{1, resultShape[1]}, + resultType.getDtype()), + outputForBatch, zero); + output = rewriter.create( + binder.getLoc(), resultType, output, outputForBatch, + /*dim=*/zero, batchValue, batchPlusOne, /*step=*/one); + } else { + output = outputForBatch; + } + rewriter.create( + binder.getLoc(), loopConditionTrue, ValueRange({output})); + } + output = batchLoop.getResult(0); + rewriter.replaceOp(binder.op, output); + return success(); + }); + patterns.onOp( + "Scan", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Location loc = binder.getLoc(); + SmallVector operands; + int64_t numScanInputs; + if (binder.tensorOperandsList(operands) || operands.size() == 0 || + binder.s64IntegerAttr(numScanInputs, "num_scan_inputs")) { + return rewriter.notifyMatchFailure(binder.op, + "Failed to get required inputs"); + } + SmallVector resultTypes; + if (binder.tensorResultTypes(resultTypes)) { + return rewriter.notifyMatchFailure(binder.op, + "result type bind failure"); + } + Region *loopBodyIn; + if (binder.getRegionAtIndex(loopBodyIn, 0)) { + return rewriter.notifyMatchFailure(binder.op, + "Failed getting LoopBody Region"); + } + + int64_t numInits = operands.size() - numScanInputs; + SmallVector initVals(operands.begin(), + operands.begin() + numInits); + SmallVector scanInputs(operands.begin() + numInits, + operands.end()); + if (scanInputs.size() < 1) { + return rewriter.notifyMatchFailure(binder.op, + "Expects at least one scan input"); + } + + Value constZero = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + Value constOne = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + SmallVector 
scanOutTypes; + for (unsigned i = numInits; i < resultTypes.size(); i++) { + auto scanOutTy = cast(resultTypes[i]); + Value sizeList = + createConstantIntList(binder, rewriter, scanOutTy.getSizes()); + initVals.push_back(Torch::createInitTensor(rewriter, loc, scanOutTy, + constZero, sizeList)); + scanOutTypes.push_back(resultTypes[i]); + } + // Create torch.prim.Loop op. + Value maxTripCount = rewriter.create( + loc, scanInputs[0], constZero); + auto constBoolTrue = rewriter.create( + binder.getLoc(), rewriter.getBoolAttr(true)); + auto primLoop = rewriter.create( + loc, resultTypes, maxTripCount, constBoolTrue, initVals); + rewriter.cloneRegionBefore(*loopBodyIn, primLoop.getRegion(), + primLoop.getRegion().begin()); + + // Insert index var as torch.int argument in the loop body, as + // the primLoopOp loopBody expects torch.int as first argument. + primLoop.getRegion().insertArgument( + 0u, rewriter.getType(), loc); + auto loopInd = primLoop.getRegion().getArgument(0); + + // The block arguments of onnx.scan needs to be replaced with + // slice of scan inputs. + rewriter.setInsertionPointToStart(&primLoop.getRegion().front()); + for (unsigned i = 0; i < numScanInputs; i++) { + auto loopBlockArg = + primLoop.getRegion().getArgument(numInits + 1 + i); + Value extract = rewriter.create( + loc, loopBlockArg.getType(), scanInputs[i], constZero, loopInd); + loopBlockArg.replaceAllUsesWith(extract); + } + primLoop.getRegion().front().eraseArguments(numInits + 1, + /*count=*/numScanInputs); + + // Collect the output slices to form scan outputs and replace the + // terminator. + SmallVector locs(scanOutTypes.size(), loc); + primLoop.getRegion().front().addArguments(scanOutTypes, locs); + + PatternRewriter::InsertionGuard guard(rewriter); + Operation *terminator = primLoop.getRegion().front().getTerminator(); + auto terminatorOperands = terminator->getOperands(); + SmallVector resTerminatorOperands( + terminatorOperands.begin(), terminatorOperands.begin() + numInits); + SmallVector scanOutSlices(terminatorOperands.begin() + numInits, + terminatorOperands.end()); + rewriter.setInsertionPoint(terminator); + for (unsigned i = 0; i < scanOutSlices.size(); i++) { + Value self = BlockArgument::Value( + primLoop.getRegion().getArgument(numInits + 1 + i)); + FailureOr src = Torch::unsqueezeTensor( + rewriter, binder.op, scanOutSlices[i], constZero); + if (failed(src)) + return failure(); + Value scanOut = rewriter.create( + loc, scanOutTypes[i], self, src.value(), constZero, + /*start=*/loopInd, + /*end=*/loopInd, constOne); + resTerminatorOperands.push_back(scanOut); + } + + Value terminatorCond = constBoolTrue; + rewriter.replaceOpWithNewOp( + terminator, terminatorCond, resTerminatorOperands); + rewriter.replaceOp(binder.op, primLoop); return success(); }); } diff --git a/lib/Conversion/TorchOnnxToTorch/OnnxLstmExpander.cpp b/lib/Conversion/TorchOnnxToTorch/OnnxLstmExpander.cpp deleted file mode 100644 index 4c2ad051e0be..000000000000 --- a/lib/Conversion/TorchOnnxToTorch/OnnxLstmExpander.cpp +++ /dev/null @@ -1,514 +0,0 @@ -#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" -#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h" -#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" -#include "torch-mlir/Dialect/Torch/Utils/Utils.h" - -using namespace mlir; -using namespace mlir::torch::Torch; -namespace mlir::torch::onnx_c { - -Value createActivationByName(ImplicitLocOpBuilder &b, StringRef name, - Value input) { - if (name == "Sigmoid") - return b.create(input.getType(), input); - if (name 
== "Tanh") - return b.create(input.getType(), input); - if (name == "Relu") - return b.create(input.getType(), input); - llvm_unreachable("Unsupported activation function"); -} - -// @struct LstmWeights -// @brief A structure to hold LSTM weights. -// -// Each W_ weight matrix should have shape [hidden_size, input_size]. -// Each R_ weight matrix should have shape [hidden_size, hidden_size]. -// Each bias vector should have shape [4 * hidden_size]. -struct LstmWeights { - Value W_i, W_o, W_f, W_c; - Value R_i, R_o, R_f, R_c; - Value Wb_i, Wb_o, Wb_f, Wb_c; - Value Rb_i, Rb_o, Rb_f, Rb_c; -}; -struct LstmActivations { - std::string f; - std::string g; - std::string h; -}; - -struct LstmCellState { - Value H; - Value C; -}; -// This function represents a Long Short-Term Memory (LSTM) cell operation. -// -// @param b A builder for constructing operations. -// @param Xt The input sequence. It has a shape of [batch_size, input_size]. -// @param H_prev The previous hidden state. It has a shape of [batch_size, -// hidden_size]. -// @param C_prev The previous cell state. It has a shape of [batch_size, -// hidden_size]. -// @param weights The weights for the LSTM cell. See @ref LstmWeights for shapes -// @param activations The activation functions for the LSTM cell. Members f,g,h -// correspond to f,g,h in https://onnx.ai/onnx/operators/onnx__LSTM.html -// @return The state of the LSTM cell after the operation. -LstmCellState lstm_cell(ImplicitLocOpBuilder &b, Value Xt, Value H_prev, - Value C_prev, LstmWeights weights, - LstmActivations activations) { - - auto intType = b.getType(); - auto hTy = cast(H_prev.getType()); - - Value cstOne = b.create(intType, b.getI64IntegerAttr(1)); - - // Apply linear/matmul for each gate separately - // names are consistent with ONNX LSTM documentation - Value i_x = b.create(hTy, Xt, weights.W_i, weights.Wb_i); - Value i_h = b.create(hTy, H_prev, weights.R_i, weights.Rb_i); - Value i = b.create(hTy, i_x, i_h, cstOne); - Value i_act = createActivationByName(b, activations.f, i); - - Value o_x = b.create(hTy, Xt, weights.W_o, weights.Wb_o); - Value o_h = b.create(hTy, H_prev, weights.R_o, weights.Rb_o); - Value o = b.create(hTy, o_x, o_h, cstOne); - Value o_act = createActivationByName(b, activations.f, o); - - Value f_x = b.create(hTy, Xt, weights.W_f, weights.Wb_f); - Value f_h = b.create(hTy, H_prev, weights.R_f, weights.Rb_f); - Value f = b.create(hTy, f_x, f_h, cstOne); - Value f_act = createActivationByName(b, activations.f, f); - - Value ct_x = b.create(hTy, Xt, weights.W_c, weights.Wb_c); - Value ct_h = b.create(hTy, H_prev, weights.R_c, weights.Rb_c); - Value ct = b.create(hTy, ct_x, ct_h, cstOne); - Value ct_act = createActivationByName(b, activations.g, ct); - - Value C_forget = b.create(hTy, f_act, C_prev); - Value C_input = b.create(hTy, i_act, ct_act); - - LstmCellState newCellState; - newCellState.C = b.create(hTy, C_forget, C_input, cstOne); - Value C_new_act = createActivationByName(b, activations.h, newCellState.C); - newCellState.H = b.create(hTy, o_act, C_new_act); - return newCellState; -} - -struct LstmLayerOutput { - Value Y; - Value Y_h; - Value Y_c; -}; - -// @brief This function implements the LSTM (Long Short-Term Memory) layer -// operation. -// -// The core computation is performed in a loop that iterates over the sequence -// length. In each iteration, it selects the corresponding input, computes the -// new hidden state and cell state using the lstm_cell function, and updates the -// output tensor. 
-// -// @return A struct containing the hidden state history, final hidden state, -// and final cell state. -LstmLayerOutput lstm_layer(ImplicitLocOpBuilder &b, Value X, Value initial_h, - Value initial_c, LstmWeights weights, - LstmActivations activations) { - - Location loc = b.getLoc(); - - auto xTy = cast(X.getType()); - auto hTy = cast(initial_h.getType()); - // these names are snake_case for consistency with onnx.LSTM documentation - int64_t seq_len = xTy.getSizes()[0]; - int64_t batch_size = xTy.getSizes()[1]; - int64_t input_size = xTy.getSizes()[2]; - int64_t hidden_size = hTy.getSizes()[1]; - - auto cTy = hTy; - - auto intType = b.getType(); - - Value cstNone = b.create(); - Value cstZero = b.create(intType, b.getI64IntegerAttr(0)); - Value cstOne = b.create(intType, b.getI64IntegerAttr(1)); - Value cstSeqLen = - b.create(intType, b.getI64IntegerAttr(seq_len)); - Value cstBatchSize = - b.create(intType, b.getI64IntegerAttr(batch_size)); - Value cstHiddenSize = - b.create(intType, b.getI64IntegerAttr(hidden_size)); - - auto yTy = b.getType( - SmallVector{seq_len, batch_size, hidden_size}, hTy.getDtype()); - - auto YShapeList = b.create( - b.getType(intType), - ValueRange({cstSeqLen, cstBatchSize, cstHiddenSize})); - - int64_t hDtypeInt = - static_cast(getScalarTypeForType(hTy.getDtype())); - Value hDtypeIntVal = - b.create(loc, b.getI64IntegerAttr(hDtypeInt)); - - Value Y_initial = b.create(yTy, YShapeList, hDtypeIntVal, - cstNone, cstNone, cstNone); - - // Create a for-like PrimLoopOp. - Value maxTripCount = - b.create(intType, b.getI64IntegerAttr(seq_len)); - Value loopConditionTrue = b.create(true); - - Type loopIndexType = intType; - auto loop = b.create( - TypeRange({yTy, hTy, cTy}), maxTripCount, loopConditionTrue, - ValueRange({Y_initial, initial_h, initial_c})); - { - OpBuilder::InsertionGuard guard(b); - Block *loopBody = - b.createBlock(&loop.getRegion(), loop.getRegion().begin(), - TypeRange({ - loopIndexType, - yTy, - hTy, - cTy, - }), - {loc, loc, loc, loc} // locs for the loop body arguments - ); - - Value loopIndex = loopBody->getArgument(0); - Value Y_prev = loopBody->getArgument(1); - Value H_prev = loopBody->getArgument(2); - Value C_prev = loopBody->getArgument(3); - - auto xTy = cast(X.getType()); - auto XtType = b.getType( - llvm::SmallVector{batch_size, input_size}, xTy.getDtype()); - - Value Xt = b.create(XtType, X, cstZero, loopIndex); - - auto [H_new, C_new] = - lstm_cell(b, Xt, H_prev, C_prev, weights, activations); - - Type hTyUnsqueezed = b.getType( - llvm::SmallVector{1, batch_size, hidden_size}, hTy.getDtype()); - Value H_new_unsqueezed = - b.create(hTyUnsqueezed, H_new, cstZero); - - auto loopIndexPlusOne = b.create(intType, loopIndex, cstOne); - Value Y_new = - b.create(yTy, Y_prev, H_new_unsqueezed, cstZero, - loopIndex, loopIndexPlusOne, cstOne); - - b.create(loopConditionTrue, - ValueRange({Y_new, H_new, C_new})); - } - LstmLayerOutput output; - output.Y = loop.getResult(0); - output.Y_h = loop.getResult(1); - output.Y_c = loop.getResult(2); - return output; -} -// @brief Expands an ONNX LSTM operation into torch ops. -// -// This function primarily handles the binding of operands and slicing of the -// weight matrix. The majority of the lowering process is managed in the -// lstm_layer and lstm_cell. For the shapes and meanings of the inputs, refer to -// the ONNX LSTM documentation at: -// https://onnx.ai/onnx/operators/onnx__LSTM.html -// The variable names are also consistent with the aforementioned documentation. 
-// -// This is not e2e tested here but is verified to work numerically downstream in -// SHARK-TestSuite. -// -// TODO: include this test case when the test infrastructure stops initializing -// weights separately for the reference and tested layers. -// @code{.py} -// class LSTMModule(torch.nn.Module): -// def __init__(self): -// super().__init__() -// self.lstm = torch.nn.LSTM(10, 20, 1) -// @export -// @annotate_args([ -// None, -// ([5, 1, 10], torch.float32, True), -// ([1, 1, 20], torch.float32, True), -// ([1, 1, 20], torch.float32, True), -// ]) -// def forward(self, input, h0, c0): -// return self.lstm(input, (h0, c0)) -// -// @register_test_case(module_factory=LSTMModule) -// def LSTMModule_basic(module, tu: TestUtils): -// inputs = torch.zeros(5,1,10) -// h0 = torch.zeros(1,1,20) -// c0 = torch.zeros(1,1,20) -// -// output, (hn, cn) = module.forward(inputs, h0, c0) -// @endcode -// -// @param binder The OpBinder object used for binding operands. -LogicalResult OnnxLstmExpander(OpBinder binder, - ConversionPatternRewriter &rewriter) { - Location loc = binder.getLoc(); - mlir::ImplicitLocOpBuilder b(loc, rewriter); - - std::string direction; - - ValueTensorType yTy, Y_hType, Y_cType; - if (binder.tensorResultTypeAtIndex(yTy, 0) || - binder.tensorResultTypeAtIndex(Y_hType, 1) || - binder.tensorResultTypeAtIndex(Y_cType, 2)) { - return rewriter.notifyMatchFailure(binder.op, - "At least one outputs must be present"); - } - Value X; - if (binder.tensorOperandAtIndex(X, 0)) - return rewriter.notifyMatchFailure(binder.op, - "Missing required input tensor X"); - Value W; - if (binder.tensorOperandAtIndex(W, 1)) - return rewriter.notifyMatchFailure(binder.op, - "Missing required input tensor W"); - Value R; - if (binder.tensorOperandAtIndex(R, 2)) - return rewriter.notifyMatchFailure(binder.op, - "Missing required input tensor R"); - int64_t hidden_size; - if (binder.s64IntegerAttr(hidden_size, "hidden_size")) - return rewriter.notifyMatchFailure( - binder.op, "Missing required attribute hidden_size"); - - auto xTy = cast(X.getType()); - auto wTy = cast(W.getType()); - Value B; - if (binder.tensorOperandAtIndex(B, 3)) { - B = b.create(W.getType(), W); - } - - llvm::SmallVector activationsList; - if (binder.stringArrayAttr(activationsList, "activations")) - return rewriter.notifyMatchFailure( - binder.op, "Missing required attribute; activations"); - - LstmActivations activations; - activations.f = "Sigmoid"; - activations.g = "Tanh"; - activations.h = "Tanh"; - if (activationsList.size() == 3) { - activations.f = activationsList[0]; - activations.g = activationsList[1]; - activations.h = activationsList[2]; - } else if (activationsList.size() != 0) { - return rewriter.notifyMatchFailure( - binder.op, "activations must be empty have 3 elements, but " + - std::to_string(activationsList.size()) + - " are provided."); - } - - if (!binder.customOpNameStringAttr(direction, "direction", "forward") && - direction != "forward") - return rewriter.notifyMatchFailure(binder.op, - "Unsupported direction attribute value. 
" - "Only 'forward' is supported but '" + - direction + "' is provided."); - int64_t num_directions = 1 + (direction == "bidirectional"); - - auto XShape = xTy.getSizes(); - int64_t batch_size = XShape[1]; - int64_t input_size = XShape[2]; - if (num_directions != wTy.getSizes()[0]) - return rewriter.notifyMatchFailure( - binder.op, "num_directions (" + std::to_string(num_directions) + - ") does not match the first dimension of wTy (" + - std::to_string(wTy.getSizes()[0]) + ")"); - if (num_directions != 1) - return rewriter.notifyMatchFailure( - binder.op, "num_directions (" + std::to_string(num_directions) + - ") is not equal to 1"); - if (4 * hidden_size != wTy.getSizes()[1]) - return rewriter.notifyMatchFailure( - binder.op, "4 times hidden_size (" + std::to_string(4 * hidden_size) + - ") does not match the second dimension of wTy (" + - std::to_string(wTy.getSizes()[1]) + ")"); - if (wTy.getSizes()[2] != input_size) - return rewriter.notifyMatchFailure( - binder.op, - "The third dimension of wTy (" + std::to_string(wTy.getSizes()[2]) + - ") does not match input_size (" + std::to_string(input_size) + ")"); - - /** - * @brief Splits the input tensor based on the provided direction. - * - * This function is used to split the LSTM parameters (W, R, B) into forward - * and backward directions. The input tensor is expected to have the forward - * and backward parameters concatenated along the 0th dimension. The function - * returns a tensor that contains the parameters for the specified direction. - * - * @param direction The direction to split out. 0 for forward, 1 for backward. - * @param input The input tensor to split. - * @return The split tensor for the specified direction. - */ - auto getDirection = [&](int64_t direction, Value input) { - auto inputType = cast(input.getType()); - - // drop 0th dimension - auto outputType = cast(inputType.getWithSizesAndDtype( - llvm::SmallVector{inputType.getSizes().drop_front()}, - inputType.getDtype())); - - auto intType = b.getType(); - Value selectDim = b.create(intType, b.getI64IntegerAttr(0)); - Value cstDirection = - b.create(intType, b.getI64IntegerAttr(direction)); - return b.create(outputType, input, selectDim, - cstDirection); - }; - - Value W_forward = getDirection(0, W); - Value R_forward = getDirection(0, R); - Value B_forward = getDirection(0, B); - - auto hTy = b.getType( - llvm::SmallVector{num_directions, batch_size, hidden_size}, - xTy.getDtype()); - - auto intType = b.getType(); - - Value cstNumDirections = - b.create(intType, b.getI64IntegerAttr(num_directions)); - Value cstBatchSize = - b.create(intType, b.getI64IntegerAttr(batch_size)); - Value cstHiddenSize = - b.create(intType, b.getI64IntegerAttr(hidden_size)); - Value cstNone = b.create(); - Value cstZero = b.create(intType, b.getI64IntegerAttr(0)); - Value cstOne = b.create(intType, b.getI64IntegerAttr(1)); - - Value hShape = b.create( - b.getType(intType), - ValueRange({cstNumDirections, cstBatchSize, cstHiddenSize})); - - Value cstDtype = getDtypeIntValueForType(rewriter, loc, xTy.getDtype()); - - Value initial_h; - if (binder.tensorOperandAtIndex(initial_h, 5)) { - initial_h = - b.create(hTy, hShape, cstDtype, cstNone, cstNone, cstNone); - } - Value initial_c; - if (binder.tensorOperandAtIndex(initial_c, 6)) { - initial_c = - b.create(hTy, hShape, cstDtype, cstNone, cstNone, cstNone); - } - - Value initial_h_forward = getDirection(0, initial_h); - Value initial_c_forward = getDirection(0, initial_c); - - if (num_directions != 1) { - return rewriter.notifyMatchFailure( - 
binder.op, "Unsupported num_directions. Only 1 is supported but " + - std::to_string(num_directions) + " is provided."); - // TODO: support bidirectional LSTM by doing both directions and replacing - // Unsqueeze with Stack - } - // Everything hereon is for the forward direction, with the direction - // dimention squeezed out. - - LstmWeights weights; // weights and biases - - auto intConst = [&](int64_t val) { - return b.create(intType, b.getI64IntegerAttr(val)); - }; - - // split B into Wb and Rb - Value inputWeightsEndIdx = intConst(4 * hidden_size); - Value recurrentWeightsStartIdx = inputWeightsEndIdx; - Value recurrentWeightsEndIdx = intConst(8 * hidden_size); - auto biasType = b.getType( - llvm::SmallVector{hidden_size * 4}, wTy.getDtype()); - Value Wb = b.create(biasType, - /*input=*/B_forward, - /*dim=*/cstZero, - /*start=*/cstZero, - /*end=*/inputWeightsEndIdx, - /*step=*/cstOne); - Value Rb = b.create(biasType, - /*input=*/B_forward, - /*dim=*/cstZero, - /*start=*/recurrentWeightsStartIdx, - /*end=*/recurrentWeightsEndIdx, - /*step=*/cstOne); - - // gate splitting - auto gateBiasType = b.getType( - llvm::SmallVector{hidden_size}, - cast(Wb.getType()).getDtype()); - auto gateWeightsTypeIH = b.getType( - llvm::SmallVector{hidden_size, input_size}, - cast(W_forward.getType()).getDtype()); - auto gateWeightsTypeHH = b.getType( - llvm::SmallVector{hidden_size, hidden_size}, - cast(R_forward.getType()).getDtype()); - - Value inputGateWeightsEndIdx = intConst(hidden_size); - Value outputGateWeightsEndIdx = intConst(2 * hidden_size); - Value forgetGateWeightsEndIdx = intConst(3 * hidden_size); - Value cellGateWeightsEndIdx = intConst(4 * hidden_size); - - auto sliceIOFC = [&](std::function slicerFunction) { - // slice into 4 components and return tuple - return std::make_tuple( - slicerFunction(cstZero, inputGateWeightsEndIdx), - slicerFunction(inputGateWeightsEndIdx, outputGateWeightsEndIdx), - slicerFunction(outputGateWeightsEndIdx, forgetGateWeightsEndIdx), - slicerFunction(forgetGateWeightsEndIdx, cellGateWeightsEndIdx)); - }; - - auto sliceGateBias = [&](Value startIdx, Value endIdx) { - return b.create(gateBiasType, Wb, cstZero, startIdx, - endIdx, cstOne); - }; - std::tie(weights.Wb_i, weights.Wb_o, weights.Wb_f, weights.Wb_c) = - sliceIOFC(sliceGateBias); - - auto sliceGateBiasR = [&](Value startIdx, Value endIdx) { - return b.create(gateBiasType, Rb, cstZero, startIdx, - endIdx, cstOne); - }; - std::tie(weights.Rb_i, weights.Rb_o, weights.Rb_f, weights.Rb_c) = - sliceIOFC(sliceGateBiasR); - - auto sliceGateWeightsIH = [&](Value startIdx, Value endIdx) { - return b.create(gateWeightsTypeIH, W_forward, cstZero, - startIdx, endIdx, cstOne); - }; - std::tie(weights.W_i, weights.W_o, weights.W_f, weights.W_c) = - sliceIOFC(sliceGateWeightsIH); - - auto sliceGateWeightsHH = [&](Value startIdx, Value endIdx) { - return b.create(gateWeightsTypeHH, R_forward, cstZero, - startIdx, endIdx, cstOne); - }; - std::tie(weights.R_i, weights.R_o, weights.R_f, weights.R_c) = - sliceIOFC(sliceGateWeightsHH); - LstmLayerOutput lstmLayerOutput = lstm_layer( - b, X, initial_h_forward, initial_c_forward, weights, activations); - - auto Y_h_Y_c_unsqueezed_type = b.getType( - llvm::SmallVector{num_directions, batch_size, hidden_size}, - cast(lstmLayerOutput.Y_h.getType()).getDtype()); - Value Y_h_unsqueezed = b.create( - Y_h_Y_c_unsqueezed_type, lstmLayerOutput.Y_h, cstZero); - Value Y_c_unsqueezed = b.create( - Y_h_Y_c_unsqueezed_type, lstmLayerOutput.Y_c, cstZero); - - // unsqueeze num_directions dim1 
of Y - // to create the onnx.LSTM output shape [seq_length, num_directions, - // batch_size, hidden_size] - Value Y_unsqueezed = - b.create(yTy, lstmLayerOutput.Y, cstOne); - - rewriter.replaceOp(binder.op, mlir::ValueRange{Y_unsqueezed, Y_h_unsqueezed, - Y_c_unsqueezed}); - return success(); -} -} // namespace mlir::torch::onnx_c diff --git a/lib/Conversion/TorchOnnxToTorch/OnnxRecurrentLayerOpExpanders.cpp b/lib/Conversion/TorchOnnxToTorch/OnnxRecurrentLayerOpExpanders.cpp new file mode 100644 index 000000000000..317a5459ea38 --- /dev/null +++ b/lib/Conversion/TorchOnnxToTorch/OnnxRecurrentLayerOpExpanders.cpp @@ -0,0 +1,1511 @@ +#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" +#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" + +using namespace mlir; +using namespace mlir::torch::Torch; + +namespace mlir::torch::onnx_c { + +/** + * @brief Splits the input tensor based on the provided direction. + * + * This function is used to split the LSTM parameters (W, R, B) into forward + * and backward directions. The input tensor is expected to have the forward + * and backward parameters concatenated along the 0th dimension. The function + * returns a tensor that contains the parameters for the specified direction. + * + * @param direction The direction to split out. 0 for forward, 1 for backward. + * @param input The input tensor to split. + * @return The split tensor for the specified direction. + */ +Value getDirection(ImplicitLocOpBuilder b, int64_t direction, Value input) { + auto inputType = cast(input.getType()); + auto outputType = cast(inputType.getWithSizesAndDtype( + llvm::SmallVector{inputType.getSizes().drop_front()}, + inputType.getDtype())); + auto intType = b.getType(); + Value selectDim = b.create(intType, b.getI64IntegerAttr(0)); + Value cstDirection = + b.create(intType, b.getI64IntegerAttr(direction)); + return b.create(outputType, input, selectDim, cstDirection); +} + +struct RnnWeights { + Value Wi; + Value Ri; + Value Wbi; + Value Rbi; +}; + +struct RnnActivations { + std::string f; +}; + +Value rnn_cell(ImplicitLocOpBuilder &b, Value Xt, Value H_prev, + RnnWeights weights, RnnActivations activations) { + auto hTy = cast(H_prev.getType()); + + auto intType = b.getType(); + Value cstOne = b.create(intType, b.getI64IntegerAttr(1)); + + Value i_x = b.create(hTy, Xt, weights.Wi, weights.Wbi); + Value i_h = b.create(hTy, H_prev, weights.Ri, weights.Rbi); + Value i = b.create(hTy, i_x, i_h, cstOne); + + Value H_new = createActivationByName(b, activations.f, i); + return H_new; +} + +struct RnnLayerOutput { + Value Y; + Value Y_h; +}; + +RnnLayerOutput rnn_layer(ImplicitLocOpBuilder &b, Value X, Value initial_h, + RnnWeights weights, RnnActivations activations) { + Location loc = b.getLoc(); + + auto xTy = cast(X.getType()); + auto hTy = cast(initial_h.getType()); + int64_t seq_len = xTy.getSizes()[0]; + int64_t batch_size = xTy.getSizes()[1]; + int64_t input_size = xTy.getSizes()[2]; + int64_t hidden_size = hTy.getSizes()[1]; + + auto intType = b.getType(); + + Value cstNone = b.create(); + Value cstZero = b.create(intType, b.getI64IntegerAttr(0)); + Value cstOne = b.create(intType, b.getI64IntegerAttr(1)); + Value cstSeqLen = + b.create(intType, b.getI64IntegerAttr(seq_len)); + Value cstBatchSize = + b.create(intType, b.getI64IntegerAttr(batch_size)); + Value cstHiddenSize = + b.create(intType, b.getI64IntegerAttr(hidden_size)); + + auto yTy = b.getType( + 
SmallVector{seq_len, batch_size, hidden_size}, hTy.getDtype()); + + auto YShapeList = b.create( + b.getType(intType), + ValueRange({cstSeqLen, cstBatchSize, cstHiddenSize})); + + int64_t hDtypeInt = + static_cast(getScalarTypeForType(hTy.getDtype())); + Value hDtypeIntVal = + b.create(loc, b.getI64IntegerAttr(hDtypeInt)); + + Value Y_initial = b.create(yTy, YShapeList, hDtypeIntVal, + cstNone, cstNone, cstNone); + + Value maxTripCount = + b.create(intType, b.getI64IntegerAttr(seq_len)); + Value loopConditionTrue = b.create(true); + + Type loopIndexType = intType; + auto loop = b.create(TypeRange({yTy, hTy}), maxTripCount, + loopConditionTrue, + ValueRange({Y_initial, initial_h})); + { + OpBuilder::InsertionGuard guard(b); + Block *loopBody = + b.createBlock(&loop.getRegion(), loop.getRegion().begin(), + TypeRange({ + loopIndexType, + yTy, + hTy, + }), + {loc, loc, loc} // locs for the loop body arguments + ); + + Value loopIndex = loopBody->getArgument(0); + Value Y_prev = loopBody->getArgument(1); + Value H_prev = loopBody->getArgument(2); + + auto xTy = cast(X.getType()); + auto XtType = b.getType( + llvm::SmallVector{batch_size, input_size}, xTy.getDtype()); + + Value Xt = b.create(XtType, X, cstZero, loopIndex); + + Value H_new = rnn_cell(b, Xt, H_prev, weights, activations); + + Type hTyUnsqueezed = b.getType( + llvm::SmallVector{1, batch_size, hidden_size}, hTy.getDtype()); + Value H_new_unsqueezed = + b.create(hTyUnsqueezed, H_new, cstZero); + + auto loopIndexPlusOne = b.create(intType, loopIndex, cstOne); + Value Y_new = + b.create(yTy, Y_prev, H_new_unsqueezed, cstZero, + loopIndex, loopIndexPlusOne, cstOne); + + b.create(loopConditionTrue, + ValueRange({Y_new, H_new})); + } + RnnLayerOutput output; + output.Y = loop.getResult(0); + output.Y_h = loop.getResult(1); + return output; +} + +static Value StaticTranspose(ImplicitLocOpBuilder b, Value value, int64_t dim0, + int64_t dim1) { + auto valueTy = cast(value.getType()); + + SmallVector valueShape(valueTy.getSizes()); + std::swap(valueShape[dim0], valueShape[dim1]); + valueTy = b.getType(valueShape, valueTy.getDtype()); + + auto intType = b.getType(); + Value dim0v = b.create(intType, b.getI64IntegerAttr(dim0)); + Value dim1v = b.create(intType, b.getI64IntegerAttr(dim1)); + + return b.create(valueTy, value, dim0v, dim1v); +} + +LogicalResult OnnxRnnExpander(OpBinder binder, + ConversionPatternRewriter &rewriter) { + Location loc = binder.getLoc(); + mlir::ImplicitLocOpBuilder b(loc, rewriter); + + auto intType = b.getType(); + Value cstNone = b.create(); + Value cstZero = b.create(intType, b.getI64IntegerAttr(0)); + Value cstOne = b.create(intType, b.getI64IntegerAttr(1)); + + int64_t num_directions = Torch::kUnknownSize; + int64_t hidden_size = Torch::kUnknownSize; + + // Attributes + llvm::SmallVector activationsList; + RnnActivations activations; + activations.f = "Tanh"; + if (!binder.stringArrayAttr(activationsList, "activations") && + activationsList.size() > 0) { + if (activationsList.size() == 1) { + activations.f = activationsList[0]; + } else if (activationsList.size() == 2) { + return rewriter.notifyMatchFailure( + binder.op, "Bi-directional RNN is not yet supported, yet two " + "activation function names are provided"); + } else { + return rewriter.notifyMatchFailure( + binder.op, "Unsupported number of activation functions: " + + std::to_string(activationsList.size()) + + " are provided."); + } + } + + std::string direction; + if (!binder.customOpNameStringAttr(direction, "direction", "forward") && + direction != 
"forward") + return rewriter.notifyMatchFailure(binder.op, + "Unsupported direction attribute value. " + "Only 'forward' is supported but '" + + direction + "' is provided."); + num_directions = (direction == "bidirectional") ? 2 : 1; + + // hidden_size is required according to the docs, + // but if we encounter a model that doesn't have it + // that we really want to just push through, consider + // deleting this check and making it infer the hidden size + if (binder.s64IntegerAttr(hidden_size, "hidden_size")) + return rewriter.notifyMatchFailure( + binder.op, "Missing required attribute hidden_size"); + + // Other attributes + int64_t layout; + if (binder.s64IntegerAttr(layout, "layout", 0)) + return rewriter.notifyMatchFailure(binder.op, + "Unsupported layout attribute type."); + + if (layout < 0 || layout > 1) + return rewriter.notifyMatchFailure(binder.op, + "Unsupported layout attribute value."); + + // Result types + ValueTensorType yTy, Y_hType; + auto hasResult0 = binder.tensorResultTypeAtIndex(yTy, 0); + auto hasResult1 = binder.tensorResultTypeAtIndex(Y_hType, 1); + + if (hasResult0 && hasResult1) { + return rewriter.notifyMatchFailure(binder.op, + "At least one output must be present"); + } + + // Inputs + Value X, W, R, B, initial_h; + if (binder.tensorOperandAtIndex(X, 0)) + return rewriter.notifyMatchFailure(binder.op, + "Missing required input tensor X"); + if (binder.tensorOperandAtIndex(W, 1)) + return rewriter.notifyMatchFailure(binder.op, + "Missing required input tensor W"); + if (binder.tensorOperandAtIndex(R, 2)) + return rewriter.notifyMatchFailure(binder.op, + "Missing required input tensor R"); + if (binder.tensorOperandAtIndex(B, 3)) { + // if no b found, set to null and create one later + B = nullptr; + } + if (binder.tensorOperandAtIndex(initial_h, 5)) { + // if no initial_h found, set to null and create one later + initial_h = nullptr; + } + + if (layout == 1) { + X = StaticTranspose(b, X, 0, 1); + if (initial_h) + initial_h = StaticTranspose(b, initial_h, 0, 1); + } + + // validation + auto xTy = cast(X.getType()); + auto wTy = cast(W.getType()); + auto rTy = cast(R.getType()); + auto wShape = wTy.getSizes(); + auto xShape = xTy.getSizes(); + auto rShape = rTy.getSizes(); + assert(wShape.size() == 3); + + int64_t seq_len = xShape[0]; + int64_t batch_size = xShape[1]; + int64_t x_input_size = xShape[2]; + + int64_t w_num_directions = wShape[0]; + int64_t w_hidden_size = wShape[1]; + int64_t w_input_size = wShape[2]; + + int64_t r_num_directions = rShape[0]; + if (rShape[1] != rShape[2]) + return rewriter.notifyMatchFailure( + binder.op, + "R tensor must be square, but got shape: " + std::to_string(rShape[1]) + + "x" + std::to_string(rShape[2])); + int64_t r_hidden_size = rShape[1]; + + // validate input size + if (x_input_size != w_input_size) { + return rewriter.notifyMatchFailure( + binder.op, "input_size inferred from shape of X (" + + std::to_string(x_input_size) + + ") does not match the input_size attribute value (" + + std::to_string(w_input_size) + ")"); + } + + // validate hidden size + if (w_hidden_size != Torch::kUnknownSize && hidden_size != w_hidden_size) { + return rewriter.notifyMatchFailure( + binder.op, "hidden_size inferred from shape of W (" + + std::to_string(w_hidden_size) + + ") does not match the hidden_size attribute value (" + + std::to_string(hidden_size) + ")"); + } + + if (r_hidden_size != Torch::kUnknownSize && hidden_size != r_hidden_size) { + return rewriter.notifyMatchFailure( + binder.op, "hidden_size inferred from shape of R 
(" + + std::to_string(r_hidden_size) + + ") does not match the hidden_size attribute value (" + + std::to_string(hidden_size) + ")"); + } + + // validate num directions + if (w_num_directions != Torch::kUnknownSize && + w_num_directions != num_directions) { + return rewriter.notifyMatchFailure( + binder.op, "num_directions from shape of W (" + + std::to_string(w_num_directions) + + ") does not match the direction attribute value (" + + direction + ")"); + } + + if (r_num_directions != Torch::kUnknownSize && + r_num_directions != num_directions) { + return rewriter.notifyMatchFailure( + binder.op, "num_directions from shape of R (" + + std::to_string(r_num_directions) + + ") does not match the direction attribute value (" + + direction + ")"); + } + + if (num_directions != 1) { + return rewriter.notifyMatchFailure( + binder.op, + "Unsupported num_directions. Only 1 is currently supported but " + + std::to_string(num_directions) + " is provided."); + } + + // Create B and initial_h if not provided, + // using same dtype as X + Value cstXDtype = getDtypeIntValueForType(rewriter, loc, xTy.getDtype()); + if (B == nullptr) { + SmallVector BShape = {num_directions, 2 * hidden_size}; + SmallVector BShapeListContents = { + b.create(intType, b.getI64IntegerAttr(num_directions)), + b.create(intType, b.getI64IntegerAttr(2 * hidden_size))}; + Value BShapeList = b.create( + b.getType(intType), BShapeListContents); + auto BType = b.getType(BShape, wTy.getDtype()); + B = b.create(BType, BShapeList, cstXDtype, cstNone, + cstNone, cstNone); + } + if (initial_h == nullptr) { + SmallVector initial_h_shape = {num_directions, batch_size, + hidden_size}; + SmallVector initial_h_shape_list_contents = { + b.create(intType, b.getI64IntegerAttr(num_directions)), + b.create(intType, b.getI64IntegerAttr(batch_size)), + b.create(intType, b.getI64IntegerAttr(hidden_size))}; + Value initial_h_shape_list = b.create( + b.getType(intType), initial_h_shape_list_contents); + auto initial_h_type = + b.getType(initial_h_shape, wTy.getDtype()); + initial_h = + b.create(initial_h_type, initial_h_shape_list, + cstXDtype, cstNone, cstNone, cstNone); + } + + Value W_forward = getDirection(b, 0, W); + Value R_forward = getDirection(b, 0, R); + Value B_forward = getDirection(b, 0, B); + Value initial_h_forward = getDirection(b, 0, initial_h); + + Value cstHiddenSize = + b.create(intType, b.getI64IntegerAttr(hidden_size)); + + RnnWeights weights; + weights.Wi = W_forward; + weights.Ri = R_forward; + weights.Wbi = b.create( + b.getType(llvm::SmallVector{hidden_size}, + wTy.getDtype()), + B_forward, cstZero, cstZero, cstHiddenSize, cstOne); + weights.Rbi = b.create( + b.getType(llvm::SmallVector{hidden_size}, + wTy.getDtype()), + B_forward, cstZero, cstHiddenSize, + b.create( + cstHiddenSize, + b.create(intType, b.getI64IntegerAttr(2))), + cstOne); + + RnnLayerOutput rnnLayerOutput = + rnn_layer(b, X, initial_h_forward, weights, activations); + + auto Y_h_unsqueezed_type = b.getType( + llvm::SmallVector{num_directions, batch_size, hidden_size}, + cast(rnnLayerOutput.Y_h.getType()).getDtype()); + Value Y_h_unsqueezed = b.create(Y_h_unsqueezed_type, + rnnLayerOutput.Y_h, cstZero); + + auto Y_unsqueezed_type = b.getType( + llvm::SmallVector{seq_len, num_directions, batch_size, + hidden_size}, + cast(rnnLayerOutput.Y_h.getType()).getDtype()); + Value Y_unsqueezed = + b.create(Y_unsqueezed_type, rnnLayerOutput.Y, cstOne); + + if (layout == 1) { + Y_h_unsqueezed = StaticTranspose(b, Y_h_unsqueezed, 0, 1); + Y_unsqueezed = StaticTranspose(b, 
Y_unsqueezed, 1, 2); + Y_unsqueezed = StaticTranspose(b, Y_unsqueezed, 0, 1); + } + + if (!yTy) + Y_unsqueezed = cstNone; + if (!Y_hType) + Y_h_unsqueezed = cstNone; + + rewriter.replaceOp(binder.op, {Y_unsqueezed, Y_h_unsqueezed}); + return success(); +} + +// @struct LstmWeights +// @brief A structure to hold LSTM weights. +// +// Each W_ weight matrix should have shape [hidden_size, input_size]. +// Each R_ weight matrix should have shape [hidden_size, hidden_size]. +// Each bias vector should have shape [4 * hidden_size]. +struct LstmWeights { + Value W_i, W_o, W_f, W_c; + Value R_i, R_o, R_f, R_c; + Value Wb_i, Wb_o, Wb_f, Wb_c; + Value Rb_i, Rb_o, Rb_f, Rb_c; +}; +struct LstmActivations { + std::string f; + std::string g; + std::string h; +}; + +struct LstmCellState { + Value H; + Value C; +}; +// This function represents a Long Short-Term Memory (LSTM) cell operation. +// +// @param b A builder for constructing operations. +// @param Xt The input sequence. It has a shape of [batch_size, input_size]. +// @param H_prev The previous hidden state. It has a shape of [batch_size, +// hidden_size]. +// @param C_prev The previous cell state. It has a shape of [batch_size, +// hidden_size]. +// @param weights The weights for the LSTM cell. See @ref LstmWeights for shapes +// @param activations The activation functions for the LSTM cell. Members f,g,h +// correspond to f,g,h in https://onnx.ai/onnx/operators/onnx__LSTM.html +// @return The state of the LSTM cell after the operation. +LstmCellState lstm_cell(ImplicitLocOpBuilder &b, Value Xt, Value H_prev, + Value C_prev, LstmWeights weights, + LstmActivations activations) { + + auto intType = b.getType(); + auto hTy = cast(H_prev.getType()); + + Value cstOne = b.create(intType, b.getI64IntegerAttr(1)); + + // Apply linear/matmul for each gate separately + // names are consistent with ONNX LSTM documentation + Value i_x = b.create(hTy, Xt, weights.W_i, weights.Wb_i); + Value i_h = b.create(hTy, H_prev, weights.R_i, weights.Rb_i); + Value i = b.create(hTy, i_x, i_h, cstOne); + Value i_act = createActivationByName(b, activations.f, i); + + Value o_x = b.create(hTy, Xt, weights.W_o, weights.Wb_o); + Value o_h = b.create(hTy, H_prev, weights.R_o, weights.Rb_o); + Value o = b.create(hTy, o_x, o_h, cstOne); + Value o_act = createActivationByName(b, activations.f, o); + + Value f_x = b.create(hTy, Xt, weights.W_f, weights.Wb_f); + Value f_h = b.create(hTy, H_prev, weights.R_f, weights.Rb_f); + Value f = b.create(hTy, f_x, f_h, cstOne); + Value f_act = createActivationByName(b, activations.f, f); + + Value ct_x = b.create(hTy, Xt, weights.W_c, weights.Wb_c); + Value ct_h = b.create(hTy, H_prev, weights.R_c, weights.Rb_c); + Value ct = b.create(hTy, ct_x, ct_h, cstOne); + Value ct_act = createActivationByName(b, activations.g, ct); + + Value C_forget = b.create(hTy, f_act, C_prev); + Value C_input = b.create(hTy, i_act, ct_act); + + LstmCellState newCellState; + newCellState.C = b.create(hTy, C_forget, C_input, cstOne); + Value C_new_act = createActivationByName(b, activations.h, newCellState.C); + newCellState.H = b.create(hTy, o_act, C_new_act); + return newCellState; +} + +struct LstmLayerOutput { + Value Y; + Value Y_h; + Value Y_c; +}; + +// @brief This function implements the LSTM (Long Short-Term Memory) layer +// operation. +// +// The core computation is performed in a loop that iterates over the sequence +// length. 
In each iteration, it selects the corresponding input, computes the +// new hidden state and cell state using the lstm_cell function, and updates the +// output tensor. +// +// @return A struct containing the hidden state history, final hidden state, +// and final cell state. +LstmLayerOutput lstm_layer(ImplicitLocOpBuilder &b, Value X, Value initial_h, + Value initial_c, LstmWeights weights, + LstmActivations activations) { + + Location loc = b.getLoc(); + + auto xTy = cast(X.getType()); + auto hTy = cast(initial_h.getType()); + // these names are snake_case for consistency with onnx.LSTM documentation + int64_t seq_len = xTy.getSizes()[0]; + int64_t batch_size = xTy.getSizes()[1]; + int64_t input_size = xTy.getSizes()[2]; + int64_t hidden_size = hTy.getSizes()[1]; + + auto cTy = hTy; + + auto intType = b.getType(); + + Value cstNone = b.create(); + Value cstZero = b.create(intType, b.getI64IntegerAttr(0)); + Value cstOne = b.create(intType, b.getI64IntegerAttr(1)); + Value cstSeqLen = + b.create(intType, b.getI64IntegerAttr(seq_len)); + Value cstBatchSize = + b.create(intType, b.getI64IntegerAttr(batch_size)); + Value cstHiddenSize = + b.create(intType, b.getI64IntegerAttr(hidden_size)); + + auto yTy = b.getType( + SmallVector{seq_len, batch_size, hidden_size}, hTy.getDtype()); + + auto YShapeList = b.create( + b.getType(intType), + ValueRange({cstSeqLen, cstBatchSize, cstHiddenSize})); + + int64_t hDtypeInt = + static_cast(getScalarTypeForType(hTy.getDtype())); + Value hDtypeIntVal = + b.create(loc, b.getI64IntegerAttr(hDtypeInt)); + + Value Y_initial = b.create(yTy, YShapeList, hDtypeIntVal, + cstNone, cstNone, cstNone); + + // Create a for-like PrimLoopOp. + Value maxTripCount = + b.create(intType, b.getI64IntegerAttr(seq_len)); + Value loopConditionTrue = b.create(true); + + Type loopIndexType = intType; + auto loop = b.create( + TypeRange({yTy, hTy, cTy}), maxTripCount, loopConditionTrue, + ValueRange({Y_initial, initial_h, initial_c})); + { + OpBuilder::InsertionGuard guard(b); + Block *loopBody = + b.createBlock(&loop.getRegion(), loop.getRegion().begin(), + TypeRange({ + loopIndexType, + yTy, + hTy, + cTy, + }), + {loc, loc, loc, loc} // locs for the loop body arguments + ); + + Value loopIndex = loopBody->getArgument(0); + Value Y_prev = loopBody->getArgument(1); + Value H_prev = loopBody->getArgument(2); + Value C_prev = loopBody->getArgument(3); + + auto xTy = cast(X.getType()); + auto XtType = b.getType( + llvm::SmallVector{batch_size, input_size}, xTy.getDtype()); + + Value Xt = b.create(XtType, X, cstZero, loopIndex); + + auto [H_new, C_new] = + lstm_cell(b, Xt, H_prev, C_prev, weights, activations); + + Type hTyUnsqueezed = b.getType( + llvm::SmallVector{1, batch_size, hidden_size}, hTy.getDtype()); + Value H_new_unsqueezed = + b.create(hTyUnsqueezed, H_new, cstZero); + + auto loopIndexPlusOne = b.create(intType, loopIndex, cstOne); + Value Y_new = + b.create(yTy, Y_prev, H_new_unsqueezed, cstZero, + loopIndex, loopIndexPlusOne, cstOne); + + b.create(loopConditionTrue, + ValueRange({Y_new, H_new, C_new})); + } + LstmLayerOutput output; + output.Y = loop.getResult(0); + output.Y_h = loop.getResult(1); + output.Y_c = loop.getResult(2); + return output; +} +// @brief Expands an ONNX LSTM operation into torch ops. +// +// This function primarily handles the binding of operands and slicing of the +// weight matrix. The majority of the lowering process is managed in the +// lstm_layer and lstm_cell. 
For the shapes and meanings of the inputs, refer to +// the ONNX LSTM documentation at: +// https://onnx.ai/onnx/operators/onnx__LSTM.html +// The variable names are also consistent with the aforementioned documentation. +// +// This is not e2e tested here but is verified to work numerically downstream in +// SHARK-TestSuite. +// +// TODO: include this test case when the test infrastructure stops initializing +// weights separately for the reference and tested layers. +// @code{.py} +// class LSTMModule(torch.nn.Module): +// def __init__(self): +// super().__init__() +// self.lstm = torch.nn.LSTM(10, 20, 1) +// @export +// @annotate_args([ +// None, +// ([5, 1, 10], torch.float32, True), +// ([1, 1, 20], torch.float32, True), +// ([1, 1, 20], torch.float32, True), +// ]) +// def forward(self, input, h0, c0): +// return self.lstm(input, (h0, c0)) +// +// @register_test_case(module_factory=LSTMModule) +// def LSTMModule_basic(module, tu: TestUtils): +// inputs = torch.zeros(5,1,10) +// h0 = torch.zeros(1,1,20) +// c0 = torch.zeros(1,1,20) +// +// output, (hn, cn) = module.forward(inputs, h0, c0) +// @endcode +// +// @param binder The OpBinder object used for binding operands. +LogicalResult OnnxLstmExpander(OpBinder binder, + ConversionPatternRewriter &rewriter) { + Location loc = binder.getLoc(); + mlir::ImplicitLocOpBuilder b(loc, rewriter); + + std::string direction; + + ValueTensorType yTy, Y_hType, Y_cType; + if (binder.tensorResultTypeAtIndex(yTy, 0) && + binder.tensorResultTypeAtIndex(Y_hType, 1) && + binder.tensorResultTypeAtIndex(Y_cType, 2)) { + return rewriter.notifyMatchFailure(binder.op, + "At least one output must be present"); + } + Value X; + if (binder.tensorOperandAtIndex(X, 0)) + return rewriter.notifyMatchFailure(binder.op, + "Missing required input tensor X"); + Value W; + if (binder.tensorOperandAtIndex(W, 1)) + return rewriter.notifyMatchFailure(binder.op, + "Missing required input tensor W"); + Value R; + if (binder.tensorOperandAtIndex(R, 2)) + return rewriter.notifyMatchFailure(binder.op, + "Missing required input tensor R"); + int64_t hidden_size; + if (binder.s64IntegerAttr(hidden_size, "hidden_size")) + return rewriter.notifyMatchFailure( + binder.op, "Missing required attribute hidden_size"); + + auto xTy = cast(X.getType()); + auto wTy = cast(W.getType()); + + // TODO: add defaults for activation_alpha, activation_beta attributes + + llvm::SmallVector activationsList; + if (binder.stringArrayAttr(activationsList, "activations")) + return rewriter.notifyMatchFailure( + binder.op, "Missing required attribute; activations"); + + if (!binder.customOpNameStringAttr(direction, "direction", "forward") && + direction != "forward" && direction != "bidirectional") + return rewriter.notifyMatchFailure( + binder.op, "Unsupported direction attribute value. " + "Only 'forward' / 'bidirectional' are supported but '" + + direction + "' is provided."); + int64_t num_directions = 1 + (direction == "bidirectional"); + bool isBidirectional = direction == "bidirectional"; + // There can be backward activations too + // if backward -> look for 6 activations (what happens when only three?)
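+ // With the default activations f=Sigmoid, g=Tanh, h=Tanh, each step of the
+ // recurrence built by lstm_cell above roughly corresponds to the following
+ // NumPy sketch (W, R, Wb, Rb and the "i"/"o"/"f"/"c" keys are illustrative
+ // names only, not part of this lowering):
+ // @code{.py}
+ // import numpy as np
+ //
+ // def sigmoid(x):
+ //     return 1.0 / (1.0 + np.exp(-x))
+ //
+ // def lstm_cell_step(Xt, H_prev, C_prev, W, R, Wb, Rb,
+ //                    f=sigmoid, g=np.tanh, h=np.tanh):
+ //     # W[k]: [hidden, input], R[k]: [hidden, hidden], Wb[k]/Rb[k]: [hidden]
+ //     def gate(k, act):
+ //         return act(Xt @ W[k].T + Wb[k] + H_prev @ R[k].T + Rb[k])
+ //     i, o, fg = gate("i", f), gate("o", f), gate("f", f)
+ //     c = gate("c", g)
+ //     C_new = fg * C_prev + i * c
+ //     H_new = o * h(C_new)
+ //     return H_new, C_new
+ // @endcode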
+ + int64_t num_activations = activationsList.size(); + if (num_activations != 0 && num_activations != 3 && num_activations != 6) { + return rewriter.notifyMatchFailure( + binder.op, "activations must either be empty (default), have 3 elements" + " (forward) or, have 6 elements (bidirectional), but " + + std::to_string(activationsList.size()) + + " are provided."); + } + // TODO : Add checks, defaults and fails for inputs - sequence_lens, P and + // attrs- clip, input_forget, layout + + Value B; + if (binder.tensorOperandAtIndex(B, 3)) { + Value none = b.create(); + Value cstHiddenx8 = b.create( + b.getType(), b.getI64IntegerAttr(8 * hidden_size)); + Value cstNumDir = b.create( + b.getType(), b.getI64IntegerAttr(num_directions)); + auto BType = b.getType( + llvm::SmallVector{num_directions, 8 * hidden_size}, + cast(W.getType()).getDtype()); + Value zerosShapeList = b.create( + b.getType(b.getType()), + SmallVector{cstNumDir, cstHiddenx8}); + B = b.create(BType, zerosShapeList, none, none, none, none); + } + + LstmActivations activations, activationsRev; + // Default case (both forward and reverse) + activations.f = "Sigmoid"; + activations.g = "Tanh"; + activations.h = "Tanh"; + activationsRev.f = "Sigmoid"; + activationsRev.g = "Tanh"; + activationsRev.h = "Tanh"; + + // forward only (also to be added for bidirectional case) + if (num_activations >= 3) { + activations.f = activationsList[0]; + activations.g = activationsList[1]; + activations.h = activationsList[2]; + } + + // bidirectional + if (num_activations == 6) { + activationsRev.f = activationsList[3]; + activationsRev.g = activationsList[4]; + activationsRev.h = activationsList[5]; + } + + float clip; + if (!binder.f32FloatAttr(clip, "clip", 0.0) && clip != 0.0) + return rewriter.notifyMatchFailure(binder.op, + "clip attribute not supported"); + + int64_t input_forget; + if (!binder.s64IntegerAttr(input_forget, "input_forget", 0) && + input_forget != 0) + return rewriter.notifyMatchFailure( + binder.op, "only input_forget = 0 supported. 
Got input_forgt = " + + std::to_string(input_forget)); + + int64_t layout; + if (!binder.s64IntegerAttr(layout, "layout", 0) && layout != 0 && layout != 1) + return rewriter.notifyMatchFailure( + binder.op, "invalid value of layout attribute, expecting 0 / 1 got " + + std::to_string(layout)); + + auto XShape = xTy.getSizes(); + int64_t seq_len, batch_size; + if (layout == 0) { + seq_len = XShape[0]; + batch_size = XShape[1]; + } else { + seq_len = XShape[1]; + batch_size = XShape[0]; + } + + int64_t input_size = XShape[2]; + if (num_directions != wTy.getSizes()[0]) + return rewriter.notifyMatchFailure( + binder.op, "num_directions (" + std::to_string(num_directions) + + ") does not match the first dimension of wTy (" + + std::to_string(wTy.getSizes()[0]) + ")"); + + if (4 * hidden_size != wTy.getSizes()[1]) + return rewriter.notifyMatchFailure( + binder.op, "4 times hidden_size (" + std::to_string(4 * hidden_size) + + ") does not match the second dimension of wTy (" + + std::to_string(wTy.getSizes()[1]) + ")"); + if (wTy.getSizes()[2] != input_size) + return rewriter.notifyMatchFailure( + binder.op, + "The third dimension of wTy (" + std::to_string(wTy.getSizes()[2]) + + ") does not match input_size (" + std::to_string(input_size) + ")"); + + Value W_forward = getDirection(b, 0, W); + Value R_forward = getDirection(b, 0, R); + Value B_forward = getDirection(b, 0, B); + + Value W_reverse, R_reverse, B_reverse; + if (isBidirectional) { + W_reverse = getDirection(b, 1, W); + R_reverse = getDirection(b, 1, R); + B_reverse = getDirection(b, 1, B); + } + + auto hTy = b.getType( + llvm::SmallVector{num_directions, batch_size, hidden_size}, + xTy.getDtype()); + + auto intType = b.getType(); + + Value cstNumDirections = + b.create(intType, b.getI64IntegerAttr(num_directions)); + Value cstBatchSize = + b.create(intType, b.getI64IntegerAttr(batch_size)); + Value cstHiddenSize = + b.create(intType, b.getI64IntegerAttr(hidden_size)); + Value cstNone = b.create(); + Value cstZero = b.create(intType, b.getI64IntegerAttr(0)); + Value cstOne = b.create(intType, b.getI64IntegerAttr(1)); + + Value hShape = b.create( + b.getType(intType), + ValueRange({cstNumDirections, cstBatchSize, cstHiddenSize})); + + Value cstDtype = getDtypeIntValueForType(rewriter, loc, xTy.getDtype()); + + Value initial_h; + if (binder.tensorOperandAtIndex(initial_h, 5)) { + // default created for layout 0 + initial_h = + b.create(hTy, hShape, cstDtype, cstNone, cstNone, cstNone); + } else { + if (layout == 1) + initial_h = StaticTranspose(b, initial_h, 0, 1); + } + + Value initial_c; + if (binder.tensorOperandAtIndex(initial_c, 6)) { + // default created for layout 0 + initial_c = + b.create(hTy, hShape, cstDtype, cstNone, cstNone, cstNone); + } else { + if (layout == 1) + initial_c = StaticTranspose(b, initial_c, 0, 1); + } + + // convert X from layout 1 to layout 0 + if (layout == 1) + X = StaticTranspose(b, X, 0, 1); + + // X, initial_h, initial_c are now in layout 0 + + Value initial_h_forward = getDirection(b, 0, initial_h); + Value initial_c_forward = getDirection(b, 0, initial_c); + + Value initial_h_reverse, initial_c_reverse; + if (isBidirectional) { + initial_h_reverse = getDirection(b, 1, initial_h); + initial_c_reverse = getDirection(b, 1, initial_c); + } + + // Everything hereon is for the forward direction (unless in bidirectional if + // block), with the direction dimention squeezed out and all inputs in layout + // 0 format + + LstmWeights weights, weightsRev; // weights and biases + + auto intConst = [&](int64_t val) 
{ + return b.create(intType, b.getI64IntegerAttr(val)); + }; + + // split B into Wb and Rb + Value inputWeightsEndIdx = intConst(4 * hidden_size); + Value recurrentWeightsStartIdx = inputWeightsEndIdx; + Value recurrentWeightsEndIdx = intConst(8 * hidden_size); + auto biasType = b.getType( + llvm::SmallVector{hidden_size * 4}, wTy.getDtype()); + // forward + Value Wb = b.create(biasType, + /*input=*/B_forward, + /*dim=*/cstZero, + /*start=*/cstZero, + /*end=*/inputWeightsEndIdx, + /*step=*/cstOne); + Value Rb = b.create(biasType, + /*input=*/B_forward, + /*dim=*/cstZero, + /*start=*/recurrentWeightsStartIdx, + /*end=*/recurrentWeightsEndIdx, + /*step=*/cstOne); + Value Wb_reverse, Rb_reverse; + if (isBidirectional) { + // reverse + Wb_reverse = b.create(biasType, + /*input=*/B_reverse, + /*dim=*/cstZero, + /*start=*/cstZero, + /*end=*/inputWeightsEndIdx, + /*step=*/cstOne); + Rb_reverse = b.create(biasType, + /*input=*/B_reverse, + /*dim=*/cstZero, + /*start=*/recurrentWeightsStartIdx, + /*end=*/recurrentWeightsEndIdx, + /*step=*/cstOne); + } + + // gate splitting + auto gateBiasType = b.getType( + llvm::SmallVector{hidden_size}, + cast(Wb.getType()).getDtype()); + auto gateWeightsTypeIH = b.getType( + llvm::SmallVector{hidden_size, input_size}, + cast(W_forward.getType()).getDtype()); + auto gateWeightsTypeHH = b.getType( + llvm::SmallVector{hidden_size, hidden_size}, + cast(R_forward.getType()).getDtype()); + + Value inputGateWeightsEndIdx = intConst(hidden_size); + Value outputGateWeightsEndIdx = intConst(2 * hidden_size); + Value forgetGateWeightsEndIdx = intConst(3 * hidden_size); + Value cellGateWeightsEndIdx = intConst(4 * hidden_size); + + auto sliceIOFC = [&](std::function slicerFunction, + Value WoB) { + // slice into 4 components and return tuple + return std::make_tuple( + slicerFunction(cstZero, inputGateWeightsEndIdx, WoB), + slicerFunction(inputGateWeightsEndIdx, outputGateWeightsEndIdx, WoB), + slicerFunction(outputGateWeightsEndIdx, forgetGateWeightsEndIdx, WoB), + slicerFunction(forgetGateWeightsEndIdx, cellGateWeightsEndIdx, WoB)); + }; + + auto sliceGateBias = [&](Value startIdx, Value endIdx, Value WoB) { + return b.create(gateBiasType, WoB, cstZero, startIdx, + endIdx, cstOne); + }; + std::tie(weights.Wb_i, weights.Wb_o, weights.Wb_f, weights.Wb_c) = + sliceIOFC(sliceGateBias, Wb); + + if (isBidirectional) + std::tie(weightsRev.Wb_i, weightsRev.Wb_o, weightsRev.Wb_f, + weightsRev.Wb_c) = sliceIOFC(sliceGateBias, Wb_reverse); + + auto sliceGateBiasR = [&](Value startIdx, Value endIdx, Value WoB) { + return b.create(gateBiasType, WoB, cstZero, startIdx, + endIdx, cstOne); + }; + std::tie(weights.Rb_i, weights.Rb_o, weights.Rb_f, weights.Rb_c) = + sliceIOFC(sliceGateBiasR, Rb); + + if (isBidirectional) + std::tie(weightsRev.Rb_i, weightsRev.Rb_o, weightsRev.Rb_f, + weightsRev.Rb_c) = sliceIOFC(sliceGateBiasR, Rb_reverse); + + auto sliceGateWeightsIH = [&](Value startIdx, Value endIdx, Value WoB) { + return b.create(gateWeightsTypeIH, WoB, cstZero, + startIdx, endIdx, cstOne); + }; + std::tie(weights.W_i, weights.W_o, weights.W_f, weights.W_c) = + sliceIOFC(sliceGateWeightsIH, W_forward); + + if (isBidirectional) + std::tie(weightsRev.W_i, weightsRev.W_o, weightsRev.W_f, weightsRev.W_c) = + sliceIOFC(sliceGateWeightsIH, W_reverse); + + auto sliceGateWeightsHH = [&](Value startIdx, Value endIdx, Value WoB) { + return b.create(gateWeightsTypeHH, WoB, cstZero, + startIdx, endIdx, cstOne); + }; + + std::tie(weights.R_i, weights.R_o, weights.R_f, weights.R_c) = + 
sliceIOFC(sliceGateWeightsHH, R_forward); + + if (isBidirectional) + std::tie(weightsRev.R_i, weightsRev.R_o, weightsRev.R_f, weightsRev.R_c) = + sliceIOFC(sliceGateWeightsHH, R_reverse); + + LstmLayerOutput lstmLayerOutput = lstm_layer( + b, X, initial_h_forward, initial_c_forward, weights, activations); + + Value Y_h_result, Y_c_result, Y_result; + + // if forward (unidirectional) unsqueeze and output + auto YallDtype = + cast(lstmLayerOutput.Y_h.getType()).getDtype(); + auto Y_h_Y_c_uni_type = b.getType( + llvm::SmallVector{1, batch_size, hidden_size}, YallDtype); + auto Y_uni_type = b.getType( + llvm::SmallVector{seq_len, 1, batch_size, hidden_size}, + YallDtype); + auto Y_h_Y_c_res_type = b.getType( + llvm::SmallVector{num_directions, batch_size, hidden_size}, + YallDtype); + auto Y_res_type = b.getType( + llvm::SmallVector{seq_len, num_directions, batch_size, + hidden_size}, + YallDtype); + + Value Y_h_forward = + b.create(Y_h_Y_c_uni_type, lstmLayerOutput.Y_h, cstZero); + + Value Y_c_forward = + b.create(Y_h_Y_c_uni_type, lstmLayerOutput.Y_c, cstZero); + + // unsqueeze num_directions dim1 of Y + // to create the onnx.LSTM output shape [seq_length, num_directions, + // batch_size, hidden_size] + Value Y_forward = + b.create(Y_uni_type, lstmLayerOutput.Y, cstOne); + + Y_result = Y_forward; + Y_h_result = Y_h_forward; + Y_c_result = Y_c_forward; + + // add bidrectional reverse layer + // this is just flip X, lstm layer, flip results, stack + // flip X + Value dim0, X_reverse, Y_h_reverse, Y_c_reverse, Y_reverse_unflipped, + Y_reverse, Y_output_list, Y_h_output_list, Y_c_output_list; + LstmLayerOutput revLstmLayerOutput; + if (isBidirectional) { + dim0 = b.create(b.getType(intType), + SmallVector{cstZero}); + X_reverse = b.create(xTy, X, dim0); // flip along seq_len dim + revLstmLayerOutput = + lstm_layer(b, X_reverse, initial_h_reverse, initial_c_reverse, + weightsRev, activationsRev); + + // unsqueeze Y_rev, Y_h_rev, Y_c_rev + Y_h_reverse = b.create(Y_h_Y_c_uni_type, + revLstmLayerOutput.Y_h, cstZero); + Y_c_reverse = b.create(Y_h_Y_c_uni_type, + revLstmLayerOutput.Y_c, cstZero); + Y_reverse_unflipped = + b.create(Y_uni_type, revLstmLayerOutput.Y, cstOne); + + // flip Y_rev on dim 0 [seq_len] + Y_reverse = b.create(Y_uni_type, Y_reverse_unflipped, dim0); + + // Concat forward and reverse results on dim 1 + Y_output_list = + b.create(b.getType(Y_uni_type), + SmallVector{Y_forward, Y_reverse}); + Y_result = b.create(Y_res_type, Y_output_list, cstOne); + + // Concat forward and reverse results on dim 0 + Y_h_output_list = b.create( + b.getType(Y_h_Y_c_uni_type), + SmallVector{Y_h_forward, Y_h_reverse}); + Y_h_result = + b.create(Y_h_Y_c_res_type, Y_h_output_list, cstZero); + + Y_c_output_list = b.create( + b.getType(Y_h_Y_c_uni_type), + SmallVector{Y_c_forward, Y_c_reverse}); + Y_c_result = + b.create(Y_h_Y_c_res_type, Y_c_output_list, cstZero); + } + + if (layout == 1) { + // Update Y, Y_h, Y_c results to layout 1 + Y_result = StaticTranspose(b, Y_result, 1, 2); + Y_result = StaticTranspose(b, Y_result, 0, 1); + Y_h_result = StaticTranspose(b, Y_h_result, 0, 1); + Y_c_result = StaticTranspose(b, Y_c_result, 0, 1); + } + + // Only add outputs specified in onnx output node + SmallVector actualOutputs = {Y_result, Y_h_result, Y_c_result}, + outputs; + ValueTensorType resTy; + for (int i = 0; i < binder.getNumResults(); ++i) { + if (!binder.tensorResultTypeAtIndex(resTy, i) && !resTy) { + outputs.push_back(cstNone); + } else { + outputs.push_back(actualOutputs[i]); + } + } + + 
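The gate slicing above follows the ONNX LSTM layout, in which W, R and the combined bias pack the gates in i, o, f, c order and the per-direction bias holds Wb and Rb back to back (8*hidden_size values). As a semantics reference only (not code from this patch), the following self-contained C++ sketch computes one cell step with the ONNX default activations (Sigmoid for the gates, Tanh for the candidate and the output); peepholes, clipping and sequence-length masking are omitted and all names are illustrative. The GRU expander later in this file slices its weights in the same style.

// Scalar reference for one ONNX LSTM cell step (gate order i, o, f, c).
// Assumed shapes: W[4H x I], R[4H x H], Wb[4H], Rb[4H], x[I], h[H], c[H].
#include <cmath>
#include <vector>

static float sigmoidf(float v) { return 1.0f / (1.0f + std::exp(-v)); }

void lstm_cell_reference(int I, int H,
                         const std::vector<float> &W,  // 4H x I, row-major
                         const std::vector<float> &R,  // 4H x H, row-major
                         const std::vector<float> &Wb, // 4H
                         const std::vector<float> &Rb, // 4H
                         const std::vector<float> &x,  // I
                         std::vector<float> &h,        // H, updated in place
                         std::vector<float> &c) {      // H, updated in place
  // Pre-activation of gate g (0=i, 1=o, 2=f, 3=c) for hidden unit j.
  auto gate = [&](int g, int j) {
    int row = g * H + j;
    float acc = Wb[row] + Rb[row];
    for (int k = 0; k < I; ++k)
      acc += W[row * I + k] * x[k];
    for (int k = 0; k < H; ++k)
      acc += R[row * H + k] * h[k];
    return acc;
  };
  std::vector<float> hNew(H), cNew(H);
  for (int j = 0; j < H; ++j) {
    float i_t = sigmoidf(gate(0, j));  // input gate
    float o_t = sigmoidf(gate(1, j));  // output gate
    float f_t = sigmoidf(gate(2, j));  // forget gate
    float g_t = std::tanh(gate(3, j)); // cell candidate
    cNew[j] = f_t * c[j] + i_t * g_t;  // Ct = ft . Ct-1 + it . ct~
    hNew[j] = o_t * std::tanh(cNew[j]); // Ht = ot . tanh(Ct)
  }
  h = hNew;
  c = cNew;
}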
rewriter.replaceOp(binder.op, outputs); + return success(); +} + +// W[zrh] - W parameter weight matrix for update, reset, and hidden gates +// R[zrh] - R recurrence weight matrix for update, reset, and hidden gates +// Wb[zrh] - W bias vectors for update, reset, and hidden gates +// Rb[zrh] - R bias vectors for update, reset, and hidden gates +// backwards currently not supported + +struct GruWeights { + Value Wz; + Value Wr; + Value Wh; + Value Rz; + Value Rr; + Value Rh; + Value Wbz; + Value Wbr; + Value Wbh; + Value Rbz; + Value Rbr; + Value Rbh; +}; + +struct GruLayerOutput { + Value Y; + Value Y_h; +}; + +struct GruActivations { + std::string f; + std::string g; +}; + +Value gru_cell(ImplicitLocOpBuilder &b, Value Xt, Value H_prev, + GruWeights weights, GruActivations activations, + bool linear_before_reset) { + auto hTy = cast(H_prev.getType()); + + auto intType = b.getType(); + Value cstOne = b.create(intType, b.getI64IntegerAttr(1)); + + Value z_w = b.create(hTy, Xt, weights.Wz, weights.Wbz); + Value z_r = b.create(hTy, H_prev, weights.Rz, weights.Rbz); + Value z_pre = b.create(hTy, z_w, z_r, cstOne); + Value zt = createActivationByName(b, activations.f, z_pre); + + Value r_w = b.create(hTy, Xt, weights.Wr, weights.Wbr); + Value r_r = b.create(hTy, H_prev, weights.Rr, weights.Rbr); + Value r_pre = b.create(hTy, r_w, r_r, cstOne); + Value rt = createActivationByName(b, activations.f, r_pre); + + Value h_w = b.create(hTy, Xt, weights.Wh, weights.Wbh); + Value h_r; + if (linear_before_reset) { + // when linear_before_reset = 1, multiply r with H_prev to reset + // before applying linear layer + Value h_linear = + b.create(hTy, H_prev, weights.Rh, weights.Rbh); + h_r = b.create(hTy, h_linear, rt); + } else { + // otherwise, multiply first and then apply linear layer + Value h_reset = b.create(hTy, H_prev, rt); + h_r = b.create(hTy, h_reset, weights.Rh, weights.Rbh); + } + Value h_pre = b.create(hTy, h_w, h_r, cstOne); + Value ht = createActivationByName(b, activations.g, h_pre); + + // Create a constant tensor filled with ones, matching the shape of zt + Value cstNone = b.create(); + int64_t typeInt = (int64_t)getScalarTypeForType(hTy.getDtype()); + Value dtype = b.create(b.getI64IntegerAttr(typeInt)); + Value ones = b.create( + hTy, zt, dtype, /*layout=*/cstNone, + /*device=*/cstNone, /*pin_memory=*/cstNone, /*memory_format=*/cstNone); + + Value one_minus_zt = b.create(hTy, ones, zt, cstOne); + Value ht_scaled = b.create(hTy, one_minus_zt, ht); + Value H_prev_zt = b.create(hTy, H_prev, zt); + Value H_new = b.create(hTy, ht_scaled, H_prev_zt, cstOne); + + return H_new; +} + +GruLayerOutput gru_layer(ImplicitLocOpBuilder &b, Value X, Value initial_h, + GruWeights weights, GruActivations activations, + bool linear_before_reset) { + Location loc = b.getLoc(); + + auto xTy = cast(X.getType()); + auto hTy = cast(initial_h.getType()); + + // Get sizes and store them in intermediate variables + auto xTySizes = xTy.getSizes(); + auto hTySizes = hTy.getSizes(); + + int64_t seq_len = xTySizes[0]; + int64_t batch_size = xTySizes[1]; + int64_t input_size = xTySizes[2]; + int64_t hidden_size = hTySizes[1]; + + auto intType = b.getType(); + + Value cstNone = b.create(); + Value cstZero = b.create(intType, b.getI64IntegerAttr(0)); + Value cstOne = b.create(intType, b.getI64IntegerAttr(1)); + Value cstSeqLen = + b.create(intType, b.getI64IntegerAttr(seq_len)); + Value cstBatchSize = + b.create(intType, b.getI64IntegerAttr(batch_size)); + Value cstHiddenSize = + b.create(intType, 
b.getI64IntegerAttr(hidden_size)); + + auto yTy = b.getType( + SmallVector{seq_len, batch_size, hidden_size}, hTy.getDtype()); + + auto YShapeList = b.create( + b.getType(intType), + ValueRange({cstSeqLen, cstBatchSize, cstHiddenSize})); + + int64_t hDtypeInt = + static_cast(getScalarTypeForType(hTy.getDtype())); + Value hDtypeIntVal = b.create(b.getI64IntegerAttr(hDtypeInt)); + + Value Y_initial = b.create(yTy, YShapeList, hDtypeIntVal, + cstNone, cstNone, cstNone); + + Value maxTripCount = cstSeqLen; + Value loopConditionTrue = b.create(true); + + Type loopIndexType = intType; + + auto loop = b.create(TypeRange({yTy, hTy}), maxTripCount, + loopConditionTrue, + ValueRange({Y_initial, initial_h})); + + { + OpBuilder::InsertionGuard guard(b); + Block *loopBody = + b.createBlock(&loop.getRegion(), loop.getRegion().begin(), + TypeRange({loopIndexType, yTy, hTy}), {loc, loc, loc}); + + Value loopIndex = loopBody->getArgument(0); + Value Y_prev = loopBody->getArgument(1); + Value H_prev = loopBody->getArgument(2); + + auto XtType = b.getType( + llvm::SmallVector{batch_size, input_size}, xTy.getDtype()); + + Value Xt = b.create(XtType, X, cstZero, loopIndex); + + Value H_new = + gru_cell(b, Xt, H_prev, weights, activations, linear_before_reset); + + Type hTyUnsqueezed = b.getType( + llvm::SmallVector{1, batch_size, hidden_size}, hTy.getDtype()); + Value H_new_unsqueezed = + b.create(hTyUnsqueezed, H_new, cstZero); + + auto loopIndexPlusOne = b.create(intType, loopIndex, cstOne); + Value Y_new = + b.create(yTy, Y_prev, H_new_unsqueezed, cstZero, + loopIndex, loopIndexPlusOne, cstOne); + + b.create(loopConditionTrue, + ValueRange({Y_new, H_new})); + } + + GruLayerOutput output; + output.Y = loop.getResult(0); + output.Y_h = loop.getResult(1); + + return output; +} + +LogicalResult OnnxGruExpander(OpBinder binder, + ConversionPatternRewriter &rewriter) { + Location loc = binder.getLoc(); + mlir::ImplicitLocOpBuilder b(loc, rewriter); + + auto intType = b.getType(); + Value cstNone = b.create(); + Value cstZero = b.create(intType, b.getI64IntegerAttr(0)); + Value cstOne = b.create(intType, b.getI64IntegerAttr(1)); + + // Binding arguments + ValueTensorType yTy, Y_hType; + if (binder.tensorResultTypeAtIndex(yTy, 0) && + binder.tensorResultTypeAtIndex(Y_hType, 1)) { + return rewriter.notifyMatchFailure(binder.op, + "At least one output must be present"); + } + + Value X, W, R, B, initial_h, sequence_lens; + if (binder.tensorOperandAtIndex(X, 0) || binder.tensorOperandAtIndex(W, 1) || + binder.tensorOperandAtIndex(R, 2)) + return rewriter.notifyMatchFailure(binder.op, + "Missing required input tensor"); + + if (binder.tensorOperandAtIndex(B, 3)) { + // if no b found, set to null and create one later + B = nullptr; + } + + int64_t hidden_size; + if (binder.s64IntegerAttr(hidden_size, "hidden_size")) + return rewriter.notifyMatchFailure( + binder.op, "Missing required attribute hidden_size"); + + auto xTy = cast(X.getType()); + auto wTy = cast(W.getType()); + + // Setting up activations + GruActivations activations; + activations.f = "Sigmoid"; + activations.g = "Tanh"; + + llvm::SmallVector activationsList; + if (!binder.stringArrayAttr(activationsList, "activations") && + activationsList.size() == 2) { + activations.f = activationsList[0]; + activations.g = activationsList[1]; + } else if (activationsList.size() > 0) { + return rewriter.notifyMatchFailure( + binder.op, "Unsupported number of activation functions"); + } + + // Other attributes + int64_t layout; + if (binder.s64IntegerAttr(layout, 
"layout", 0)) + return rewriter.notifyMatchFailure(binder.op, + "Unsupported layout attribute type."); + + std::string direction; + if (!binder.customOpNameStringAttr(direction, "direction", "forward") && + direction != "forward") + return rewriter.notifyMatchFailure(binder.op, + "Unsupported direction attribute value"); + + int64_t num_directions = direction == "bidirectional" ? 2 : 1; + // Validations + auto XShape = xTy.getSizes(); + int64_t batch_size = (layout == 0) ? XShape[1] : XShape[0]; + int64_t seq_len = (layout == 0) ? XShape[0] : XShape[1]; + int64_t input_size = XShape[2]; + + std::ostringstream oss; + + if (num_directions != 1) { + oss << "Expected num_directions to be 1, but got " << num_directions + << ". "; + } + + if (hidden_size * 3 != wTy.getSizes()[1]) { + oss << "Expected dim 1 of W to be the same as 3*hidden_size " + << 3 * hidden_size << ", but got " << wTy.getSizes()[1] << ". "; + } + + if (wTy.getSizes()[2] != input_size) { + oss << "Expected wTy.getSizes()[2] to be " << input_size << ", but got " + << wTy.getSizes()[2] << ". "; + } + + if (!oss.str().empty()) { + return rewriter.notifyMatchFailure(binder.op, oss.str()); + } + + // Setting up initial_h + auto hTy = b.getType( + llvm::SmallVector{num_directions, batch_size, hidden_size}, + xTy.getDtype()); + + if (binder.tensorOperandAtIndex(initial_h, 5)) { + Value cstNumDirections = + b.create(intType, b.getI64IntegerAttr(num_directions)); + Value cstBatchSize = + b.create(intType, b.getI64IntegerAttr(batch_size)); + Value cstHiddenSize = + b.create(intType, b.getI64IntegerAttr(hidden_size)); + Value hShape = b.create( + b.getType(intType), + ValueRange({cstNumDirections, cstBatchSize, cstHiddenSize})); + Value cstDtype = getDtypeIntValueForType(rewriter, loc, xTy.getDtype()); + initial_h = + b.create(hTy, hShape, cstDtype, cstNone, cstNone, cstNone); + } else { + if (layout == 1) { + initial_h = StaticTranspose(b, initial_h, 0, 1); + } + } + + if (binder.tensorOperandAtIndex(sequence_lens, 4)) + sequence_lens = b.create(); + + float clip; + if (!binder.f32FloatAttr(clip, "clip") && clip != 0.0f) + return rewriter.notifyMatchFailure( + binder.op, "Clip not supported (specified with a value of " + + std::to_string(clip) + ")"); + + int64_t linear_before_reset_int; + if (binder.s64IntegerAttr(linear_before_reset_int, "linear_before_reset", 0)) + linear_before_reset_int = 0; + bool linear_before_reset = linear_before_reset_int != 0; + + // fill in B + Value cstXDtype = getDtypeIntValueForType(rewriter, loc, xTy.getDtype()); + if (B == nullptr) { + SmallVector BShape = {num_directions, 6 * hidden_size}; + SmallVector BShapeListContents = { + b.create(intType, b.getI64IntegerAttr(num_directions)), + b.create(intType, b.getI64IntegerAttr(6 * hidden_size))}; + Value BShapeList = b.create( + b.getType(intType), BShapeListContents); + auto BType = b.getType(BShape, wTy.getDtype()); + B = b.create(BType, BShapeList, cstXDtype, cstNone, + cstNone, cstNone); + } + + Value W_forward = getDirection(b, 0, W); + Value R_forward = getDirection(b, 0, R); + Value B_forward = getDirection(b, 0, B); + Value initial_h_forward = getDirection(b, 0, initial_h); + + GruWeights weights; + + // Slice a tensor into numSlices slices of size sliceSize + // This is used for slicing the weights & biases into the individual gates + auto sliceTensor = [&](Value tensor, int64_t sliceSize, int64_t numSlices, + ValueTensorType sliceType) { + SmallVector slices; + for (int64_t i = 0; i < numSlices; ++i) { + Value start = + b.create(intType, 
b.getI64IntegerAttr(i * sliceSize)); + Value end = b.create( + intType, b.getI64IntegerAttr((i + 1) * sliceSize)); + + Value slice = b.create(sliceType, tensor, + cstZero, // dim to slice on + start, end, + cstOne // step + ); + + slices.push_back(slice); + } + return slices; + }; + + // Slice W + auto wSliceType = b.getType( + llvm::SmallVector{hidden_size, input_size}, wTy.getDtype()); + auto W_slices = sliceTensor(W_forward, hidden_size, 3, wSliceType); + std::tie(weights.Wz, weights.Wr, weights.Wh) = + std::make_tuple(W_slices[0], W_slices[1], W_slices[2]); + + // Slice R + auto rSliceType = b.getType( + llvm::SmallVector{hidden_size, hidden_size}, wTy.getDtype()); + auto R_slices = sliceTensor(R_forward, hidden_size, 3, rSliceType); + std::tie(weights.Rz, weights.Rr, weights.Rh) = + std::make_tuple(R_slices[0], R_slices[1], R_slices[2]); + + // Slice B + auto bSliceType = b.getType( + llvm::SmallVector{hidden_size}, wTy.getDtype()); + auto B_slices = sliceTensor(B_forward, hidden_size, 6, bSliceType); + std::tie(weights.Wbz, weights.Wbr, weights.Wbh, weights.Rbz, weights.Rbr, + weights.Rbh) = + std::make_tuple(B_slices[0], B_slices[1], B_slices[2], B_slices[3], + B_slices[4], B_slices[5]); + + // Process inputs based on layout + if (layout == 1) { + X = StaticTranspose(b, X, 0, 1); + } + + // Weights and biases ready. Calling GRU layer to insert the actual ops. + GruLayerOutput gruLayerOutput = gru_layer(b, X, initial_h_forward, weights, + activations, linear_before_reset); + + // Process outputs based on layout + Value Y_final; + if (binder.tensorResultTypeAtIndex(yTy, 0)) { + Y_final = cstNone; + } else { + if (layout == 0) { + Y_final = b.create(yTy, gruLayerOutput.Y, cstOne); + } else { + Type yTy_original = b.getType( + llvm::SmallVector{seq_len, 1, batch_size, hidden_size}, + yTy.getDtype()); + Y_final = + b.create(yTy_original, gruLayerOutput.Y, cstOne); + Y_final = StaticTranspose(b, Y_final, 1, 2); + Y_final = StaticTranspose(b, Y_final, 0, 1); + } + } + + Value Y_h_final; + if (binder.tensorResultTypeAtIndex(Y_hType, 1)) { + Y_h_final = cstNone; + } else { + if (layout == 0) { + Y_h_final = + b.create(Y_hType, gruLayerOutput.Y_h, cstZero); + } else { + Type y_hTy_original = b.getType( + llvm::SmallVector{1, batch_size, hidden_size}, + Y_hType.getDtype()); + Y_h_final = b.create(y_hTy_original, gruLayerOutput.Y_h, + cstZero); + Y_h_final = StaticTranspose(b, Y_h_final, 0, 1); + } + } + + rewriter.replaceOp(binder.op, mlir::ValueRange{Y_final, Y_h_final}); + return success(); +} + +} // namespace mlir::torch::onnx_c diff --git a/lib/Conversion/TorchOnnxToTorch/Patterns.cpp b/lib/Conversion/TorchOnnxToTorch/Patterns.cpp index 6ca7824165d3..b4b9e4b3ddfc 100644 --- a/lib/Conversion/TorchOnnxToTorch/Patterns.cpp +++ b/lib/Conversion/TorchOnnxToTorch/Patterns.cpp @@ -8,6 +8,7 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" +#include "mlir/IR/BuiltinAttributes.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" @@ -24,12 +25,24 @@ LogicalResult OnnxCustomOpConversionPattern::matchAndRewrite( auto foundIt = namedHandlers.find(op.getNameAttr()); if (foundIt == namedHandlers.end()) return failure(); + // The domainVersion comes from the function attribute + // torch.onnx_meta.opset_version and defines the opset for all ONNX ops the + // function contains. Absent this attribute, domainVersion is 0. 
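The comment above introduces the per-op opset override. A minimal, framework-free sketch of the resulting selection rule follows; it reuses the HandlerReg/sinceVersion fields visible in the pattern, but the pickHandler helper is an illustrative name and the real pattern's fall-through when a callback fails is omitted.

// Minimal sketch of handler selection by opset version (illustrative only).
#include <cstdint>
#include <optional>
#include <vector>

struct HandlerReg {
  int64_t sinceVersion; // handler applies for opset >= sinceVersion
  // callback omitted in this sketch
};

// Returns the index of the first applicable handler, or -1 if none applies.
int pickHandler(const std::vector<HandlerReg> &handlers,
                int64_t functionOpsetVersion,       // 0 if attribute absent
                std::optional<int64_t> opVersion) { // per-op override
  int64_t effective = opVersion.value_or(functionOpsetVersion);
  for (size_t i = 0; i < handlers.size(); ++i) {
    if (effective >= handlers[i].sinceVersion)
      return static_cast<int>(i);
  }
  return -1; // no registered handler supports this version
}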
+ int64_t opDomainVersion = domainVersion; + // If the op has an individual version (torch.onnx_meta.version attribute), it + // overrides the function's domainVersion and will be used for matching later + // here. + if (auto attr = op->getAttrOfType("torch.onnx_meta.version")) { + assert(cast(attr.getType()).isSigned()); + opDomainVersion = attr.getSInt(); + } auto ®gies = foundIt->second; for (const HandlerReg ® : reggies) { - if (domainVersion < reg.sinceVersion) { + if (opDomainVersion < reg.sinceVersion) { LLVM_DEBUG(dbgs() << ": skipping conversion " << foundIt->first << ", sinceVersion=" << reg.sinceVersion - << ", for domainVersion=" << domainVersion << "\n"); + << ", for domainVersion=" << domainVersion + << ", opDomainVersion=" << opDomainVersion << "\n"); continue; } if (succeeded(reg.callback(OpBinder(op), rewriter))) { diff --git a/lib/Conversion/TorchOnnxToTorch/TorchOnnxToTorch.cpp b/lib/Conversion/TorchOnnxToTorch/TorchOnnxToTorch.cpp index ea890bf0f4b6..fa2b95c0c29f 100644 --- a/lib/Conversion/TorchOnnxToTorch/TorchOnnxToTorch.cpp +++ b/lib/Conversion/TorchOnnxToTorch/TorchOnnxToTorch.cpp @@ -45,12 +45,6 @@ class ConvertTorchOnnxToTorch // Populate our patterns for each handled domain. int64_t defaultOpsetVersion = getDefaultOpsetVersion(getOperation()); - if (defaultOpsetVersion == 0) { - emitError(getOperation().getLoc()) - << "function is missing onnx opset version attribute " - "(torch.onnx_meta.opset_version)"; - return signalPassFailure(); - } auto defaultDomainPatterns = std::make_unique( diff --git a/lib/Conversion/TorchOnnxToTorch/Utils.cpp b/lib/Conversion/TorchOnnxToTorch/Utils.cpp index dec13490666e..5361089d69d1 100644 --- a/lib/Conversion/TorchOnnxToTorch/Utils.cpp +++ b/lib/Conversion/TorchOnnxToTorch/Utils.cpp @@ -8,6 +8,7 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" using namespace mlir; @@ -16,7 +17,7 @@ using namespace mlir::torch::onnx_c; Value mlir::torch::onnx_c::createConstantIntList( OpBinder binder, ConversionPatternRewriter &rewriter, - SmallVector cstInput) { + ArrayRef cstInput) { SmallVector cstValue; for (int64_t i : cstInput) { cstValue.push_back(rewriter.create( @@ -28,7 +29,8 @@ Value mlir::torch::onnx_c::createConstantIntList( cstValue); } -Type mlir::torch::onnx_c::getQTorchTypeFromTorchIntType(Type ty) { +Torch::ValueTensorType +mlir::torch::onnx_c::getQTorchTypeFromTorchIntType(Type ty) { Torch::ValueTensorType tty = dyn_cast(ty); if (!tty) return nullptr; @@ -40,6 +42,8 @@ Type mlir::torch::onnx_c::getQTorchTypeFromTorchIntType(Type ty) { dty = Torch::QUInt8Type::get(ctx); if (dty.isSignedInteger(8)) dty = Torch::QInt8Type::get(ctx); + if (dty.isSignedInteger(16)) + dty = Torch::QInt16Type::get(ctx); if (dty.isSignedInteger(32)) dty = Torch::QInt32Type::get(ctx); @@ -97,3 +101,44 @@ mlir::torch::onnx_c::onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx) { return dtypeIntTorch; } + +LogicalResult mlir::torch::onnx_c::createTorchTransposeOp( + ConversionPatternRewriter &rewriter, Location loc, Value input, + int64_t dimA, int64_t dimB, Value &transposed) { + Type transposedType; + if (failed(getTransposedType(cast(input.getType()), + dimA, dimB, transposedType))) + return failure(); + Value cstDimA = rewriter.create( + loc, rewriter.getI64IntegerAttr(dimA)); + Value cstDimB = rewriter.create( + loc, rewriter.getI64IntegerAttr(dimB)); + 
transposed = rewriter.create( + loc, transposedType, input, cstDimA, cstDimB); + return success(); +} + +LogicalResult mlir::torch::onnx_c::createTorchPermuteOp( + OpBinder binder, ConversionPatternRewriter &rewriter, Location loc, + Value input, SmallVector permuteDims, Value &permuted) { + Type permutedType; + if (failed( + Torch::getPermutedType(cast(input.getType()), + permuteDims, permutedType))) + return failure(); + Value permuteDimsList = createConstantIntList(binder, rewriter, permuteDims); + permuted = rewriter.create(loc, permutedType, input, + permuteDimsList); + return success(); +} + +Value mlir::torch::onnx_c::createActivationByName(ImplicitLocOpBuilder &b, + StringRef name, Value input) { + if (name == "Sigmoid") + return b.create(input.getType(), input); + if (name == "Tanh") + return b.create(input.getType(), input); + if (name == "Relu") + return b.create(input.getType(), input); + llvm_unreachable("Unsupported activation function"); +} diff --git a/lib/Conversion/TorchToArith/TorchToArith.cpp b/lib/Conversion/TorchToArith/TorchToArith.cpp index 2703d48724cf..69d585c69ba4 100644 --- a/lib/Conversion/TorchToArith/TorchToArith.cpp +++ b/lib/Conversion/TorchToArith/TorchToArith.cpp @@ -12,17 +12,14 @@ #include "../PassDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Traits.h" #include "mlir/IR/DialectResourceBlobManager.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" -#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" using namespace mlir; @@ -75,8 +72,33 @@ class ConvertAtenBinaryOp : public OpConversionPattern { matchAndRewrite(AtenOp op, typename OpConversionPattern::OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.template replaceOpWithNewOp(op, adaptor.getA(), - adaptor.getB()); + Value a = adaptor.getA(); + Value b = adaptor.getB(); + if (llvm::is_one_of::value || + llvm::is_one_of::value) + b = convertScalarToDtype(rewriter, op.getLoc(), b, a.getType()); + if (llvm::is_one_of::value) + a = convertScalarToDtype(rewriter, op.getLoc(), a, b.getType()); + rewriter.template replaceOpWithNewOp(op, a, b); + return success(); + } +}; +} // namespace + +namespace { +class ConvertAtenNegIntOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenNegIntOp op, + typename OpConversionPattern::OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value a = adaptor.getA(); + rewriter.replaceOpWithNewOp( + op, + rewriter.create(op.getLoc(), /*value=*/0, + /*bitwidth=*/64), + a); return success(); } }; @@ -258,6 +280,25 @@ class ConvertAtenCastOp : public OpConversionPattern { }; } // namespace +namespace { +template +class ConvertAtenScalarArithOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenOp op, + typename OpConversionPattern::OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type resultType = + 
this->getTypeConverter()->convertType(op->getResult(0).getType()); + Value result = + convertScalarToDtype(rewriter, op.getLoc(), adaptor.getA(), resultType); + rewriter.replaceOp(op, result); + return success(); + } +}; +} // namespace + namespace { class ConvertAtenAddOp : public OpConversionPattern { public: @@ -413,11 +454,17 @@ class ConvertTorchToArith patterns.add< ConvertAtenIntComparisonOp>( typeConverter, context); - target.addIllegalOp(); + target.addIllegalOp(); + patterns.add< + ConvertAtenFloatComparisonOp>( + typeConverter, context); patterns.add< ConvertAtenFloatComparisonOp>( typeConverter, context); + patterns.add< + ConvertAtenFloatComparisonOp>( + typeConverter, context); patterns.add>( typeConverter, context); @@ -439,20 +486,32 @@ class ConvertTorchToArith target.addIllegalOp(); patterns.add(typeConverter, context); - target.addIllegalOp(); + target.addIllegalOp(); patterns.add>(typeConverter, context); patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); - - target.addIllegalOp(); + target.addIllegalOp(); + patterns.add(typeConverter, context); + target.addIllegalOp(); patterns.add>( typeConverter, context); + patterns.add>( + typeConverter, context); + patterns.add>( + typeConverter, context); patterns.add>( typeConverter, context); patterns.add>( typeConverter, context); + patterns.add>( + typeConverter, context); + patterns.add>( + typeConverter, context); target.addIllegalOp(); patterns.add>( typeConverter, context); @@ -473,6 +532,13 @@ class ConvertTorchToArith patterns.add>( typeConverter, context); target.addIllegalOp(); + target.addIllegalOp(); + patterns.add>( + typeConverter, context); + patterns.add>( + typeConverter, context); + patterns.add>( + typeConverter, context); patterns .add>( typeConverter, context); diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index d8dd75a9a233..b8c20bc73f65 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -8,12 +8,9 @@ //===----------------------------------------------------------------------===// #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/TypeSupport.h" -#include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" -#include "../PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" @@ -21,11 +18,9 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Matchers.h" #include "torch-mlir/Conversion/TorchToLinalg/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h" -#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" @@ -38,12 +33,14 @@ using namespace mlir::torch; using namespace mlir::torch::Torch; static int64_t productReduce(ArrayRef a) { - return accumulate(a.begin(), a.end(), /*init=*/1, std::multiplies()); + return accumulate(a.begin(), a.end(), /*init=*/static_cast(1), + std::multiplies()); } template LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, + int64_t &dim, SmallVector 
&resultShape, SmallVector &offsets, SmallVector &strides) { @@ -55,7 +52,6 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, Value one = rewriter.create(loc, 1); Value negone = rewriter.create(loc, -1); - int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return op->emitError("unimplemented: dim is not constant"); @@ -665,7 +661,8 @@ class ConvertAtenUnflattenIntOp "Expected input type having sizes"); } int inputRank = inputTensorType.getSizes().size(); - int outputRank = outputTensorType.getSizes().size(); + auto outputSizes = outputTensorType.getSizes(); + int outputRank = outputSizes.size(); int64_t dimInt; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dimInt))) @@ -679,23 +676,64 @@ class ConvertAtenUnflattenIntOp auto sizesOp = op.getSizes().getDefiningOp(); int numSizes = sizesOp.getNumOperands(); - SmallVector reassociations(inputRank); - if (inputRank > 0) { - for (int i = 0; i < dimInt; ++i) - reassociations[i].push_back(i); - - for (int i = 0; i < numSizes; ++i) - reassociations[dimInt].push_back(i + dimInt); - - for (int i = dimInt + numSizes; i < outputRank; ++i) - reassociations[i - numSizes + 1].push_back(i); + int64_t numDynamicReassocDims = 0; + for (int64_t i = dimInt; i < dimInt + numSizes; i++) { + if (outputSizes[i] == Torch::kUnknownSize) + numDynamicReassocDims++; } + SmallVector reassocSizes; + if (!getListConstructElements(op.getSizes(), reassocSizes) && + numDynamicReassocDims > 1) + return rewriter.notifyMatchFailure( + op, "Must be able to either infer expansion dims, or retrieve them " + "from list construct"); + auto expandTy = getTypeConverter()->convertType(outputTensorType); - auto expand = rewriter - .create( - loc, expandTy, adaptor.getSelf(), reassociations) - .getResult(); + Value expand; + // When there are less than two dynamic reassociation dims, this will lower + // to tensor.expand_shape. Otherwise, this lowers to tensor.reshape. + // TODO: in the numDynamicReassocDims >= 2 case, lower to expand_shape with + // explicitly provided outputShape once + // https://github.com/iree-org/iree/issues/17760 is resolved. 
+ if (numDynamicReassocDims < 2) { + SmallVector reassociations(inputRank); + if (inputRank > 0) { + for (int i = 0; i < dimInt; ++i) + reassociations[i].push_back(i); + for (int i = 0; i < numSizes; ++i) + reassociations[dimInt].push_back(i + dimInt); + for (int i = dimInt + numSizes; i < outputRank; ++i) + reassociations[i - numSizes + 1].push_back(i); + } + expand = rewriter + .create( + loc, expandTy, adaptor.getSelf(), reassociations) + .getResult(); + } else { + reassocSizes = getTypeConvertedValues(rewriter, loc, getTypeConverter(), + reassocSizes); + SmallVector inputShape = + getTensorSizes(rewriter, loc, adaptor.getSelf()); + inputShape = castIndexVectorToInt64Vector(rewriter, loc, inputShape); + SmallVector outputShape(inputShape.begin(), + inputShape.begin() + dimInt); + if (inputRank > 0) { + for (int i = 0; i < numSizes; ++i) + outputShape.push_back(reassocSizes[i]); + for (int i = dimInt + numSizes; i < outputRank; ++i) + outputShape.push_back(inputShape[i - numSizes + 1]); + } + + RankedTensorType shapeType = RankedTensorType::get( + ArrayRef{outputRank}, rewriter.getIntegerType(64)); + Value shapeValue = + rewriter.create(loc, shapeType, outputShape); + expand = rewriter + .create(loc, expandTy, adaptor.getSelf(), + shapeValue) + .getResult(); + } rewriter.replaceOp(op, expand); return success(); } @@ -1604,62 +1642,18 @@ class ConvertAtenSqueezeDimOp : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); - Value input = adaptor.getSelf(); - auto inputType = cast(input.getType()); - int64_t inputRank = inputType.getRank(); - - if (inputRank == 0) { - return rewriter.notifyMatchFailure( - op, "zero input rank should have been handled by the folder"); - } - int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return rewriter.notifyMatchFailure(op, "dim must be constant"); - dim = toPositiveDim(dim, inputRank); - if (!isValidDim(dim, inputRank)) - return rewriter.notifyMatchFailure(op, "dim is statically invalid"); - - // TODO: Handle the case where the dim(th) dimension is dynamic. - if (inputType.isDynamicDim(dim)) { - return rewriter.notifyMatchFailure( - op, "unimplemented: dim(th) dimension is not expected to be dynamic"); - } - - const TypeConverter *typeConverter = getTypeConverter(); - auto resultType = - cast(typeConverter->convertType(op.getType())); - int64_t resultRank = resultType.getRank(); - // If the dim(th) dimension of operand tensor type is not statically unit, - // `aten.squeeze` will behave as an identity operation. 
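To make the unflatten bookkeeping above easier to follow, here is a hedged standalone sketch of the same index math: build the reassociation groups for aten.unflatten_int and count the dynamic expanded sizes to choose between an expand_shape-style reassociation and an explicit reshape. The helper name, the plan struct and the kUnknown sentinel are illustrative.

// Reassociation groups for unflattening `dim` of a rank-R input into
// `numSizes` result dims, mirroring the decision in the lowering above.
#include <cstdint>
#include <vector>

constexpr int64_t kUnknown = -1; // stand-in for a dynamic dimension

struct UnflattenPlan {
  std::vector<std::vector<int64_t>> reassociation; // input dim -> output dims
  bool useExplicitReshape;                         // true if >= 2 dynamic sizes
};

UnflattenPlan planUnflatten(int64_t inputRank, int64_t dim,
                            const std::vector<int64_t> &sizes) {
  UnflattenPlan plan;
  int64_t numSizes = static_cast<int64_t>(sizes.size());
  int64_t numDynamic = 0;
  for (int64_t s : sizes)
    if (s == kUnknown)
      ++numDynamic;
  plan.useExplicitReshape = numDynamic >= 2;

  plan.reassociation.resize(inputRank);
  int64_t outputRank = inputRank - 1 + numSizes;
  for (int64_t i = 0; i < dim; ++i)
    plan.reassociation[i] = {i};
  for (int64_t i = 0; i < numSizes; ++i)
    plan.reassociation[dim].push_back(dim + i);
  for (int64_t i = dim + numSizes; i < outputRank; ++i)
    plan.reassociation[i - numSizes + 1] = {i};
  return plan;
}

// Example: unflattening dim 1 of a [2, 12, 5] tensor into [3, 4] gives the
// groups {0}, {1, 2}, {3}; with sizes [?, ?] the plan falls back to reshape.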
- if (inputType.getDimSize(dim) != 1) { - rewriter.replaceOpWithNewOp(op, resultType, input); - return success(); + auto squeezeTensorInfo = + squeezeTensor(rewriter, op, adaptor.getSelf(), dim); + if (failed(squeezeTensorInfo)) { + return rewriter.notifyMatchFailure(op, + "cannot generate unsqueeze tensor"); } - SmallVector reassociationMap(resultRank); - bool alreadyCrossedSqueezedDim = false; - for (int i = 0; i != resultRank; i++) { - if (alreadyCrossedSqueezedDim) { - reassociationMap[i].push_back(i + 1); - } else { - reassociationMap[i].push_back(i); - if (dim != 0 && i != dim - 1) - continue; - - alreadyCrossedSqueezedDim = true; - if (dim == 0) - reassociationMap[0].push_back(1); - if (i == dim - 1) - reassociationMap[i].push_back(dim); - } - } - // Note: In case the operand tensor type is of unit rank and is statically - // shaped with unit dimension, the `reassociationMap` will be empty and the - // input will be collapsed to a 0-D tensor. - rewriter.replaceOpWithNewOp(op, resultType, input, - reassociationMap); + rewriter.replaceOp(op, squeezeTensorInfo.value()); return success(); } }; @@ -1677,36 +1671,15 @@ class ConvertAtenUnsqueezeOp : public OpConversionPattern { int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return rewriter.notifyMatchFailure(op, "dim must be constant"); - auto inputRank = - cast(adaptor.getSelf().getType()).getRank(); - dim = toPositiveDim(dim, inputRank + 1); - if (!isValidDim(dim, inputRank + 1)) - return rewriter.notifyMatchFailure(op, "dim is statically invalid"); - SmallVector reassociationMap(inputRank); - // From the perspective of the reassociation map, the situation of - // unsqueezing before or after the last dimension is symmetrical. - // Normalize it to the "before" case. - // The 0 case is special here, since there is no last dimension to insert - // before -- we simply rely on the loop below iterating 0 times. 
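The squeeze lowering above now delegates to a squeezeTensor helper; the shape semantics that it, and the unsqueezeTensor helper used just below, must implement are simple enough to state as a reference. This is an illustrative sketch of those semantics, not the helpers' actual implementation.

// Shape semantics for aten.squeeze.dim / aten.unsqueeze (reference only).
#include <cstdint>
#include <vector>

std::vector<int64_t> squeezeDim(std::vector<int64_t> shape, int64_t dim) {
  if (dim < 0)
    dim += static_cast<int64_t>(shape.size()); // normalize negative dims
  // squeeze(dim) is the identity unless that dimension is statically 1.
  if (dim >= 0 && dim < static_cast<int64_t>(shape.size()) && shape[dim] == 1)
    shape.erase(shape.begin() + dim);
  return shape;
}

std::vector<int64_t> unsqueezeDim(std::vector<int64_t> shape, int64_t dim) {
  if (dim < 0)
    dim += static_cast<int64_t>(shape.size()) + 1; // valid range is [0, rank]
  shape.insert(shape.begin() + dim, 1);
  return shape;
}
// e.g. squeezeDim({2, 1, 3}, 1) == {2, 3}; unsqueezeDim({2, 3}, 0) == {1, 2, 3}.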
- if (dim == inputRank && inputRank != 0) - dim = inputRank - 1; - bool alreadyCrossedExpandedDim = false; - for (int i = 0; i != inputRank; i++) { - if (alreadyCrossedExpandedDim) { - reassociationMap[i].push_back(i + 1); - } else { - reassociationMap[i].push_back(i); - if (i == dim) { - reassociationMap[i].push_back(i + 1); - alreadyCrossedExpandedDim = true; - } - } + auto unsqueezeTensorInfo = + unsqueezeTensor(rewriter, op, adaptor.getSelf(), dim); + if (failed(unsqueezeTensorInfo)) { + return rewriter.notifyMatchFailure(op, + "cannot generate unsqueeze tensor"); } - auto resultType = cast( - getTypeConverter()->convertType(op->getResult(0).getType())); - rewriter.replaceOpWithNewOp( - op, resultType, adaptor.getSelf(), reassociationMap); + + rewriter.replaceOp(op, unsqueezeTensorInfo.value()); return success(); } }; @@ -1735,6 +1708,10 @@ class ConvertAtenTransposeIntOp auto inputRank = inType.getRank(); auto outType = cast( getTypeConverter()->convertType(op->getResult(0).getType())); + if (inputRank <= 1 && inType == outType) { + rewriter.replaceOp(op, {adaptor.getSelf()}); + return success(); + } auto elementType = inType.getElementType(); dim0 = toPositiveDim(dim0, inputRank); @@ -1753,32 +1730,16 @@ class ConvertAtenTransposeIntOp Value outVector = rewriter.create( loc, getAsOpFoldResult(outputDims), elementType); - SmallVector idExprs; - SmallVector swapExprs; - for (auto i = 0; i < inputRank; i++) - idExprs.push_back(getAffineDimExpr(i, rewriter.getContext())); - for (auto i = 0; i < inputRank; i++) { - if (i == dim0) - swapExprs.push_back(idExprs[dim1]); - else if (i == dim1) - swapExprs.push_back(idExprs[dim0]); - else - swapExprs.push_back(idExprs[i]); - } - SmallVector indexingMaps = { - AffineMap::get(inputRank, 0, idExprs, op.getContext()), - AffineMap::get(inputRank, 0, swapExprs, op.getContext())}; - SmallVector iteratorTypes( - inputRank, utils::IteratorType::parallel); - auto transpose = rewriter - .create( - loc, outVector.getType(), inVector, outVector, - indexingMaps, iteratorTypes, - [](OpBuilder &b, Location loc, ValueRange args) { - b.create(loc, args[0]); - }) - .getResult(0); + SmallVector permutation(inputRank); + std::iota(permutation.begin(), permutation.end(), 0); + permutation[dim0] = dim1; + permutation[dim1] = dim0; + + auto transpose = + rewriter + .create(loc, inVector, outVector, permutation) + .getResult(); rewriter.replaceOpWithNewOp(op, outType, transpose); return success(); } @@ -1800,55 +1761,15 @@ class ConvertAtenPermuteOp : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "all dimensions must be constant"); Value inVector = adaptor.getSelf(); - auto inType = cast(inVector.getType()); - int64_t inputRank = inType.getRank(); - auto outType = cast( - getTypeConverter()->convertType(op->getResult(0).getType())); - Type elementType = inType.getElementType(); - - // Check if the dimensions are a valid constants. 
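The rewritten transpose lowering above swaps two entries of an identity permutation and hands it to linalg.transpose. A small reference, with illustrative names, of how that permutation acts on a static shape:

// Building the permutation used by the transpose lowering above and applying
// it to a static shape (reference only).
#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>

std::vector<int64_t> transposedShape(const std::vector<int64_t> &shape,
                                     int64_t dim0, int64_t dim1) {
  std::vector<int64_t> perm(shape.size());
  std::iota(perm.begin(), perm.end(), 0); // identity permutation
  std::swap(perm[dim0], perm[dim1]);      // swap the two transposed dims
  std::vector<int64_t> result(shape.size());
  for (size_t i = 0; i < shape.size(); ++i)
    result[i] = shape[perm[i]]; // result dim i reads input dim perm[i]
  return result;
}
// e.g. transposedShape({2, 3, 4}, 1, 2) == {2, 4, 3}; for rank <= 1 the
// lowering above simply forwards the input unchanged.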
- int64_t numDimensions = dimensions.size(); - if (inputRank != numDimensions) + Value result; + if (failed(torch_to_linalg::permuteTensor(op, rewriter, op->getLoc(), + dimensions, inVector, result))) return rewriter.notifyMatchFailure( - op, "size of `dims` must be equal to the rank of the input"); - for (unsigned i = 0; i < numDimensions; i++) { - if (dimensions[i] < 0) - dimensions[i] = toPositiveDim(dimensions[i], inputRank); - if (!isValidDim(dimensions[i], inputRank)) - return rewriter.notifyMatchFailure(op, "dimension out of range"); - } - - Location loc = op.getLoc(); + op, "failed to perform permutation of tensor"); - SmallVector outputDims; - for (unsigned i = 0; i < inputRank; i++) - outputDims.push_back(getDimOp(rewriter, loc, inVector, dimensions[i])); - - Value outVector = rewriter.create( - loc, getAsOpFoldResult(outputDims), elementType); - SmallVector idExprs; - SmallVector swapExprs; - for (unsigned i = 0; i < inputRank; i++) - idExprs.push_back(getAffineDimExpr(i, rewriter.getContext())); - for (unsigned i = 0; i < inputRank; i++) - swapExprs.push_back(idExprs[dimensions[i]]); - - AffineMap inputMap = - AffineMap::get(inputRank, /*symbolCount=*/0, idExprs, op->getContext()); - AffineMap outputMap = AffineMap::get(inputRank, /*symbolCount=*/0, - swapExprs, op->getContext()); - SmallVector indexingMaps{inputMap, outputMap}; - SmallVector iteratorTypes( - inputRank, utils::IteratorType::parallel); - auto transpose = rewriter - .create( - loc, outVector.getType(), inVector, outVector, - indexingMaps, iteratorTypes, - [](OpBuilder &b, Location loc, ValueRange args) { - b.create(loc, args[0]); - }) - .getResult(0); - rewriter.replaceOpWithNewOp(op, outType, transpose); + auto outType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); + rewriter.replaceOpWithNewOp(op, outType, result); return success(); } }; @@ -1868,21 +1789,54 @@ class ConvertAtenSliceTensorOp : public OpConversionPattern { const TypeConverter *typeConverter = getTypeConverter(); auto input = adaptor.getSelf(); - RankedTensorType resultType = - typeConverter->convertType(op->getResult(0).getType()) - .cast(); + RankedTensorType resultType = cast( + typeConverter->convertType(op->getResult(0).getType())); - SmallVector resultShape; - SmallVector offsets; - SmallVector strides; + SmallVector resultShape, offsets, strides; + int64_t dim; if (failed(prepareArgumentsForSlicingOp( - op, adaptor, rewriter, resultShape, offsets, strides))) { + op, adaptor, rewriter, dim, resultShape, offsets, strides))) { return failure(); } + // If stride is negative, then flip the input tensor corresponding to that + // dim, update the stride for flipped tensor by multiplying it by -1, and + // update the offset as follows: + // flipped_offset = input_shape[dim] - (result_shape[dim] * flipped_stride) + // + // For example: + // Input = [0, 1, 2, 3, 4, 5] + // stride = [-2], result_shape = [2], offset = [3] + // Result = [3, 1] + // After flipping: + // Input = [5, 4, 3, 2, 1, 0] + // stride = [2], result_shape = [2], offset = [6 - (2 * 2)] = [2] + // Result = [3, 1] + + Value flippedInput = torch_to_linalg::flipTensor(rewriter, loc, input, + SmallVector{dim}); + Value cstDim = rewriter.create(loc, dim); + Value zero = rewriter.create(loc, 0); + Value isNegativeStride = rewriter.create( + loc, arith::CmpIPredicate::slt, strides[dim], zero); + strides[dim] = rewriter.create(loc, strides[dim]); + Value resShapeMulStride = + rewriter.create(loc, resultShape[dim], strides[dim]); + Value inputDim = 
rewriter.create(loc, input, cstDim); + Value flippedOffset = + rewriter.create(loc, inputDim, resShapeMulStride); + offsets[dim] = rewriter.create( + loc, isNegativeStride, flippedOffset, offsets[dim]); + + input = rewriter.create(loc, isNegativeStride, + flippedInput, input); + + SmallVector dynShape(resultType.getRank(), ShapedType::kDynamic); + auto sliceType = RankedTensorType::get( + dynShape, resultType.getElementType(), resultType.getEncoding()); Value result = rewriter.create( - loc, input, offsets, resultShape, strides); + loc, sliceType, input, offsets, resultShape, strides); rewriter.replaceOpWithNewOp(op, resultType, result); return success(); @@ -2105,16 +2059,14 @@ class ConvertAtenSliceScatterOp auto input = adaptor.getSelf(); - RankedTensorType resultType = - typeConverter->convertType(op->getResult(0).getType()) - .cast(); + RankedTensorType resultType = cast( + typeConverter->convertType(op->getResult(0).getType())); - SmallVector resultShape; - SmallVector offsets; - SmallVector strides; + SmallVector resultShape, offsets, strides; + int64_t dim; if (failed(prepareArgumentsForSlicingOp( - op, adaptor, rewriter, resultShape, offsets, strides))) { + op, adaptor, rewriter, dim, resultShape, offsets, strides))) { return failure(); } @@ -2341,9 +2293,8 @@ class ConvertAtenDiagonalOp : public OpConversionPattern { op, "diagonal dimensions cannot be identical"); Type elementType = inputType.getElementType(); - RankedTensorType outputType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); + RankedTensorType outputType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); Location loc = op.getLoc(); Value dim1Size, dim2Size; @@ -2579,9 +2530,8 @@ class ConvertAtenDiagEmbedOp : public OpConversionPattern { }) .getResult(0); - RankedTensorType resultType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); + RankedTensorType resultType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); rewriter.replaceOpWithNewOp(op, resultType, resultTensor); return success(); @@ -2589,6 +2539,167 @@ class ConvertAtenDiagEmbedOp : public OpConversionPattern { }; } // namespace +namespace { +class ConvertAtenUnfoldOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenUnfoldOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto self = adaptor.getSelf(); + RankedTensorType selfType = cast(self.getType()); + + int64_t dimension; + if (!matchPattern(op.getDimension(), m_TorchConstantInt(&dimension))) { + return rewriter.notifyMatchFailure(op, + "only support constant int dimension"); + } + int64_t size; + if (!matchPattern(op.getSize(), m_TorchConstantInt(&size))) { + return rewriter.notifyMatchFailure(op, "only support constant int size"); + } + int64_t step; + if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) { + return rewriter.notifyMatchFailure(op, "only support constant int step"); + } + + if (step <= 0) { + return rewriter.notifyMatchFailure(op, "step must be greater than zero."); + } + + int64_t selfRank = selfType.getRank(); + + // Zero-Rank case + if (selfRank == 0) { + // Empty tensor + if (size == 0) { + RankedTensorType resultType = + RankedTensorType::get({0}, selfType.getElementType()); + Value emptyTensor = rewriter.create( + loc, resultType.getShape(), resultType.getElementType()); + + rewriter.replaceOp(op, emptyTensor); + return success(); + } + + 
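The negative-stride slice handling above flips the input along the sliced dim, negates the stride, and recomputes the offset as dimSize - resultLen * stride. The following self-contained check reproduces the worked example from the code comment ([0..5], stride -2, offset 3); the function and variable names are illustrative.

// Reference for the negative-stride slice rewrite: flip the dim, negate the
// stride, and recompute the offset as dimSize - resultLen * stride.
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> sliceWithFlip(std::vector<int64_t> input, int64_t offset,
                                   int64_t stride, int64_t resultLen) {
  if (stride < 0) {
    std::vector<int64_t> flipped(input.rbegin(), input.rend());
    input = flipped;
    stride = -stride;
    offset = static_cast<int64_t>(input.size()) - resultLen * stride;
  }
  std::vector<int64_t> out;
  for (int64_t i = 0; i < resultLen; ++i)
    out.push_back(input[offset + i * stride]);
  return out;
}

int main() {
  // Input = [0, 1, 2, 3, 4, 5], stride = -2, offset = 3, result length 2.
  // Flipped input = [5, 4, 3, 2, 1, 0], stride = 2, offset = 6 - 2 * 2 = 2,
  // so the extracted slice is [3, 1], the same as stepping backwards from 3.
  assert((sliceWithFlip({0, 1, 2, 3, 4, 5}, 3, -2, 2) ==
          std::vector<int64_t>{3, 1}));
  return 0;
}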
Value unsqueezedSelf = rewriter.create( + loc, RankedTensorType::get({1}, selfType.getElementType()), self, + ArrayRef{}); + rewriter.replaceOp(op, unsqueezedSelf); + return success(); + } + + auto shape = selfType.getShape(); + + if (dimension < 0) { + dimension = toPositiveDim(dimension, selfRank); + } + if (!isValidDim(dimension, selfRank)) { + return rewriter.notifyMatchFailure(op, "dimension out of range"); + } + + Value dimSize = rewriter.create(loc, self, dimension); + + Value sizeValue = rewriter.create(loc, size); + Value sizeCheck = rewriter.create( + loc, arith::CmpIPredicate::ule, sizeValue, dimSize); + rewriter.create( + loc, sizeCheck, + rewriter.getStringAttr("size must be <= target dimension")); + + /* Calculate output shape of unfold op: + * https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html + * outputShape[dimension] is set to numBlocks, with size appended as an + * additional dimension + */ + SmallVector outputShape; + for (int64_t i = 0; i < selfRank; i++) { + if (i == dimension) { + outputShape.push_back(getDynamicOrStaticNumBlocks( + rewriter, loc, shape[dimension], dimSize, size, step)); + } else if (shape[i] == ShapedType::kDynamic) { + outputShape.push_back( + OpFoldResult(rewriter.create(loc, self, i))); + } else { + outputShape.push_back(rewriter.getIndexAttr(shape[i])); + } + } + outputShape.push_back(rewriter.getIndexAttr(size)); + + // Empty tensor to insert values into + Value outputTensor = rewriter.create( + loc, outputShape, selfType.getElementType()); + + /** + * Use reindexing to map output indices to input indices + * i.e. In output of rank 3 case: + * (i, j, k) => (i', j') where i' = i * step + k and j' = j + * if dimension == 0 + * (i, j, k) => (i', j') where i' = i and j' = j * step + k + * if dimension == 1 + */ + MLIRContext *context = rewriter.getContext(); + SmallVector outputExprs; + for (int dim = 0; dim < selfRank; ++dim) { + if (dim == dimension) { + auto idxLast = getAffineDimExpr(selfRank, context); + auto idxDimension = getAffineDimExpr(dimension, context); + + AffineExpr dimIdx = + idxLast + idxDimension * rewriter.getAffineConstantExpr(step); + outputExprs.push_back(dimIdx); + } else { + outputExprs.push_back(getAffineDimExpr(dim, context)); + } + } + + int64_t outputRank = selfRank + 1; + auto inputAffineMap = AffineMap::get(outputRank, 0, outputExprs, context); + auto outputAffineMap = + AffineMap::getMultiDimIdentityMap(outputRank, context); + + SmallVector iteratorTypes( + outputRank, utils::IteratorType::parallel); + + Value result = + rewriter + .create( + loc, outputTensor.getType(), self, outputTensor, + ArrayRef({inputAffineMap, outputAffineMap}), iteratorTypes, + [](OpBuilder &b, Location nestedLoc, ValueRange args) { + b.create(nestedLoc, args[0]); + }) + .getResult(0); + + rewriter.replaceOp(op, result); + return success(); + } + +private: + OpFoldResult getDynamicOrStaticNumBlocks(OpBuilder &rewriter, Location loc, + int64_t shapeDim, Value dimSize, + int64_t size, int64_t step) const { + /** + * numBlocks = (shape[dimension] - size) // step + 1 + */ + if (shapeDim == ShapedType::kDynamic) { + Value numBlocksSubOp = rewriter.create( + loc, dimSize, rewriter.create(loc, size)); + Value numBlocksDivOp = rewriter.create( + loc, numBlocksSubOp, + rewriter.create(loc, step)); + Value numBlocks = rewriter.create( + loc, rewriter.create(loc, 1), numBlocksDivOp); + return OpFoldResult(numBlocks); + } + + int64_t staticNumBlocks = (shapeDim - size) / step + 1; + return rewriter.getIndexAttr(staticNumBlocks); // Use 
static value + } +}; +} // namespace + namespace { class ConvertSparseOperatorOp : public OpConversionPattern { public: @@ -2606,9 +2717,8 @@ class ConvertSparseOperatorOp : public OpConversionPattern { return failure(); // Conversion is completed specified by information in the sparse tensor // type. Thus, we can rewrite all legalizedNames to the same construct. - RankedTensorType resultType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); + RankedTensorType resultType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); rewriter.replaceOpWithNewOp( op, resultType, adaptor.getOperands()[0]); return success(); @@ -2622,6 +2732,8 @@ class ConvertSparseOperatorOp : public OpConversionPattern { SmallVector ConvertSparseOperatorOp::legalizedNames = { "torch.aten._to_dense", "torch.aten._to_sparse", "torch.aten._to_csr", "torch.aten._to_csc", "torch.aten._to_bsr", "torch.aten._to_bsc", + "torch.aten.to_dense", "torch.aten.to_sparse", "torch.aten.to_csr", + "torch.aten.to_csc", "torch.aten.to_bsr", "torch.aten.to_bsc", }; } // namespace @@ -2656,7 +2768,8 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality( /*benefit=*/200); patterns.add(typeConverter, context, /*benefit=*/100); - + target.addIllegalOp(); + patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); diff --git a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp index 9254b1a17ab7..07e4b23a167d 100644 --- a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp @@ -9,17 +9,14 @@ #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" -#include "../PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Matchers.h" #include "torch-mlir/Conversion/TorchToLinalg/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h" -#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" @@ -225,23 +222,9 @@ class ConvertAtenEmbeddingBagPaddingIdxOp Value weight = adaptor.getWeight(); Value indices = adaptor.getIndices(); Value offsets = adaptor.getOffsets(); - Value scaleGradByFreq = op.getScaleGradByFreq(); Value mode = op.getMode(); - Value sparse = op.getSparse(); Value includeLastOffset = op.getIncludeLastOffset(); - bool scaleGradByFreqBool; - if (!matchPattern(scaleGradByFreq, - m_TorchConstantBool(&scaleGradByFreqBool))) { - return rewriter.notifyMatchFailure( - op, "scale_grad_by_freq is expected to be a constant boolean value."); - } - - if (scaleGradByFreqBool) { - return rewriter.notifyMatchFailure( - op, "Unimplemented: scale_grad_by_freq=True."); - } - int64_t modeInt; if (!matchPattern(mode, m_TorchConstantInt(&modeInt))) { return rewriter.notifyMatchFailure( @@ -254,18 +237,6 @@ class ConvertAtenEmbeddingBagPaddingIdxOp "not supported yet for EmbeddingBag."); } - bool isSparse; - if (!matchPattern(sparse, m_TorchConstantBool(&isSparse))) { - return rewriter.notifyMatchFailure( - op, "sparse is expected to be a constant boolean value."); - } - - if (isSparse) { - return rewriter.notifyMatchFailure( - 
op, - "Unimplemented: Sparse mode is not supported yet for EmbeddingBag."); - } - bool discardLastOffset; if (!matchPattern(includeLastOffset, m_TorchConstantBool(&discardLastOffset))) { @@ -848,7 +819,7 @@ class ConvertAtenUpsampleNearest2dOp outputSizeIntValues = getTypeConvertedValues( rewriter, loc, getTypeConverter(), outputSizeTorchInt); - if (!op.getScalesH().getType().isa()) { + if (!isa(op.getScalesH().getType())) { // Convert float values to int values. // int_value = (int64_t)ceil(float_value) Value ceilVal = rewriter.create(loc, adaptor.getScalesH()); @@ -861,7 +832,7 @@ class ConvertAtenUpsampleNearest2dOp scaleFactorsInt.push_back(scaleFactorVal); } - if (!op.getScalesW().getType().isa()) { + if (!isa(op.getScalesW().getType())) { // Convert float values to int values. // int_value = (int64_t)ceil(float_value) Value ceilVal = rewriter.create(loc, adaptor.getScalesW()); @@ -1009,7 +980,7 @@ class ConvertAtenUpsampleNearest2dBackwardOp unsigned hDimOffset = 2; SmallVector scaleFactorsFloatValues; - if (!op.getScalesH().getType().isa()) { + if (!isa(op.getScalesH().getType())) { scaleFactorsFloatValues.push_back(adaptor.getScalesH()); } else { auto scaleFactorVal = rewriter.create( @@ -1022,7 +993,7 @@ class ConvertAtenUpsampleNearest2dBackwardOp scaleFactorsFloatValues.push_back(scaleFactorVal); } - if (!op.getScalesW().getType().isa()) { + if (!isa(op.getScalesW().getType())) { scaleFactorsFloatValues.push_back(adaptor.getScalesW()); } else { auto scaleFactorVal = rewriter.create( diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index c49646e2f1c0..5a10ebc29ca9 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -9,16 +9,14 @@ #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" -#include "../PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Matchers.h" #include "torch-mlir/Conversion/TorchToLinalg/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h" -#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" @@ -44,7 +42,7 @@ static void signShift(PatternRewriter &rewriter, Location loc, Value &arg, return; int64_t minSI = -(1 << (numBits - 1)); Value minSIValue = rewriter.create( - loc, minSI, zp.getType().cast().getWidth()); + loc, minSI, cast(zp.getType()).getWidth()); zp = rewriter.create(loc, zp, minSIValue); minSIValue = rewriter.create(loc, minSI, numBits); arg = torch_to_linalg::createElementwiseLinalgGeneric( @@ -149,12 +147,12 @@ class ConvertAtenMmOp : public OpConversionPattern { "mismatching contracting dimension for torch.aten.mm")); } - auto resultTy = cast(op.getType()); - auto resultDTy = resultTy.toBuiltinTensor().getElementType(); - Type newResultType = getTypeConverter()->convertType(op.getType()); - Type elementType = cast(newResultType).getElementType(); - auto accumulatorDType = getDefaultAccType(rewriter, resultDTy); - if (accumulatorDType != resultDTy) { + TensorType resultType = + cast(getTypeConverter()->convertType(op.getType())); + Type elementType = resultType.getElementType(); + auto accumulatorDType = + getDefaultAccType(rewriter, 
lhsType.getElementType()); + if (accumulatorDType != resultType.getElementType()) { elementType = accumulatorDType; } Value zeroFill = createZeroInitTensor( @@ -189,10 +187,10 @@ class ConvertAtenMmOp : public OpConversionPattern { ValueRange{lhs, rhs, lhsZeroPoint, rhsZeroPoint}, zeroFill) .getResult(0); } else if (isUnsigned) { - matmul = rewriter - .create( - loc, zeroFill.getType(), ValueRange{lhs, rhs}, zeroFill) - .getResult(0); + auto matmulOp = rewriter.create( + loc, zeroFill.getType(), ValueRange{lhs, rhs}, zeroFill); + matmulOp.setCast(linalg::TypeFn::cast_unsigned); + matmul = matmulOp->getResult(0); } else { matmul = rewriter .create(loc, zeroFill.getType(), @@ -200,18 +198,16 @@ class ConvertAtenMmOp : public OpConversionPattern { .getResult(0); } - if (accumulatorDType != resultDTy) { - Type resultElementType = - cast(newResultType).getElementType(); + if (accumulatorDType != resultType.getElementType()) { matmul = torch_to_linalg::convertTensorToElementType( - rewriter, loc, matmul, resultElementType); + rewriter, loc, matmul, resultType.getElementType()); } // When constructed with just dynamic sizes, EmptyOp will have a result // type which has all `?`'s for dimensions, which might not be the result // type of `op`. The constraints on later linalg ops means that the result // of the MatmulOp will have this type too. So cast it to the desired type // so that in the end we have the original result type. - rewriter.replaceOpWithNewOp(op, newResultType, matmul); + rewriter.replaceOpWithNewOp(op, resultType, matmul); return success(); } @@ -227,14 +223,9 @@ class ConvertAtenFlipOp : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); - MLIRContext *context = op.getContext(); Value self = adaptor.getSelf(); auto selfRank = cast(adaptor.getSelf().getType()).getRank(); - Type elementType = - cast(adaptor.getSelf().getType()).getElementType(); - Value c1 = - rewriter.create(loc, rewriter.getIndexAttr(1)); SmallVector axis; if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(axis))) @@ -247,40 +238,8 @@ class ConvertAtenFlipOp : public OpConversionPattern { } } - // Only used to calculate flipped values, i.e. those on the flip axes. Other - // dims won't be used. - SmallVector dims = getTensorSizes(rewriter, loc, self); - for (auto flipDim : axis) - dims[flipDim] = rewriter.create(loc, dims[flipDim], c1); - - Value initTensor = createZeroInitTensor( - rewriter, loc, getTensorSizes(rewriter, loc, self), elementType); - - SmallVector iteratorTypes( - selfRank, utils::IteratorType::parallel); - SmallVector indexingMaps( - 2, AffineMap::getMultiDimIdentityMap(selfRank, context)); - Value flipped = - rewriter - .create( - loc, self.getType(), self, initTensor, indexingMaps, - iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - SmallVector indices; - for (auto i = 0; i < selfRank; i++) - indices.push_back(b.create(loc, i)); - for (auto flipDim : axis) { - indices[flipDim] = b.create( - loc, dims[flipDim], indices[flipDim]); - } - Value res = b.create(loc, self, indices) - .getResult(); - b.create(loc, res); - }) - .getResult(0); - + Value flipped = torch_to_linalg::flipTensor(rewriter, loc, self, axis); rewriter.replaceOpWithNewOp(op, self.getType(), flipped); - return success(); } }; @@ -768,15 +727,21 @@ class ConvertAtenBmmOp : public OpConversionPattern { // Check the matrixs shapes are valid for mulplication. 
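The mm/bmm lowerings in this hunk accumulate in a wider element type (getDefaultAccType) and only convert back to the destination element type at the end. A minimal standalone sketch of that idea, independent of MLIR (plain C++; matmulWithWideAccumulator is an illustrative name, with float/double standing in for the low-precision and accumulator types):

    #include <cstddef>
    #include <vector>

    // Sketch: accumulate a float matmul in double (the "accumulator dtype"),
    // then truncate back to the result element type, mirroring the
    // accumulatorDType != resultElementType path in the lowering.
    std::vector<float> matmulWithWideAccumulator(const std::vector<float> &lhs,
                                                 const std::vector<float> &rhs,
                                                 size_t m, size_t k, size_t n) {
      std::vector<float> result(m * n, 0.0f);
      for (size_t i = 0; i < m; ++i) {
        for (size_t j = 0; j < n; ++j) {
          double acc = 0.0; // wider accumulator limits rounding of partial sums
          for (size_t p = 0; p < k; ++p)
            acc += static_cast<double>(lhs[i * k + p]) *
                   static_cast<double>(rhs[p * n + j]);
          result[i * n + j] = static_cast<float>(acc); // cast back to result type
        }
      }
      return result;
    }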
checkDimEqualHelper(rewriter, loc, lhsDim2, rhsDim1); + Type accumulatorDType = getDefaultAccType(rewriter, resultElementType); Value initTensor0 = createZeroInitTensor( - rewriter, loc, ValueRange{lhsDim0, lhsDim1, rhsDim2}, - resultElementType); + rewriter, loc, ValueRange{lhsDim0, lhsDim1, rhsDim2}, accumulatorDType); Value bmm = rewriter .create(loc, initTensor0.getType(), ValueRange{lhs, rhs}, initTensor0) .getResult(0); + + if (accumulatorDType != resultElementType) { + bmm = torch_to_linalg::convertTensorToElementType(rewriter, loc, bmm, + resultElementType); + } + rewriter.replaceOpWithNewOp(op, newResultType, bmm); return success(); } @@ -793,7 +758,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { Location loc = op->getLoc(); MLIRContext *context = op->getContext(); Value input = adaptor.getInput(); /* in form of N*C*H*W */ - Value weight = adaptor.getWeight(); /* in form of F*C*H*W */ + Value weight = adaptor.getWeight(); /* in form of F*C/G*H*W */ Value bias = adaptor.getBias(); auto resultTy = cast(op.getType()); @@ -809,6 +774,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { inputZp = typeConverter->materializeTargetConversion( rewriter, loc, typeConverter->convertType(inputZp.getType()), inputZp); + inputZp = + rewriter.create(loc, rewriter.getI32Type(), inputZp); auto torchDtype = cast(make.getType()).getDtype(); inputUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype); } @@ -823,6 +790,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { weightZp = typeConverter->materializeTargetConversion( rewriter, loc, typeConverter->convertType(weightZp.getType()), weightZp); + weightZp = rewriter.create(loc, rewriter.getI32Type(), + weightZp); auto torchDtype = cast(make.getType()).getDtype(); weightUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype); } @@ -832,7 +801,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { op, "lhs and rhs of convolution must either be both int or fp"); } - if (inputZp && weightZp && !isa(bias.getType())) { + if (inputZp && !isa(bias.getType())) { auto biasDTy = cast(bias.getType()).getElementType(); if (!biasDTy.isInteger(32)) { return rewriter.notifyMatchFailure( @@ -861,7 +830,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { Type intType = IntegerType::get(context, 64); auto castIndexToInt = [&](Value v) { - return rewriter.create(loc, intType, v); + return rewriter.createOrFold(loc, intType, v); }; SmallVector paddingIntValues; @@ -887,6 +856,48 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "only support constant int dilations"); + // Checks for valid group size + int64_t numGroups; + if (!matchPattern(op.getGroups(), m_TorchConstantInt(&numGroups))) + return rewriter.notifyMatchFailure(op, + "only constant group size supported."); + Value groups = castIntToIndex(rewriter, loc, adaptor.getGroups()); + + // Adding support for 1d group convolution by converting the 1d-conv to + // 2d-conv. + // TODO: Replace this logic with the appropriate linalg op for 1-d group + // convolution once that support is added. + bool is1DGroupConv = (numSpatialDims == 1 && numGroups != 1); + if (is1DGroupConv) { + // Unsqueezing the last dim of input and weight. Also extending the + // dilation, stride, padding, and output padding lists. 
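The 1-D grouped-convolution path introduced here promotes the problem to 2-D: append a unit spatial dimension to input and weight, extend stride and dilation with 1 and padding with 0, run the 2-D grouped conv, then squeeze the trailing dimension off the result. A shape-level sketch under those assumptions (plain C++; Conv1dAs2d and promoteTo2d are illustrative names only):

    #include <cstdint>
    #include <vector>

    // Sketch of the 1-D -> 2-D group-conv rewrite at the shape level:
    // (N, C, L) becomes (N, C, L, 1); stride/dilation get a trailing 1,
    // padding (and output padding) get a trailing 0.
    struct Conv1dAs2d {
      std::vector<int64_t> inputShape;  // e.g. {N, C, L, 1}
      std::vector<int64_t> weightShape; // e.g. {F, C/G, K, 1}
      std::vector<int64_t> stride, dilation, padding;
    };

    Conv1dAs2d promoteTo2d(std::vector<int64_t> inputShape,
                           std::vector<int64_t> weightShape,
                           std::vector<int64_t> stride,
                           std::vector<int64_t> dilation,
                           std::vector<int64_t> padding) {
      inputShape.push_back(1);  // unsqueeze last dim of input
      weightShape.push_back(1); // unsqueeze last dim of weight
      stride.push_back(1);
      dilation.push_back(1);
      padding.push_back(0);
      return {inputShape, weightShape, stride, dilation, padding};
    }
    // After the 2-D convolution runs, the unit trailing dimension is squeezed
    // off the result, recovering the (N, F, Lout) layout.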
+ auto unsqueezeInputInfo = + unsqueezeTensor(rewriter, op, input, /*dim=*/-1); + if (failed(unsqueezeInputInfo)) { + return rewriter.notifyMatchFailure(op, + "cannot generate unsqueeze tensor"); + } + input = unsqueezeInputInfo.value(); + + auto unsqueezeWeightInfo = + unsqueezeTensor(rewriter, op, weight, /*dim=*/-1); + if (failed(unsqueezeWeightInfo)) { + return rewriter.notifyMatchFailure(op, + "cannot generate unsqueeze tensor"); + } + weight = unsqueezeWeightInfo.value(); + + Value cstZero = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + paddingIntValues.push_back(cstZero); + outputPaddingIntValues.push_back(cstZero); + strideInts.push_back(1); + dilationInts.push_back(1); + + inRank++; + numSpatialDims++; + } + Value inBatch = getDimOp(rewriter, loc, input, 0); Value inChannels = getDimOp(rewriter, loc, input, 1); SmallVector inDims; @@ -898,13 +909,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { for (size_t i = 2; i < inRank; i++) weightDims.push_back(getDimOp(rewriter, loc, weight, i)); - // Checks for valid group size - int64_t groupSize; - if (!matchPattern(op.getGroups(), m_TorchConstantInt(&groupSize))) - return rewriter.notifyMatchFailure(op, - "only constant group size supported."); - Value groups = castIntToIndex(rewriter, loc, adaptor.getGroups()); - auto validate = [&](Value toValidate, std::string err) { Value c0 = rewriter.create(loc, rewriter.getIndexAttr(0)); @@ -1044,6 +1048,12 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { strideInts.clear(); strideInts.append(numSpatialDims, 1); } else { + if ((int64_t)paddingIntValues.size() + 2 != + cast(input.getType()).getRank()) { + // pytorch 2.5 generates one element padding = {0} for + // Conv2dWithValidPaddingModule + return rewriter.notifyMatchFailure(op, "unexpected number of padding"); + } // Pad input paddedInput = torch_to_linalg::getDynamicZeroPaddedTensor( op, rewriter, input, paddingIntValues, /*unpaddedDims=*/2, pad); @@ -1055,15 +1065,15 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { castIndexToInt(weightDims[i]), strideIntValues[i])); } - Type accumulatorDType = getDefaultAccType(rewriter, resultDTy); + Type accumulatorDType = getDefaultAccType(rewriter, inputDTy); Value initTensor = rewriter.create( loc, getAsOpFoldResult(outDims), accumulatorDType); Value outputTensor; - if (accumulatorDType != resultDTy && !bias.getType().isa()) + if (accumulatorDType != resultDTy && !isa(bias.getType())) bias = torch_to_linalg::convertTensorToElementType(rewriter, loc, bias, accumulatorDType); - if (bias.getType().isa()) { + if (isa(bias.getType())) { Value c0; if (isa(accumulatorDType)) { c0 = rewriter.create( @@ -1081,21 +1091,16 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "expect bias to be rank 1"); auto resultRank = cast(initTensor.getType()).getRank(); - SmallVector indexingMaps = { - // bias is used to initialize the channels - dimension 1 of output - AffineMap::get(/*dimCount=*/resultRank, /*symbolCount=*/0, - rewriter.getAffineDimExpr(1), context), - rewriter.getMultiDimIdentityMap(resultRank)}; - SmallVector iteratorTypes( - resultRank, utils::IteratorType::parallel); + SmallVector addedDimensions; + // bias is used to initialize the channels - dimension 1 of + // output + for (int i = 0; i < resultRank; ++i) + if (i != 1) + addedDimensions.push_back(i); outputTensor = rewriter - .create( - loc, initTensor.getType(), bias, initTensor, - indexingMaps, iteratorTypes, - [](OpBuilder &b, 
Location loc, ValueRange args) { - b.create(loc, args[0]); - }) - .getResult(0); + .create(loc, bias, initTensor, + addedDimensions) + ->getResult(0); } auto stridesAttr = rewriter.getI64VectorAttr(strideInts); @@ -1119,14 +1124,14 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { Value conv; // the code so far is able to respect all numSpatialDims - // the code below this point is numSpatialDims specific and groupSize + // the code below this point is numSpatialDims specific and numGroups // specific // TODO: factor out the above code into a helper function, and then separate // convolution into: // - grouped 1d-3d // - grouped 1d-3d (quantized) // - ungrouped 1d-3d - if (groupSize == 1 && !inputZp && !weightZp) { + if (numGroups == 1 && !inputZp) { switch (numSpatialDims) { case 1: conv = rewriter @@ -1167,55 +1172,58 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { return success(); } - if (groupSize == 1 && inputZp && weightZp) { - // The quantized version uses a different channel ordering so we need to - // permute the tensors in order to use the existing path. We should - // eventually directly support this channel ordering. - llvm::SmallVector inPerms, weightPerms; - inPerms.push_back(0); // N stays at the front for input. - // Then we expect the spatial dimensions - for (size_t i = 0; i < numSpatialDims; ++i) { - inPerms.push_back(i + 2); - weightPerms.push_back(i + 2); - } - inPerms.push_back(1); - weightPerms.append({1, 0}); - - paddedInput = transposeValue(op.getLoc(), paddedInput, inPerms, rewriter); - weight = transposeValue(op.getLoc(), weight, weightPerms, rewriter); - outputTensor = - transposeValue(op.getLoc(), outputTensor, inPerms, rewriter); - + if (numGroups == 1 && inputZp) { switch (numSpatialDims) { case 2: conv = rewriter - .create( + .create( loc, outputTensor.getType(), ValueRange{paddedInput, weight, inputZp, weightZp}, outputTensor, stridesAttr, dilationAttr) .getResult(0); break; - case 3: + case 3: { + // The quantized version uses a different channel ordering so we need to + // permute the tensors in order to use the existing path. We should + // eventually directly support this channel ordering. + llvm::SmallVector inPerms, weightPerms; + inPerms.push_back(0); // N stays at the front for input. 
+ // Then we expect the spatial dimensions + for (size_t i = 0; i < numSpatialDims; ++i) { + inPerms.push_back(i + 2); + weightPerms.push_back(i + 2); + } + inPerms.push_back(1); + weightPerms.append({1, 0}); + + paddedInput = + transposeValue(op.getLoc(), paddedInput, inPerms, rewriter); + weight = transposeValue(op.getLoc(), weight, weightPerms, rewriter); + outputTensor = + transposeValue(op.getLoc(), outputTensor, inPerms, rewriter); + conv = rewriter .create( loc, outputTensor.getType(), ValueRange{paddedInput, weight, inputZp, weightZp}, outputTensor, stridesAttr, dilationAttr) .getResult(0); + + llvm::SmallVector outPerms; + outPerms.push_back(0); + outPerms.push_back(inPerms.size() - 1); + for (size_t i = 0; i < numSpatialDims; ++i) { + outPerms.push_back(i + 1); + } + conv = transposeValue(op.getLoc(), conv, outPerms, rewriter); + break; + } default: return rewriter.notifyMatchFailure( op, "unimplemented: only 1D, 2D, and 3D convolution supported"); }; - llvm::SmallVector outPerms; - outPerms.push_back(0); - outPerms.push_back(inPerms.size() - 1); - for (size_t i = 0; i < numSpatialDims; ++i) { - outPerms.push_back(i + 1); - } - conv = transposeValue(op.getLoc(), conv, outPerms, rewriter); - Type newResultType = getTypeConverter()->convertType(op.getType()); if (accumulatorDType != resultDTy) { Type resultElementType = @@ -1227,38 +1235,90 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { return success(); } - if (inputZp || weightZp) - return rewriter.notifyMatchFailure( - op, "unimplemented: quantized grouped convolutions"); - - if (numSpatialDims != 2) - return rewriter.notifyMatchFailure( - op, "unimplemented: only 2D grouped convolution supported"); - - // Special depthwise case + // Special depthwise case: Cin = Cout = groups. + // Note: pytorch considers Cin == groups (Cout possibly a non-zero multiple + // of groups) to be depthwise in their documentation, but the linalg ops + // apparently disagree. auto inShape = makeShapeTorchCompatible( cast(input.getType()).getShape()); auto weightShape = makeShapeTorchCompatible( cast(weight.getType()).getShape()); - if (weightShape[0] != kUnknownSize && inShape[1] == groupSize && - weightShape[0] % inShape[1] == 0 && weightShape[1] == 1) { - // Collapse weight shape - SmallVector collapsedDims = {{0, 1}, {2}, {3}}; - SmallVector collapsedShape{ - (weightShape[0] == kUnknownSize ? 
kUnknownSize - : weightShape[0] * weightShape[1]), - weightShape[2], weightShape[3]}; + if (inShape[1] == numGroups && weightShape[0] == numGroups && + weightShape[1] == 1) { + // Collapse weight shape (C/G == 1) + SmallVector collapsedDims = {{0, 1}}; + SmallVector collapsedShape{weightShape[0] * weightShape[1]}; + for (unsigned i = 0; i < numSpatialDims; i++) { + collapsedDims.push_back({i + 2}); + collapsedShape.push_back(weightShape[i + 2]); + } Type collapsedType = RankedTensorType::get( makeShapeLLVMCompatible(collapsedShape), weightDTy); Value collapsedWeight = rewriter.create( loc, collapsedType, weight, collapsedDims); - - conv = rewriter - .create( - loc, outputTensor.getType(), - ValueRange{paddedInput, collapsedWeight}, outputTensor, - stridesAttr, dilationAttr) - .getResult(0); + if (!inputZp) { + switch (numSpatialDims) { + case 1: + conv = rewriter + .create( + loc, outputTensor.getType(), + ValueRange{paddedInput, collapsedWeight}, outputTensor, + stridesAttr, dilationAttr) + .getResult(0); + break; + case 2: + conv = rewriter + .create( + loc, outputTensor.getType(), + ValueRange{paddedInput, collapsedWeight}, outputTensor, + stridesAttr, dilationAttr) + .getResult(0); + break; + default: + return rewriter.notifyMatchFailure( + op, "unimplemented: only 1D and 2D depthwise convolution " + "supported for special case of group convolution"); + }; + } else { + if (numSpatialDims != 2) + return rewriter.notifyMatchFailure( + op, "unimplemented: only 2D depthwise quantized convolution " + "supported for special case of group convolution"); + + // currently, the only named depthwise qconv op is nhwc_hwc + // input: nchw -> nhwc; weight (collapsed): chw -> hwc + // linalg conv result nhwc -> nchw + // inPerms = [0, 2, 3, 1] + // weightPerms = [1, 2, 0] + // resultPerms = [0, 3, 1, 2] + llvm::SmallVector inPerms, weightPerms, resultPerms; + inPerms.push_back(0); + resultPerms.append({0, static_cast(numSpatialDims + 1)}); + for (size_t i = 0; i < numSpatialDims; ++i) { + inPerms.push_back(i + 2); + weightPerms.push_back(i + 1); + resultPerms.push_back(i + 1); + } + inPerms.push_back(1); + weightPerms.push_back(0); + + paddedInput = + transposeValue(op.getLoc(), paddedInput, inPerms, rewriter); + collapsedWeight = + transposeValue(op.getLoc(), collapsedWeight, weightPerms, rewriter); + outputTensor = + transposeValue(op.getLoc(), outputTensor, inPerms, rewriter); + + conv = + rewriter + .create( + loc, outputTensor.getType(), + ValueRange{paddedInput, collapsedWeight, inputZp, weightZp}, + outputTensor, stridesAttr, dilationAttr) + .getResult(0); + // convert output nhwc -> nchw + conv = transposeValue(op.getLoc(), conv, resultPerms, rewriter); + } Type newResultType = getTypeConverter()->convertType(op.getType()); if (accumulatorDType != resultDTy) { @@ -1267,10 +1327,25 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv, resultElementType); } + + if (is1DGroupConv) { + // Squeezing the last dim of the result of conv. 
+ auto squeezeOutputInfo = squeezeTensor(rewriter, op, conv, /*dim=*/-1); + if (failed(squeezeOutputInfo)) { + return rewriter.notifyMatchFailure(op, + "cannot generate squeeze tensor"); + } + conv = squeezeOutputInfo.value(); + } + rewriter.replaceOpWithNewOp(op, newResultType, conv); return success(); } + if (numSpatialDims != 2) + return rewriter.notifyMatchFailure( + op, "unimplemented: only 1D and 2D grouped convolution supported"); + // Grouped case, use the grouped conv linalg op auto expandGroups = [&](Value tensor, size_t dim) { auto inType = cast(tensor.getType()); @@ -1279,12 +1354,12 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { SmallVector outShape; for (auto i = 0; i < (long)inShape.size(); i++) { if (i == 1) { - outShape.push_back(groupSize); + outShape.push_back(numGroups); } if (i == (long)dim) { outShape.push_back(inShape[i] == kUnknownSize ? kUnknownSize - : inShape[i] / groupSize); + : inShape[i] / numGroups); } else { outShape.push_back(inShape[i]); } @@ -1310,8 +1385,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { auto inShape = makeShapeTorchCompatible(inType.getShape()); SmallVector outShape{ - groupSize, - (inShape[0] == kUnknownSize ? kUnknownSize : inShape[0] / groupSize)}; + numGroups, + (inShape[0] == kUnknownSize ? kUnknownSize : inShape[0] / numGroups)}; outShape.append(inShape.begin() + 1, inShape.end()); SmallVector indices{{0, 1}}; @@ -1328,13 +1403,22 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { auto expandOutputTensor = expandGroups(outputTensor, 1); // TODO: add 1D and 3D case - conv = rewriter - .create( - loc, expandOutputTensor.getResultType(), - ValueRange{paddedInputExpanded, weightExpanded}, - expandOutputTensor.getResult(), stridesAttr, dilationAttr) - .getResult(0); - + if (!inputZp) { + conv = rewriter + .create( + loc, expandOutputTensor.getResultType(), + ValueRange{paddedInputExpanded, weightExpanded}, + expandOutputTensor.getResult(), stridesAttr, dilationAttr) + .getResult(0); + } else { + conv = rewriter + .create( + loc, expandOutputTensor.getResultType(), + ValueRange{paddedInputExpanded, weightExpanded, inputZp, + weightZp}, + expandOutputTensor.getResult(), stridesAttr, dilationAttr) + .getResult(0); + } conv = rewriter.create( loc, outputTensor.getType(), conv, expandOutputTensor.getReassociationIndices()); @@ -1345,12 +1429,210 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv, resultElementType); } + + if (is1DGroupConv) { + // Squeezing the last dim of the result of conv. + auto squeezeOutputInfo = squeezeTensor(rewriter, op, conv, /*dim=*/-1); + if (failed(squeezeOutputInfo)) { + return rewriter.notifyMatchFailure(op, + "cannot generate squeeze tensor"); + } + conv = squeezeOutputInfo.value(); + } rewriter.replaceOpWithNewOp(op, newResultType, conv); return success(); } }; } // namespace +namespace { + +/// Creates coefficients based on DFT definition, see +/// https://en.wikipedia.org/wiki/Discrete_Fourier_transform. 
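The coefficient matrix built by getDFTMatmulCoeff below encodes the DFT definition entrywise: coeff(n, m) = cos(2*pi*n*m/N) - j*sin(2*pi*n*m/N), with N the signal length and N/2 + 1 output bins for the real FFT. A standalone sketch of the same formula (plain C++ with std::complex; dftCoefficients is an illustrative name):

    #include <cmath>
    #include <complex>
    #include <cstdint>
    #include <vector>

    // Sketch: build the N x (N/2 + 1) rfft coefficient matrix in row-major
    // order, coeff(n, m) = cos(2*pi*n*m / N) - j*sin(2*pi*n*m / N).
    std::vector<std::complex<double>> dftCoefficients(int64_t fftLength) {
      int64_t outputFftDim = fftLength / 2 + 1;
      double scale = 2.0 * M_PI / static_cast<double>(fftLength);
      std::vector<std::complex<double>> coeffs;
      coeffs.reserve(fftLength * outputFftDim);
      for (int64_t n = 0; n < fftLength; ++n)
        for (int64_t m = 0; m < outputFftDim; ++m)
          coeffs.emplace_back(std::cos(scale * n * m), -std::sin(scale * n * m));
      return coeffs;
    }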
+Value getDFTMatmulCoeff(OpBuilder b, Location loc, + RankedTensorType matrixType) { + + ComplexType complexTy = llvm::cast(matrixType.getElementType()); + mlir::FloatType floatType = + llvm::cast(complexTy.getElementType()); + + // scale = 2 * pi / N + double scale = 2 * M_PI / matrixType.getDimSize(0); + + SmallVector> values; + for (auto i : llvm::seq(0, matrixType.getDimSize(0))) { + for (auto j : llvm::seq(0, matrixType.getDimSize(1))) { + double v = scale * i * j; + double realV = cos(v); + double imagV = -sin(v); + + bool unused; + APFloat real(realV); + real.convert(floatType.getFloatSemantics(), APFloat::rmNearestTiesToEven, + &unused); + APFloat imag(imagV); + imag.convert(floatType.getFloatSemantics(), APFloat::rmNearestTiesToEven, + &unused); + + values.push_back(std::complex(real, imag)); + } + } + return b.create( + loc, matrixType, DenseElementsAttr::get(matrixType, values)); +} + +struct ConvertAtenFftRfftOp final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(AtenFftRfftOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = adaptor.getSelf(); + + int64_t dim; + auto dimVal = op.getDim(); + if (isa(dimVal.getType())) { + dim = -1; + } else if (!matchPattern(dimVal, torch::Torch::m_TorchConstantInt(&dim))) { + return rewriter.notifyMatchFailure( + op, "unimplemented: requires dim to be constant"); + } + + if (!isa(op.getN().getType())) { + return rewriter.notifyMatchFailure(op, "unimplemented: parameter n"); + } + + if (!isa(op.getNorm().getType())) { + return rewriter.notifyMatchFailure(op, "unimplemented: parameter norm"); + } + + RankedTensorType inputType = + cast(adaptor.getSelf().getType()); + if (!inputType.hasRank()) { + return rewriter.notifyMatchFailure( + op, "unsupported: only ranked tensors are supported"); + } + + const ArrayRef inputShape = inputType.getShape(); + dim += dim < 0 ? inputShape.size() : 0; + + const int64_t fftLength = inputShape[dim]; + if (fftLength == ShapedType::kDynamic) { + return rewriter.notifyMatchFailure( + op, "unsupported: FFT signal length must be static"); + } + const int64_t rank = inputType.getRank(); + const int64_t lastDim = rank - 1; + const int64_t outputFftDim = fftLength / 2 + 1; + const bool needTranspose = dim != lastDim; + + // Transpose if FFT dimension is not the last one + llvm::SmallVector perms = llvm::to_vector(llvm::seq(rank)); + std::swap(perms[dim], perms[lastDim]); + if (needTranspose) { + self = transposeValue(loc, self, perms, rewriter); + } + + RankedTensorType newResultType = llvm::cast( + getTypeConverter()->convertType(op.getType())); + ComplexType complexElemType = + llvm::cast(newResultType.getElementType()); + Type elemType = complexElemType.getElementType(); + + // coeffMatrix : tensor> + RankedTensorType coeffType = + RankedTensorType::get({fftLength, outputFftDim}, complexElemType); + // coeffMatrix(n,m) = cos(2 pi n m / N) - j sin(2 pi n m / N) + Value coeffMatrix = getDFTMatmulCoeff(rewriter, loc, coeffType); + + // #matmul_trait = { + // indexing_maps = [ + // affine_map<(d_0, ... d_m, f, o) -> (d_0, ... d_m, f)>, + // affine_map<(d_0, ... d_m, f, o) -> (f, o)>, + // affine_map<(d_0, ... d_m, f, o) -> (d_0, ... 
d_m, o)> + // ], + // iterator_types = ["parallel", ..., "parallel", "reduction", "parallel"] + // } + // linalg.generic #matmul_trait + // ins(%A, %B : tensor, + // tensor>) + // outs(%C : tensor>) { + // ^bb0(%a: f32, %b: complex, %c: complex) : + // %re = complex.re %b : f32 + // %im = complex.im %b : f32 + // %mulre = arith.mulf %a, %re: f32 + // %mulim = arith.mulf %a, %im: f32 + // %mulcplx = complex.create %mulre, %mulim : complex + // %add = complex.add %c, %mulcplx: complex + // linalg.yield %add : complex + // } -> (tensor>) + + Value lhs = self; + Value rhs = coeffMatrix; + RankedTensorType lhsType = llvm::cast(lhs.getType()); + ArrayRef lhsShape(lhsType.getShape()); + ArrayRef rhsShape(coeffType.getShape()); + + unsigned batchRank = lhsShape.size() - 1; + + SmallVector lhsExpr; + SmallVector rhsExpr; + SmallVector outExpr; + SmallVector iteratorTypes( + batchRank, utils::IteratorType::parallel); + SmallVector resultShape; + for (unsigned i = 0; i < batchRank; i++) { + lhsExpr.push_back(rewriter.getAffineDimExpr(i)); + outExpr.push_back(rewriter.getAffineDimExpr(i)); + resultShape.push_back(getDimOp(rewriter, loc, lhs, i)); + } + unsigned fIdx = batchRank, oIdx = batchRank + 1; + lhsExpr.insert(lhsExpr.end(), {rewriter.getAffineDimExpr(fIdx)}); + rhsExpr.insert(rhsExpr.end(), {rewriter.getAffineDimExpr(fIdx), + rewriter.getAffineDimExpr(oIdx)}); + outExpr.insert(outExpr.end(), {rewriter.getAffineDimExpr(oIdx)}); + resultShape.insert(resultShape.end(), + {getDimOp(rewriter, loc, rhs, rhsShape.size() - 1)}); + + Value zeroTensor = + createZeroInitTensor(rewriter, loc, resultShape, complexElemType); + auto indexingMaps = AffineMap::inferFromExprList( + {lhsExpr, rhsExpr, outExpr}, rewriter.getContext()); + iteratorTypes.insert(iteratorTypes.end(), {utils::IteratorType::reduction, + utils::IteratorType::parallel}); + + Value complexRes = + rewriter + .create( + loc, zeroTensor.getType(), + /*inputs=*/ValueRange{lhs, rhs}, + /*outputs=*/zeroTensor, indexingMaps, iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value l = args[0], r = args[1], res = args[2]; + Value re = b.create(loc, elemType, r); + Value im = b.create(loc, elemType, r); + Value mulRe = b.create(loc, l, re); + Value mulIm = b.create(loc, l, im); + Value mulCplx = b.create( + loc, complexElemType, mulRe, mulIm); + Value add = b.create(loc, mulCplx, res); + b.create(loc, add); + }) + .getResult(0); + + // Transpose back + if (needTranspose) { + complexRes = transposeValue(loc, complexRes, perms, rewriter); + } + + rewriter.replaceOp(op, complexRes); + return success(); + } +}; + +} // namespace + void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { @@ -1365,4 +1647,6 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality( patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); } diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index 70b27fd84f24..90b5b2af77a8 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -9,18 +9,14 @@ #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" -#include "../PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include 
"mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Matchers.h" #include "torch-mlir/Conversion/TorchToLinalg/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h" -#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" -#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" using namespace mlir; @@ -173,11 +169,42 @@ static LogicalResult createPoolingOp( Value windowTensor = rewriter.create( loc, getAsOpFoldResult(shape), elementType); - result = rewriter - .create(loc, outTensorInitialized.getType(), - ValueRange{paddedInput, windowTensor}, - outTensorInitialized, stridesAttr, dilationAttr) - .getResult(0); + Value permutedInput = paddedInput, permutedOutput = outTensorInitialized; + if (dimensionality == 3) { + // Permute input and output tensor as follows: + // (n,c,d,h,w) -> (n,d,h,w,c) + SmallVector dimensions = {0, 2, 3, 4, 1}; + if (failed(torch_to_linalg::permuteTensor(op, rewriter, op->getLoc(), + dimensions, paddedInput, + permutedInput))) + return rewriter.notifyMatchFailure( + op, "failed to perform permutation of tensor"); + + if (failed(torch_to_linalg::permuteTensor(op, rewriter, op->getLoc(), + dimensions, outTensorInitialized, + permutedOutput))) + return rewriter.notifyMatchFailure( + op, "failed to perform permutation of tensor"); + } + + Value poolingResult = + rewriter + .create(loc, permutedOutput.getType(), + ValueRange{permutedInput, windowTensor}, permutedOutput, + stridesAttr, dilationAttr) + .getResult(0); + + result = poolingResult; + if (dimensionality == 3) { + // Permute output tensor as follows: + // (n,d,h,w,c) -> (n,c,d,h,w) + SmallVector dimensions = {0, 4, 1, 2, 3}; + if (failed(torch_to_linalg::permuteTensor( + op, rewriter, op->getLoc(), dimensions, poolingResult, result))) + return rewriter.notifyMatchFailure( + op, "failed to perform permutation of tensor"); + } + return success(); } @@ -197,28 +224,41 @@ template <> struct DimensionTraits { static_assert(Dim == Dim); }; +template <> +struct DimensionTraits + : DimensionTraits {}; + template <> struct DimensionTraits { static constexpr int64_t Dim = 3; // unused const variable warning suppression: static_assert(Dim == Dim); }; +template <> +struct DimensionTraits + : DimensionTraits {}; + template class ConvertAtenMaxPoolOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; + static const bool withIndices = + llvm::is_one_of::value; + private: static const int64_t Dim = DimensionTraits::Dim; - LogicalResult createPoolingMax3D(AtenMaxPool3dOp &op, - typename OpTy::Adaptor adaptor, + LogicalResult createPoolingMax3D(OpTy &op, typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter, SmallVectorImpl &kernelSizeIntValues, SmallVectorImpl &strideInts, SmallVectorImpl &paddingInts, SmallVectorImpl &dilationInts, - bool ceilMode) const { - SmallVector outTensorShape; + bool ceilMode, + SmallVectorImpl &outTensorShape, + Value &paddedInput, Value &poolingOp) const { + static_assert(Dim == 3, "op must be MaxPool3d or MaxPool3dWithIndices"); Value self = adaptor.getSelf(); Type elementType = cast(self.getType()).getElementType(); TypedAttr smallestFPValueAttr = rewriter.getFloatAttr( @@ -228,8 +268,8 @@ class ConvertAtenMaxPoolOp : public OpConversionPattern { Value initValue = rewriter.create(op->getLoc(), smallestFPValueAttr); - Value paddedInput = padInputTensor(op, rewriter, 
self, ceilMode, 3, - strideInts, paddingInts, initValue); + paddedInput = padInputTensor(op, rewriter, self, ceilMode, 3, strideInts, + paddingInts, initValue); auto outTensorInitialized = computeOutputTensor( op, rewriter, self, 3, ceilMode, strideInts, paddingInts, dilationInts, @@ -282,25 +322,160 @@ class ConvertAtenMaxPoolOp : public OpConversionPattern { SmallVector(5, utils::IteratorType::parallel); iteratorTypes.append(3, utils::IteratorType::reduction); SmallVector indexingMaps = {mapInput, mapKernel, mapOutput}; - Value poolingOp = + poolingOp = rewriter + .create( + op->getLoc(), + /* result types */ outTensorInitialized.getType(), + /* operands */ ValueRange({paddedInput, windowTensor}), + /* outputs */ outTensorInitialized, + /*indexingMaps=*/indexingMaps, + /*iteratorTypes=*/iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value currentVal = args[0], accMaxValue = args[2]; + Value max_result = b.create( + loc, currentVal, accMaxValue); + b.create(loc, max_result); + }) + .getResult(0); + + return success(); + } + + // Returns the corresponding indices of the input tensor for the max pooling + // result tensor. + // + // For finding the indices, we follow the below method: + // + // Take maxpool2d as an example to illustrate. Let's say the input tensor is a + // 4-d tensor. The maxpool2d and indices will also be a 4-d tensor. Then: + // for i in range(N): + // for j in range(C): + // for m in range(Hout): + // for n in range(Wout): + // for p in range(kH): + // for r in range(kW): + // indexH = m * stride[0] + p * dilation[0] + // indexW = n * stride[0] + r * dilation[0] + // if paddedInput[i, j, indexH, indexW] == + // maxPool2d[i, j, m, n]: + // indices[i, j, m, n] = + // (indexH - padding[0]) * W + + // (indexW - padding[1]) + // + LogicalResult + computeMaxPoolingIndices(Value maxPool, Value paddedInput, OpTy &op, + typename OpTy::Adaptor adaptor, + ConversionPatternRewriter &rewriter, + SmallVectorImpl &outTensorShape, + SmallVectorImpl &kernelSizeIntValues, + SmallVectorImpl &strideInts, + SmallVectorImpl &paddingInts, + SmallVectorImpl &dilationInts, int64_t rank, + Value &indicesResult) const { + Location loc = op->getLoc(); + RankedTensorType indicesRankedTensorType = cast( + this->getTypeConverter()->convertType(op->getResult(1).getType())); + Value cstMinusOne = + rewriter.create(loc, rewriter.getI64IntegerAttr(-1)); + Value indicesTensor = + createInitTensor(rewriter, loc, outTensorShape, + indicesRankedTensorType.getElementType(), cstMinusOne); + + SmallVector kernelSize = + castIntVectorToIndexVector(rewriter, loc, kernelSizeIntValues); + SmallVector padding = + getAsConstantIndexValues(rewriter, loc, paddingInts); + SmallVector dilation = + getAsConstantIndexValues(rewriter, loc, dilationInts); + SmallVector kernelStride = + getAsConstantIndexValues(rewriter, loc, strideInts); + + Value windowTensor = rewriter.create( + loc, getAsOpFoldResult(kernelSize), + indicesRankedTensorType.getElementType()); + + SmallVector inputExprs, outputExprs, kernelExprs; + for (unsigned i = 0; i < rank; i++) { + inputExprs.push_back(rewriter.getAffineDimExpr(i)); + outputExprs.push_back(rewriter.getAffineDimExpr(i)); + } + for (unsigned i = 0; i < rank - 2; i++) { + kernelExprs.push_back(rewriter.getAffineDimExpr(i + rank)); + } + + // If computing indices for maxpool2d, we have six dimensions here. Each + // corresponding to N, C, Hout, Wout, kH, and kW, respectively, as described + // in the algorithm above. 
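The rule in the comment above reduces to replaying each pooling window, recomputing the padded-input coordinates from the output position and kernel offset, and, where the value equals the pooled maximum, storing the flattened position in the unpadded input. A scalar sketch of that index rule for the 2-D case (plain C++; maxPoolFlatIndex is an illustrative name):

    #include <cstdint>

    // Sketch of the index rule: for output position (m, n) and kernel offset
    // (p, r), the padded-input coordinates are
    //   indexH = m * strideH + p * dilationH
    //   indexW = n * strideW + r * dilationW
    // and the value stored in `indices` is the flattened position within the
    // unpadded input, (indexH - padH) * W + (indexW - padW). The lowering
    // evaluates this inside linalg.generic only where the element equals the
    // pooled maximum.
    int64_t maxPoolFlatIndex(int64_t m, int64_t n, int64_t p, int64_t r,
                             int64_t strideH, int64_t strideW,
                             int64_t dilationH, int64_t dilationW,
                             int64_t padH, int64_t padW, int64_t inputW) {
      int64_t indexH = m * strideH + p * dilationH;
      int64_t indexW = n * strideW + r * dilationW;
      return (indexH - padH) * inputW + (indexW - padW);
    }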
+ SmallVector indexingMaps = AffineMap::inferFromExprList( + {inputExprs, kernelExprs, outputExprs}, rewriter.getContext()); + SmallVector iteratorTypes( + rank, utils::IteratorType::parallel); + iteratorTypes.append(rank - 2, utils::IteratorType::reduction); + + // Extract pooling dimensions of input shape. + SmallVector inputSubShape; + for (unsigned i = 0; i < rank - 2; i++) { + inputSubShape.push_back( + getDimOp(rewriter, loc, adaptor.getSelf(), i + 2)); + } + + indicesResult = rewriter .create( - op->getLoc(), - /* result types */ outTensorInitialized.getType(), - /* operands */ ValueRange({paddedInput, windowTensor}), - /* outputs */ outTensorInitialized, + loc, /*resultTensorTypes=*/indicesTensor.getType(), + /*inputs=*/ValueRange({maxPool, windowTensor}), + /*outputs=*/indicesTensor, /*indexingMaps=*/indexingMaps, /*iteratorTypes=*/iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { - Value currentVal = args[0], accMaxValue = args[2]; - Value max_result = - b.create(loc, currentVal, accMaxValue); - ; - b.create(loc, max_result); + Value maxVal = args[0], res = args[2]; + + SmallVector inputDims; + inputDims.append({b.create(loc, 0), + b.create(loc, 1)}); + for (unsigned i = 2; i < rank; i++) { + Value mainIndex = b.create(loc, i); + Value subIndex = + b.create(loc, i + rank - 2); + Value origin = b.create(loc, mainIndex, + kernelStride[i - 2]); + Value offset = + b.create(loc, subIndex, dilation[i - 2]); + inputDims.push_back( + b.create(loc, origin, offset)); + } + + Value input = + b.create(loc, paddedInput, inputDims); + Value pred = b.create( + loc, arith::CmpFPredicate::OEQ, input, maxVal); + + Value outIndex = + b.create(loc, b.getIndexAttr(0)); + Value curInputStride = + b.create(loc, b.getIndexAttr(1)); + for (unsigned i = 0; i < rank - 2; i++) { + Value minusPadding = b.create( + loc, inputDims[rank - 1 - i], padding[rank - 3 - i]); + Value timesStride = b.create( + loc, minusPadding, curInputStride); + outIndex = + b.create(loc, outIndex, timesStride); + curInputStride = b.create( + loc, curInputStride, inputSubShape[rank - 3 - i]); + } + Value result = b.create( + loc, pred, castIndexToInt64(b, loc, outIndex), res); + + Value predInvalidIndex = b.create( + loc, arith::CmpIPredicate::eq, res, cstMinusOne); + Value out = b.create(loc, predInvalidIndex, + result, res); + + b.create(loc, out); }) .getResult(0); - Type newResultType = this->getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, newResultType, poolingOp); + return success(); } @@ -334,226 +509,321 @@ class ConvertAtenMaxPoolOp : public OpConversionPattern { Type elementType = cast(self.getType()).getElementType(); - if constexpr (Dim == 1) { - SmallVector outTensorShape; - Value maxPool1d, paddedInput; - TypedAttr smallestFPValueAttr = rewriter.getFloatAttr( + TypedAttr smallestValueAttr; + + if (auto fpty = dyn_cast(elementType)) { + smallestValueAttr = rewriter.getFloatAttr( elementType, - APFloat::getInf( - cast(elementType).getFloatSemantics(), - /*Negative=*/true)); + APFloat::getInf(fpty.getFloatSemantics(), /*Negative=*/true)); + } else if (auto intTy = dyn_cast(elementType)) { + int64_t bw = intTy.getIntOrFloatBitWidth(); + smallestValueAttr = rewriter.getIntegerAttr( + elementType, intTy.isUnsigned() ? 
APInt::getMinValue(bw) + : APInt::getSignedMinValue(bw)); + } + + if (!smallestValueAttr) + return rewriter.notifyMatchFailure(op, "invalid element type"); + + // `maxPool` contains the result of maxpool 1d/2d/3d operation over the + // input, `paddedInput` means the padded result of input tensor. + Value maxPool, paddedInput; + Type maxPoolResultType = + typeConverter->convertType(op->getResult(0).getType()); + SmallVector outTensorShape; + if constexpr (Dim == 1) { if (failed(createPoolingOp( op, rewriter, self, /*supportNonFPInput=*/true, ceilMode, /*dimensionality=*/1, kernelSizeIntValues, strideInts, - paddingInts, dilationInts, smallestFPValueAttr, outTensorShape, - paddedInput, maxPool1d))) + paddingInts, dilationInts, smallestValueAttr, outTensorShape, + paddedInput, maxPool))) return rewriter.notifyMatchFailure(op, "unable to compute maxpool1d"); - Type newResultType = this->getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, newResultType, maxPool1d); - return success(); } else if constexpr (Dim == 2) { - SmallVector outTensorShape; - // `maxpool2d` contains the result of maxpool2d operation over the input. - Value maxPool2d, paddedInput; - TypedAttr smallestFPValueAttr = rewriter.getFloatAttr( - elementType, - APFloat::getInf( - cast(elementType).getFloatSemantics(), - /*Negative=*/true)); if (failed(createPoolingOp( op, rewriter, self, /*supportNonFPInput=*/true, ceilMode, /*dimensionality=*/2, kernelSizeIntValues, strideInts, - paddingInts, dilationInts, smallestFPValueAttr, outTensorShape, - paddedInput, maxPool2d))) + paddingInts, dilationInts, smallestValueAttr, outTensorShape, + paddedInput, maxPool))) return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d"); - Type newResultType = this->getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, newResultType, maxPool2d); - return success(); } else { - return createPoolingMax3D(op, adaptor, rewriter, kernelSizeIntValues, - strideInts, paddingInts, dilationInts, - ceilMode); + if (failed(createPoolingMax3D(op, adaptor, rewriter, kernelSizeIntValues, + strideInts, paddingInts, dilationInts, + ceilMode, outTensorShape, paddedInput, + maxPool))) + return rewriter.notifyMatchFailure(op, "unable to compute maxpool3d"); + } + + Value outMaxPool = rewriter.create( + op->getLoc(), maxPoolResultType, maxPool); + SmallVector outResult({outMaxPool}); + if (withIndices) { + Value indicesResult; + if (failed(computeMaxPoolingIndices( + maxPool, paddedInput, op, adaptor, rewriter, outTensorShape, + kernelSizeIntValues, strideInts, paddingInts, dilationInts, + selfRank, indicesResult))) + return rewriter.notifyMatchFailure(op, + "unable to compute maxpool indices"); + Type indicesResultType = + typeConverter->convertType(op->getResult(1).getType()); + Value outIndices = rewriter.create( + op->getLoc(), indicesResultType, indicesResult); + outResult.push_back(outIndices); } + rewriter.replaceOp(op, outResult); + + return success(); } }; } // namespace namespace { -// Returns the result of maxpool2d over the input tensor. And the corresponding -// indices of the input tensor for the values of the result tensor. -// -// The result of the maxpool2d operation is calculated using the helper function -// written above. For finding the indices, we follow the below method: -// -// Let's say the input tensor is a 4-d tensor. The maxpool2d and indices will -// also be a 4-d tensor. 
Then: -// for i in range(N): -// for j in range(C): -// for m in range(Hout): -// for n in range(Wout): -// for p in range(kH): -// for r in range(kW): -// indexH = m * stride[0] + p * dilation[0] -// indexW = n * stride[0] + r * dilation[0] -// if paddedInput[i, j, indexH, indexW] == -// maxPool2d[i, j, m, n]: -// indices[i, j, m, n] = (indexH - padding[0]) * W + -// (indexW - padding[1]) -// -class ConvertAtenMaxPool2dWithIndicesOp - : public OpConversionPattern { +// Max unpooling operation, takes result of max_pooling op and indices and +// tries to reconstructs original pooling input by filling out values by either +// values from self or zero. +// Upstream CPU implementation use parallel loop over the indices array to fill +// out tensor but such approach requires random access writes, which is tricky +// to represent in linalg. +// Instead we are using a different method: we are mapping each input/index +// value to multiple output values via affine maps in linalg.generic, then, +// inside the body of generic, we compute out index and compare it with expected +// index we got from input, returning either input or zero. +// This method only works if we have non-overlapping pooling windows. +// In case of overlap (e.g. kernel_size=2, stride=1) we need to map many-to-many +// input to output values and do a reduction. To construct such mapping we need +// to know original Kernel size, but it doesn't encoded in aten op. We cannot +// reconstruct kernel_size either as such reconstruction is ambiguous (e.g. for +// input_size=2, output_size=5 and stride=2, kernel_size can be either 2 or 3). +// What worse, without knowing kernel size we cannot even reliably detect such +// cases and this conversion will just return invalid values. +class ConvertAtenMaxUnpool3dOp final + : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(AtenMaxPool2dWithIndicesOp op, OpAdaptor adaptor, + matchAndRewrite(AtenMaxUnpool3dOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); + Location loc = op->getLoc(); const TypeConverter *typeConverter = getTypeConverter(); Value self = adaptor.getSelf(); - RankedTensorType selfType = cast(self.getType()); - Type elementType = selfType.getElementType(); - RankedTensorType indicesRankedTensorType = - getTypeConverter() - ->convertType(op->getResult(1).getType()) - .cast(); - - // TODO: Add support for 3D inputs. - if (selfType.getRank() == 3) - return rewriter.notifyMatchFailure( - op, "unimplemented: only support 4D input"); + auto selfType = cast(self.getType()); - bool ceilMode; - SmallVector kernelSizeIntValues; - SmallVector strideInts, paddingInts, dilationInts; - if (!matchPattern(op.getDilation(), - m_TorchListOfConstantInts(dilationInts))) + ArrayRef inputSize = selfType.getShape().take_back(3); + if (ShapedType::isDynamicShape(inputSize)) return rewriter.notifyMatchFailure(op, - "only support constant int dilations"); - if (failed(checkAndGetPoolingParameters( - op, rewriter, typeConverter, ceilMode, kernelSizeIntValues, - strideInts, paddingInts))) - return rewriter.notifyMatchFailure(op, "invalid pooling parameters"); + "input type must be of static shape"); - // `maxpool2d` contains the result of maxpool2d operation over the input. 
- auto smallestFPValueAttr = rewriter.getFloatAttr( - elementType, - APFloat::getInf(cast(elementType).getFloatSemantics(), - /*Negative=*/true)); - Value maxPool2d, paddedInput; - SmallVector outTensorShape; - if (failed(createPoolingOp( - op, rewriter, self, /*supportNonFPInput=*/false, ceilMode, - /*dimensionality=*/2, kernelSizeIntValues, strideInts, paddingInts, - dilationInts, smallestFPValueAttr, outTensorShape, paddedInput, - maxPool2d))) - return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d"); + Value indices = adaptor.getIndices(); + auto indicesType = cast(indices.getType()); + if (inputSize != indicesType.getShape().take_back(3)) + return rewriter.notifyMatchFailure(op, "input/indices shape mismatch"); - Value cstMinusOne = - rewriter.create(loc, rewriter.getI64IntegerAttr(-1)); - Value indicesTensor = - createInitTensor(rewriter, loc, outTensorShape, - indicesRankedTensorType.getElementType(), cstMinusOne); + auto resType = typeConverter->convertType(op.getType()); + if (!resType) + return rewriter.notifyMatchFailure(op, "invalid result type"); - SmallVector kernelSize = - castIntVectorToIndexVector(rewriter, loc, kernelSizeIntValues); - SmallVector padding = - getAsConstantIndexValues(rewriter, loc, paddingInts); - SmallVector dilation = - getAsConstantIndexValues(rewriter, loc, dilationInts); - SmallVector stride = - getAsConstantIndexValues(rewriter, loc, strideInts); + ArrayRef inferredOutSize = resType.getShape().take_back(3); + if (ShapedType::isDynamicShape(inferredOutSize)) + return rewriter.notifyMatchFailure(op, + "output type must be of static shape"); - Value windowTensor = rewriter.create( - loc, getAsOpFoldResult(kernelSize), - indicesRankedTensorType.getElementType()); + { + SmallVector output; + if (!matchPattern(op.getOutputSize(), m_TorchListOfConstantInts(output))) + return rewriter.notifyMatchFailure(op, + "only support constant int output"); - SmallVector inputExprs, outputExprs, kernelExprs; - for (unsigned i = 0; i < 4; i++) { - inputExprs.push_back(rewriter.getAffineDimExpr(i)); - outputExprs.push_back(rewriter.getAffineDimExpr(i)); + if (inferredOutSize != ArrayRef(output)) + return rewriter.notifyMatchFailure(op, "Invalid output size"); } - kernelExprs.push_back(rewriter.getAffineDimExpr(4)); - kernelExprs.push_back(rewriter.getAffineDimExpr(5)); + SmallVector stride; + SmallVector padding; - // Here we have six dimensions, each corresponding to N, C, Hout, Wout, kH, - // and kW, respectively, as described in the algorithm above. 
- SmallVector indexingMaps = AffineMap::inferFromExprList( - {inputExprs, kernelExprs, outputExprs}, rewriter.getContext()); - SmallVector iteratorTypes( - 4, utils::IteratorType::parallel); - iteratorTypes.push_back(utils::IteratorType::reduction); - iteratorTypes.push_back(utils::IteratorType::reduction); + if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(stride))) + return rewriter.notifyMatchFailure(op, + "only support constant int strides"); - // Input format is : [N, C, H, W] - Value inputShapeW = getDimOp(rewriter, loc, self, 3); + if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding))) + return rewriter.notifyMatchFailure(op, + "only support constant int padding"); - Value indicesResult = - rewriter - .create( - loc, /*resultTensorTypes=*/indicesTensor.getType(), - /*inputs=*/ValueRange({maxPool2d, windowTensor}), - /*outputs=*/indicesTensor, - /*indexingMaps=*/indexingMaps, - /*iteratorTypes=*/iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value maxVal = args[0], res = args[2]; + // TODO: add support for asymmetric padding coming from "onnx.MaxUnpool" + // (padding.size() == 6). + if (stride.size() != 3 || padding.size() != 3) + return rewriter.notifyMatchFailure( + op, "stride and padding must be of size 3"); + + int64_t outRank = resType.getRank(); + int64_t NC = outRank - 3; + + for (auto &&[inDim, outDim, str, pad] : + llvm::zip_equal(inputSize, inferredOutSize, stride, padding)) { + // Kernel size computation is ambiguous, this formula will return the + // biggest possible kernel size. As there is no way to know actual kernel + // size we have to treat it conservatively and always bail if kernel size + // potentially bigger than stride. + int64_t kernelSize = outDim - (inDim - 1) * str + 2 * pad; + if (kernelSize > str) + return rewriter.notifyMatchFailure( + op, "potential pooling windows overlapping is detected, this case " + "is not supported yet"); + } - Value i = b.create(loc, 0); - Value j = b.create(loc, 1); - Value m = b.create(loc, 2); - Value n = b.create(loc, 3); - Value p = b.create(loc, 4); - Value r = b.create(loc, 5); - - Value mTimesStride = - b.create(loc, m, stride[0]); - Value pTimesDilation = - b.create(loc, p, dilation[0]); - Value indexH = b.create(loc, mTimesStride, - pTimesDilation); - Value nTimesStride = - b.create(loc, n, stride[1]); - Value rTimesDilation = - b.create(loc, r, dilation[1]); - Value indexW = b.create(loc, nTimesStride, - rTimesDilation); - Value input = b.create( - loc, paddedInput, ValueRange{i, j, indexH, indexW}); - Value pred = b.create( - loc, arith::CmpFPredicate::OEQ, input, maxVal); + Type indexType = rewriter.getIndexType(); + SmallVector outSizePadded; + for (auto &&[i, size] : llvm::enumerate(resType.getShape())) { + if (int64_t(i) < NC) { + outSizePadded.emplace_back( + rewriter.create(loc, self, i)); + continue; + } + int64_t pad = padding[i - NC]; + + outSizePadded.emplace_back( + rewriter.create(loc, size + pad)); + } - Value indexHMinusPadding = - b.create(loc, indexH, padding[0]); - Value indexWMinusPadding = - b.create(loc, indexW, padding[1]); - Value outIndex = b.create( - loc, indexHMinusPadding, inputShapeW); - outIndex = b.create(loc, outIndex, - indexWMinusPadding); - Value result = b.create( - loc, pred, castIndexToInt64(b, loc, outIndex), res); + auto ceilDiv = [](int64_t v1, int64_t v2) -> int64_t { + return (v1 + v2 - 1) / v2; + }; + + // In case if input tensor size is not divisible by stride + // (e.g. 
pooling_input_size=5, kernel_size=2, stride=2, output_size=2) + // pad self and indices tensors to avoid out of bounds access. + SmallVector expectedInputShape = + llvm::to_vector(resType.getShape().drop_back(3)); + for (auto &&[str, pad, resSize] : + llvm::zip_equal(stride, padding, inferredOutSize)) + expectedInputShape.emplace_back(ceilDiv(resSize, str) + pad * 2); + + if (expectedInputShape != selfType.getShape()) { + // TODO: this is probably expensive, and it may be possible to solve by + // cleverly constructing affine maps for the next linalg.generic op, + // but I'm not smart enough to figure this out. + + SmallVector low(outRank, 0); + SmallVector high(NC, 0); + for (auto &&[inpSize, outSize] : llvm::zip_equal( + inputSize, ArrayRef(expectedInputShape).take_back(3))) { + high.emplace_back(outSize - inpSize); + } + + // Pad the indices tensor with a value which cannot appear in real data + // (-1) so it will never match. In this case we can pad self with any + // value, as it will never affect the output. + Value zero = rewriter.create( + loc, rewriter.getZeroAttr(selfType.getElementType())); + Value invalidIdx = rewriter.create( + loc, rewriter.getIntegerAttr(indicesType.getElementType(), -1)); + self = + torch_to_linalg::getPaddedTensor(op, rewriter, self, low, high, zero); + indices = torch_to_linalg::getPaddedTensor(op, rewriter, indices, low, + high, invalidIdx); + } - Value predInvalidIndex = b.create( - loc, arith::CmpIPredicate::eq, res, cstMinusOne); - Value out = b.create(loc, predInvalidIndex, - result, res); + Value init = rewriter.create( + loc, getAsOpFoldResult(outSizePadded), selfType.getElementType()); + + SmallVector inputExprs; + SmallVector outputExprs; + for (auto i : llvm::seq(0, outRank)) { + AffineExpr dim = rewriter.getAffineDimExpr(i); + if (i < NC) { + inputExprs.emplace_back(dim); + } else { + int64_t j = i - NC; + inputExprs.emplace_back(dim.floorDiv(stride[j])); + } + outputExprs.emplace_back(dim); + } - b.create(loc, out); - }) + SmallVector indexingMaps = AffineMap::inferFromExprList( + {inputExprs, inputExprs, outputExprs}, rewriter.getContext()); + + SmallVector iteratorTypes( + outRank, utils::IteratorType::parallel); + + auto computeIndex = [&](OpBuilder &b, Location loc) -> Value { + // Next linalg.generic uses identity mapping for the unpooled tensor, + // compute linear index for output element, which we will the compare with + // values which came from indices tensor. + Value ret; + for (auto i : llvm::seq(NC, outRank)) { + Value idx = b.create(loc, i); + // If pool input was padded, adjust indices so they start at 0 in the + // non-padded area. Indices outside non-padded area will make no sense, + // but it doesnt matter as we will cut the padded area later by + // extract_slice. + int64_t pad = padding[i - NC]; + if (pad != 0) { + Value padVal = b.create(loc, pad); + idx = b.create(loc, idx, padVal); + } + + if (!ret) { + ret = idx; + } else { + Value size = + b.create(loc, resType.getShape()[i]); + ret = b.create(loc, ret, size); + ret = b.create(loc, ret, idx); + } + } + return ret; + }; + + auto builder = [&](OpBuilder &b, Location loc, ValueRange args) { + // Compute current output linear index and compare it with the value + // from indices arg. 
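The generic body below realizes the comparison described in the class comment: each output position computes its own linear index over the pooling space and keeps the input value only when it equals the index recorded by the pooling op. A 1-D scalar sketch of the same rule, assuming zero padding and non-overlapping windows (plain C++; maxUnpool1d is an illustrative name):

    #include <cstdint>
    #include <vector>

    // Sketch: 1-D max_unpool under the non-overlapping-window, zero-padding
    // assumption. Every output position receives either the pooled value whose
    // stored index matches that position, or zero.
    std::vector<float> maxUnpool1d(const std::vector<float> &pooled,
                                   const std::vector<int64_t> &indices,
                                   int64_t outputSize, int64_t stride) {
      std::vector<float> out(outputSize, 0.0f);
      for (int64_t pos = 0; pos < outputSize; ++pos) {
        int64_t window = pos / stride; // which pooled element maps to this slot
        if (window < static_cast<int64_t>(pooled.size()) &&
            indices[window] == pos) // stored index matches this position
          out[pos] = pooled[window];
      }
      return out;
    }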
+ Value input = args[0]; + Value zero = b.create( + loc, rewriter.getZeroAttr(input.getType())); + Value index = b.create(loc, indexType, args[1]); + Value currentIndex = computeIndex(b, loc); + Value cmp = b.create(loc, arith::CmpIPredicate::eq, index, + currentIndex); + Value out = b.create(loc, cmp, input, zero); + b.create(loc, out); + }; + + Value result = + rewriter + .create(loc, + /*resultTensorTypes=*/init.getType(), + /*inputs=*/ValueRange({self, indices}), + /*outputs=*/init, + /*indexingMaps=*/indexingMaps, + /*iteratorTypes=*/iteratorTypes, builder) .getResult(0); - Type maxPool2dResultType = - getTypeConverter()->convertType(op->getResult(0).getType()); - Type indicesResultType = - getTypeConverter()->convertType(op->getResult(1).getType()); - Value outMaxpool2d = - rewriter.create(loc, maxPool2dResultType, maxPool2d); - Value outIndices = - rewriter.create(loc, indicesResultType, indicesResult); + if (llvm::any_of(padding, [](int64_t v) { return v != 0; })) { + // MaxPool input was padded, unpad it by taking the slice. + SmallVector offsetVals(NC, rewriter.getI64IntegerAttr(0)); + for (int64_t pad : padding) + offsetVals.emplace_back(rewriter.getI64IntegerAttr(pad)); + + SmallVector sizeVals; + for (auto &&[i, dim] : llvm::enumerate(resType.getShape())) { + if (!ShapedType::isDynamic(dim)) { + sizeVals.emplace_back(rewriter.getI64IntegerAttr(dim)); + continue; + } + + sizeVals.emplace_back(rewriter.create(loc, self, i)); + } + SmallVector stridesVals(outRank, + rewriter.getI64IntegerAttr(1)); + result = rewriter.create(loc, result, offsetVals, + sizeVals, stridesVals); + } + + if (result.getType() != resType) + result = rewriter.create(loc, resType, result); - rewriter.replaceOp(op, {outMaxpool2d, outIndices}); + rewriter.replaceOp(op, result); return success(); } }; @@ -588,6 +858,16 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern { strideInts, paddingInts))) return rewriter.notifyMatchFailure(op, "invalid pooling parameters"); + // Decode strideInts into strideInts and dilation + if (strideInts.size() == 2 * Dim) { + for (int i = 0; i < Dim; i++) { + dilationInts[i] = strideInts[Dim + i]; + } + for (int i = 0; i < Dim; i++) { + strideInts.pop_back(); + } + } + // TODO: Add support for count_include_pad equal to `False`. bool countIncludePad; if (!matchPattern(op.getCountIncludePad(), @@ -595,13 +875,6 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern { return rewriter.notifyMatchFailure( op, "count_include_pad must be a constant"); - // If the padding is zero then there is no padding to include. - if (!countIncludePad && - !llvm::all_of(paddingInts, [](int64_t p) { return p == 0; })) { - return rewriter.notifyMatchFailure( - op, "unimplemented: count_include_pad is expected to be true"); - } - // `sumPool` contains the result of sumpool operation over the input. Value sumPool, paddedInput; SmallVector outTensorShape; @@ -611,40 +884,166 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern { paddingInts, dilationInts, rewriter.getZeroAttr(inputElementType), outTensorShape, paddedInput, sumPool))) return rewriter.notifyMatchFailure(op, "unable to compute sumpool"); - Value divisor; - if constexpr (std::is_same()) { - Value kHtimeskW = rewriter.create( - loc, kernelSizeIntValues[0], kernelSizeIntValues[1]); - divisor = isa(op.getDivisorOverride().getType()) - ? 
kHtimeskW - : adaptor.getDivisorOverride(); - } else { - divisor = kernelSizeIntValues[0]; - } - divisor = convertScalarToDtype(rewriter, loc, divisor, resultElementType); + // Compute the average of sumPool. Value outputTensor = rewriter.create( loc, getAsOpFoldResult(outTensorShape), resultElementType); SmallVector indexingMapsAvg( 2, rewriter.getMultiDimIdentityMap(Dim + 2)); SmallVector iteratorTypesAvg( Dim + 2, utils::IteratorType::parallel); - Value avgPool = - rewriter - .create( - loc, outputTensor.getType(), sumPool, outputTensor, - /*indexingMaps=*/indexingMapsAvg, - /*iteratorTypes=*/iteratorTypesAvg, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value avg; - if (isa(resultElementType)) - avg = b.create(loc, args[0], divisor); - else if (isa(resultElementType)) - avg = b.create(loc, args[0], divisor); - b.create(loc, avg); - }) - .getResult(0); + Value avgPool; + Value divisor; + // Case1: AtenAvgPool1d/2dOp with countIncludePad=false support. + if constexpr (std::is_same()) { + auto selfType = cast(self.getType()); + const int64_t selfRank = selfType.getRank(); + int64_t wDim = toPositiveDim(-1, selfRank); + int64_t hDim = toPositiveDim(-2, selfRank); + Value inputHeight = getDimOp(rewriter, loc, self, hDim); + Value inputWidth = getDimOp(rewriter, loc, self, wDim); + RankedTensorType sumPoolType = cast(sumPool.getType()); + const int64_t rank = sumPoolType.getRank(); + int dimH = toPositiveDim(-2, rank); + int dimW = toPositiveDim(-1, rank); + avgPool = + rewriter + .create( + loc, outputTensor.getType(), sumPool, outputTensor, + /*indexingMaps=*/indexingMapsAvg, + /*iteratorTypes=*/iteratorTypesAvg, + [&](OpBuilder &b, Location loc, ValueRange args) { + // The algorithm for computing the divisor with + // count_include_pad is manily based on pytorch + // implementation. The following code is comment + // with pytorch code. 
+ // https://github.com/pytorch/pytorch/blob/4a6dfbe4806b361c43210dfd56db64c4097c66bb/aten/src/ATen/native/cpu/AvgPoolKernel.cpp#L78 + Value indexOh = + b.create(loc, /*value=*/dimH); + Value oh = castIndexToInt64(b, loc, indexOh); + Value indexOw = + b.create(loc, /*value=*/dimW); + Value ow = castIndexToInt64(b, loc, indexOw); + + // int64_t ih0 = oh * dH - padH; + Value dH = rewriter.create( + loc, rewriter.getI64IntegerAttr(strideInts[0])); + Value padH = rewriter.create( + loc, rewriter.getI64IntegerAttr(paddingInts[0])); + Value ohDH = b.create(loc, oh, dH); + Value ih0 = b.create(loc, ohDH, padH); + // int64_t iw0 = ow * dW - padW; + Value dW = rewriter.create( + loc, rewriter.getI64IntegerAttr(strideInts[1])); + Value padW = rewriter.create( + loc, rewriter.getI64IntegerAttr(paddingInts[1])); + Value owDW = b.create(loc, ow, dW); + Value iw0 = b.create(loc, owDW, padW); + // int64_t ih1 = std::min(ih0 + kH, input_height + padH); + Value ih = castIndexToInt64(b, loc, inputHeight); + Value ih0KH = b.create( + loc, ih0, kernelSizeIntValues[0]); + Value ihPadH = b.create(loc, ih, padH); + Value ih1 = b.create(loc, ih0KH, ihPadH); + // int64_t iw1 = std::min(iw0 + kW, input_width + padW); + Value iw = castIndexToInt64(b, loc, inputWidth); + Value iw0KW = b.create( + loc, iw0, kernelSizeIntValues[1]); + Value iwPadW = b.create(loc, iw, padW); + Value iw1 = b.create(loc, iw0KW, iwPadW); + // int64_t pool_size = (ih1 - ih0) * (iw1 - iw0); + Value ih1Ih0 = b.create(loc, ih1, ih0); + Value iw1Iw0 = b.create(loc, iw1, iw0); + Value poolSize = + b.create(loc, ih1Ih0, iw1Iw0); + // ih0 = std::max(ih0, 0); + Value cstZero = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + Value ih0Clamped = + b.create(loc, ih0, cstZero); + // iw0 = std::max(iw0, 0); + Value iw0Clamped = + b.create(loc, iw0, cstZero); + // ih1 = std::min(ih1, input_height); + Value ih1Clamped = b.create(loc, ih1, ih); + // iw1 = std::min(iw1, input_width); + Value iw1Clamped = b.create(loc, iw1, iw); + // if (divisor_override.has_value()) { + // divisor = divisor_override.value(); + // } else { + // if(count_include_pad) { + // divisor = pool_size; + // } else { + // divisor = (ih1 - ih0) * (iw1 - iw0); + // } + // } + if (countIncludePad) { + divisor = convertScalarToDtype(b, loc, poolSize, + resultElementType); + } else { + Value ih1_ih0 = + b.create(loc, ih1Clamped, ih0Clamped); + Value iw1_iw0 = + b.create(loc, iw1Clamped, iw0Clamped); + divisor = b.create(loc, ih1_ih0, iw1_iw0); + } + // AtenAvgPool2/3dOp has an optional divisor_override + // attribute while AtenAvgPool1dOp does not. + if constexpr (std::is_same()) { + if (!isa( + op.getDivisorOverride().getType())) + divisor = adaptor.getDivisorOverride(); + } + + divisor = convertScalarToDtype(b, loc, divisor, + resultElementType); + Value avg; + if (isa(resultElementType)) + avg = b.create(loc, args[0], divisor); + else if (isa(resultElementType)) + avg = b.create(loc, args[0], divisor); + b.create(loc, avg); + }) + .getResult(0); + rewriter.replaceOpWithNewOp(op, resultType, avgPool); + return success(); + } + // TODO: Add support for count_include_pad equal to `False` in + // AtenAvgPool1/3dOp. + if (!countIncludePad && + !llvm::all_of(paddingInts, [](int64_t p) { return p == 0; })) { + return rewriter.notifyMatchFailure( + op, "unimplemented: count_include_pad is expected to be true for " + "AtenAvgPool3dOp"); + } + + // Case2: AtenAvgPool1/3dOp without count_include_pad equal to `False`. 
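Before the Case2 fallback below, which keeps the constant kernel-size divisor, the per-output-element divisor that Case1 builds out of arith ops can be restated as a scalar sketch of the same PyTorch rule (divisor_override omitted; all names are illustrative):

// Scalar sketch of the divisor computed in the payload above, mirroring
// PyTorch's AvgPoolKernel: the padded window size when count_include_pad
// is true, otherwise only the in-bounds part of the window.
#include <algorithm>
#include <cstdint>

int64_t avgPoolDivisor(int64_t oh, int64_t ow,      // output position
                       int64_t H, int64_t W,        // input height/width
                       int64_t kH, int64_t kW,      // kernel size
                       int64_t dH, int64_t dW,      // stride
                       int64_t padH, int64_t padW,
                       bool countIncludePad) {
  int64_t ih0 = oh * dH - padH;
  int64_t iw0 = ow * dW - padW;
  int64_t ih1 = std::min(ih0 + kH, H + padH);
  int64_t iw1 = std::min(iw0 + kW, W + padW);
  int64_t poolSize = (ih1 - ih0) * (iw1 - iw0);     // window incl. padding
  ih0 = std::max(ih0, int64_t(0));
  iw0 = std::max(iw0, int64_t(0));
  ih1 = std::min(ih1, H);
  iw1 = std::min(iw1, W);
  return countIncludePad ? poolSize : (ih1 - ih0) * (iw1 - iw0);
}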
+ divisor = kernelSizeIntValues[0]; + for (uint32_t i = 1; i < kernelSizeIntValues.size(); i++) { + divisor = + rewriter.create(loc, divisor, kernelSizeIntValues[i]); + } + if constexpr (!std::is_same()) { + divisor = isa(op.getDivisorOverride().getType()) + ? divisor + : adaptor.getDivisorOverride(); + } + divisor = convertScalarToDtype(rewriter, loc, divisor, resultElementType); + avgPool = rewriter + .create( + loc, outputTensor.getType(), sumPool, outputTensor, + /*indexingMaps=*/indexingMapsAvg, + /*iteratorTypes=*/iteratorTypesAvg, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value avg; + if (isa(resultElementType)) + avg = b.create(loc, args[0], divisor); + else if (isa(resultElementType)) + avg = b.create(loc, args[0], divisor); + b.create(loc, avg); + }) + .getResult(0); rewriter.replaceOpWithNewOp(op, resultType, avgPool); return success(); } @@ -722,10 +1121,10 @@ class AdaptiveMaxPoolingHelper : public AdaptivePoolingHelper { Location loc = op->getLoc(); const TypeConverter *typeConverter = opConversionPattern.getTypeConverter(); - outputType = typeConverter->convertType(op.getResult0().getType()) - .template cast(); - auxTensorType = typeConverter->convertType(op.getResult1().getType()) - .template cast(); + outputType = cast( + typeConverter->convertType(op.getResult0().getType())); + auxTensorType = cast( + typeConverter->convertType(op.getResult1().getType())); Type auxTensorElementType = auxTensorType.getElementType(); auto smallestFPValueAttr = rewriter.getFloatAttr( elementType, @@ -804,8 +1203,8 @@ class AdaptiveAvgPoolingHelper : public AdaptivePoolingHelper { Location loc = op->getLoc(); const TypeConverter *typeConverter = opConversionPattern.getTypeConverter(); - outputType = typeConverter->convertType(op.getResult().getType()) - .template cast(); + outputType = cast( + typeConverter->convertType(op.getResult().getType())); buffVal = rewriter.create( loc, elementType, rewriter.getFloatAttr(elementType, 0)); auxTensor = rewriter.create( @@ -1121,14 +1520,25 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality( patterns.add>(typeConverter, context); target.addIllegalOp(); - patterns.add(typeConverter, context); - target.addIllegalOp(); + target.addIllegalOp(); + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, + context); + + target.addIllegalOp(); + patterns.add(typeConverter, context); + + target.addIllegalOp(); patterns .add>( typeConverter, context); patterns .add>( typeConverter, context); + patterns + .add>( + typeConverter, context); target.addIllegalOp(); patterns.add>( diff --git a/lib/Conversion/TorchToLinalg/Random.cpp b/lib/Conversion/TorchToLinalg/Random.cpp index 3a0b81f5a10a..854e3f86d367 100644 --- a/lib/Conversion/TorchToLinalg/Random.cpp +++ b/lib/Conversion/TorchToLinalg/Random.cpp @@ -9,19 +9,15 @@ #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" -#include "../PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Matchers.h" #include "torch-mlir/Conversion/TorchToLinalg/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h" -#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" -#include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include 
"torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" using namespace mlir; @@ -47,9 +43,8 @@ class ConvertAtenDropoutOp : public OpConversionPattern { if (train) return failure(); - auto resultType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); + auto resultType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); rewriter.replaceOpWithNewOp(op, resultType, adaptor.getInput()); return success(); @@ -65,8 +60,8 @@ static Value toLinearIndex(OpBuilder &b, Location loc, Value result = b.create(loc, b.getZeroAttr(b.getI64Type())); for (auto [index, stride] : llvm::zip(indicesIntValues, shapeIntValues)) { - assert(index.getType().isa() && - stride.getType().isa() && + assert(isa(index.getType()) && + isa(stride.getType()) && "Input arrays to `toLinearIndex` must only contain values of type " "`mlir::IntegerType`"); Value mul = b.create(loc, result, stride); @@ -113,6 +108,25 @@ static Value randomUniformUInt(OpBuilder &b, Location loc, Value ctr, return bitwiseXOr(t, shiftRight32(add(mul(x, x), y))); } +// generate uniform random Float64 +static Value randomUniformF64(OpBuilder &b, Location loc, Value ctr, Value key, + Value min, Value max) { + Value randomVal = randomUniformUInt(b, loc, ctr, key); + // scale = (max - min) * const(F64, 5.4210108E-20) + // which is derived from rand(min,max) = + // rand()/(RAND_MAX/(max-min)) where RAND_MAX = 2^64 - 1 + Value epsilon = b.create( + loc, b.getFloatAttr(b.getF64Type(), 5.4210108E-20)); + Value range = b.create(loc, max, min); + Value scale = b.create(loc, range, epsilon); + // res = cast(F64, tempN) * scale + min + Value updateFloat = b.create(loc, b.getF64Type(), randomVal); + Value updateScaled = b.create(loc, updateFloat, scale); + Value uniformSample = b.create(loc, updateScaled, min); + + return uniformSample; +} + namespace { class ConvertAtenUniformOp : public OpConversionPattern { public: @@ -134,7 +148,7 @@ class ConvertAtenUniformOp : public OpConversionPattern { if (!isa(elemTy)) return rewriter.notifyMatchFailure(op, "This op only support float type"); - if (!generator.getType().isa()) + if (!isa(generator.getType())) return rewriter.notifyMatchFailure( op, "The generator has to be None because only global default " "generator is supported"); @@ -168,24 +182,11 @@ class ConvertAtenUniformOp : public OpConversionPattern { Value linearIndex = toLinearIndex(b, loc, indicesIntValues, sizesIntValues); - Value randomVal = randomUniformUInt(b, loc, linearIndex, key); - - // scale = (max - min) * const(F64, 5.4210108E-20) - // which is derived from rand(min,max) = - // rand()/(RAND_MAX/(max-min)) where RAND_MAX = 2^64 - 1 - Value epsilon = b.create( - loc, b.getFloatAttr(min.getType(), 5.4210108E-20)); - Value range = b.create(loc, max, min); - Value scale = b.create(loc, range, epsilon); - - // res = cast(F64, tempN) * scale + min - Value updateFloat = - b.create(loc, f64Ty, randomVal); - Value updateScaled = - b.create(loc, updateFloat, scale); - Value res = b.create(loc, updateScaled, min); + + Value res = + randomUniformF64(b, loc, linearIndex, key, min, max); Value truncRes = res; - if (elemTy.isa()) + if (isa(elemTy)) truncRes = b.create(loc, elemTy, res); b.create(loc, truncRes); }) @@ -198,6 +199,323 @@ class ConvertAtenUniformOp : public OpConversionPattern { }; } // namespace +namespace { +class ConvertAtenMultinomialOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenMultinomialOp op, OpAdaptor 
adaptor, + ConversionPatternRewriter &rewriter) const override { + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) + return failure(); + Location loc = op.getLoc(); + Value self = adaptor.getSelf(); + Value numSamples = adaptor.getNumSamples(); + Value generator = adaptor.getGenerator(); + RankedTensorType selfType = cast(self.getType()); + Type elemTy = selfType.getElementType(); + Type f64Ty = rewriter.getF64Type(); + Type i64Ty = rewriter.getI64Type(); + Type indexTy = rewriter.getIndexType(); + int64_t inputRank = selfType.getRank(); + bool bReplacement; + + if (!isa(elemTy)) + return rewriter.notifyMatchFailure(op, "This op only support float type"); + + if (!mlir::isa(generator.getType())) + return rewriter.notifyMatchFailure( + op, "The generator has to be None because only global default " + "generator is supported"); + + if (!matchPattern(op.getReplacement(), m_TorchConstantBool(&bReplacement))) + return rewriter.notifyMatchFailure( + op, "Unsupported: replacement must be a boolean value"); + + if (!bReplacement) + return rewriter.notifyMatchFailure(op, + "Unimplemented: replacement = False"); + + if (!mlir::isa(numSamples.getType())) { + return rewriter.notifyMatchFailure( + op, "Unsupported: num_samples must be an integer value"); + } + + if (!(inputRank == 1 || inputRank == 2)) { + return rewriter.notifyMatchFailure( + op, "torch.multinomial accepts only rank 1 or 2 tensors as weights"); + } + + Value cstZero = rewriter.create( + loc, i64Ty, rewriter.getI64IntegerAttr(0)); + Value cstOne = rewriter.create( + loc, i64Ty, rewriter.getI64IntegerAttr(1)); + Value zeroIndex = rewriter.create(loc, 0); + Value oneIndex = rewriter.create(loc, 1); + Value numSamplesIndex = + rewriter.create(loc, indexTy, numSamples); + + Value numDistributions; + Value numCategoriesIndex; + ValueRange resultShape; + if (inputRank == 1) { + numDistributions = cstOne; + numCategoriesIndex = + rewriter.create(loc, indexTy, self, zeroIndex); + resultShape = ValueRange{numSamplesIndex}; + } else { + Value numDistIndex = + rewriter.create(loc, indexTy, self, zeroIndex); + numCategoriesIndex = + rewriter.create(loc, indexTy, self, oneIndex); + numDistributions = + rewriter.create(loc, i64Ty, numDistIndex); + resultShape = ValueRange{numDistIndex, numSamplesIndex}; + } + + Value numCategories = + rewriter.create(loc, i64Ty, numCategoriesIndex); + Value resultTensor = rewriter.create( + loc, getAsOpFoldResult(resultShape), i64Ty); + + // sum weights for normalization + torch_to_linalg::ReductionOpInfo opInfo; + if (inputRank == 1) + opInfo = {false, self, {0}}; + else + opInfo = {false, self, {1}}; + + Value initSum = rewriter.create( + loc, f64Ty, rewriter.getF64FloatAttr(0.0)); + int64_t srcWidth = cast(elemTy).getWidth(); + if (srcWidth > 64) + op->emitWarning("Op bitwidth will be truncated from " + + std::to_string(srcWidth) + " bits to 64 bits."); + auto sumBody = [&](OpBuilder &b, Location loc, ValueRange payloadArgs) { + Value input = payloadArgs[0]; + if (srcWidth < 64) + input = b.create(loc, f64Ty, input); + if (srcWidth > 64) + input = b.create(loc, f64Ty, input); + Value result = payloadArgs[1]; + Value nextSum = b.create(loc, input, result); + b.create(loc, nextSum); + }; + Value sumWeights = torch_to_linalg::createReductionLinalgGeneric( + rewriter, loc, opInfo, initSum, sumBody); + + // Get multinomial samples for each weight vector + auto multinomialComputation = [&](OpBuilder &b, Location loc, Value j, + ValueRange args) { + Value jIndex = b.create(loc, indexTy, j); + + Value sum; + if 
(inputRank == 1) { + sum = b.create(loc, sumWeights, ValueRange{}); + } else { + sum = b.create(loc, sumWeights, ValueRange{jIndex}); + } + + // compute cdf in loop + Value initCdf = b.create( + loc, getAsOpFoldResult(ValueRange{numCategoriesIndex}), f64Ty); + Value cdf = + b.create( + loc, cstZero, numCategories, cstOne, ValueRange{initCdf}, + [&](OpBuilder &b, Location loc, Value i, ValueRange vals) { + Value distribution = vals[0]; + // if (i > 0) + auto comparisonPredicate = arith::CmpIPredicateAttr::get( + b.getContext(), arith::CmpIPredicate::sgt); + Value condition = b.create( + loc, comparisonPredicate, i, cstZero); + Value iIndex = b.create(loc, indexTy, i); + // curr_cum = i > 0 ? prob[i] + prob[i-1] : prob[i] + ValueRange ind; + if (inputRank == 1) { + ind = ValueRange{iIndex}; + } else { + ind = ValueRange{jIndex, iIndex}; + } + Value currWeight = b.create(loc, self, ind); + if (srcWidth < 64) + currWeight = b.create(loc, f64Ty, currWeight); + if (srcWidth > 64) + currWeight = + b.create(loc, f64Ty, currWeight); + Value currMass = b.create(loc, currWeight, sum); + Value currCum = + b.create( + loc, condition, + [&](OpBuilder &b, Location loc) { + Value prevI = + b.create(loc, i, cstOne); + Value prevIndex = b.create( + loc, indexTy, prevI); + Value prevMass = b.create( + loc, distribution, ValueRange{prevIndex}); + Value currSum = b.create( + loc, currMass, prevMass); + b.create(loc, ValueRange(currSum)); + }, + [&](OpBuilder &b, Location loc) { + b.create(loc, ValueRange{currMass}); + }) + .getResult(0); + + Value updatedCdf = b.create( + loc, currCum, distribution, ValueRange(iIndex)); + b.create(loc, ValueRange(updatedCdf)); + }) + .getResult(0); + + /* + * Above we've computed the CDF for the unnormalized distribution given to + * us by the user. In order to actually sample from this distribution we + * do the following below: 1) Sample a random floating point value, r in + * [0,1), from a uniform distribution. 2) Perform a binary search in the + * cdf to find the first bin in the CDF where cdf[i] < r. This guarantees + * a random sample from the provided distribution with the appropriate + * probabilities. + * + * This logic is pulled straight from PyTorch's Multinomial Kernel: + * https://github.com/pytorch/pytorch/blob/e4623de4cf6097ff399aa9eb0cef44b44ca76da4/aten/src/ATen/native/cpu/MultinomialKernel.cpp#L23 + * */ + + // Get key, min and max used by RNG. 
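The sampling loop constructed below draws a uniform value and binary-searches the CDF computed above. As a scalar model of the same left/right update performed by the scf.while (names illustrative):

// Scalar sketch: map a uniform draw r in [0, 1) to a category by finding
// the first CDF bin that is >= r, using the same bisection as the IR.
#include <cstdint>
#include <vector>

int64_t sampleFromCdf(const std::vector<double> &cdf, double r) {
  int64_t left = 0;
  int64_t right = (int64_t)cdf.size();    // numCategories
  while (right > left) {
    int64_t mid = left + (right - left) / 2;
    if (cdf[mid] < r)
      left = mid + 1;                      // sample lies in a later bin
    else
      right = mid;                         // this bin or an earlier one
  }
  return left;                             // first bin with cdf[bin] >= r
}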
+ Value key = b.create(loc); + Value min = b.create(loc, f64Ty, + rewriter.getF64FloatAttr(0.0)); + Value max = b.create(loc, f64Ty, + rewriter.getF64FloatAttr(1.0)); + + // iterate and sample class indices + Value result = args[0]; + Value finalResult = + rewriter + .create( + loc, cstZero, numSamples, cstOne, ValueRange{result}, + [&](OpBuilder &b, Location loc, Value i, ValueRange args) { + // Sample random float + Value uniformSample = + randomUniformF64(b, loc, i, key, min, max); + + // binary search in cdf to find our sample + Value left = b.create( + loc, i64Ty, b.getI64IntegerAttr(0)); + Value right = numCategories; + + auto checkCondition = [&](OpBuilder &b, Location loc, + ValueRange vals) { + Value left = vals[0]; + Value right = vals[1]; + + // while (right > left) + auto comparisonPredicate = arith::CmpIPredicateAttr::get( + b.getContext(), arith::CmpIPredicate::sgt); + Value loopCondition = b.create( + loc, comparisonPredicate, right, left); + b.create(loc, loopCondition, vals); + }; + + ValueRange whileResults = + b.create( + loc, TypeRange{i64Ty, i64Ty}, + ValueRange{left, right}, checkCondition, + [&](OpBuilder &b, Location loc, ValueRange vals) { + Value left = vals[0]; + Value right = vals[1]; + + Value two = b.create( + loc, i64Ty, b.getI64IntegerAttr(2)); + Value diff = + b.create(loc, right, left); + Value diffMid = + b.create(loc, diff, two); + Value midPointer = + b.create(loc, left, diffMid); + Type indexTy = b.getIndexType(); + Value midIndex = b.create( + loc, indexTy, midPointer); + + // branch and update search indices + auto thenBlock = [&](OpBuilder &b, + Location loc) { + // left = mid + 1 + Value newLeft = b.create( + loc, midPointer, cstOne); + + b.create( + loc, ValueRange{newLeft, right}); + }; + auto elseBlock = [&](OpBuilder &b, + Location loc) { + // right = mid + b.create( + loc, ValueRange{left, midPointer}); + }; + + Value cumProb = b.create( + loc, cdf, ValueRange{midIndex}); + auto cmpPredicate = + arith::CmpFPredicateAttr::get( + b.getContext(), + arith::CmpFPredicate::OLT); + Value branchCondition = b.create( + loc, cmpPredicate, cumProb, uniformSample); + ValueRange branchResults = + b.create(loc, branchCondition, + thenBlock, elseBlock) + .getResults(); + Value newLeft = branchResults[0]; + Value newRight = branchResults[1]; + + b.create( + loc, ValueRange{newLeft, newRight}); + }) + .getResults(); + + // sample_idx = left_pointer + Value samplePointer = whileResults[0]; + Value iIndex = + b.create(loc, indexTy, i); + + Value prevResult = args[0]; + Value newResult; + if (inputRank == 1) { + // result[i] = sample_idx + newResult = b.create( + loc, samplePointer, prevResult, ValueRange{iIndex}); + } else { + // result[j][i] = sample_idx + newResult = b.create( + loc, samplePointer, prevResult, + ValueRange{jIndex, iIndex}); + } + + b.create(loc, ValueRange{newResult}); + }) + .getResult(0); + + b.create(loc, ValueRange{finalResult}); + }; + + Value finalResultTensor = + rewriter + .create(loc, cstZero, numDistributions, cstOne, + ValueRange{resultTensor}, + multinomialComputation) + .getResult(0); + + Type newResultType = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, newResultType, + finalResultTensor); + + return success(); + } +}; +} // namespace + void mlir::torch::torch_to_linalg::populateRandomPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { @@ -206,4 +524,6 @@ void mlir::torch::torch_to_linalg::populateRandomPatternsAndLegality( 
patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); } diff --git a/lib/Conversion/TorchToLinalg/Reduction.cpp b/lib/Conversion/TorchToLinalg/Reduction.cpp index ffb3350a0733..0e1f6426f958 100644 --- a/lib/Conversion/TorchToLinalg/Reduction.cpp +++ b/lib/Conversion/TorchToLinalg/Reduction.cpp @@ -9,18 +9,14 @@ #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" -#include "../PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Matchers.h" #include "torch-mlir/Conversion/TorchToLinalg/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h" -#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" @@ -90,11 +86,8 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { bool isUnsigned = false; if (!isa(inElementType)) { if (isa(inElementType)) { - auto integerTy = op.getSelf() - .getType() - .template cast() - .getDtype() - .template dyn_cast(); + auto integerTy = dyn_cast( + cast(op.getSelf().getType()).getDtype()); isUnsigned = integerTy.isUnsigned(); } else { return rewriter.notifyMatchFailure( @@ -284,7 +277,7 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { static Value createAbsOpForNormOps(OpBuilder &b, Location loc, Value elem, Type resultElementType) { - if (elem.getType().isa()) { + if (isa(elem.getType())) { return b.create(loc, elem); } @@ -380,11 +373,8 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc, if (isa(resultElementType)) return b.create(loc, self, result); else if (isa(resultElementType)) { - IntegerType intType = max.getSelf() - .getType() - .cast() - .getDtype() - .dyn_cast(); + IntegerType intType = dyn_cast( + cast(max.getSelf().getType()).getDtype()); if (intType.isUnsigned()) return b.create(loc, self, result); if (intType.isSigned()) @@ -397,11 +387,8 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc, if (isa(resultElementType)) return b.create(loc, self, result); else if (isa(resultElementType)) { - IntegerType intType = min.getSelf() - .getType() - .cast() - .getDtype() - .dyn_cast(); + IntegerType intType = dyn_cast( + cast(min.getSelf().getType()).getDtype()); if (intType.isUnsigned()) return b.create(loc, self, result); if (intType.isSigned()) @@ -661,9 +648,8 @@ class ConvertReductionOp : public ConversionPattern { return opInfo; Location loc = op->getLoc(); - auto resultType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); + auto resultType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); Type elemType = resultType.getElementType(); LogicalResult elemTypeCheck = validateReductionElementType(op, elemType, rewriter); diff --git a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp index add928392719..02853b14072a 100644 --- a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp +++ b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp @@ -9,17 +9,13 @@ #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" -#include "../PassDetail.h" 
#include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Matchers.h" #include "torch-mlir/Conversion/TorchToLinalg/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h" -#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" @@ -101,8 +97,12 @@ class ConvertAtenConstantPadNdOp Type newResultType = getTypeConverter()->convertType(op.getType()); Type elementType = cast(newResultType).getElementType(); + + auto dstOriginalDtype = + cast(op.getType()).getDtype(); Value castedValue = - convertScalarToDtype(rewriter, loc, adaptor.getValue(), elementType); + convertScalarToDtype(rewriter, loc, adaptor.getValue(), elementType, + std::nullopt, dstOriginalDtype); Type padType = tensor::PadOp::inferResultType( cast(self.getType()), staticLow, staticHigh); @@ -183,15 +183,13 @@ class ConvertAtenReplicationPad2dOp for (auto i : {TOP, VCENTER, BOTTOM}) { for (auto j : {LEFT, HCENTER, RIGHT}) { - auto constVtile{ + auto constVtile{dyn_cast_or_null( mlir::dyn_cast(vTile[i].getDefiningOp()) - .getValue() - .dyn_cast_or_null()}; + .getValue())}; - auto constHtile{ + auto constHtile{dyn_cast_or_null( mlir::dyn_cast(hTile[j].getDefiningOp()) - .getValue() - .dyn_cast_or_null()}; + .getValue())}; auto vSize = constVtile.getInt(); auto hSize = constHtile.getInt(); @@ -215,26 +213,38 @@ class ConvertAtenReplicationPad2dOp Value one = getConstant(rewriter, loc, 1, indexType); Value hDimSizeMinusOne = createSub(hDimSize, one); Value vDimSizeMinusOne = createSub(vDimSize, one); - SmallVector allOneStrides(numDims, one); - - SmallVector extractOffsetsLT(numDims, zero); - extractOffsetsLT[hDim] = zero; - extractOffsetsLT[vDim] = zero; - SmallVector extractShapeLR(numDims, one); - extractShapeLR[hDim] = one; - extractShapeLR[vDim] = vDimSize; - - SmallVector extractOffsetsRight(numDims, zero); - extractOffsetsRight[hDim] = hDimSizeMinusOne; - extractOffsetsRight[vDim] = zero; - - SmallVector extractOffsetsBottom(numDims, zero); - extractOffsetsBottom[hDim] = zero; - extractOffsetsBottom[vDim] = vDimSizeMinusOne; - - SmallVector extractShapeTB(numDims, one); - extractShapeTB[hDim] = hDimSize; - extractShapeTB[vDim] = one; + SmallVector allOneStridesVal(numDims, one); + SmallVector allOneStrides = + getAsOpFoldResult(allOneStridesVal); + + SmallVector extractOffsetsLTVal(numDims, zero); + extractOffsetsLTVal[hDim] = zero; + extractOffsetsLTVal[vDim] = zero; + SmallVector extractOffsetsLT = + getAsOpFoldResult(extractOffsetsLTVal); + SmallVector extractShapeLRVal(numDims, one); + extractShapeLRVal[hDim] = one; + extractShapeLRVal[vDim] = vDimSize; + SmallVector extractShapeLR = + getAsOpFoldResult(extractShapeLRVal); + + SmallVector extractOffsetsRightVal(numDims, zero); + extractOffsetsRightVal[hDim] = hDimSizeMinusOne; + extractOffsetsRightVal[vDim] = zero; + SmallVector extractOffsetsRight = + getAsOpFoldResult(extractOffsetsRightVal); + + SmallVector extractOffsetsBottomVal(numDims, zero); + extractOffsetsBottomVal[hDim] = zero; + extractOffsetsBottomVal[vDim] = vDimSizeMinusOne; + SmallVector extractOffsetsBottom = + getAsOpFoldResult(extractOffsetsBottomVal); + + SmallVector extractShapeTBVal(numDims, one); + extractShapeTBVal[hDim] = hDimSize; + 
extractShapeTBVal[vDim] = one; + SmallVector extractShapeTB = + getAsOpFoldResult(extractShapeTBVal); SmallVector tensorsLeft; SmallVector tensorsRight; @@ -246,24 +256,26 @@ class ConvertAtenReplicationPad2dOp Value vCenterLeftSlice = rewriter.create( loc, input, extractOffsetsLT, extractShapeLR, allOneStrides); Value vLeftSlice = vCenterLeftSlice; + SmallVector extractIndices(numDims, zero); if (hasTopPadding) { - Value topLeftValue = rewriter.create( - loc, input, ValueRange{zero, zero, zero, zero}); + Value topLeftValue = + rewriter.create(loc, input, extractIndices); // pad vCenterLeftSlice on the top - SmallVector lowPadding(4, 0); - SmallVector highPadding(4, 0); - lowPadding[2] = padInts[2]; + SmallVector lowPadding(numDims, 0); + SmallVector highPadding(numDims, 0); + lowPadding[vDim] = padInts[2]; vLeftSlice = torch_to_linalg::getPaddedTensor( op, rewriter, vLeftSlice, lowPadding, highPadding, topLeftValue); } if (hasBottomPadding) { - Value bottomLeftValue = rewriter.create( - loc, input, ValueRange{zero, zero, vDimSizeMinusOne, zero}); + extractIndices[vDim] = vDimSizeMinusOne; + Value bottomLeftValue = + rewriter.create(loc, input, extractIndices); // pad vLeftSlice at the bottom - SmallVector lowPadding(4, 0); - SmallVector highPadding(4, 0); - highPadding[2] = padInts[3]; + SmallVector lowPadding(numDims, 0); + SmallVector highPadding(numDims, 0); + highPadding[vDim] = padInts[3]; vLeftSlice = torch_to_linalg::getPaddedTensor( op, rewriter, vLeftSlice, lowPadding, highPadding, bottomLeftValue); } @@ -271,7 +283,7 @@ class ConvertAtenReplicationPad2dOp tensorsLeft.push_back(vLeftSlice); } Value leftPadTile = - rewriter.create(loc, 3, tensorsLeft); + rewriter.create(loc, hDim, tensorsLeft); tensorsRes.push_back(leftPadTile); } if (hasTopPadding) { @@ -289,33 +301,35 @@ class ConvertAtenReplicationPad2dOp tensorsCenter.push_back(bottomHcenterSlice); } } - centerTile = rewriter.create(loc, 2, tensorsCenter); + centerTile = rewriter.create(loc, vDim, tensorsCenter); tensorsRes.push_back(centerTile); if (hasRightPadding) { Value vCenterRightSlice = rewriter.create( loc, input, extractOffsetsRight, extractShapeLR, allOneStrides); Value vRightSlice = vCenterRightSlice; + SmallVector extractIndices(numDims, zero); + extractIndices[hDim] = hDimSizeMinusOne; if (hasTopPadding) { Value topRightValue = rewriter.create( loc, input, ValueRange{zero, zero, zero, hDimSizeMinusOne}); // pad vCenterRightSlice on the top - SmallVector lowPadding(4, 0); - SmallVector highPadding(4, 0); - lowPadding[2] = padInts[2]; + SmallVector lowPadding(numDims, 0); + SmallVector highPadding(numDims, 0); + lowPadding[vDim] = padInts[2]; vRightSlice = torch_to_linalg::getPaddedTensor( op, rewriter, vRightSlice, lowPadding, highPadding, topRightValue); } if (hasBottomPadding) { - Value bottomRightValue = rewriter.create( - loc, input, - ValueRange{zero, zero, vDimSizeMinusOne, hDimSizeMinusOne}); + extractIndices[vDim] = vDimSizeMinusOne; + Value bottomRightValue = + rewriter.create(loc, input, extractIndices); // Pad vCenterRightSlice or vRightTopPaddedSlice at the bottom. 
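The slice, pad, and concat sequence in this pattern realizes ordinary edge replication. As a semantic reference only, not the IR construction, each padded element reads the nearest in-bounds input element; a small sketch with illustrative names:

// Scalar sketch of replication_pad2d semantics on a single H x W channel:
// out-of-range coordinates are clamped back to the nearest valid index.
#include <algorithm>
#include <cstdint>
#include <vector>

std::vector<float> replicationPad2d(const std::vector<float> &in, int64_t H,
                                    int64_t W, int64_t padLeft,
                                    int64_t padRight, int64_t padTop,
                                    int64_t padBottom) {
  int64_t outH = H + padTop + padBottom, outW = W + padLeft + padRight;
  std::vector<float> out(outH * outW);
  for (int64_t i = 0; i < outH; ++i) {
    int64_t srcI = std::clamp(i - padTop, int64_t(0), H - 1);
    for (int64_t j = 0; j < outW; ++j) {
      int64_t srcJ = std::clamp(j - padLeft, int64_t(0), W - 1);
      out[i * outW + j] = in[srcI * W + srcJ];
    }
  }
  return out;
}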
- SmallVector lowPadding(4, 0); - SmallVector highPadding(4, 0); - highPadding[2] = padInts[3]; + SmallVector lowPadding(numDims, 0); + SmallVector highPadding(numDims, 0); + highPadding[vDim] = padInts[3]; vRightSlice = torch_to_linalg::getPaddedTensor( op, rewriter, vRightSlice, lowPadding, highPadding, bottomRightValue); @@ -324,10 +338,10 @@ class ConvertAtenReplicationPad2dOp tensorsRight.push_back(vRightSlice); } Value rightPadTile = - rewriter.create(loc, 3, tensorsRight); + rewriter.create(loc, hDim, tensorsRight); tensorsRes.push_back(rightPadTile); } - Value resTensor = rewriter.create(loc, 3, tensorsRes); + Value resTensor = rewriter.create(loc, hDim, tensorsRes); Type newResultType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, newResultType, resTensor); return success(); @@ -373,8 +387,8 @@ class ConvertConstantTensorAllocOp : public OpConversionPattern { for (auto size : resultSize) resultSizeIndex.push_back(castIntToIndex(rewriter, loc, size)); - auto resultType = typeConverter->convertType(op.getType()) - .template cast(); + auto resultType = + cast(typeConverter->convertType(op.getType())); Type resultElementType; if (isa(op.getDtype().getType())) { resultElementType = resultType.getElementType(); @@ -430,7 +444,7 @@ class ConvertAtenEmptyMemoryFormatOp op, "unimplemented: pin_memory must be either None or false"); // Only `none`, `contiguous` and `preserve` memory_format is supported. - if (!op.getMemoryFormat().getType().isa()) { + if (!isa(op.getMemoryFormat().getType())) { int64_t memoryFormat; if (!matchPattern(op.getMemoryFormat(), m_TorchConstantInt(&memoryFormat))) @@ -445,7 +459,7 @@ class ConvertAtenEmptyMemoryFormatOp } // TODO: Add support for device arg other than cpu. - if (!op.getDevice().getType().isa()) { + if (!isa(op.getDevice().getType())) { std::string device; if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device))) return rewriter.notifyMatchFailure( @@ -457,7 +471,7 @@ class ConvertAtenEmptyMemoryFormatOp // TODO: Add support for non-strided layout. // torch.layout is by default strided i.e. 0. - if (!op.getLayout().getType().isa()) { + if (!isa(op.getLayout().getType())) { int64_t tensorLayout; if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout))) return rewriter.notifyMatchFailure( @@ -482,7 +496,7 @@ class ConvertAtenEmptyMemoryFormatOp auto resultType = cast(typeConverter->convertType(op.getType())); Type resultElementType; - if (op.getDtype().getType().isa()) { + if (isa(op.getDtype().getType())) { resultElementType = getDefaultDtypeForTorchScalar( Torch::FloatType::get(op->getContext())); } else { @@ -531,7 +545,7 @@ class ConvertAtenArangeStartStepOp // The pin_memory should be either `False` or `none`. 
bool pinMemory; - if (!op.getPinMemory().getType().isa() && + if (!isa(op.getPinMemory().getType()) && (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) || pinMemory)) { return rewriter.notifyMatchFailure( @@ -540,9 +554,8 @@ class ConvertAtenArangeStartStepOp Location loc = op.getLoc(); const TypeConverter *typeConverter = this->getTypeConverter(); - RankedTensorType resultType = - typeConverter->convertType(op->getResult(0).getType()) - .cast(); + RankedTensorType resultType = cast( + typeConverter->convertType(op->getResult(0).getType())); Type dtype = resultType.getElementType(); Value start = convertScalarToDtype(rewriter, loc, adaptor.getStart(), dtype); diff --git a/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp b/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp index 1f8b2f980a9c..ab5fec18f9b2 100644 --- a/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp +++ b/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp @@ -9,17 +9,11 @@ #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" -#include "../PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/Matchers.h" #include "torch-mlir/Conversion/Utils/Utils.h" -#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" -#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" using namespace mlir; @@ -144,17 +138,16 @@ class ConvertAtenScalarToTensorLike : public ConversionPattern { requires_grad = tensorFloatOp.getRequiresGrad(); } // TODO: Dtype conversion. - if (!dtype.getType().isa()) + if (!isa(dtype.getType())) return rewriter.notifyMatchFailure(op, "Unimplemented non-None dtype"); // TODO: Device information. 
- if (!device.getType().isa()) + if (!isa(device.getType())) return rewriter.notifyMatchFailure( op, "Unimplemented non-None device information"); - RankedTensorType resultType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); + RankedTensorType resultType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); Type outElementType = resultType.getElementType(); Value elemValProm = convertScalarToDtype(rewriter, loc, elemVal, outElementType); @@ -177,9 +170,8 @@ class ConvertPrimNumToTensorScalarOp if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op.getLoc(); - RankedTensorType resultType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); + RankedTensorType resultType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); Type outElementType = resultType.getElementType(); Value elemVal = adaptor.getA(); Value elemValProm = diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp index a4451041fb49..01b1d4b973b6 100644 --- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp +++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp @@ -14,14 +14,10 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" -#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" -#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" @@ -47,6 +43,7 @@ class ConvertTorchToLinalg registry.insert(); registry.insert(); registry.insert(); + registry.insert(); registry.insert(); TorchConversion::getBackendTypeConversionDependentDialects(registry); } @@ -56,7 +53,7 @@ class ConvertTorchToLinalg ConversionTarget target(*context); target.addLegalDialect< linalg::LinalgDialect, func::FuncDialect, cf::ControlFlowDialect, - math::MathDialect, sparse_tensor::SparseTensorDialect, + math::MathDialect, scf::SCFDialect, sparse_tensor::SparseTensorDialect, tensor::TensorDialect, arith::ArithDialect, complex::ComplexDialect>(); target.addLegalOp(); diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index e369df0d066e..4ebdfbf94129 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -7,25 +7,26 @@ // //===----------------------------------------------------------------------===// +#include "mlir/IR/BuiltinTypes.h" #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" -#include "../PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Matchers.h" #include "torch-mlir/Conversion/TorchToLinalg/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h" -#include 
"torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "llvm/ADT/APSInt.h" #include +#include #include using namespace mlir; @@ -152,59 +153,18 @@ static Value createFpOpWithDtype(OpBuilder &b, const TypeConverter *converter, return convertScalarToDtype(b, loc, newOp, outTy, std::nullopt, outTTy); } -template -static Value createCompareTensorOp(OpBuilder &b, Location loc, OpTy op, - Value lhs, Value rhs) { - static_assert(std::is_same() || - std::is_same() || - std::is_same() || - std::is_same() || - std::is_same() || - std::is_same(), - "unimplemented: op type not supported"); - - Type lhsDtype = lhs.getType(); - Type rhsDtype = rhs.getType(); - - // TODO: Type promotion in case of different `lhsDtype` and `rhsDtype` needs - // to be handled. - if (lhsDtype != rhsDtype) { - op.emitError("unimplemented: lhs and rhs dtype must be same"); - return nullptr; - } - - Type elementalType = cast(op.getSelf().getType()).getDtype(); - if constexpr (std::is_same()) { - return createLessThan(b, loc, elementalType, lhs, rhs); - } - if constexpr (std::is_same()) { - return createLessThanOrEqual(b, loc, elementalType, lhs, rhs); - } - if constexpr (std::is_same()) { - return createGreaterThan(b, loc, elementalType, lhs, rhs); - } - if constexpr (std::is_same()) { - return createGreaterThanOrEqual(b, loc, elementalType, lhs, rhs); - } - if constexpr (std::is_same()) { - return createEqual(b, loc, elementalType, lhs, rhs); - } - if constexpr (std::is_same()) { - return createNotEqual(b, loc, elementalType, lhs, rhs); - } - llvm_unreachable("unimplemented: op type not supported"); -} +template +struct is_any_same : std::disjunction...> {}; template -static Value createCompareScalarOp(OpBuilder &b, Location loc, OpTy op, - Value lhs, Value rhs) { - static_assert(std::is_same() || - std::is_same() || - std::is_same() || - std::is_same() || - std::is_same() || - std::is_same(), - "unimplemented: op type not supported"); +static Value createCompareOp(OpBuilder &b, Location loc, OpTy op, Value lhs, + Value rhs) { + static_assert( + is_any_same(), + "unimplemented: op type not supported"); Type lhsDtype = lhs.getType(); Type rhsDtype = rhs.getType(); @@ -232,22 +192,22 @@ static Value createCompareScalarOp(OpBuilder &b, Location loc, OpTy op, return nullptr; } - if constexpr (std::is_same()) { + if constexpr (is_any_same()) { return createLessThan(b, loc, elementalType, lhs, rhs); } - if constexpr (std::is_same()) { + if constexpr (is_any_same()) { return createLessThanOrEqual(b, loc, elementalType, lhs, rhs); } - if constexpr (std::is_same()) { + if constexpr (is_any_same()) { return createGreaterThan(b, loc, elementalType, lhs, rhs); } - if constexpr (std::is_same()) { + if constexpr (is_any_same()) { return createGreaterThanOrEqual(b, loc, elementalType, lhs, rhs); } - if constexpr (std::is_same()) { + if constexpr (is_any_same()) { return createEqual(b, loc, elementalType, lhs, rhs); } - if constexpr (std::is_same()) { + if constexpr (is_any_same()) { return createNotEqual(b, loc, elementalType, lhs, rhs); } llvm_unreachable("unimplemented: op type not supported"); @@ -350,6 +310,70 @@ Value createDivModePayload(OpBuilder &b, Location loc, return quotient; } +template +Value createRemainderPayload(OpBuilder &b, Location loc, + const TypeConverter *converter, + ValueRange payloadArgs, OpT op, + ArrayRef 
operands) { + static_assert( + llvm::is_one_of(), + "op must be a tensor/scalar remainder op"); + typename OpT::Adaptor adaptor(operands); + Type dtype = cast(converter->convertType(op.getType())) + .getElementType(); + Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); + Value rhs = convertScalarToDtype( + b, loc, + std::is_same_v ? operands[1] : payloadArgs[1], + dtype); + + // The remainder op we wish to create would look roughly like this: + // rem = a % b + // if rem != 0 AND (rem < 0 XOR b < 0) rem += b + // This is how python calucates remainders for floats and longs: + // https://github.com/python/cpython/blob/2afd1751dd9a35d4ec03b708e3e5cddd72c43f7e/Objects/floatobject.c#L645 + // https://github.com/python/cpython/blob/2afd1751dd9a35d4ec03b708e3e5cddd72c43f7e/Objects/longobject.c#L3662 + Value result; + if (isa(dtype)) { + Value remainder = b.create(loc, lhs, rhs); + + Value zero = b.create(loc, b.getZeroAttr(dtype)); + Value remainderNotEqualToZero = b.create( + loc, arith::CmpFPredicate::ONE, remainder, zero); + Value otherLessThanZero = + b.create(loc, arith::CmpFPredicate::OLT, rhs, zero); + Value remainderLessThanZero = b.create( + loc, arith::CmpFPredicate::OLT, remainder, zero); + Value xorCondition = + b.create(loc, otherLessThanZero, remainderLessThanZero); + Value condition = + b.create(loc, remainderNotEqualToZero, xorCondition); + Value fixedRemainder = b.create(loc, remainder, rhs); + result = + b.create(loc, condition, fixedRemainder, remainder); + } else { + assert(dtype.isInteger() && + "dtype should be a float or integer (signless or signed)"); + Value remainder = b.create(loc, lhs, rhs); + + Value zero = b.create(loc, b.getZeroAttr(dtype)); + Value remainderNotEqualToZero = + b.create(loc, arith::CmpIPredicate::ne, remainder, zero); + Value otherLessThanZero = + b.create(loc, arith::CmpIPredicate::slt, rhs, zero); + Value remainderLessThanZero = b.create( + loc, arith::CmpIPredicate::slt, remainder, zero); + Value xorCondition = + b.create(loc, otherLessThanZero, remainderLessThanZero); + Value condition = + b.create(loc, remainderNotEqualToZero, xorCondition); + Value fixedRemainder = b.create(loc, remainder, rhs); + result = + b.create(loc, condition, fixedRemainder, remainder); + } + return result; +} + static Value createLinalgPayloadCalculationForElementwiseOp( OpBuilder &b, Location loc, const TypeConverter *converter, ValueRange payloadArgs, Operation *op, ArrayRef operands) { @@ -425,7 +449,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto clone = dyn_cast(op)) { int64_t memoryFormat; - if (!clone.getMemoryFormat().getType().isa() && + if (!isa(clone.getMemoryFormat().getType()) && (!matchPattern(clone.getMemoryFormat(), m_TorchConstantInt(&memoryFormat)) || (memoryFormat != torch_upstream::MemoryFormat::Contiguous && @@ -437,24 +461,22 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return payloadArgs[0]; } if (auto bitwiseAndTensor = dyn_cast(op)) { - if (bitwiseAndTensor.getType() - .cast() - .getDtype() - .isa()) { + if (isa( + cast(bitwiseAndTensor.getType()).getDtype())) { bitwiseAndTensor.emitError( "Bitwise_And does not support floating point dtype"); return nullptr; } - Type dtype = converter->convertType(bitwiseAndTensor.getType()) - .cast() + Type dtype = cast( + converter->convertType(bitwiseAndTensor.getType())) .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); return b.create(loc, lhs, 
rhs); } if (auto bitwiseAndScalar = dyn_cast(op)) { - Type dtype = converter->convertType(bitwiseAndScalar.getType()) - .cast() + Type dtype = cast( + converter->convertType(bitwiseAndScalar.getType())) .getElementType(); if (!isa(dtype)) { bitwiseAndScalar.emitError( @@ -472,32 +494,28 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return b.create(loc, self, other); } if (auto bitwiseOrTensor = dyn_cast(op)) { - if (bitwiseOrTensor.getType() - .cast() - .getDtype() - .isa()) { + if (isa( + cast(bitwiseOrTensor.getType()).getDtype())) { bitwiseOrTensor.emitError( "Bitwise_Or does not support floating point dtype"); return nullptr; } - Type dtype = converter->convertType(bitwiseOrTensor.getType()) - .cast() + Type dtype = cast( + converter->convertType(bitwiseOrTensor.getType())) .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); return b.create(loc, lhs, rhs); } if (auto bitwiseXorTensor = dyn_cast(op)) { - if (bitwiseXorTensor.getType() - .cast() - .getDtype() - .isa()) { + if (isa( + cast(bitwiseXorTensor.getType()).getDtype())) { bitwiseXorTensor.emitError( "Bitwise_Xor does not support floating point dtype"); return nullptr; } - Type dtype = converter->convertType(bitwiseXorTensor.getType()) - .cast() + Type dtype = cast( + converter->convertType(bitwiseXorTensor.getType())) .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); @@ -505,8 +523,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto bitwiseRightShiftTensor = dyn_cast(op)) { - Type dtype = converter->convertType(bitwiseRightShiftTensor.getType()) - .cast() + Type dtype = cast( + converter->convertType(bitwiseRightShiftTensor.getType())) .getElementType(); if (!isa(dtype)) { bitwiseRightShiftTensor.emitError( @@ -519,8 +537,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto bitwiseLeftShiftTensor = dyn_cast(op)) { - Type dtype = converter->convertType(bitwiseLeftShiftTensor.getType()) - .cast() + Type dtype = cast( + converter->convertType(bitwiseLeftShiftTensor.getType())) .getElementType(); if (!isa(dtype)) { bitwiseLeftShiftTensor.emitError( @@ -533,7 +551,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (isa(op)) { MLIRContext *context = op->getContext(); - Type floatDtype = mlir::FloatType::getF64(context); + Type floatDtype = mlir::Float64Type::get(context); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], floatDtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], floatDtype); Value zero = @@ -553,14 +571,24 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (isa(op)) { MLIRContext *context = op->getContext(); - Type floatDtype = mlir::FloatType::getF64(context); + Type floatDtype = mlir::Float64Type::get(context); Value self = convertScalarToDtype(b, loc, payloadArgs[0], floatDtype); Value zero = b.create(loc, b.getFloatAttr(floatDtype, 0)); return createEqual(b, loc, floatDtype, self, zero); } + if (auto complex = dyn_cast(op)) { + auto ctype = cast( + cast(converter->convertType(complex.getType())) + .getElementType()); + Type stype = ctype.getElementType(); + + Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], stype); + Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], stype); + return b.create(loc, ctype, lhs, rhs); + } if (isa(op)) { - if (payloadArgs[0].getType().isa()) + if 
(isa(payloadArgs[0].getType())) return b.create(loc, payloadArgs[0]); return b.create(loc, payloadArgs[0]); } @@ -656,20 +684,16 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return b.create(loc, cmp, arg, zeroPoint); } if (auto round = dyn_cast(op)) { - if (!round.getType() - .cast() - .getDtype() - .isa()) { + if (!isa( + cast(round.getType()).getDtype())) { round.emitError("unimplemented: non-floating point dtype"); return nullptr; } return b.create(loc, payloadArgs[0]); } if (auto prelu = dyn_cast(op)) { - if (!prelu.getType() - .cast() - .getDtype() - .isa()) { + if (!isa( + cast(prelu.getType()).getDtype())) { prelu.emitError("unimplemented: non-floating point dtype"); return nullptr; } @@ -688,10 +712,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return b.create(loc, positivePart, scaledNegativePart); } if (auto gelu = dyn_cast(op)) { - if (!gelu.getType() - .cast() - .getDtype() - .isa()) { + if (!isa( + cast(gelu.getType()).getDtype())) { gelu.emitError("unimplemented: non-floating point dtype"); return nullptr; } @@ -735,10 +757,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return nullptr; } if (auto geluBackward = dyn_cast(op)) { - if (!geluBackward.getType() - .cast() - .getDtype() - .isa()) { + if (!isa( + cast(geluBackward.getType()).getDtype())) { geluBackward.emitError("unimplemented: non-floating point dtype"); return nullptr; } @@ -773,10 +793,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto hardtanhBackward = dyn_cast(op)) { AtenHardtanhBackwardOp::Adaptor adaptor(operands); - if (!hardtanhBackward.getType() - .cast() - .getDtype() - .isa()) { + if (!isa( + cast(hardtanhBackward.getType()).getDtype())) { hardtanhBackward.emitError("unimplemented: non-floating point dtype"); return nullptr; } @@ -811,6 +829,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( if (isa(dtype)) { Value scaled = b.create(loc, rhs, alpha); return b.create(loc, lhs, scaled); + } else if (dtype.isInteger(1)) { + Value scaled = b.create(loc, rhs, alpha); + return b.create(loc, lhs, scaled); } else { Value scaled = b.create(loc, rhs, alpha); return b.create(loc, lhs, scaled); @@ -839,6 +860,28 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return b.create(loc, lhs, scaled); } } + if (auto lshiftScalar = dyn_cast(op)) { + Type dtype = + cast(converter->convertType(lshiftScalar.getType())) + .getElementType(); + Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype); + Value other = + convertScalarToDtype(b, loc, operands[1], dtype, + /*srcOriginalDtype=*/operands[1].getType(), + /*dstOriginalDtype=*/dtype); + return b.create(loc, self, other); + } + if (auto rshiftScalar = dyn_cast(op)) { + Type dtype = + cast(converter->convertType(rshiftScalar.getType())) + .getElementType(); + Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype); + Value other = + convertScalarToDtype(b, loc, operands[1], dtype, + /*srcOriginalDtype=*/operands[1].getType(), + /*dstOriginalDtype=*/dtype); + return b.create(loc, self, other); + } if (auto subScalar = dyn_cast(op)) { Type dtype = cast(converter->convertType(subScalar.getType())) @@ -911,28 +954,22 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return b.create(loc, lhs, rhs); } if (auto ltTensor = dyn_cast(op)) { - return createCompareTensorOp(b, loc, ltTensor, payloadArgs[0], - payloadArgs[1]); + return createCompareOp(b, loc, ltTensor, payloadArgs[0], payloadArgs[1]); } if (auto leTensor = dyn_cast(op)) { - 
return createCompareTensorOp(b, loc, leTensor, payloadArgs[0], - payloadArgs[1]); + return createCompareOp(b, loc, leTensor, payloadArgs[0], payloadArgs[1]); } if (auto gtTensor = dyn_cast(op)) { - return createCompareTensorOp(b, loc, gtTensor, payloadArgs[0], - payloadArgs[1]); + return createCompareOp(b, loc, gtTensor, payloadArgs[0], payloadArgs[1]); } if (auto geTensor = dyn_cast(op)) { - return createCompareTensorOp(b, loc, geTensor, payloadArgs[0], - payloadArgs[1]); + return createCompareOp(b, loc, geTensor, payloadArgs[0], payloadArgs[1]); } if (auto eqTensor = dyn_cast(op)) { - return createCompareTensorOp(b, loc, eqTensor, payloadArgs[0], - payloadArgs[1]); + return createCompareOp(b, loc, eqTensor, payloadArgs[0], payloadArgs[1]); } if (auto neTensor = dyn_cast(op)) { - return createCompareTensorOp(b, loc, neTensor, payloadArgs[0], - payloadArgs[1]); + return createCompareOp(b, loc, neTensor, payloadArgs[0], payloadArgs[1]); } if (auto div = dyn_cast(op)) { AtenDivTensorOp::Adaptor adaptor(operands); @@ -970,10 +1007,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto pow = dyn_cast(op)) { - if (!pow.getType() - .cast() - .getDtype() - .isa()) { + if (!isa( + cast(pow.getType()).getDtype())) { pow.emitError("unimplemented: non-floating point dtype"); return nullptr; } @@ -986,12 +1021,20 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Type dtype = cast(converter->convertType(pow.getType())) .getElementType(); if (!isa(dtype)) { + // The result type is integer when both operands are integer. + // Torch then uses the following implementation: + // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Pow.h pow.emitError("unimplemented: non-floating point dtype"); return nullptr; } - Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); - Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); - return b.create(loc, lhs, rhs); + Type powType = dtype; + if (payloadArgs[0].getType().isInteger() || + payloadArgs[1].getType().isInteger()) + powType = mlir::Float64Type::get(op->getContext()); + Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], powType); + Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], powType); + auto powOp = b.create(loc, lhs, rhs); + return convertScalarToDtype(b, loc, powOp, dtype); } if (auto imag = dyn_cast(op)) { @@ -1017,27 +1060,27 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto gtScalar = dyn_cast(op)) { - return createCompareScalarOp(b, loc, gtScalar, payloadArgs[0], operands[1]); + return createCompareOp(b, loc, gtScalar, payloadArgs[0], operands[1]); } if (auto geScalar = dyn_cast(op)) { - return createCompareScalarOp(b, loc, geScalar, payloadArgs[0], operands[1]); + return createCompareOp(b, loc, geScalar, payloadArgs[0], operands[1]); } if (auto eqScalar = dyn_cast(op)) { - return createCompareScalarOp(b, loc, eqScalar, payloadArgs[0], operands[1]); + return createCompareOp(b, loc, eqScalar, payloadArgs[0], operands[1]); } if (auto neScalar = dyn_cast(op)) { - return createCompareScalarOp(b, loc, neScalar, payloadArgs[0], operands[1]); + return createCompareOp(b, loc, neScalar, payloadArgs[0], operands[1]); } if (auto ltScalar = dyn_cast(op)) { - return createCompareScalarOp(b, loc, ltScalar, payloadArgs[0], operands[1]); + return createCompareOp(b, loc, ltScalar, payloadArgs[0], operands[1]); } if (auto leScalar = dyn_cast(op)) { - return createCompareScalarOp(b, loc, leScalar, payloadArgs[0], operands[1]); + return createCompareOp(b, loc, 
leScalar, payloadArgs[0], operands[1]); } if (auto whereSelf = dyn_cast(op)) { @@ -1050,10 +1093,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto lerp = dyn_cast(op)) { - if (!lerp.getType() - .cast() - .getDtype() - .isa()) { + if (!isa( + cast(lerp.getType()).getDtype())) { lerp.emitError("unimplemented: non-floating point dtype"); return nullptr; } @@ -1067,9 +1108,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto minimum = dyn_cast(op)) { Type dtype = cast(minimum.getType()).getDtype(); - Type elemTy = converter->convertType(minimum.getType()) - .cast() - .getElementType(); + Type elemTy = + cast(converter->convertType(minimum.getType())) + .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], elemTy); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], elemTy); Value pred = createLessThan(b, loc, dtype, lhs, rhs); @@ -1077,9 +1118,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto maximum = dyn_cast(op)) { Type dtype = cast(maximum.getType()).getDtype(); - Type elemTy = converter->convertType(maximum.getType()) - .cast() - .getElementType(); + Type elemTy = + cast(converter->convertType(maximum.getType())) + .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], elemTy); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], elemTy); Value pred = createGreaterThan(b, loc, dtype, lhs, rhs); @@ -1089,8 +1130,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp( AtenClampOp::Adaptor adaptor(operands); auto min = adaptor.getMin(); auto max = adaptor.getMax(); - if (min.getType().isa() || - max.getType().isa()) { + if (isa(min.getType()) || + isa(max.getType())) { clamp.emitError("unimplemented: runtime optional type"); return nullptr; } @@ -1128,9 +1169,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( }; auto result = payloadArgs[0]; - if (!min.getType().isa()) + if (!isa(min.getType())) result = cmpSelect(result, min, /*getMax=*/false); - if (!max.getType().isa()) + if (!isa(max.getType())) result = cmpSelect(result, max, /*getMax=*/true); return result; } @@ -1138,8 +1179,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp( AtenClampTensorOp::Adaptor adaptor(operands); auto min = adaptor.getMin(); auto max = adaptor.getMax(); - if (min.getType().isa() || - max.getType().isa()) { + if (isa(min.getType()) || + isa(max.getType())) { clampTensor.emitError("unimplemented: runtime optional type"); return nullptr; } @@ -1148,7 +1189,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( .getElementType(); bool isMinNone = true; auto result = payloadArgs[0]; - if (!min.getType().isa()) { + if (!isa(min.getType())) { isMinNone = false; auto minPromoted = convertScalarToDtype(b, loc, payloadArgs[1], dtype); Value pred; @@ -1166,7 +1207,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } result = b.create(loc, pred, minPromoted, result); } - if (!max.getType().isa()) { + if (!isa(max.getType())) { max = isMinNone ? 
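For the optional-bound handling in the clamp payloads, the net scalar semantics are the usual clamp, and AtenLerpTensorOp just before them is the standard linear interpolation. A hedged sketch of both, assuming float operands (None bounds are simply skipped):

#include <algorithm>

// clamp: apply whichever bounds are present (Torch None operands are skipped).
float clampRef(float x, const float *minVal, const float *maxVal) {
  float result = x;
  if (minVal)
    result = std::max(result, *minVal);  // lower bound, if provided
  if (maxVal)
    result = std::min(result, *maxVal);  // upper bound, if provided
  return result;
}

// lerp: start + weight * (end - start), per the torch.lerp definition.
float lerpRef(float start, float end, float weight) {
  return start + weight * (end - start);
}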
payloadArgs[1] : payloadArgs[2]; auto maxPromoted = convertScalarToDtype(b, loc, max, dtype); Value pred; @@ -1220,6 +1261,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto atenToDtype = dyn_cast(op)) { Value input = payloadArgs[0]; + Type inputElementType = + cast(atenToDtype.getSelf().getType()).getDtype(); Type dtype = cast(converter->convertType(atenToDtype.getType())) .getElementType(); @@ -1238,7 +1281,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } resultElementType = *maybeResultElementType; Value result = convertScalarToDtype(b, loc, input, dtype, - /*srcOriginalDtype=*/std::nullopt, + /*srcOriginalDtype=*/inputElementType, /*dstOriginalDtype=*/resultElementType); return result; } @@ -1255,67 +1298,12 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return b.create(loc, self, other); } if (auto remScalar = dyn_cast(op)) { - Type newResultType = converter->convertType(remScalar.getType()) - .cast() - .getElementType(); - - Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType); - Value other = convertScalarToDtype(b, loc, operands[1], newResultType); - Value result; - - if (isa(newResultType)) { - result = b.create(loc, self, other); - } else if (isa(newResultType)) { - result = b.create(loc, self, other); - } else { - remScalar.emitError( - "Unsupported type encountered for AtenRemainderScalarOp."); - } - - return result; + return createRemainderPayload(b, loc, converter, payloadArgs, remScalar, + operands); } if (auto remTensor = dyn_cast(op)) { - Type newResultType = converter->convertType(remTensor.getType()) - .cast() - .getElementType(); - - Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType); - Value other = convertScalarToDtype(b, loc, payloadArgs[1], newResultType); - Value result; - - if (isa(newResultType)) { - result = b.create(loc, self, other); - } else if (isa(newResultType)) { - result = b.create(loc, self, other); - } else { - remTensor.emitError( - "Unsupported type encountered for AtenRemainderTensorOp."); - } - - return result; - } - if (auto fmod = dyn_cast(op)) { - Type newResultType = converter->convertType(fmod.getType()) - .cast() - .getElementType(); - - Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType); - Value other = convertScalarToDtype(b, loc, payloadArgs[1], newResultType); - Value result; - - if (isa(newResultType)) { - Value n = b.create(loc, self, other); - n = b.create(loc, n); - Value n_y = b.create(loc, n, other); - result = b.create(loc, self, n_y); - } else if (isa(newResultType)) { - Value n = b.create(loc, self, other); - Value n_y = b.create(loc, n, other); - result = b.create(loc, self, n_y); - } else { - fmod.emitError("Unsupported type encountered for AtenFmodTensorOp."); - } - return result; + return createRemainderPayload(b, loc, converter, payloadArgs, remTensor, + operands); } if (auto reciprocal = dyn_cast(op)) { Type dtype = @@ -1423,9 +1411,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto bitwiseNot = dyn_cast(op)) { - Type elementType = converter->convertType(bitwiseNot.getType()) - .cast() - .getElementType(); + Type elementType = + cast(converter->convertType(bitwiseNot.getType())) + .getElementType(); if (isa(elementType)) { bitwiseNot.emitError("Bitwise_Not does not support floating point dtype"); return nullptr; @@ -1509,7 +1497,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( scale = b.create(loc, valueTy, scale); value = b.create(loc, value, scale); - 
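Since the hand-written remainder/fmod payloads are now funneled through createRemainderPayload, it is worth spelling out the scalar semantics being preserved. A hedged reference sketch for floating-point operands: fmod truncates the quotient toward zero, while PyTorch's remainder floors it, so the two differ when the operand signs differ.

#include <cmath>

// Scalar reference for the folded payloads (illustrative only).
float fmodRef(float self, float other) {
  // mirrors the removed AtenFmodTensorOp body: n = trunc(self/other); self - n*other
  return self - std::trunc(self / other) * other;
}
float remainderRef(float self, float other) {
  // PyTorch remainder follows floor-division (Python %) semantics
  return self - std::floor(self / other) * other;
}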
value = b.create(loc, value); + value = b.create(loc, value); value = b.create(loc, value, zp); auto destTy = payloadArgs[1].getType(); @@ -1528,12 +1516,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp( b.create(loc, b.getFloatAttr(valueTy, minI)); Value maxVal = b.create(loc, b.getFloatAttr(valueTy, maxI)); - Value minCmp = - b.create(loc, arith::CmpFPredicate::ULT, value, minVal); - Value maxCmp = - b.create(loc, arith::CmpFPredicate::UGT, value, maxVal); - value = b.create(loc, minCmp, minVal, value); - value = b.create(loc, maxCmp, maxVal, value); + value = b.create(loc, value, minVal); + value = b.create(loc, value, maxVal); if (isUnsigned) { value = b.create(loc, destTy, value); @@ -1544,6 +1528,48 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return value; } + if (auto isClose = dyn_cast(op)) { + double rtol, atol; + bool equalNan; + if (!matchPattern(isClose.getRtol(), m_TorchConstantFloat(&rtol))) { + isClose.emitError("rtol must be a scalar constant"); + return nullptr; + } + if (!matchPattern(isClose.getAtol(), m_TorchConstantFloat(&atol))) { + isClose.emitError("atol must be a scalar constant"); + return nullptr; + } + if (!matchPattern(isClose.getEqualNan(), m_TorchConstantBool(&equalNan))) { + isClose.emitError("unimplemented: equal_nan is expected to be false"); + return nullptr; + } + auto lhsType = mlir::dyn_cast(payloadArgs[0].getType()); + auto rhsType = mlir::dyn_cast(payloadArgs[1].getType()); + if (!lhsType || !rhsType) { + isClose.emitError("unimplemented: only FP element type is supported"); + return nullptr; + } + // Choose the widest float type as compute type. + auto computeType = + lhsType.getWidth() > rhsType.getWidth() ? lhsType : rhsType; + computeType = computeType.getWidth() >= 32 ? 
computeType : b.getF32Type(); + auto cvtArg0 = convertScalarToDtype(b, loc, payloadArgs[0], computeType); + auto cvtArg1 = convertScalarToDtype(b, loc, payloadArgs[1], computeType); + // Reference to the definition of torch.isclose: + // ∣input − other∣ <= atol + rtol × ∣other∣ + auto diff = b.create(loc, computeType, cvtArg0, cvtArg1); + auto absDiff = b.create(loc, computeType, diff); + auto cstRtol = + b.create(loc, b.getFloatAttr(computeType, rtol)); + auto absOther = b.create(loc, computeType, cvtArg1); + auto mul = b.create(loc, computeType, cstRtol, absOther); + auto cstAtol = + b.create(loc, b.getFloatAttr(computeType, atol)); + auto threshold = b.create(loc, computeType, cstAtol, mul); + return b.create(loc, arith::CmpFPredicate::ULE, absDiff, + threshold); + } + op->emitError("unimplemented lowering in " "createLinalgPayloadCalculationForElementwiseOp"); return nullptr; @@ -1586,11 +1612,12 @@ class ConvertElementwiseOp : public ConversionPattern { AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenPowScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp, - AtenRemainderScalarOp, AtenRemainderTensorOp, AtenFmodTensorOp, - AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, + AtenRemainderScalarOp, AtenRemainderTensorOp, AtenAbsOp, + AtenComplexOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp, - AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, AtenGeScalarOp, + AtenBitwiseRightShiftTensorOp, Aten__Lshift__ScalarOp, + Aten__Rshift__ScalarOp, AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp, @@ -1602,7 +1629,7 @@ class ConvertElementwiseOp : public ConversionPattern { AtenFillScalarOp, AtenFillTensorOp, AtenAtanOp, AtenAcosOp, AtenAtanhOp, AtenAcoshOp, AtenAsinOp, AtenAsinhOp, AtenRealOp, AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp, - AtenQuantizePerTensorOp>(op)) + AtenQuantizePerTensorOp, AtenIscloseOp>(op)) return rewriter.notifyMatchFailure(op, "not a supported elementwise op"); if (failed(verifyLinalgCompatibleTypes(op, rewriter))) @@ -1610,10 +1637,9 @@ class ConvertElementwiseOp : public ConversionPattern { Location loc = op->getLoc(); auto tensorOperands = llvm::to_vector<6>(llvm::make_filter_range( - operands, [](Value v) { return v.getType().isa(); })); - auto resultType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); + operands, [](Value v) { return isa(v.getType()); })); + auto resultType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); bool hadErrorCreatingPayload = false; Value generic = torch_to_linalg::createElementwiseLinalgGeneric( rewriter, loc, tensorOperands, resultType.getElementType(), @@ -1660,7 +1686,7 @@ class ConvertAtenNllLossForwardOp return rewriter.notifyMatchFailure(op, "dim must be constant"); // TODO: Incorporate the weight argument. 
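The new isclose payload follows the documented torch.isclose predicate; a minimal scalar sketch of what the emitted arith ops compute, assuming equal_nan == false and float inputs (the wider operand type, at least f32, is used as the compute type):

#include <cmath>

// |input - other| <= atol + rtol * |other|, emitted as arith.cmpf ule.
bool iscloseRef(float input, float other, double rtol, double atol) {
  float absDiff = std::fabs(input - other);
  float threshold = static_cast<float>(atol + rtol * std::fabs(other));
  return absDiff <= threshold;
}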
- if (!weight.getType().isa()) + if (!isa(weight.getType())) return rewriter.notifyMatchFailure( op, "Unimplemented, the weight operand is not incorporated."); @@ -1675,9 +1701,8 @@ class ConvertAtenNllLossForwardOp return rewriter.notifyMatchFailure( op, "expected input and target to be rank <= 2"); } - RankedTensorType resultType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); + RankedTensorType resultType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); Type elementType = resultType.getElementType(); Value zeroVal = rewriter.create( @@ -1718,7 +1743,27 @@ class ConvertAtenNllLossForwardOp if (reduction == torch_upstream::Reduction::Sum || reduction == torch_upstream::Reduction::Mean) { - Value numOfElems = getTensorSize(rewriter, loc, finalRes); + + Value zeroIVal = rewriter.create( + loc, rewriter.getZeroAttr(rewriter.getI32Type())); + auto countInfo = torch_to_linalg::ReductionOpInfo{false, target, dimSet}; + Value numOfElems = torch_to_linalg::createReductionLinalgGeneric( + rewriter, loc, countInfo, + /*initElem=*/zeroIVal, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value targetVal = args[0]; + Value indTarget = rewriter.create( + loc, rewriter.getIndexType(), targetVal); + Value cmpEq = rewriter.create( + loc, arith::CmpIPredicate::ne, indTarget, ignoreIndexVal); + cmpEq = rewriter.create(loc, rewriter.getI32Type(), + cmpEq); + Value add = rewriter.create(loc, args[1], cmpEq); + rewriter.create(loc, add); + }); + + numOfElems = rewriter.create( + loc, rewriter.getI32Type(), numOfElems, ArrayRef{}); numOfElems = convertScalarToDtype(rewriter, loc, numOfElems, elementType); auto opInfo = torch_to_linalg::ReductionOpInfo{false, finalRes, dimSet}; @@ -1951,7 +1996,7 @@ class ConvertAtenNllLossBackwardOp Value input = adaptor.getSelf(); Value target = adaptor.getTarget(); Value weight = adaptor.getWeight(); - bool weightIsNone = op.getWeight().getType().isa(); + bool weightIsNone = isa(op.getWeight().getType()); Value ignoreIndex = castIntToIndex(rewriter, loc, adaptor.getIgnoreIndex()); Value totalWeight = adaptor.getTotalWeight(); @@ -2072,9 +2117,8 @@ class ConvertAtenNllLossBackwardOp }) ->getResult(0); - RankedTensorType resultType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); + RankedTensorType resultType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); rewriter.replaceOpWithNewOp(op, resultType, gradInput); return success(); } @@ -2217,9 +2261,8 @@ class ConvertTensorStaticInfoCastOp LogicalResult matchAndRewrite(TensorStaticInfoCastOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - RankedTensorType resultType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); + RankedTensorType resultType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); rewriter.replaceOpWithNewOp(op, resultType, adaptor.getOperand()); return success(); @@ -2246,7 +2289,7 @@ class ConvertLogitOp : public OpConversionPattern { if (succeeded(checkNotNone(rewriter, op, eps))) handleEps = true; - if (handleEps && !eps.getType().isa()) { + if (handleEps && !isa(eps.getType())) { op.emitError("Logit does not support non-floating point type"); return failure(); } @@ -2320,9 +2363,8 @@ class ConvertAtenIntReprOp : public OpConversionPattern { LogicalResult matchAndRewrite(AtenIntReprOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - RankedTensorType resultType = getTypeConverter() - 
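The mean-reduction change above replaces the raw element count with the number of targets that are not ignore_index, matching PyTorch's nll_loss averaging. A scalar sketch of the divisor computation, assuming 1-D targets (illustrative names):

#include <cstdint>
#include <vector>

// The reduction generic accumulates (target != ignoreIndex) as an i32 count,
// which is then converted to the loss element type before the divide.
float nllLossMeanDivisor(const std::vector<int64_t> &targets, int64_t ignoreIndex) {
  int32_t numOfElems = 0;
  for (int64_t t : targets)
    numOfElems += (t != ignoreIndex) ? 1 : 0;  // cmpi ne + extui + addi
  return static_cast<float>(numOfElems);
}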
->convertType(op->getResult(0).getType()) - .cast(); + RankedTensorType resultType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); rewriter.replaceOpWithNewOp(op, resultType, adaptor.getSelf()); return success(); @@ -2365,8 +2407,8 @@ class ConvertDequantizePerChannel zeropoint = converter->materializeTargetConversion( rewriter, loc, converter->convertType(zeropoint.getType()), zeropoint); - auto resultType = converter->convertType(op->getResult(0).getType()) - .cast(); + auto resultType = cast( + converter->convertType(op->getResult(0).getType())); llvm::SmallVector dynSizes; for (auto [index, dim] : llvm::enumerate(resultType.getShape())) { @@ -2406,6 +2448,10 @@ class ConvertDequantizePerChannel } else if (zeropointDTy.isSignedInteger(8)) { zeropoint = b.create(loc, b.getI32Type(), zeropoint); + } else if (zeropointDTy.isInteger(64)) { + zeropoint = + b.create(loc, b.getI32Type(), zeropoint); + op->emitWarning() << "truncated zero point from 64 to 32 bit"; } Value sub = rewriter.create(loc, operand, zeropoint); @@ -2450,9 +2496,7 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern { Location loc = op->getLoc(); Type int64type = rewriter.getI64Type(); Type floatType = rewriter.getF32Type(); - Value zeroIndex = rewriter.create(loc, 0); Value oneIndex = rewriter.create(loc, 1); - Value twoIndex = rewriter.create(loc, 2); Value zeroFloat = rewriter.create( loc, rewriter.getFloatAttr(floatType, 0.0)); Value oneFloat = rewriter.create( @@ -2461,7 +2505,6 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern { loc, rewriter.getFloatAttr(floatType, 2.0)); Value input = adaptor.getInput(); auto inputType = cast(input.getType()); - auto inputShape = inputType.getShape(); Value innerDim0a = rewriter.create(loc, input, 2); Value innerDim1a = rewriter.create(loc, input, 3); Value innerDim0b = @@ -2482,42 +2525,21 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern { rewriter.create(loc, innerDim1d, twoFloat); Value grid = adaptor.getGrid(); auto gridType = cast(grid.getType()); - auto gridShape = gridType.getShape(); auto gridRank = gridType.getRank(); - SmallVector extractGridOffsets0(gridRank, zeroIndex); - SmallVector extractGridShape = getTensorSizes(rewriter, loc, grid); - SmallVector extractGridStride(gridRank, oneIndex); - int64_t lastGridDim = gridRank - 1; - extractGridShape[lastGridDim] = oneIndex; - extractGridStride[lastGridDim] = twoIndex; - SmallVector extractGridOffsets1(gridRank, zeroIndex); - extractGridOffsets1[lastGridDim] = oneIndex; - SmallVector gridShapeExtracted(gridShape); - gridShapeExtracted.back() = 1; - SmallVector gridShapeCollapsed{gridShape[0], gridShape[1], - gridShape[2]}; - auto grid0 = rewriter.create( - loc, grid, extractGridOffsets0, extractGridShape, extractGridStride); - auto grid1 = rewriter.create( - loc, grid, extractGridOffsets1, extractGridShape, extractGridStride); - SmallVector associations{ReassociationIndices{0}, - ReassociationIndices{1}, - ReassociationIndices{2, 3}}; - auto gridCollapsed0 = - rewriter.create(loc, grid0, associations); - auto gridCollapsed1 = - rewriter.create(loc, grid1, associations); - AffineMap gridMap = AffineMap::get(4, 0, - {rewriter.getAffineDimExpr(0), - rewriter.getAffineDimExpr(2), - rewriter.getAffineDimExpr(3)}, - op->getContext()); - SmallVector gridMaps{gridMap, gridMap, - rewriter.getMultiDimIdentityMap(gridRank)}; + SmallVector gridMaps{ + AffineMap::get( + 4, 0, + {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(2), + rewriter.getAffineDimExpr(3), 
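The new i64 zero-point branch in the per-channel dequantize lowering truncates to i32 (and warns) before the usual dequantization. A hedged scalar sketch, assuming the standard (q - zero_point) * scale formula used by the surrounding code:

#include <cstdint>

// Values outside the i32 range would wrap during the truncation, which is why
// the lowering emits a warning rather than failing outright.
float dequantizeWithI64ZeroPoint(int32_t operand, int64_t zeroPoint64, float scale) {
  int32_t zeroPoint = static_cast<int32_t>(zeroPoint64); // trunci i64 -> i32
  return static_cast<float>(operand - zeroPoint) * scale;
}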
rewriter.getAffineConstantExpr(0)}, + op->getContext()), + AffineMap::get( + 4, 0, + {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(2), + rewriter.getAffineDimExpr(3), rewriter.getAffineConstantExpr(1)}, + op->getContext()), + rewriter.getMultiDimIdentityMap(inputType.getRank())}; SmallVector gridIterators( gridRank, utils::IteratorType::parallel); - SmallVector resultShape{inputShape[0], inputShape[1], gridShape[1], - gridShape[2]}; auto lambdaExtract = [](OpBuilder &b, Location loc, Value input, Value idxA, Value idxB, Value idxC, Value idxD) -> Value { SmallVector index{idxA, idxB, idxC, idxD}; @@ -2556,25 +2578,24 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern { return res; }; - auto resultType = getTypeConverter() - ->convertType(op.getResult().getType()) - .cast(); - SmallVector resultSize{}; + auto resultType = cast( + getTypeConverter()->convertType(op.getResult().getType())); + Value alignCorners = adaptor.getAlignCorners(); + Value interMode = adaptor.getInterpolationMode(); + SmallVector dynamicSizes{}; if (resultType.isDynamicDim(0)) - resultSize.push_back(rewriter.create(loc, input, 0)); + dynamicSizes.push_back(rewriter.create(loc, input, 0)); if (resultType.isDynamicDim(1)) - resultSize.push_back(rewriter.create(loc, input, 1)); + dynamicSizes.push_back(rewriter.create(loc, input, 1)); if (resultType.isDynamicDim(2)) - resultSize.push_back(rewriter.create(loc, grid, 1)); + dynamicSizes.push_back(rewriter.create(loc, grid, 1)); if (resultType.isDynamicDim(3)) - resultSize.push_back(rewriter.create(loc, grid, 2)); - Value alignCorners = adaptor.getAlignCorners(); - Value interMode = adaptor.getInterpolationMode(); - Value resultFinal = - rewriter.create(loc, resultType, resultSize); + dynamicSizes.push_back(rewriter.create(loc, grid, 2)); + tensor::EmptyOp emptyOp = + rewriter.create(loc, resultType, dynamicSizes); auto sGrid = rewriter.create( - loc, TypeRange{resultType}, ValueRange{gridCollapsed0, gridCollapsed1}, - ValueRange(resultFinal), gridMaps, gridIterators, + loc, TypeRange{resultType}, ValueRange{grid, grid}, ValueRange(emptyOp), + gridMaps, gridIterators, [&](OpBuilder &b, Location loc, ValueRange args) { Value gr0 = args[1]; Value gr1 = args[0]; @@ -2672,224 +2693,446 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern { }; } // namespace -static Value NearestInterpolate(OpBuilder &b, Location loc, Value outputSizeH, - Value outputSizeW, Value input, - Value inputSizeH, Value inputSizeW) { +static Value nearestInterpolate(OpBuilder &b, Location loc, + SmallVector outputSizes, Value input, + SmallVector inputSizes, + SmallVector scaleValues, + std::string coordStr, std::string nearestMode) { - auto inputType = input.getType().cast(); + auto inputType = cast(input.getType()); auto inputRank = inputType.getRank(); - Value yOut = b.create(loc, 2); - Value xOut = b.create(loc, 3); - - Value inputHFP = b.create(loc, b.getF32Type(), inputSizeH); - Value inputWFP = b.create(loc, b.getF32Type(), inputSizeW); - - Value outputSizeHFP = - b.create(loc, b.getF32Type(), outputSizeH); - Value outputSizeWFP = - b.create(loc, b.getF32Type(), outputSizeW); - - // scale = length_resized / length_original - // x_original = x_resized / scale - Value hScale = b.create(loc, outputSizeHFP, inputHFP); - Value wScale = b.create(loc, outputSizeWFP, inputWFP); - - Value yOutInt = b.create(loc, b.getI64Type(), yOut); - Value yOutFP = b.create(loc, b.getF32Type(), yOutInt); - Value yProj = b.create(loc, yOutFP, hScale); + SmallVector indices; + for 
(unsigned i = 0; i < inputRank; i++) { + indices.push_back(b.create(loc, i)); + } - Value xOutInt = b.create(loc, b.getI64Type(), xOut); - Value xOutFP = b.create(loc, b.getF32Type(), xOutInt); - Value xProj = b.create(loc, xOutFP, wScale); + for (unsigned i = 2; i < inputRank; i++) { + Value outIndex = indices[i]; - // get nearest pixel using floor - Value yNearestFP = b.create(loc, yProj); - Value xNearestFP = b.create(loc, xProj); + Value inputSizeFP = + b.create(loc, b.getF32Type(), inputSizes[i - 2]); - Value yNearestInt = - b.create(loc, b.getI64Type(), yNearestFP); - Value yNearest = - b.create(loc, b.getIndexType(), yNearestInt); + Value outputSizeFP = + b.create(loc, b.getF32Type(), outputSizes[i - 2]); - Value xNearestInt = - b.create(loc, b.getI64Type(), xNearestFP); - Value xNearest = - b.create(loc, b.getIndexType(), xNearestInt); + // scale = length_resized / length_original + // x_original = x_resized / scale + Value scale; + if (scaleValues.empty()) + scale = b.create(loc, outputSizeFP, inputSizeFP); + else + scale = scaleValues[i - 2]; + + Value outInt = b.create(loc, b.getI64Type(), outIndex); + Value outFP = b.create(loc, b.getF32Type(), outInt); + Value proj; + if (coordStr.empty() || coordStr == "_asymmetric") { + proj = b.create(loc, outFP, scale); + } else if (coordStr == "_half_pixel") { + Value cstHalf = b.create(loc, b.getF32FloatAttr(0.5)); + Value add = b.create(loc, outFP, cstHalf); + Value div = b.create(loc, add, scale); + proj = b.create(loc, div, cstHalf); + } else { + llvm_unreachable("Unsupported coordination transformation mode"); + } + + Value nearestFP; + // get nearest pixel using floor + if (nearestMode == "floor" || nearestMode == "") { + nearestFP = b.create(loc, proj); + } else if (nearestMode == "round_prefer_floor") { + Value cstHalf = b.create(loc, b.getF32FloatAttr(0.5)); + Value floor = b.create(loc, proj); + Value ceil = b.create(loc, proj); + Value decimal = b.create(loc, proj, floor); + Value cmp = b.create(loc, arith::CmpFPredicate::ULE, + decimal, cstHalf); + nearestFP = b.create(loc, cmp, floor, ceil); + } else if (nearestMode == "round_prefer_ceil") { + Value cstHalf = b.create(loc, b.getF32FloatAttr(0.5)); + Value cstOne = b.create(loc, b.getF32FloatAttr(1)); + Value floor = b.create(loc, proj); + Value ceil = b.create(loc, proj); + Value decimal = b.create(loc, proj, floor); + Value cmp = b.create(loc, arith::CmpFPredicate::UGE, + decimal, cstHalf); + nearestFP = b.create(loc, cmp, ceil, floor); + Value inputSizeMOne = b.create(loc, inputSizeFP, cstOne); + // don't extract out of bounds + nearestFP = b.create(loc, nearestFP, inputSizeMOne); + } else if (nearestMode == "ceil") { + Value cstOne = b.create(loc, b.getF32FloatAttr(1)); + Value inputSizeMOne = b.create(loc, inputSizeFP, cstOne); + nearestFP = b.create(loc, proj); + nearestFP = b.create(loc, nearestFP, inputSizeMOne); + } else { + llvm_unreachable("Unsupported nearest mode"); + } + Value nearestInt = + b.create(loc, b.getI64Type(), nearestFP); + Value nearest = + b.create(loc, b.getIndexType(), nearestInt); - SmallVector indices; - for (unsigned i = 0; i < inputRank; i++) { - indices.push_back(b.create(loc, i)); + indices[i] = nearest; } - - int hDimOffset = 2; - indices[hDimOffset] = yNearest; - indices[hDimOffset + 1] = xNearest; Value retVal = b.create(loc, input, indices); return retVal; } -static Value BilinearInterpolate(OpBuilder &b, - Aten__InterpolateSizeListScaleListOp op, - Location loc, Value outputSizeH, - Value outputSizeW, Value input, - Value inputSizeH, 
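A compact scalar sketch of the per-dimension logic nearestInterpolate now implements (the coordStr and nearestMode strings follow the onnx.Resize vocabulary; the helper name is illustrative):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <string>

float nearestSourceIndex(int64_t outIdx, float inSize, float outSize,
                         const std::string &coordStr, const std::string &nearestMode) {
  float scale = outSize / inSize;                 // scale = length_resized / length_original
  float out = static_cast<float>(outIdx);
  float proj = (coordStr.empty() || coordStr == "_asymmetric")
                   ? out / scale                  // x_original = x_resized / scale
                   : (out + 0.5f) / scale - 0.5f; // "_half_pixel"
  float nearest;
  if (nearestMode.empty() || nearestMode == "floor") {
    nearest = std::floor(proj);
  } else if (nearestMode == "round_prefer_floor") {
    float fl = std::floor(proj);
    nearest = (proj - fl <= 0.5f) ? fl : std::ceil(proj);
  } else if (nearestMode == "round_prefer_ceil") {
    float fl = std::floor(proj);
    nearest = (proj - fl >= 0.5f) ? std::ceil(proj) : fl;
    nearest = std::min(nearest, inSize - 1.0f);   // don't extract out of bounds
  } else { // "ceil"
    nearest = std::min(std::ceil(proj), inSize - 1.0f);
  }
  return nearest;
}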
Value inputSizeW) { - int hDimOffset = 2; - auto inputType = input.getType().cast(); +static SmallVector coordinateTransform( + OpBuilder &b, Aten__InterpolateSizeListScaleListOp op, Location loc, + SmallVector outputSizes, Value input, SmallVector inputSizes, + SmallVector scaleValues, std::string coordStr, bool alignCornersBool, + SmallVector indices, bool clip) { + + unsigned dimOffset = 2; + auto inputType = cast(input.getType()); auto inputRank = inputType.getRank(); - Value cstOneEps = b.create(loc, b.getF32FloatAttr(1.001)); Value cstOneFloat = b.create(loc, b.getF32FloatAttr(1.0)); Value cstHalf = b.create(loc, b.getF32FloatAttr(0.5)); Value zero = b.create(loc, b.getF32FloatAttr(0.0)); - Value yOut = b.create(loc, 2); - Value xOut = b.create(loc, 3); + SmallVector proj; + for (unsigned i = 0; i < inputRank - dimOffset; i++) { + // length_original + Value inputFP = + b.create(loc, b.getF32Type(), inputSizes[i]); + // length_resized + Value outputSizeFP = + b.create(loc, b.getF32Type(), outputSizes[i]); + // scale = length_resized/length_original + Value scale; + if (alignCornersBool) { + // x_original = x_resized * (length_original - 1) / (length_resized - 1) + Value inputSubOne = b.create(loc, inputFP, cstOneFloat); + Value outputSizeSubOne = + b.create(loc, outputSizeFP, cstOneFloat); + Value cmp = b.create(loc, arith::CmpFPredicate::UEQ, + outputSizeSubOne, zero); + scale = b.create(loc, inputSubOne, outputSizeSubOne); + scale = b.create(loc, cmp, zero, scale); + coordStr = "_align_corners"; + } else if (scaleValues.empty()) + scale = b.create(loc, outputSizeFP, inputFP); + else + scale = scaleValues[i]; + // y_resized + Value outInt = b.create(loc, b.getI64Type(), + indices[i + dimOffset]); + Value outFP = b.create(loc, b.getF32Type(), outInt); + Value preClip; + if (coordStr == "_align_corners") { + preClip = b.create(loc, outFP, scale); + } + if (coordStr == "_asymmetric") { + preClip = b.create(loc, outFP, scale); + } + if (coordStr == "_pytorch_half_pixel" || coordStr == "" || + coordStr == "_half_pixel_symmetric") { + // half-pixel modes + // y_resized + 0.5 + Value outPlusHalf = b.create(loc, outFP, cstHalf); + // (y_resized + 0.5) / scale + Value outDivScale = b.create(loc, outPlusHalf, scale); + // _ - 0.5 + preClip = b.create(loc, outDivScale, cstHalf); + } + // for half_pixel_symmetric, need to compute offset from raw scales + if (coordStr == "_half_pixel_symmetric" && !scaleValues.empty()) { + Value outputSizeFromScale = b.create(loc, inputFP, scale); + Value adjustment = + b.create(loc, outputSizeFP, outputSizeFromScale); + Value cstTwo = b.create(loc, b.getF32FloatAttr(2.0)); + Value center = b.create(loc, inputFP, cstTwo); + Value oneMAdjustment = + b.create(loc, cstOneFloat, adjustment); + Value offset = b.create(loc, center, oneMAdjustment); + preClip = b.create(loc, offset, preClip); + } + // for pytorch half pixel , special case for length_resized == 1: + if (coordStr == "_pytorch_half_pixel") { + Value cmp = b.create(loc, arith::CmpFPredicate::UEQ, + outputSizeFP, cstOneFloat); + preClip = b.create(loc, cmp, zero, preClip); + } + if (clip) { + // preClip is the fp position inside the input image to extract from. + // clip to [0,inf) + Value max = b.create(loc, preClip, zero); + Value inputSubOne = b.create(loc, inputFP, cstOneFloat); + // clip to [0,length_original - 1]. + // proj is properly within the input image. 
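Summarizing the coordinate modes handled by coordinateTransform as a scalar sketch over one spatial dimension (the half_pixel_symmetric offset derived from raw scale factors is omitted here for brevity, and the helper name is illustrative):

#include <algorithm>
#include <cstdint>
#include <string>

float sourceCoordinate(int64_t outIdx, float inLen, float outLen, float scale,
                       const std::string &coordStr, bool alignCorners, bool clip) {
  float x = static_cast<float>(outIdx);
  float proj;
  if (alignCorners) {
    // x_original = x_resized * (length_original - 1) / (length_resized - 1)
    proj = (outLen == 1.0f) ? 0.0f : x * (inLen - 1.0f) / (outLen - 1.0f);
  } else if (coordStr == "_asymmetric") {
    proj = x / scale;
  } else { // "", "_pytorch_half_pixel", "_half_pixel_symmetric"
    proj = (x + 0.5f) / scale - 0.5f;
    if (coordStr == "_pytorch_half_pixel" && outLen == 1.0f)
      proj = 0.0f;   // special case for length_resized == 1
  }
  if (clip)  // clamp to [0, length_original - 1] before extraction
    proj = std::min(std::max(proj, 0.0f), inLen - 1.0f);
  return proj;
}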
+ proj.push_back(b.create(loc, max, inputSubOne)); + } else { + proj.push_back(preClip); + } + } + return proj; +} + +static Value bilinearInterpolate(OpBuilder &b, + Aten__InterpolateSizeListScaleListOp op, + Location loc, SmallVector outputSizes, + Value input, SmallVector inputSizes, + SmallVector scaleValues, + std::string coordStr) { + unsigned dimOffset = 2; + auto inputType = cast(input.getType()); + auto inputRank = inputType.getRank(); + + Value cstOneFloat = b.create(loc, b.getF32FloatAttr(1.0)); bool alignCornersBool; matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool)); - Value yProj, xProj; - if (alignCornersBool) { - // x_original = x_resized * (length_original - 1) / (length_resized - 1) - Value inputHFP = b.create(loc, b.getF32Type(), inputSizeH); - Value outputSizeHFP = - b.create(loc, b.getF32Type(), outputSizeH); - Value yOutInt = b.create(loc, b.getI64Type(), yOut); - Value yOutFP = b.create(loc, b.getF32Type(), yOutInt); - Value inputHSubOne = b.create(loc, inputHFP, cstOneFloat); - Value outputSizeHSubOne = - b.create(loc, outputSizeHFP, cstOneFloat); - Value hScale = - b.create(loc, inputHSubOne, outputSizeHSubOne); - Value yProjBeforeClamp = b.create(loc, yOutFP, hScale); - Value yMax = b.create(loc, yProjBeforeClamp, zero); - Value outputSizeHSubOneEps = - b.create(loc, outputSizeHFP, cstOneEps); - yProj = b.create(loc, outputSizeHSubOneEps, yMax); - - Value inputWFP = b.create(loc, b.getF32Type(), inputSizeW); - Value outputSizeWFP = - b.create(loc, b.getF32Type(), outputSizeW); - Value xOutInt = b.create(loc, b.getI64Type(), xOut); - Value xOutFP = b.create(loc, b.getF32Type(), xOutInt); - Value inputWSubOne = b.create(loc, inputWFP, cstOneFloat); - Value outputSizeWSubOne = - b.create(loc, outputSizeWFP, cstOneFloat); - Value wScale = - b.create(loc, inputWSubOne, outputSizeWSubOne); - Value xProjBeforeClamp = b.create(loc, xOutFP, wScale); - Value xMax = b.create(loc, xProjBeforeClamp, zero); - Value outputSizeWSubOneEps = - b.create(loc, outputSizeWFP, cstOneEps); - xProj = b.create(loc, outputSizeWSubOneEps, xMax); - } else { - // y_original = (y_resized + 0.5) / scale - 0.5 - Value inputHFP = b.create(loc, b.getF32Type(), inputSizeH); - Value outputSizeHFP = - b.create(loc, b.getF32Type(), outputSizeH); - Value hScale = b.create(loc, outputSizeHFP, inputHFP); - Value yOutInt = b.create(loc, b.getI64Type(), yOut); - Value yOutFP = b.create(loc, b.getF32Type(), yOutInt); - Value yPlusHalf = b.create(loc, yOutFP, cstHalf); - Value yDivScale = b.create(loc, yPlusHalf, hScale); - Value ySubHalf = b.create(loc, yDivScale, cstHalf); - Value yMax = b.create(loc, ySubHalf, zero); - Value inputHSubOne = b.create(loc, inputHFP, cstOneEps); - yProj = b.create(loc, yMax, inputHSubOne); - - Value inputWFP = b.create(loc, b.getF32Type(), inputSizeW); - Value outputSizeWFP = - b.create(loc, b.getF32Type(), outputSizeW); - Value wScale = b.create(loc, outputSizeWFP, inputWFP); - Value xOutInt = b.create(loc, b.getI64Type(), xOut); - Value xOutFP = b.create(loc, b.getF32Type(), xOutInt); - Value xPlusHalf = b.create(loc, xOutFP, cstHalf); - Value xDivScale = b.create(loc, xPlusHalf, wScale); - Value xSubHalf = b.create(loc, xDivScale, cstHalf); - // clamp - Value xMax = b.create(loc, xSubHalf, zero); - Value inputWSubOne = b.create(loc, inputWFP, cstOneEps); - xProj = b.create(loc, xMax, inputWSubOne); - } - Value yLow = b.create(loc, yProj); - Value yProjPlusOne = b.create(loc, cstOneFloat, yProj); - Value yHigh = b.create(loc, yProjPlusOne); - - Value 
xLow = b.create(loc, xProj); - Value xProjPlusOne = b.create(loc, cstOneFloat, xProj); - Value xHigh = b.create(loc, xProjPlusOne); - SmallVector indices; for (unsigned i = 0; i < inputRank; i++) { indices.push_back(b.create(loc, i)); } - Value yLowInt = b.create(loc, b.getI64Type(), yLow); - Value yLowIdx = b.create(loc, b.getIndexType(), yLowInt); - - Value xLowInt = b.create(loc, b.getI64Type(), xLow); - Value xLowIdx = b.create(loc, b.getIndexType(), xLowInt); - Value yHighInt = b.create(loc, b.getI64Type(), yHigh); - Value yHighIdx = - b.create(loc, b.getIndexType(), yHighInt); - - Value xHighInt = b.create(loc, b.getI64Type(), xHigh); - Value xHighIdx = - b.create(loc, b.getIndexType(), xHighInt); - - indices[hDimOffset] = yLowIdx; - indices[hDimOffset + 1] = xLowIdx; + SmallVector proj, high, low, highFP, lowFP; + proj = coordinateTransform(b, op, loc, outputSizes, input, inputSizes, + scaleValues, coordStr, alignCornersBool, indices, + true); + for (unsigned i = 0; i < inputRank - dimOffset; i++) { + // length_original + Value inputFP = + b.create(loc, b.getF32Type(), inputSizes[i]); + Value inputSubOne = b.create(loc, inputFP, cstOneFloat); + + // for bilinear interpolation, we look for the nearest indices below and + // above proj + lowFP.push_back(b.create(loc, proj[i])); + Value projPlusOne = b.create(loc, cstOneFloat, proj[i]); + highFP.push_back(b.create(loc, projPlusOne)); + + Value lowInt = b.create(loc, b.getI64Type(), lowFP[i]); + low.push_back(b.create(loc, b.getIndexType(), lowInt)); + + // highFP could be out-of-bounds, so make sure to clip it down before + // extracting. If highFP actually gets clipped here, then high[i] will + // extract at the last pixel, but will treat it as if it were extracted from + // one further position when computing the interpolation weights. + Value highExtract = + b.create(loc, projPlusOne, inputSubOne); + highExtract = b.create(loc, b.getI64Type(), highExtract); + high.push_back( + b.create(loc, b.getIndexType(), highExtract)); + } + + indices[dimOffset] = low[0]; + indices[dimOffset + 1] = low[1]; Value p00 = b.create(loc, input, indices); - indices[hDimOffset] = yLowIdx; - indices[hDimOffset + 1] = xHighIdx; + indices[dimOffset] = low[0]; + indices[dimOffset + 1] = high[1]; Value p01 = b.create(loc, input, indices); - indices[hDimOffset] = yHighIdx; - indices[hDimOffset + 1] = xLowIdx; + indices[dimOffset] = high[0]; + indices[dimOffset + 1] = low[1]; Value p10 = b.create(loc, input, indices); - indices[hDimOffset] = yHighIdx; - indices[hDimOffset + 1] = xHighIdx; + indices[dimOffset] = high[0]; + indices[dimOffset + 1] = high[1]; Value p11 = b.create(loc, input, indices); - // p00 p01 - // p10 p11 - // (xhigh - xproj) / (xhigh - xlow) * p00 + (xproj - xlow) / - // (xhigh - xlow) * p01 - Value xHighMinusxProj = b.create(loc, xHigh, xProj); - Value xHighMinusxLow = b.create(loc, xHigh, xLow); - Value w0 = b.create(loc, xHighMinusxProj, xHighMinusxLow); - Value lhs = b.create(loc, w0, p00); - - Value xProjMinusxLow = b.create(loc, xProj, xLow); - Value w1 = b.create(loc, xProjMinusxLow, xHighMinusxLow); - Value rhs = b.create(loc, w1, p01); - - Value xInter = b.create(loc, lhs, rhs); + // Let Aij := area rect((yProj,xProj) <-> (y_i*,x_j*)), + // where i* = i+1 mod 2 and x_0 = xLow, x_1 = xHigh etc. 
+ // We interpolate via the weighted average of pij by weights Aij + // the formula is retval = Sum(pij*Aij for i and j in range(2)) + // Note: we do not need to divide by total rect area == 1 + + // lengths : Aij == dyi*dxj + Value dy0 = b.create(loc, highFP[0], proj[0]); + Value dy1 = b.create(loc, proj[0], lowFP[0]); + Value dx0 = b.create(loc, highFP[1], proj[1]); + Value dx1 = b.create(loc, proj[1], lowFP[1]); + + // left = A00*p00 + A01*p01 = dy0(dx0p00 + dx1p01) + Value dx0p00 = b.create(loc, dx0, p00); + Value dx1p01 = b.create(loc, dx1, p01); + Value sum = b.create(loc, dx0p00, dx1p01); + Value left = b.create(loc, dy0, sum); + // right = A10*p10 + A11*p11 = dy1(dx0p10 + dx1p11) + Value dx0p10 = b.create(loc, dx0, p10); + Value dx1p11 = b.create(loc, dx1, p11); + sum = b.create(loc, dx0p10, dx1p11); + Value right = b.create(loc, dy1, sum); + + return b.create(loc, left, right); +} - // (xhigh - xproj) / (xhigh - xlow) * p10 + (xproj - xlow) / - // (xhigh - xlow) * p11 - lhs = b.create(loc, w0, p10); - rhs = b.create(loc, w1, p11); +static Value bicubicInterpolate(OpBuilder &b, + Aten__InterpolateSizeListScaleListOp op, + Location loc, SmallVector outputSizes, + Value input, SmallVector inputSizes, + SmallVector scaleValues, + std::string coordStr) { + unsigned dimOffset = 2; + auto inputType = cast(input.getType()); + auto inputRank = inputType.getRank(); - Value xInter1 = b.create(loc, lhs, rhs); + Value inputFPH = + b.create(loc, b.getF32Type(), inputSizes[0]); + Value inputFPW = + b.create(loc, b.getF32Type(), inputSizes[1]); - // (yhigh - yproj) / (yhigh - ylow) * xInter + (yproj - ylow) - // / (yhigh - ylow) * xInter1 - Value yHighMinusyProj = b.create(loc, yHigh, yProj); - Value yHighMinusyLow = b.create(loc, yHigh, yLow); - w0 = b.create(loc, yHighMinusyProj, yHighMinusyLow); - lhs = b.create(loc, w0, xInter); + Value a = b.create(loc, b.getF32FloatAttr(-0.75)); + Value zero = b.create(loc, b.getF32FloatAttr(0.0)); + Value cstOneFloat = b.create(loc, b.getF32FloatAttr(1.0)); + Value cstTwoFloat = b.create(loc, b.getF32FloatAttr(2.0)); + Value cstThreeFloat = + b.create(loc, b.getF32FloatAttr(3.0)); + Value cstFourFloat = b.create(loc, b.getF32FloatAttr(4.0)); + Value cstFiveFloat = b.create(loc, b.getF32FloatAttr(5.0)); + Value cstEightFloat = + b.create(loc, b.getF32FloatAttr(8.0)); + + // (a+2)|x|^3 - (a+3)|x|^2 + 1 for xDistance (|x| <= 1) + auto WeightLessThanEqualOne = [&](Value xDistance) -> Value { + Value xDistanceSquared = b.create(loc, xDistance, xDistance); + Value xDistanceCubed = + b.create(loc, xDistanceSquared, xDistance); + + Value lessEqualOne = b.create(loc, a, cstTwoFloat); + lessEqualOne = b.create(loc, xDistanceCubed, lessEqualOne); + Value aPlusThree = b.create(loc, a, cstThreeFloat); + aPlusThree = b.create(loc, xDistanceSquared, aPlusThree); + lessEqualOne = b.create(loc, lessEqualOne, aPlusThree); + lessEqualOne = b.create(loc, lessEqualOne, cstOneFloat); + + return lessEqualOne; + }; + + // a|x|^3 - 5a|x|^2 + 8a|x| - 4a for xDistance (1 < |x| < 2) + auto WeightLessThanTwo = [&](Value xDistance) -> Value { + Value xDistanceSquared = b.create(loc, xDistance, xDistance); + Value xDistanceCubed = + b.create(loc, xDistanceSquared, xDistance); + // a|x|^3 + Value lessThanTwo = b.create(loc, xDistanceCubed, a); + + Value fiveA = b.create(loc, xDistanceSquared, a); + fiveA = b.create(loc, fiveA, cstFiveFloat); + // a|x|^3 - 5a|x|^2 + lessThanTwo = b.create(loc, lessThanTwo, fiveA); + + Value eightA = b.create(loc, a, xDistance); + eightA = b.create(loc, 
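The area-weight formulation used for the bilinear case reads as a standard bilinear blend; a scalar sketch over one 2x2 neighborhood (p00/p01/p10/p11 and the dy/dx distances mirror the names in the generic body, and dy0 + dy1 == dx0 + dx1 == 1 because the high index is the low index plus one before clipping):

// retval = dy0*(dx0*p00 + dx1*p01) + dy1*(dx0*p10 + dx1*p11)
float bilinearBlend(float p00, float p01, float p10, float p11,
                    float yProj, float yLow, float yHigh,
                    float xProj, float xLow, float xHigh) {
  float dy0 = yHigh - yProj, dy1 = yProj - yLow;
  float dx0 = xHigh - xProj, dx1 = xProj - xLow;
  float left = dy0 * (dx0 * p00 + dx1 * p01);
  float right = dy1 * (dx0 * p10 + dx1 * p11);
  return left + right;
}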
eightA, cstEightFloat); + // a|x|^3 - 5a|x|^2 + 8a|x| + lessThanTwo = b.create(loc, eightA, lessThanTwo); + + Value fourA = b.create(loc, a, cstFourFloat); + // a|x|^3 - 5a|x|^2 + 8a|x| - 4a + lessThanTwo = b.create(loc, lessThanTwo, fourA); + return lessThanTwo; + }; - Value yProjMinusyLow = b.create(loc, yProj, yLow); - w1 = b.create(loc, yProjMinusyLow, yHighMinusyLow); - rhs = b.create(loc, w1, xInter1); + bool alignCornersBool; + matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool)); - Value retVal = b.create(loc, lhs, rhs); + SmallVector indices; + for (unsigned i = 0; i < inputRank; i++) { + indices.push_back(b.create(loc, i)); + } - return retVal; + SmallVector proj; + + proj = coordinateTransform(b, op, loc, outputSizes, input, inputSizes, + scaleValues, coordStr, alignCornersBool, indices, + false); + + // get the nearest neighbors of proj + Value x1 = b.create(loc, proj[1]); + Value x_1 = b.create(loc, x1, cstOneFloat); + Value x_2 = b.create(loc, x_1, cstOneFloat); + Value x2 = b.create(loc, x1, cstOneFloat); + + Value y1 = b.create(loc, proj[0]); + Value y_1 = b.create(loc, y1, cstOneFloat); + Value y_2 = b.create(loc, y_1, cstOneFloat); + Value y2 = b.create(loc, y1, cstOneFloat); + + // calculate the distance of nearest neighbors x and y to proj + Value y2Distance = b.create(loc, proj[0], y2); + y2Distance = b.create(loc, y2Distance); + Value y1Distance = b.create(loc, proj[0], y1); + y1Distance = b.create(loc, y1Distance); + Value y_1Distance = b.create(loc, proj[0], y_1); + y_1Distance = b.create(loc, y_1Distance); + Value y_2Distance = b.create(loc, proj[0], y_2); + y_2Distance = b.create(loc, y_2Distance); + + Value x2Distance = b.create(loc, proj[1], x2); + x2Distance = b.create(loc, x2Distance); + Value x1Distance = b.create(loc, proj[1], x1); + x1Distance = b.create(loc, x1Distance); + Value x_1Distance = b.create(loc, proj[1], x_1); + x_1Distance = b.create(loc, x_1Distance); + Value x_2Distance = b.create(loc, proj[1], x_2); + x_2Distance = b.create(loc, x_2Distance); + + SmallVector y{y_2, y_1, y1, y2}; + SmallVector x{x_2, x_1, x1, x2}; + + SmallVector wys{ + WeightLessThanTwo(y_2Distance), WeightLessThanEqualOne(y_1Distance), + WeightLessThanEqualOne(y1Distance), WeightLessThanTwo(y2Distance)}; + SmallVector wxs{ + WeightLessThanTwo(x_2Distance), WeightLessThanEqualOne(x_1Distance), + WeightLessThanEqualOne(x1Distance), WeightLessThanTwo(x2Distance)}; + + // clip the nearest neighbors points to inside the original image + for (int k = 0; k < 4; k++) { + Value yClipped = b.create(loc, y[k], zero); + Value inputHSubOne = b.create(loc, inputFPH, cstOneFloat); + yClipped = b.create(loc, yClipped, inputHSubOne); + Value yInt = b.create(loc, b.getI64Type(), yClipped); + y[k] = b.create(loc, b.getIndexType(), yInt); + + Value xClipped = b.create(loc, x[k], zero); + Value inputWSubOne = b.create(loc, inputFPW, cstOneFloat); + xClipped = b.create(loc, xClipped, inputWSubOne); + Value xInt = b.create(loc, b.getI64Type(), xClipped); + x[k] = b.create(loc, b.getIndexType(), xInt); + } + // 1. Compute x_original and y_original (proj) + // 2. Compute nearest x and y neighbors + // 3. Compute Wx Wy + // 4. Extract inputs at nearest neighbors (inputExtracts) + // 5. 
Compute weighted sum (yield this) + + // 4 nearest x neighbors : [x_2, x_1, x1, x2] of x_original + // 4 nearest y neighbors : [y_2, y_1, y1, y2] of y_original + // Sum_x is over 4 nearest x neighbors (similar for Sum_y) + // f(x_original, y_original) = Sum_y Sum_x W(x_original - x)*input[x,y] + // * W(y_original - y) + Value fxy = zero; + + for (int j = 0; j < 4; j++) { + Value wy = wys[j]; + Value xInterpy = zero; + + indices[dimOffset] = y[j]; + + for (int i = 0; i < 4; i++) { + Value wx = wxs[i]; + + indices[dimOffset + 1] = x[i]; + + Value p = b.create(loc, input, indices); + + Value wxp = b.create(loc, wx, p); + xInterpy = b.create(loc, xInterpy, wxp); + } + Value wyXInterpy = b.create(loc, wy, xInterpy); + fxy = b.create(loc, fxy, wyXInterpy); + } + + return fxy; } namespace { @@ -2902,57 +3145,62 @@ class ConvertInterpolateOp ConversionPatternRewriter &rewriter) const override { std::string mode; + // note: to support onnx.Resize, we are passing some extra options through + // the mode attribute. For example, onnx.Resize with mode="linear" and + // coordinate_transformation_mode="asymmetric" will lower to an interpolate + // op with the non-standard mode="bilinear_asymmetric". matchPattern(op.getMode(), m_TorchConstantStr(mode)); - if (mode != "bilinear" && mode != "nearest") { + if (mode.substr(0, 8) != "bilinear" && mode.substr(0, 7) != "nearest" && + mode.substr(0, 5) != "cubic") { return failure(); } Location loc = op->getLoc(); Value input = adaptor.getInput(); - auto inputType = input.getType().cast(); + auto inputType = cast(input.getType()); auto inputRank = inputType.getRank(); - - if (inputType.isDynamicDim(2) || inputType.isDynamicDim(3)) { - return rewriter.notifyMatchFailure(op, "error: Dynamic dim on resize op"); - } - - SmallVector outputSizeIntValues; - - if (!op.getScaleFactor().getType().isa()) { - SmallVector ScaleFactorTorchFloat; + if (mode.substr(0, 8) == "bilinear" && inputRank != 4) + return rewriter.notifyMatchFailure( + op, + "cannot perform bilinear interpolation when input spatial dims != 2"); + + SmallVector outputSizeIntValues; + SmallVector inputSizes; + SmallVector ScaleFactorFloatValues; + for (unsigned i = 2; i < inputRank; i++) { + Value inputSize = getDimOp(rewriter, loc, input, i); + inputSizes.push_back(rewriter.create( + loc, rewriter.getIntegerType(64), inputSize)); + } + + if (!isa(op.getScaleFactor().getType())) { + bool recompScale; + if (!matchPattern(op.getRecomputeScaleFactor(), + m_TorchConstantBool(&recompScale))) + recompScale = false; + SmallVector ScaleFactorTorchFloat; if (!getListConstructElements(op.getScaleFactor(), ScaleFactorTorchFloat)) return rewriter.notifyMatchFailure( op, "unimplemented: the output_size is not constructed from " "ListConstruct"); - SmallVector ScaleFactorFloatValues; ScaleFactorFloatValues = getTypeConvertedValues( rewriter, loc, getTypeConverter(), ScaleFactorTorchFloat); - Value inputSizeH = rewriter.create( - loc, rewriter.getI64IntegerAttr(inputType.getShape()[2])); - Value inputHFP = rewriter.create( - loc, rewriter.getF32Type(), inputSizeH); - Value scale = rewriter.create(loc, inputHFP.getType(), - ScaleFactorFloatValues[0]); - Value outputSizeH = rewriter.create(loc, inputHFP, scale); - Value outputH = rewriter.create(loc, outputSizeH); - outputH = - rewriter.create(loc, rewriter.getI64Type(), outputH); - - Value inputSizeW = rewriter.create( - loc, rewriter.getI64IntegerAttr(inputType.getShape()[3])); - Value inputWFP = rewriter.create( - loc, rewriter.getF32Type(), inputSizeW); - scale = 
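For reference, the bicubic path uses the standard cubic convolution kernel with a = -0.75; a scalar sketch of the weight function the two lambdas above implement (the lowering only ever evaluates distances below 2, so the trailing zero branch is just the textbook kernel):

#include <cmath>

// Cubic convolution kernel, a = -0.75:
//   (a+2)|x|^3 - (a+3)|x|^2 + 1     for |x| <= 1
//   a|x|^3 - 5a|x|^2 + 8a|x| - 4a   for 1 < |x| < 2
float cubicWeight(float d) {
  const float a = -0.75f;
  d = std::fabs(d);
  if (d <= 1.0f)
    return (a + 2.0f) * d * d * d - (a + 3.0f) * d * d + 1.0f;
  if (d < 2.0f)
    return a * d * d * d - 5.0f * a * d * d + 8.0f * a * d - 4.0f * a;
  return 0.0f;
}

// f(xProj, yProj) = Sum_j Wy[j] * Sum_i Wx[i] * input[y[j], x[i]]
// over the 4 nearest neighbors per dimension, exactly as the generic body sums.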
rewriter.create(loc, inputWFP.getType(), - ScaleFactorFloatValues[1]); - Value outputSizeW = rewriter.create(loc, inputWFP, scale); - Value outputW = rewriter.create(loc, outputSizeW); - outputW = - rewriter.create(loc, rewriter.getI64Type(), outputW); - - outputSizeIntValues.push_back(outputH); - outputSizeIntValues.push_back(outputW); + for (unsigned i = 0; i < inputRank - 2; i++) { + Value inputSizeFP = rewriter.create( + loc, rewriter.getF32Type(), inputSizes[i]); + ScaleFactorFloatValues[i] = rewriter.create( + loc, inputSizeFP.getType(), ScaleFactorFloatValues[i]); + Value outputSize = rewriter.create( + loc, inputSizeFP, ScaleFactorFloatValues[i]); + outputSize = rewriter.create(loc, outputSize); + outputSize = rewriter.create( + loc, rewriter.getI64Type(), outputSize); + outputSizeIntValues.push_back(outputSize); + } + if (recompScale) + ScaleFactorFloatValues.clear(); } else { - SmallVector outputSizeTorchInt; + SmallVector outputSizeTorchInt; if (!getListConstructElements(op.getSize(), outputSizeTorchInt)) return rewriter.notifyMatchFailure( op, "unimplemented: the output_size is not constructed from " @@ -2960,20 +3208,16 @@ class ConvertInterpolateOp outputSizeIntValues = getTypeConvertedValues( rewriter, loc, getTypeConverter(), outputSizeTorchInt); } - int hDimOffset = 2; - SmallVector dims = getTensorSizes(rewriter, loc, input); - dims[hDimOffset] = castIntToIndex(rewriter, loc, outputSizeIntValues[0]); - dims[hDimOffset + 1] = - castIntToIndex(rewriter, loc, outputSizeIntValues[1]); + SmallVector dims = getTensorSizesUntilDim(rewriter, loc, input, 1); + for (unsigned i = 2; i < inputRank; i++) { + dims.push_back(castIntToIndex(rewriter, loc, outputSizeIntValues[i - 2])); + } Value outTensor = rewriter.create( loc, getAsOpFoldResult(dims), inputType.getElementType()); - AffineMap idMap = rewriter.getMultiDimIdentityMap(inputRank); - SmallVector iteratorTypes( inputRank, utils::IteratorType::parallel); - Value finalRes = rewriter .create( @@ -2981,21 +3225,26 @@ class ConvertInterpolateOp /*indexingMaps=*/idMap, /*iteratorTypes=*/iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { - Value outputSizeH = outputSizeIntValues[0]; - Value outputSizeW = outputSizeIntValues[1]; - Value inputSizeH = b.create( - loc, b.getI64IntegerAttr(inputType.getShape()[2])); - Value inputSizeW = b.create( - loc, b.getI64IntegerAttr(inputType.getShape()[3])); Value retVal; - if (mode == "nearest") { - retVal = - NearestInterpolate(b, loc, outputSizeH, outputSizeW, - input, inputSizeH, inputSizeW); - } else if (mode == "bilinear") { - retVal = BilinearInterpolate(b, op, loc, outputSizeH, - outputSizeW, input, inputSizeH, - inputSizeW); + if (mode.substr(0, 7) == "nearest") { + std::string coordTfMode = + mode.substr(7, mode.find(",") - 7); + std::string nearestMode = + (mode.find(",") == std::string::npos) + ? 
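Because the mode string now doubles as a carrier for onnx.Resize options, a short sketch of how the dispatch just below slices it apart may be useful (this is a hypothetical helper for illustration, not part of the patch):

#include <cstddef>
#include <string>

// e.g. "nearest_half_pixel,round_prefer_floor"
//   -> kind = "nearest", coordTfMode = "_half_pixel", nearestMode = "round_prefer_floor"
// e.g. "bilinear_asymmetric" -> kind = "bilinear", coordTfMode = "_asymmetric"
struct ResizeMode {
  std::string kind;        // "nearest", "bilinear", or "cubic"
  std::string coordTfMode; // coordinate_transformation_mode suffix, may be empty
  std::string nearestMode; // nearest rounding mode, may be empty
};

ResizeMode parseResizeMode(const std::string &mode) {
  ResizeMode parsed;
  std::size_t prefixLen = mode.rfind("nearest", 0) == 0    ? 7
                          : mode.rfind("bilinear", 0) == 0 ? 8
                                                           : 5; // "cubic"
  parsed.kind = mode.substr(0, prefixLen);
  std::size_t comma = mode.find(",");
  parsed.coordTfMode = mode.substr(
      prefixLen, comma == std::string::npos ? std::string::npos : comma - prefixLen);
  if (comma != std::string::npos)
    parsed.nearestMode = mode.substr(comma + 1);
  return parsed;
}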
"" + : mode.substr(mode.find(",") + 1); + retVal = nearestInterpolate( + b, loc, outputSizeIntValues, input, inputSizes, + ScaleFactorFloatValues, coordTfMode, nearestMode); + } else if (mode.substr(0, 8) == "bilinear") { + retVal = bilinearInterpolate( + b, op, loc, outputSizeIntValues, input, inputSizes, + ScaleFactorFloatValues, mode.substr(8)); + } else if (mode.substr(0, 5) == "cubic") { + + retVal = bicubicInterpolate( + b, op, loc, outputSizeIntValues, input, inputSizes, + ScaleFactorFloatValues, mode.substr(5)); } b.create(loc, retVal); }) @@ -3007,6 +3256,378 @@ class ConvertInterpolateOp } }; } // namespace + +namespace { +// This pattern row reduces a matrix, then returns the product of it's diagonal +// elements +class ConvertAtenLinalgDetOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenLinalgDetOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + MLIRContext *context = op->getContext(); + Value input = adaptor.getA(); + auto inputType = cast(input.getType()); + unsigned inputRank = inputType.getRank(); + auto elemTy = inputType.getElementType(); + bool isBatched = (inputRank == 3); + Value cstZero = rewriter.create(loc, 0); + Value cstOne = rewriter.create(loc, 1); + Value cstZeroF = getConstant(rewriter, loc, 0, elemTy); + // get some shapes + SmallVector inputShape(inputType.getShape()); + + SmallVector sliceShape(inputShape); + sliceShape[sliceShape.size() - 2] = 1; + + SmallVector diagShape(inputType.getShape()); + diagShape[diagShape.size() - 2] = 1; + diagShape[diagShape.size() - 1] = 1; + + ArrayRef diagCollapseShape(diagShape); + diagCollapseShape = diagCollapseShape.drop_back(); + + auto sliceTy = RankedTensorType::get(sliceShape, elemTy); + auto diagTy = RankedTensorType::get(diagShape, elemTy); + auto diagCollapseTy = RankedTensorType::get(diagCollapseShape, elemTy); + + SmallVector diagReassociations; + diagReassociations.reserve(diagCollapseShape.size()); + int64_t diagRank = diagCollapseShape.size(); + for (int i = 0, s = diagRank - 1; i < s; ++i) + diagReassociations.push_back(ReassociationIndices{i}); + diagReassociations.push_back(ReassociationIndices{diagRank - 1, diagRank}); + + // get some sizes + SmallVector inputSizes = getTensorSizes(rewriter, loc, input); + Value chDim = isBatched ? inputSizes[0] : cstOne; + Value matDim = inputSizes[inputRank - 1]; + Value matDimMinusOne = rewriter.create(loc, matDim, cstOne); + ArrayRef sliceSizes(inputSizes.begin(), inputSizes.end() - 1); + // initialize a tensor to store the diagonal elements found during row + // reduction + Value initDiags = rewriter.create( + loc, getAsOpFoldResult(sliceSizes), elemTy); + // loop over each pivot row in A. Get the diagonal, then reduce the + // subdiagonal Don't perform the loop on the last row since no further + // reduction is needed. + auto rowReductionLoop = rewriter.create( + loc, /*start=*/cstZero, /*end=*/matDimMinusOne, /*step=*/cstOne, + /*yeild_to=*/ValueRange{input, initDiags}, /*body_lambda=*/ + [&](OpBuilder &b, Location loc, Value row, ValueRange vals) { + // extract row i from input Tensor of shape CxNxN or shape + // NxN. 
+ OpFoldResult cstOneFold = getAsOpFoldResult(cstOne); + OpFoldResult cstZeroFold = getAsOpFoldResult(cstZero); + SmallVector offsets(inputRank, cstZeroFold); + offsets[inputRank - 2] = row; + SmallVector strides(inputRank, cstOneFold); + auto sizes = getAsOpFoldResult(inputSizes); + sizes[inputRank - 2] = cstOneFold; + // offsets = [0, row, 0], sizes = [C, 1, N] -> pivot row + Value pivot = b.create( + loc, sliceTy, vals[0], offsets, sizes, strides); + // extract diagonal elements and insert them into vals[1] + offsets.back() = row; + sizes.back() = cstOneFold; + // offsets = [0, row, row], sizes = [C, 1, 1] -> diag(row,row) + Value diag = b.create( + loc, diagTy, vals[0], offsets, sizes, strides); + + Value diagCollapse = b.create( + loc, diagCollapseTy, diag, diagReassociations); + + SmallVector diagOffsets(inputRank - 1, cstZeroFold); + diagOffsets.back() = row; + SmallVector diagStrides(inputRank - 1, cstOneFold); + SmallVector diagSizes = getAsOpFoldResult(sliceSizes); + diagSizes.back() = cstOneFold; + // offsets = [0, row], sizes = [C, 1] insert to [C,N] + Value updatedDiags = b.create( + loc, diagCollapse, vals[1], diagOffsets, diagSizes, diagStrides); + // the subpivot matrix column size, as a Value, is matDim - row - + // cstOne. This can't be statically converted to an int64_t, since row + // is the loop index, so this is left as a dynamic dim. + SmallVector subPivotShape(inputType.getShape()); + subPivotShape[inputRank - 2] = ShapedType::kDynamic; + ArrayRef subDiagShape(subPivotShape.begin(), + subPivotShape.end() - 1); + auto subPivotTy = RankedTensorType::get(subPivotShape, elemTy); + auto subDiagTy = RankedTensorType::get(subDiagShape, elemTy); + Value rowPlusOne = b.create(loc, row, cstOne); + offsets[inputRank - 2] = getAsOpFoldResult(rowPlusOne); + sizes[inputRank - 2] = getAsOpFoldResult( + b.create(loc, matDim, rowPlusOne)); + // offsets = [0, row + 1, row], sizes = [C, N - row - 1, 1] -> A_j,row + // with j > row + Value subDiag = b.create( + loc, subDiagTy, vals[0], offsets, sizes, strides); + offsets.back() = cstZeroFold; + sizes.back() = getAsOpFoldResult(matDim); + // offsets = [0, row + 1, 0], sizes = [C, N - row - 1, N] -> elements + // below pivot row + Value subPivot = b.create( + loc, subPivotTy, vals[0], offsets, sizes, strides); + Value initResult = b.create(loc, sizes, elemTy); + // write a generic op to perform subpivot = subpivot - + // (subdiag/diag)*pivot + // d0 = batches, d1 = row, d2 = column -> pivot(d0,d2), diag(d0), + // subPivot(d0,d1,d2), subDiag(d0, d1); output(d0,d1,d2) + SmallVector allDims; + for (unsigned i = 0; i < inputRank; i++) + allDims.push_back(b.getAffineDimExpr(i)); + SmallVector rowIterator(1, allDims[0]); + SmallVector colIterator; + SmallVector batchIterator; + if (isBatched) { + rowIterator.push_back(allDims[1]); + colIterator.push_back(allDims[0]); + colIterator.push_back(rewriter.getAffineConstantExpr(0)); + colIterator.push_back(allDims[2]); + batchIterator.push_back(allDims[0]); + batchIterator.push_back(getAffineConstantExpr(0, context)); + batchIterator.push_back(getAffineConstantExpr(0, context)); + } else { + colIterator.push_back(rewriter.getAffineConstantExpr(0)); + colIterator.push_back(allDims[1]); + batchIterator.push_back(getAffineConstantExpr(0, context)); + batchIterator.push_back(getAffineConstantExpr(0, context)); + } + SmallVector indexingMaps; + indexingMaps.push_back( + AffineMap::get(inputRank, 0, colIterator, context)); + indexingMaps.push_back( + AffineMap::get(inputRank, 0, batchIterator, context)); + 
indexingMaps.push_back(b.getMultiDimIdentityMap(inputRank)); + indexingMaps.push_back( + AffineMap::get(inputRank, 0, rowIterator, context)); + indexingMaps.push_back(b.getMultiDimIdentityMap(inputRank)); + SmallVector iteratorTypes( + inputRank, utils::IteratorType::parallel); + Value reducedSubPivot = + b.create( + loc, subPivotTy, ValueRange{pivot, diag, subPivot, subDiag}, + initResult, indexingMaps, iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + // for d0 in batches, d1 in subpivotrows, d2 in columns + // let i represent the pivot row index (scf loop index) + Value pivotd0d2 = args[0]; + Value diagd0 = args[1]; + Value subPivotd0d1d2 = args[2]; + Value subDiagd0d1 = args[3]; + // coeff = A_d1,i / A_i,i + Value coeff = + b.create(loc, subDiagd0d1, diagd0); + auto cmp = b.create( + loc, arith::CmpFPredicate::ONE, diagd0, cstZeroF); + b.create( + loc, cmp, + b.getStringAttr( + "unimplemented: determinants requiring " + "permutations and singular matrices")); + // coeff*A_i,d2 + Value scaledPivotValue = + b.create(loc, coeff, pivotd0d2); + // result = A_d1,d2 - (A_d1,i/A_i,i)*A_i,d2 + // so that when d2 = i, A_d1,i - (A_d1,i/A_i,i) * A_i,i = 0 + Value result = b.create(loc, subPivotd0d1d2, + scaledPivotValue); + b.create(loc, result); + }) + .getResult(0); + Value rowReductionResult = b.create( + loc, reducedSubPivot, vals[0], offsets, sizes, strides); + b.create(loc, + ValueRange{rowReductionResult, updatedDiags}); + }); + Value allDiagsExceptLast = rowReductionLoop.getResult(1); + SmallVector offsets(inputRank, + getAsOpFoldResult(matDimMinusOne)); + SmallVector strides(inputRank, getAsOpFoldResult(cstOne)); + SmallVector sizes(inputRank, getAsOpFoldResult(cstOne)); + sizes[0] = getAsOpFoldResult(chDim); + if (isBatched) + offsets[0] = getAsOpFoldResult(cstZero); + Value lastDiag = rewriter.create( + loc, diagTy, rowReductionLoop.getResult(0), offsets, sizes, strides); + offsets.pop_back(); + strides.pop_back(); + sizes.pop_back(); + + lastDiag = rewriter.create( + loc, diagCollapseTy, lastDiag, diagReassociations); + + Value allDiags = rewriter.create( + loc, lastDiag, allDiagsExceptLast, offsets, sizes, strides); + // linalg generic to do reduce prod for allDiags along back dim. + // the result of that generic will be the determinant + SmallVector indexingMaps; + indexingMaps.push_back(rewriter.getMultiDimIdentityMap(inputRank - 1)); + AffineExpr resultExpr = isBatched ? 
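As a plain-C++ cross-check of what the row-reduction loop computes: Gaussian elimination without pivoting turns A into an upper-triangular matrix whose diagonal product equals det(A), and the lowering asserts whenever a zero pivot would require a row permutation. A minimal sketch, assuming a single unbatched NxN matrix:

#include <cassert>
#include <cstddef>
#include <vector>

// det(A) via elimination; mirrors the scf.for over pivot rows followed by the
// product reduction over the collected diagonal entries.
float detRef(std::vector<std::vector<float>> A) {
  assert(!A.empty());
  std::size_t n = A.size();
  float det = 1.0f;
  for (std::size_t row = 0; row + 1 < n; ++row) {
    assert(A[row][row] != 0.0f &&
           "unimplemented: determinants requiring permutations and singular matrices");
    for (std::size_t j = row + 1; j < n; ++j) {
      float coeff = A[j][row] / A[row][row];     // A_j,row / A_row,row
      for (std::size_t col = 0; col < n; ++col)
        A[j][col] -= coeff * A[row][col];        // subPivot -= (subDiag/diag) * pivot
    }
    det *= A[row][row];                          // collect the pivot (diagonal) entry
  }
  det *= A[n - 1][n - 1];                        // last diagonal entry
  return det;
}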
rewriter.getAffineDimExpr(0) + : getAffineConstantExpr(0, context); + indexingMaps.push_back(AffineMap::get(inputRank - 1, 0, resultExpr)); + SmallVector iteratorTypes( + inputRank - 2, utils::IteratorType::parallel); + iteratorTypes.push_back(utils::IteratorType::reduction); + Value initDet = createInitTensor(rewriter, loc, ValueRange{chDim}, elemTy, + getConstant(rewriter, loc, 1.0, elemTy)); + Value determinant = + rewriter + .create( + loc, initDet.getType(), ValueRange{allDiags}, initDet, + indexingMaps, iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value prod = b.create(loc, args[0], args[1]); + b.create(loc, prod); + }) + .getResult(0); + Type newResultType = + getTypeConverter()->convertType(op.getResult().getType()); + if (isBatched) { + rewriter.replaceOpWithNewOp(op, newResultType, + determinant); + return success(); + } + + determinant = rewriter.create( + loc, newResultType, determinant, + llvm::ArrayRef{}); + rewriter.replaceOp(op, ValueRange{determinant}); + return success(); + } +}; +} // namespace + +namespace { +class ConvertAtenPolarOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenPolarOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) + return failure(); + + Location loc = op.getLoc(); + const TypeConverter *typeConverter = getTypeConverter(); + MLIRContext *context = rewriter.getContext(); + + Value absTensor = adaptor.getAbs(); + Value angleTensor = adaptor.getAngle(); + + RankedTensorType resultType = + cast(typeConverter->convertType(op.getType())); + auto elementType = resultType.getElementType(); + + SmallVector resultShape; + for (int64_t i = 0; i < resultType.getRank(); i++) { + auto currentDimSize = rewriter.create(loc, absTensor, i); + resultShape.push_back(currentDimSize); + } + + Value outTensor = rewriter.create( + loc, getAsOpFoldResult(resultShape), elementType); + + SmallVector outputExpr; + for (unsigned i = 0; i < resultType.getRank(); i++) { + outputExpr.push_back(getAffineDimExpr(i, context)); + } + + AffineMap identityMap = + AffineMap::get(resultType.getRank(), 0, outputExpr, op->getContext()); + + SmallVector indexingMaps{identityMap, identityMap, identityMap}; + SmallVector iteratorTypes( + resultType.getRank(), utils::IteratorType::parallel); + auto complexVar = + rewriter + .create( + loc, outTensor.getType(), ValueRange{absTensor, angleTensor}, + outTensor, indexingMaps, iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + // out = abs⋅cos(angle) + abs⋅sin(angle)⋅j + Value abs = args[0]; + Value angle = args[1]; + Value realVal = b.create(loc, angle); + Value imagVal = b.create(loc, angle); + realVal = b.create(loc, abs, realVal); + imagVal = b.create(loc, abs, imagVal); + Value complexVal = b.create( + loc, elementType, realVal, imagVal); + b.create(loc, complexVal); + }) + .getResult(0); + rewriter.replaceOpWithNewOp(op, resultType, complexVar); + return success(); + } +}; +} // namespace + +namespace { +class ConvertSymConstrainRangeOp + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenSymConstrainRangeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) + return failure(); + + auto loc = op.getLoc(); + auto min = op.getMin(); + auto max = op.getMax(); + + int64_t 
minValue = std::numeric_limits::min(); + int64_t maxValue = std::numeric_limits::max(); + + Type operandType = getTypeConverter()->convertType(op.getSize().getType()); + + if (!isa(min.getType())) + if (!matchPattern(min, m_TorchConstantInt(&minValue))) + return rewriter.notifyMatchFailure( + op, "Expected min value to be constant integer"); + + if (!isa(max.getType())) + if (!matchPattern(max, m_TorchConstantInt(&maxValue))) + return rewriter.notifyMatchFailure( + op, "Expected max value to be constant integer"); + + if (maxValue < minValue) { + std::string errorMsg = + "Max must be greater than or equal to min, got min = " + + std::to_string(minValue) + ", max = " + std::to_string(maxValue); + return op.emitError(errorMsg); + } + + min = getConstant(rewriter, loc, minValue, operandType); + max = getConstant(rewriter, loc, maxValue, operandType); + + // Check min <= size <= max + + // FIXME:: Skip the below checks if constraint ops are already inserted as + // part of symbol expr evaluation + auto checkMin = rewriter.create( + loc, arith::CmpIPredicate::sle, min, adaptor.getSize()); + auto checkMax = rewriter.create( + loc, arith::CmpIPredicate::sle, adaptor.getSize(), max); + auto compareVal = rewriter.create(loc, checkMin, checkMax); + + std::string assertMessage = "Size constraint failed. Expected range: [" + + std::to_string(minValue) + ", " + + std::to_string(maxValue) + "]"; + rewriter.create(loc, compareVal, + rewriter.getStringAttr(assertMessage)); + + rewriter.eraseOp(op); + return success(); + } +}; +} // namespace + void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { @@ -3020,20 +3641,21 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( AtenClampTensorOp, AtenRsubScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp, AtenPreluOp, AtenPowScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op, - AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, + AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenComplexOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp, - AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, AtenGeScalarOp, - AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, - AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp, - AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, + AtenBitwiseRightShiftTensorOp, Aten__Lshift__ScalarOp, + Aten__Rshift__ScalarOp, AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, + AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, + AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, + AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp, AtenAcosOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp, - AtenTrilOp, AtenRemainderScalarOp, AtenFmodTensorOp, - AtenRemainderTensorOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, - AtenFillTensorOp, AtenRealOp, AtenImagOp, AtenDequantizeSelfOp, - AtenDequantizeTensorOp, AtenQuantizePerTensorOp>(); + AtenTrilOp, AtenRemainderScalarOp, AtenRemainderTensorOp, + AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp, + AtenRealOp, AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp, + 
AtenQuantizePerTensorOp, AtenIscloseOp>(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); @@ -3064,4 +3686,10 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); } diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index c015ce563dd6..98dbc1957892 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -13,11 +13,8 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/Matchers.h" #include "torch-mlir/Conversion/TorchToLinalg/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h" -#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" @@ -55,7 +52,7 @@ Value torch_to_linalg::getPaddedTensor( Value torch_to_linalg::getZeroPaddedTensor( Operation *op, OpBuilder &b, Value &input, SmallVectorImpl &paddingInts) { - assert(input.getType().isa() && + assert(isa(input.getType()) && "input must be RankedTensorType"); Location loc = op->getLoc(); Value c0 = b.create( @@ -70,31 +67,25 @@ Value torch_to_linalg::getZeroPaddedTensor( Value torch_to_linalg::getDynamicZeroPaddedTensor( Operation *op, OpBuilder &b, Value &input, SmallVectorImpl &padding, int unpaddedDims, Value pad) { - assert(input.getType().isa() && + assert(isa(input.getType()) && "input must be RankedTensorType"); - unsigned int inRank = cast(input.getType()).getRank(); Location loc = op->getLoc(); SmallVector inputDims = getTensorSizes(b, loc, input); Value c0 = b.create(loc, b.getI64IntegerAttr(0)); SmallVector paddingIncludingUnchanged(unpaddedDims, c0); paddingIncludingUnchanged.append(padding); - assert(unpaddedDims + padding.size() == inRank && + assert(static_cast(unpaddedDims + padding.size()) == + cast(input.getType()).getRank() && "sum of unpaddedDims and padding.size() must equal to inputRank"); for (auto pad = paddingIncludingUnchanged.begin(); pad < paddingIncludingUnchanged.end(); pad++) *pad = castIntToIndex(b, loc, *pad); - Type elementType = cast(input.getType()).getElementType(); - // TODO: audit possibility of sparsity on this tensor - Type inputType = - RankedTensorType::get(makeShapeLLVMCompatible(llvm::ArrayRef( - SmallVector(inRank, kUnknownSize))), - elementType); - SmallVector paddingValues = getAsOpFoldResult(paddingIncludingUnchanged); - return b.create(loc, inputType, input, /*low=*/paddingValues, + + return b.create(loc, Type{}, input, /*low=*/paddingValues, /*high=*/paddingValues, pad); } @@ -106,25 +97,41 @@ Value torch_to_linalg::getOutputDimForConvOps(OpBuilder &b, Location loc, Value c1 = b.create(loc, b.getI64IntegerAttr(1)); Value c2 = b.create(loc, b.getI64IntegerAttr(2)); - Value doublePadding = b.create(loc, paddingInt, c2); + Value doublePadding = b.createOrFold(loc, paddingInt, c2); // in + 2 * padding - Value inAddDoublePadding = - b.create(loc, castIndexToInt64(b, loc, in), doublePadding); + Value inAddDoublePadding = b.createOrFold( + loc, castIndexToInt64(b, loc, in), doublePadding); // 
dilation * (kernelSize - 1) - Value kernelSizeSub1 = b.create(loc, kernelSizeInt, c1); + Value kernelSizeSub1 = b.createOrFold(loc, kernelSizeInt, c1); Value dilationTimesKernelSize = - b.create(loc, dilationInt, kernelSizeSub1); + b.createOrFold(loc, dilationInt, kernelSizeSub1); - Value temp = - b.create(loc, inAddDoublePadding, dilationTimesKernelSize); - Value dividend = b.create(loc, temp, c1); + Value temp = b.createOrFold(loc, inAddDoublePadding, + dilationTimesKernelSize); + Value dividend = b.createOrFold(loc, temp, c1); Value division; if (ceilMode) - division = b.create(loc, dividend, strideInt); + division = b.createOrFold(loc, dividend, strideInt); else - division = b.create(loc, dividend, strideInt); - Value out = b.create(loc, division, c1); + division = b.createOrFold(loc, dividend, strideInt); + Value out = b.createOrFold(loc, division, c1); + + if (ceilMode) { + Value outMinusOneTimesStride = + b.createOrFold(loc, division, strideInt); + Value inAddLeftPadding = b.createOrFold( + loc, castIndexToInt64(b, loc, in), paddingInt); + + auto reduceOutputDimCond = + b.createOrFold(loc, arith::CmpIPredicate::uge, + outMinusOneTimesStride, inAddLeftPadding); + + auto reducedDim = b.createOrFold(loc, reduceOutputDimCond, + division, out); + return castIntToIndex(b, loc, reducedDim); + } + return castIntToIndex(b, loc, out); } @@ -568,6 +575,8 @@ bool torch_to_linalg::isUnsignedTorchType(Type type) { return false; if (isa(type)) return true; + if (isa(type)) + return false; if (isa(type)) return false; if (auto intTy = dyn_cast(type)) @@ -575,3 +584,84 @@ bool torch_to_linalg::isUnsignedTorchType(Type type) { llvm_unreachable("Unknown type checked for signedness"); return false; } + +LogicalResult torch_to_linalg::permuteTensor(Operation *op, + PatternRewriter &rewriter, + Location loc, + SmallVector dimensions, + Value input, Value &result) { + auto inType = cast(input.getType()); + int64_t inputRank = inType.getRank(); + Type elementType = inType.getElementType(); + + // Check for 0-D tensor. + if (inputRank == 0) { + result = input; + return success(); + } + + // Check if the dimensions are a valid constants. + int64_t numDimensions = dimensions.size(); + if (inputRank != numDimensions) + return rewriter.notifyMatchFailure( + op, "size of `dims` must be equal to the rank of the input"); + for (uint32_t i = 0; i < numDimensions; i++) { + if (dimensions[i] < 0) + dimensions[i] = toPositiveDim(dimensions[i], inputRank); + if (!isValidDim(dimensions[i], inputRank)) + return rewriter.notifyMatchFailure(op, "dimension out of range"); + } + + SmallVector outputDims; + for (uint32_t i = 0; i < inputRank; i++) + outputDims.push_back(getDimOp(rewriter, loc, input, dimensions[i])); + + Value outVector = rewriter.create( + loc, getAsOpFoldResult(outputDims), elementType); + + result = + rewriter.create(loc, input, outVector, dimensions) + ->getResult(0); + return success(); +} + +// Flips an input tensor based on the values of axis list. +Value torch_to_linalg::flipTensor(PatternRewriter &rewriter, Location loc, + Value input, SmallVector axis) { + Value c1 = rewriter.create(loc, rewriter.getIndexAttr(1)); + Type elementType = cast(input.getType()).getElementType(); + auto selfRank = cast(input.getType()).getRank(); + + // Only used to calculate flipped values, i.e. those on the flip axes. Other + // dims won't be used. 
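(Editorial note, not part of the patch: for reference, the output-size rule implemented by getOutputDimForConvOps above, including the new ceil-mode correction that drops a trailing window starting entirely in the right padding. Illustrative Python only; the helper name is hypothetical.)

def out_dim(in_size, kernel, stride, padding, dilation, ceil_mode):
    # (in + 2*pad - dilation*(kernel - 1) - 1) / stride + 1, rounded up in ceil mode
    dividend = in_size + 2 * padding - dilation * (kernel - 1) - 1
    if not ceil_mode:
        return dividend // stride + 1
    out = -(-dividend // stride) + 1   # ceildiv(dividend, stride) + 1
    # Drop the last window if it would start in the right padding only,
    # i.e. (out - 1) * stride >= in + left padding.
    if (out - 1) * stride >= in_size + padding:
        out -= 1
    return out

print(out_dim(6, 3, 2, 1, 1, ceil_mode=False))  # 3
print(out_dim(6, 3, 2, 1, 1, ceil_mode=True))   # 4
print(out_dim(5, 2, 2, 1, 1, ceil_mode=True))   # 3 (extra window dropped)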
+ SmallVector dims = getTensorSizes(rewriter, loc, input); + for (auto flipDim : axis) + dims[flipDim] = rewriter.create(loc, dims[flipDim], c1); + + Value initTensor = createZeroInitTensor( + rewriter, loc, getTensorSizes(rewriter, loc, input), elementType); + + SmallVector iteratorTypes(selfRank, + utils::IteratorType::parallel); + SmallVector indexingMaps( + 2, AffineMap::getMultiDimIdentityMap(selfRank, rewriter.getContext())); + Value flipped = + rewriter + .create( + loc, input.getType(), input, initTensor, indexingMaps, + iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + SmallVector indices; + for (auto i = 0; i < selfRank; i++) + indices.push_back(b.create(loc, i)); + for (auto flipDim : axis) { + indices[flipDim] = b.create(loc, dims[flipDim], + indices[flipDim]); + } + Value res = b.create(loc, input, indices) + .getResult(); + b.create(loc, res); + }) + .getResult(0); + return flipped; +} diff --git a/lib/Conversion/TorchToSCF/TorchToSCF.cpp b/lib/Conversion/TorchToSCF/TorchToSCF.cpp index 60206f03999b..27e0a61f4b31 100644 --- a/lib/Conversion/TorchToSCF/TorchToSCF.cpp +++ b/lib/Conversion/TorchToSCF/TorchToSCF.cpp @@ -12,12 +12,10 @@ #include "../PassDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/IR/BuiltinTypes.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" -#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" using namespace mlir; @@ -254,7 +252,7 @@ class ConvertTorchPrimLoopForLikeOp : public OpConversionPattern { // "block" arguments for (const auto &barg : enumerate(op.getRegion().front().getArguments())) { Value to = block->getArgument(barg.index()); - if (to.getType().isa()) + if (isa(to.getType())) to = rewriter.create(loc, rewriter.getI64Type(), to); Type targetType = to.getType(); diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index 377795d843d9..d6ba57a08a8f 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -14,14 +14,12 @@ #include "Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" #include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" #include "torch-mlir/Conversion/Utils/Utils.h" -#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" @@ -36,37 +34,8 @@ using namespace mlir::torch; using namespace mlir::torch::Torch; using namespace mlir::torch::torch_to_stablehlo; -namespace { - -template -static Value getConstantLike(OpBuilder &b, Location loc, T constant, - Value val) { - Type ty = getElementTypeOrSelf(val.getType()); - auto getAttr = [&]() -> Attribute { - if (isa(ty)) - return b.getIntegerAttr(ty, constant); - if (isa(ty)) - return b.getFloatAttr(ty, constant); - if (auto complexTy = dyn_cast(ty)) - return complex::NumberAttr::get(complexTy, constant, 0); - llvm_unreachable("unhandled element type"); - }; - return b.create(loc, 
cast(getAttr()), - val); -} - -Value getConstantLike(OpBuilder &b, Location loc, const APFloat &constant, - Value val) { - Type ty = getElementTypeOrSelf(val.getType()); - return b.create(loc, b.getFloatAttr(ty, constant), - val); -} - -} // namespace - LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op, - mlir::Value &self, mlir::Value &other, - size_t dimSizeIndexBits) { + mlir::Value &self, mlir::Value &other) { auto selfTy = dyn_cast(self.getType()); auto otherTy = dyn_cast(other.getType()); auto selfRank = selfTy.getRank(); @@ -76,16 +45,16 @@ LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op, if (selfRank > otherRank) { auto unsqueezeDims = llvm::to_vector<4>(llvm::seq(0, selfRank - otherRank)); - auto unsqueezeInfo = hlo::unsqueezeTensor(rewriter, op, other, - unsqueezeDims, dimSizeIndexBits); + auto unsqueezeInfo = + hlo::unsqueezeTensor(rewriter, op, other, unsqueezeDims); if (failed(unsqueezeInfo)) return failure(); other = *unsqueezeInfo; } else if (otherRank > selfRank) { auto unsqueezeDims = llvm::to_vector<4>(llvm::seq(0, otherRank - selfRank)); - auto unsqueezeInfo = hlo::unsqueezeTensor(rewriter, op, self, unsqueezeDims, - dimSizeIndexBits); + auto unsqueezeInfo = + hlo::unsqueezeTensor(rewriter, op, self, unsqueezeDims); if (failed(unsqueezeInfo)) return failure(); self = *unsqueezeInfo; @@ -176,10 +145,11 @@ class ConvertAtenUnaryOp : public OpConversionPattern { if (!selfType) { return op.emitError("only Tensor types supported in StableHLO"); } - auto outType = OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template cast(); - self = hlo::promoteType(rewriter, op.getLoc(), self, outType); + auto outType = cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); + self = + hlo::promoteType(rewriter, op.getLoc(), self, outType.getElementType()); rewriter.replaceOpWithNewOp(op, outType, self); return success(); } @@ -233,12 +203,13 @@ class ConvertAtenUnaryPromoteToFPOp : public OpConversionPattern { auto selfTy = cast(self.getType()); if (!selfTy) return op.emitError("only Tensor types supported in StableHLO"); - auto resultTy = OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template cast(); + auto resultTy = cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); if (isa(resultTy.getElementType())) { - Value src = hlo::promoteType(rewriter, op.getLoc(), self, resultTy); + Value src = hlo::promoteType(rewriter, op.getLoc(), self, + resultTy.getElementType()); rewriter.replaceOpWithNewOp(op, resultTy, src); return success(); } else { @@ -261,9 +232,9 @@ class ConvertAtenConstPatternOp : public OpConversionPattern { matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto outType = OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template dyn_cast(); + auto outType = dyn_cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); if (!outType) return op.emitError("only Tensor types supported in StableHLO"); @@ -307,8 +278,8 @@ class ConvertAtenTensorToScalarLikeOp : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { auto inputType = dyn_cast(adaptor.getA().getType()); if (!inputType) - op.emitError("only Tensor types supported in StableHLO"); + Location loc = op.getLoc(); Value input = adaptor.getA(); SmallVector inputSizes = getTensorSizes(rewriter, loc, input); @@ -320,14 +291,24 @@ class 
ConvertAtenTensorToScalarLikeOp : public OpConversionPattern { for (int64_t i = 0; i < inputRank; i++) checkDimEqualHelper(rewriter, loc, inputSizes[i], constantOne); + // handle unsigned interger + if (inputType.getElementType().isUnsignedInteger()) { + input = rewriter.create( + loc, input, + rewriter.getIntegerType( + inputType.getElementType().getIntOrFloatBitWidth())); + } + Value constantZero = rewriter.create(loc, rewriter.getIndexAttr(0)); SmallVector indices(inputRank, constantZero); Value result = rewriter.create(loc, input, indices); Type resultType = this->getTypeConverter()->convertType(op->getResult(0).getType()); - rewriter.replaceOp(op, convertScalarToDtype(rewriter, loc, result, - resultType, inputDtype)); + rewriter.replaceOp( + op, + convertScalarToDtype(rewriter, loc, result, resultType, inputDtype, + /*srcOriginalDtype=*/inputType.getElementType())); return success(); } }; @@ -351,12 +332,12 @@ class ConvertAtenBinaryBroadcastOp : public OpConversionPattern { if (!lhsTy || !rhsTy) return op.emitError("only Tensor types supported"); - auto outTy = OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template cast(); + auto outTy = cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); - lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy); - rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy); + lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy.getElementType()); + rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy.getElementType()); rewriter.replaceOpWithNewOp(op, outTy, lhs, rhs, /*broadcast_attr*/ nullptr); @@ -384,9 +365,9 @@ class ConvertAtenAddSubOp : public OpConversionPattern { if (!lhsType) return op.emitError("only Tensor types supported in StableHLO"); - TensorType outType = OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template cast(); + TensorType outType = cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); Type outElemTy = outType.getElementType(); if (!outElemTy.isIntOrFloat()) { @@ -402,8 +383,8 @@ class ConvertAtenAddSubOp : public OpConversionPattern { } } - lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType); - rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType); + lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outElemTy); + rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outElemTy); if (!skipMultiplyAlpha(op.getAlpha())) { Value alpha = hlo::scalarToStablehloTensor(rewriter, op, @@ -458,8 +439,8 @@ class ConvertAtenMulDivOp : public OpConversionPattern { } } DenseI64ArrayAttr bcastDimensions; - lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType); - rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType); + lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outElemTy); + rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outElemTy); auto loc = op.getLoc(); Value result = rewriter.create(loc, outType, lhs, rhs, bcastDimensions); @@ -497,16 +478,17 @@ class ConvertAtenMulDivOp : public OpConversionPattern { if (isa(outElemTy)) result = rewriter.create(loc, result).getResult(); else if (!outElemTy.isUnsignedInteger()) { - TensorType defaultIntToFloatType = - outType.cloneWith(outType.getShape(), rewriter.getF64Type()); + Type defaultIntToFloatType = rewriter.getF64Type(); lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, defaultIntToFloatType); rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, defaultIntToFloatType); - result = rewriter.create(loc, defaultIntToFloatType, lhs, rhs, - 
bcastDimensions); + result = rewriter.create( + loc, outType.cloneWith(outType.getShape(), defaultIntToFloatType), + lhs, rhs, bcastDimensions); result = rewriter.create(loc, result).getResult(); - result = hlo::promoteType(rewriter, op.getLoc(), result, outType); + result = hlo::promoteType(rewriter, op.getLoc(), result, + outType.getElementType()); } } rewriter.replaceOp(op, result); @@ -534,10 +516,12 @@ class ConvertAtenCompareOp : public OpConversionPattern { if (!lhsTy) { return op.emitError("only Tensor types supported in StableHLO"); } + bool isRhsScalar = false; if (!rhsTy) { rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(), rhs.getType()); rhsTy = dyn_cast(rhs.getType()); + isRhsScalar = true; } auto outType = cast( @@ -552,16 +536,28 @@ class ConvertAtenCompareOp : public OpConversionPattern { } if (isa(lhsElemTy) && isa(rhsElemTy)) { - lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsTy); + // torch.lt(x_int, 1.1) use fp32 as compute type + // torch.lt(x_int, y_float) use y's float type as compute type + Type promoteTo = isRhsScalar ? rewriter.getF32Type() : rhsElemTy; + lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, promoteTo); + rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, promoteTo); } else if (isa(lhsElemTy) && isa(rhsElemTy)) { - rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy); + // always use lhs's float type as compute type + rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy); } else { - if (lhsElemTy.getIntOrFloatBitWidth() > - rhsElemTy.getIntOrFloatBitWidth()) { - rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy); + if (isRhsScalar) { + // torch.lt(x_float, 1.1) use x's float type as compute type + // torch.lt(x_int, 1) use x's int type as compute type + rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy); } else { - lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsTy); + // torch.lt(x_float, y_float) use higher bitwidth as compute type + Type promoteTo = lhsElemTy.getIntOrFloatBitWidth() > + rhsElemTy.getIntOrFloatBitWidth() + ? 
lhsElemTy + : rhsElemTy; + lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, promoteTo); + rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, promoteTo); } } lhsElemTy = dyn_cast(lhs.getType()).getElementType(); @@ -637,15 +633,15 @@ class ConvertAtenLogicalBinaryOp : public OpConversionPattern { if (!lhsTy) return op.emitError("lhs must be a ranked tensor type"); - TensorType outType = OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template cast(); + TensorType outType = cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); Type outElemTy = outType.getElementType(); - lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType); + lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outElemTy); if (!rhsTy) { rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs, outElemTy); } - rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType); + rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outElemTy); DenseI64ArrayAttr bcastDimensions; rewriter.replaceOpWithNewOp(op, outType, lhs, rhs, @@ -755,15 +751,15 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto outType = cast(getTypeConverter()->convertType(op.getType())); // promote self and other types - self = hlo::promoteType(rewriter, op.getLoc(), self, outType); - other = hlo::promoteType(rewriter, op.getLoc(), other, outType); + self = + hlo::promoteType(rewriter, op.getLoc(), self, outType.getElementType()); + other = + hlo::promoteType(rewriter, op.getLoc(), other, outType.getElementType()); - if (failed( - broadcastRanks(rewriter, op, self, cond, options.dimSizeIndexBits))) + if (failed(broadcastRanks(rewriter, op, self, cond))) return op.emitError("failed broadcast self and condition ranks"); - if (failed( - broadcastRanks(rewriter, op, other, cond, options.dimSizeIndexBits))) + if (failed(broadcastRanks(rewriter, op, other, cond))) return op.emitError("failed broadcast other and condition ranks"); rewriter.replaceOpWithNewOp( @@ -783,7 +779,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( getTypeConverter()->convertType(op->getResult(0).getType())); if (options.enableStaticShape && selfTy.hasStaticShape()) { - Value bcastOp = hlo::promoteAndBroadcast(rewriter, self, outType); + Value bcastOp = + hlo::promoteAndBroadcast(rewriter, self, outType, std::nullopt); rewriter.replaceOp(op, bcastOp); return success(); } @@ -928,84 +925,55 @@ LogicalResult ConvertAtenOp::matchAndRewrite( "for AtenReciprocalOp"); } - Value oneTensor = getConstantLike(rewriter, op->getLoc(), 1, input); + Value oneTensor = + hlo::getConstantLike(rewriter, op->getLoc(), 1, input); rewriter.replaceOpWithNewOp(op, outTy, oneTensor, input); return success(); } -// AtenPowTensorScalarOp -template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - AtenPowTensorScalarOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Value lhs = adaptor.getSelf(); - auto lhsType = dyn_cast(lhs.getType()); - Value rhs = adaptor.getExponent(); - TensorType rhsType = dyn_cast(rhs.getType()); - - if (!lhsType) - return op.emitError("only Tensor types supported in StableHLO"); - - auto outType = OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template cast(); - - Type outElemTy = outType.getElementType(); - if (!outElemTy.isIntOrFloat()) { - return op.emitError( - "only floating-point or integer datatype legalization supported"); - } - - if (!rhsType) { - rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs, outElemTy); - } - DenseI64ArrayAttr bcastDimensions; - lhs = 
hlo::promoteType(rewriter, op.getLoc(), lhs, outType); - rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType); - auto loc = op.getLoc(); - Value result = rewriter.create(loc, outType, lhs, rhs, - bcastDimensions); - - rewriter.replaceOp(op, result); - return success(); -} - -// AtenPowScalarOp -template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - AtenPowScalarOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Value lhs = adaptor.getSelf(); - auto lhsType = dyn_cast(lhs.getType()); - Value rhs = adaptor.getExponent(); - auto rhsType = dyn_cast(rhs.getType()); +namespace { +template +class ConvertAtenPowOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto outType = cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); - if (!rhsType) - return op.emitError("only Tensor types supported in StableHLO"); + Type outElemTy = outType.getElementType(); + if (!outElemTy.isIntOrFloat()) { + return op.emitError( + "only floating-point or integer datatype legalization supported"); + } - auto outType = cast( - OpConversionPattern::getTypeConverter()->convertType( - op.getType())); + Value lhs = adaptor.getSelf(); + auto lhsType = dyn_cast(lhs.getType()); + Value rhs = adaptor.getExponent(); + auto rhsType = dyn_cast(rhs.getType()); - Type outElemTy = outType.getElementType(); - if (!outElemTy.isIntOrFloat()) { - return op.emitError( - "only floating-point or integer datatype legalization supported"); - } + if (!lhsType && !rhsType) { + return op.emitError("only Tensor types supported in StableHLO"); + } + if (!lhsType) { + lhs = hlo::scalarToStablehloTensor(rewriter, op, lhs, outElemTy); + } + if (!rhsType) { + rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs, outElemTy); + } - if (!lhsType) { - lhs = hlo::scalarToStablehloTensor(rewriter, op, lhs, outElemTy); + lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outElemTy); + rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outElemTy); + DenseI64ArrayAttr bcastDimensions; + rewriter.replaceOpWithNewOp(op, outType, lhs, rhs, + bcastDimensions); + return success(); } - DenseI64ArrayAttr bcastDimensions; - lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType); - rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType); - auto loc = op.getLoc(); - Value result = rewriter.create(loc, outType, lhs, rhs, - bcastDimensions); - - rewriter.replaceOp(op, result); - return success(); -} +}; +} // namespace // PrimNumToTensorScalarOp template <> @@ -1070,12 +1038,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return op->emitError("only float tensor in relu op is supported"); } - Value zeroTensor; - zeroTensor = getConstantLike( - rewriter, op->getLoc(), - APFloat::getZero(cast(lhsElemTy).getFloatSemantics(), - false), - lhs); + Value zeroTensor = + hlo::getConstantLike(rewriter, op->getLoc(), 0, lhs); rewriter.replaceOpWithNewOp(op, lhs, zeroTensor); return success(); } @@ -1102,13 +1066,13 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return op.emitError("unsupported approximate: ") << approximate; } - Value one = getConstantLike(rewriter, loc, 1.0, input); - Value two = getConstantLike(rewriter, loc, 2.0, input); - Value three = getConstantLike(rewriter, loc, 3.0, input); - Value half = getConstantLike(rewriter, loc, 0.5, input); + Value one = 
hlo::getConstantLike(rewriter, loc, 1.0, input); + Value two = hlo::getConstantLike(rewriter, loc, 2.0, input); + Value three = hlo::getConstantLike(rewriter, loc, 3.0, input); + Value half = hlo::getConstantLike(rewriter, loc, 0.5, input); // 2/pi - Value twoDivPi = getConstantLike(rewriter, loc, M_2_PI, input); - Value t = getConstantLike(rewriter, loc, 0.044715, input); + Value twoDivPi = hlo::getConstantLike(rewriter, loc, M_2_PI, input); + Value t = hlo::getConstantLike(rewriter, loc, 0.044715, input); // x * 0.5 auto inputMulHalf = rewriter.create(loc, input, half); @@ -1145,9 +1109,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return op.emitError("only ranked tensor type is supported."); } auto outTy = cast(getTypeConverter()->convertType(op.getType())); - input = hlo::promoteType(rewriter, op.getLoc(), input, outTy); + input = + hlo::promoteType(rewriter, op.getLoc(), input, outTy.getElementType()); - auto two = getConstantLike(rewriter, op.getLoc(), 2.0, input); + auto two = hlo::getConstantLike(rewriter, op.getLoc(), 2.0, input); auto log2Op = rewriter.create(op.getLoc(), two); auto logInputOp = rewriter.create(op.getLoc(), input); @@ -1167,9 +1132,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } auto outTy = cast(getTypeConverter()->convertType(op.getType())); - input = hlo::promoteType(rewriter, op.getLoc(), input, outTy); + input = + hlo::promoteType(rewriter, op.getLoc(), input, outTy.getElementType()); - auto ten = getConstantLike(rewriter, op.getLoc(), 10.0, input); + auto ten = hlo::getConstantLike(rewriter, op.getLoc(), 10.0, input); auto log10Op = rewriter.create(op.getLoc(), ten); auto logInputOp = rewriter.create(op.getLoc(), input); @@ -1177,6 +1143,49 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +// AtenLogitOp +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenLogitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + + Value self = adaptor.getSelf(); + auto selfTy = dyn_cast(self.getType()); + if (!selfTy) { + return op.emitError("only ranked tensor type is supported."); + } + + auto outTy = cast(getTypeConverter()->convertType(op.getType())); + self = hlo::promoteType(rewriter, op.getLoc(), self, outTy.getElementType()); + + selfTy = dyn_cast(self.getType()); + + Value eps = adaptor.getEps(); + auto epsTy = eps.getType(); + Value newSelf; + if (!isa(epsTy)) { + auto epsTensor = hlo::scalarToStablehloTensor(rewriter, op, eps, + selfTy.getElementType()); + Value oneEpsTensor = hlo::getConstantLike(rewriter, loc, 1.0, epsTensor); + auto max = + rewriter.create(loc, oneEpsTensor, epsTensor); + newSelf = rewriter.create(loc, epsTensor, self, max); + } else { + newSelf = self; + } + + Value one = hlo::getConstantLike(rewriter, loc, 1.0, self); + Value zi1 = rewriter.create(loc, one, newSelf); + Value newZi = rewriter.create(loc, newSelf, zi1); + + Value log = rewriter.create(loc, outTy, newZi); + + rewriter.replaceOp(op, log); + + return success(); +} + // AtenErfOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( @@ -1290,42 +1299,44 @@ LogicalResult ConvertAtenOp::matchAndRewrite( "non-bool cudnn_enabled unsupported"); } if (training) { - Type outputTy = getTypeConverter()->convertType(op.getType()); - Type batchMeanOrVarTy = - RankedTensorType::get(weightTy.getShape(), inputTy.getElementType()); + TensorType outputTy = + cast(getTypeConverter()->convertType(op.getType())); Value output; // supported mixed types, like input type is fp16 and weight type is fp32. 
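(Editorial note, not part of the patch: the aten.logit lowering added earlier in this file clamps the input into [eps, 1 - eps] when eps is provided and then applies log(x / (1 - x)). A minimal NumPy reference of that behaviour; function name is hypothetical.)

import numpy as np

def logit(x, eps=None):
    # Clamp only when eps is given, matching the StableHLO lowering above.
    if eps is not None:
        x = np.clip(x, eps, 1.0 - eps)
    return np.log(x / (1.0 - x))

x = np.array([0.0, 0.25, 0.5, 0.999], dtype=np.float32)
print(logit(x, eps=1e-3))   # finite everywhere thanks to the clamp
print(logit(x))             # -inf at 0.0 (NumPy warns), as with torch.logit without eps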
if (inputTy.getElementType() != weightTy.getElementType()) { - RankedTensorType convertedType = inputTy; - if (cast(weightTy.getElementType()).getWidth() > - cast(inputTy.getElementType()).getWidth()) { - convertedType = RankedTensorType::get(inputTy.getShape(), - weightTy.getElementType()); + Type computeType = inputTy.getElementType(); + if (weightTy.getElementType().getIntOrFloatBitWidth() > + inputTy.getElementType().getIntOrFloatBitWidth()) { + computeType = weightTy.getElementType(); } - input = hlo::promoteType(rewriter, op.getLoc(), input, convertedType); - weight = hlo::promoteType(rewriter, op.getLoc(), weight, convertedType); - bias = hlo::promoteType(rewriter, op.getLoc(), bias, convertedType); + input = hlo::promoteType(rewriter, op.getLoc(), input, computeType); + weight = hlo::promoteType(rewriter, op.getLoc(), weight, computeType); + bias = hlo::promoteType(rewriter, op.getLoc(), bias, computeType); auto batchNormTrainingResult = rewriter.create( - op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input, + op.getLoc(), + RankedTensorType::get(inputTy.getShape(), computeType), + RankedTensorType::get(weightTy.getShape(), computeType), + RankedTensorType::get(weightTy.getShape(), computeType), input, weight, bias, rewriter.getF32FloatAttr(eps), rewriter.getI64IntegerAttr(feature_index)); output = hlo::promoteType(rewriter, op.getLoc(), batchNormTrainingResult.getResult(0), - cast(outputTy)); + outputTy.getElementType()); } else { auto batchNormTrainingResult = rewriter.create( - op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input, - weight, bias, rewriter.getF32FloatAttr(eps), + op.getLoc(), outputTy, weightTy, weightTy, input, weight, bias, + rewriter.getF32FloatAttr(eps), rewriter.getI64IntegerAttr(feature_index)); output = batchNormTrainingResult.getResult(0); } rewriter.replaceOp(op, output); return success(); } else { - Type outputTy = getTypeConverter()->convertType(op.getType()); + TensorType outputTy = + cast(getTypeConverter()->convertType(op.getType())); SmallVector castShape{inputTy.getShape().begin(), inputTy.getShape().end()}; castShape[1] = weightTy.getShape()[0]; @@ -1338,26 +1349,25 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value output; // supported mixed types, like input type is fp16 and weight type is fp32. 
if (inputTy.getElementType() != weightTy.getElementType()) { - RankedTensorType convertedType = inputTy; - if (cast(weightTy.getElementType()).getWidth() > - cast(inputTy.getElementType()).getWidth()) { - convertedType = RankedTensorType::get(inputTy.getShape(), - weightTy.getElementType()); + Type computeType = inputTy.getElementType(); + if (weightTy.getElementType().getIntOrFloatBitWidth() > + inputTy.getElementType().getIntOrFloatBitWidth()) { + computeType = weightTy.getElementType(); } - input = - hlo::promoteType(rewriter, op.getLoc(), inputCasted, convertedType); - weight = hlo::promoteType(rewriter, op.getLoc(), weight, convertedType); - bias = hlo::promoteType(rewriter, op.getLoc(), bias, convertedType); + input = hlo::promoteType(rewriter, op.getLoc(), inputCasted, computeType); + weight = hlo::promoteType(rewriter, op.getLoc(), weight, computeType); + bias = hlo::promoteType(rewriter, op.getLoc(), bias, computeType); runningMean = - hlo::promoteType(rewriter, op.getLoc(), runningMean, convertedType); + hlo::promoteType(rewriter, op.getLoc(), runningMean, computeType); runningVar = - hlo::promoteType(rewriter, op.getLoc(), runningVar, convertedType); + hlo::promoteType(rewriter, op.getLoc(), runningVar, computeType); Value bnResult = rewriter.create( - op.getLoc(), convertedType, input, weight, bias, runningMean, - runningVar, rewriter.getF32FloatAttr(eps), + op.getLoc(), RankedTensorType::get(inputTy.getShape(), computeType), + input, weight, bias, runningMean, runningVar, + rewriter.getF32FloatAttr(eps), rewriter.getI64IntegerAttr(feature_index)); output = hlo::promoteType(rewriter, op.getLoc(), bnResult, - cast(outputTy)); + outputTy.getElementType()); } else { output = rewriter.create( op.getLoc(), inputCasted.getType(), inputCasted, weight, bias, @@ -1454,9 +1464,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Generate "scale" and "offset" Value for stablehlo.BatchNormTrainingOp. 
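(Editorial note, not part of the patch: both batch-norm branches above now handle mixed input/weight dtypes by promoting everything to the wider element type, running the StableHLO batch-norm op there, and demoting the result to the output element type. A sketch of that promote-compute-demote pattern, inference form; illustrative NumPy only, names hypothetical.)

import numpy as np

def batch_norm_infer(x, weight, bias, mean, var, eps=1e-5):
    # e.g. x: (N, C, ...) fp16 and per-channel params fp32 -> compute in fp32.
    compute = np.promote_types(x.dtype, weight.dtype)
    xc = x.astype(compute)
    shape = (1, -1) + (1,) * (x.ndim - 2)      # broadcast per-channel params
    y = (xc - mean.reshape(shape)) / np.sqrt(var.reshape(shape) + eps)
    y = y * weight.reshape(shape) + bias.reshape(shape)
    return y.astype(x.dtype)                   # demote back to the output dtype

x = np.random.rand(2, 3, 4).astype(np.float16)
w = np.ones(3, np.float32)
b = np.zeros(3, np.float32)
m = np.zeros(3, np.float32)
v = np.ones(3, np.float32)
print(batch_norm_infer(x, w, b, m, v).dtype)   # float16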
SmallVector zeroConstVec( - numFeatureDimSize, APFloat::getZero(inputTy.getElementType() - .cast() - .getFloatSemantics())); + numFeatureDimSize, + APFloat::getZero( + cast(inputTy.getElementType()).getFloatSemantics())); SmallVector oneConstVec( numFeatureDimSize, APFloat( @@ -1503,8 +1513,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( .value()); // Apply affine transform: output x weight + bias [element-wise] - auto bcastedWeight = hlo::promoteAndBroadcast(rewriter, weight, outputTy); - auto bcastedBias = hlo::promoteAndBroadcast(rewriter, bias, outputTy); + auto bcastedWeight = + hlo::promoteAndBroadcast(rewriter, weight, outputTy, std::nullopt); + auto bcastedBias = + hlo::promoteAndBroadcast(rewriter, bias, outputTy, std::nullopt); auto outputMulWeight = rewriter.create(op->getLoc(), output, bcastedWeight); auto finalOuput = rewriter.create( @@ -1539,7 +1551,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Promote type for (auto &v : builtinTensors) { - v = hlo::promoteType(rewriter, op->getLoc(), v, outType); + v = hlo::promoteType(rewriter, op->getLoc(), v, outType.getElementType()); } rewriter.replaceOpWithNewOp( @@ -1649,8 +1661,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( maxValue = *maxInfo; } if (inputType.hasStaticShape()) { - minValue = hlo::promoteAndBroadcast(rewriter, minValue, inputType); - maxValue = hlo::promoteAndBroadcast(rewriter, maxValue, inputType); + minValue = + hlo::promoteAndBroadcast(rewriter, minValue, inputType, std::nullopt); + maxValue = + hlo::promoteAndBroadcast(rewriter, maxValue, inputType, std::nullopt); } rewriter.replaceOpWithNewOp(op, minValue, input, maxValue); @@ -1666,9 +1680,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Location loc = op->getLoc(); // Get element type of resultType as dtype - auto outType = this->getTypeConverter() - ->convertType(op.getType()) - .cast(); + auto outType = cast( + this->getTypeConverter()->convertType(op.getType())); auto dtype = outType.getElementType(); if (!isa(dtype) && !isa(dtype)) { return rewriter.notifyMatchFailure( @@ -1711,7 +1724,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenConstantPadNdOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value self = adaptor.getSelf(); - auto selfTy = self.getType().cast(); + auto selfTy = cast(self.getType()); auto selfElemTy = selfTy.getElementType(); int64_t rank = selfTy.getRank(); @@ -1764,12 +1777,13 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure(op, "Unsupported value of approximate"); } // Create constant value - Value kAlpha = getConstantLike(rewriter, loc, 0.70710678118654752440, input); + Value kAlpha = + hlo::getConstantLike(rewriter, loc, 0.70710678118654752440, input); Value cstAlpha0 = - getConstantLike(rewriter, loc, 1.12837916709551257390, input); - Value half = getConstantLike(rewriter, loc, .5, input); - Value one = getConstantLike(rewriter, loc, 1.0, input); - Value negHalf = getConstantLike(rewriter, loc, -0.5, input); + hlo::getConstantLike(rewriter, loc, 1.12837916709551257390, input); + Value half = hlo::getConstantLike(rewriter, loc, .5, input); + Value one = hlo::getConstantLike(rewriter, loc, 1.0, input); + Value negHalf = hlo::getConstantLike(rewriter, loc, -0.5, input); // Compute Value kBeta0 = @@ -1796,29 +1810,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } -template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - AtenPowTensorTensorOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Value lhs = 
adaptor.getSelf(); - auto lhsTy = cast(lhs.getType()); - Value rhs = adaptor.getExponent(); - auto rhsTy = cast(rhs.getType()); - - if (!lhsTy || !rhsTy) - return op.emitError("only Tensor types supported"); - - auto outTy = - cast(this->getTypeConverter()->convertType(op.getType())); - - lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy); - rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy); - - rewriter.replaceOpWithNewOp(op, outTy, lhs, rhs, - /*broadcast_attr*/ nullptr); - return success(); -} - // Converts `aten.empty.memory_format` to `tensor.empty` op. template <> LogicalResult ConvertAtenOp::matchAndRewrite( @@ -1985,8 +1976,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto resultType = cast(getTypeConverter()->convertType(op.getType())); - lhs = hlo::promoteType(rewriter, op->getLoc(), lhs, resultType); - rhs = hlo::promoteType(rewriter, op->getLoc(), rhs, resultType); + lhs = hlo::promoteType(rewriter, op->getLoc(), lhs, + resultType.getElementType()); + rhs = hlo::promoteType(rewriter, op->getLoc(), rhs, + resultType.getElementType()); rewriter.replaceOpWithNewOp(op, lhs, rhs); return success(); } @@ -2003,8 +1996,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto resultType = cast(getTypeConverter()->convertType(op.getType())); - lhs = hlo::promoteType(rewriter, op->getLoc(), lhs, resultType); - rhs = hlo::promoteType(rewriter, op->getLoc(), rhs, resultType); + lhs = hlo::promoteType(rewriter, op->getLoc(), lhs, + resultType.getElementType()); + rhs = hlo::promoteType(rewriter, op->getLoc(), rhs, + resultType.getElementType()); stablehlo::MulOp mul; auto div = rewriter.create(loc, lhs, rhs); @@ -2032,7 +2027,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto resultType = cast(getTypeConverter()->convertType(op.getType())); - rhs = hlo::promoteAndBroadcast(rewriter, rhs, resultType); + rhs = hlo::promoteAndBroadcast(rewriter, rhs, resultType, std::nullopt); rewriter.replaceOpWithNewOp(op, lhs, rhs); return success(); } @@ -2047,11 +2042,106 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto resultType = cast(getTypeConverter()->convertType(op.getType())); - rhs = hlo::promoteAndBroadcast(rewriter, rhs, resultType); + rhs = hlo::promoteAndBroadcast(rewriter, rhs, resultType, std::nullopt); rewriter.replaceOpWithNewOp(op, lhs, rhs); return success(); } +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenTrilOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + Location loc = op.getLoc(); + + Value self = adaptor.getSelf(); + + auto selfTy = cast(self.getType()); + if (!selfTy.hasStaticShape()) { + return op->emitError("dynamic shaped input is not supported"); + } + + ArrayRef selfShape = selfTy.getShape(); + int64_t selfRank = selfTy.getRank(); + auto iotaElementTy = mlir::IntegerType::get(op.getContext(), 64); + auto iotaTy = RankedTensorType::get( + {selfShape[selfRank - 2], selfShape[selfRank - 1]}, iotaElementTy); + Value colIdxTensor = + rewriter.create(loc, iotaTy, 1).getResult(); + Value rowIdxTensor = + rewriter.create(loc, iotaTy, 0).getResult(); + + Value diagonal = adaptor.getDiagonal(); + Value diagonalTensor = + rewriter.create(loc, diagonal).getResult(); + + auto bcastDimensions = rewriter.getDenseI64ArrayAttr({1}); + Value shiftedRowIdxTensor = rewriter.create( + loc, rowIdxTensor, diagonalTensor, bcastDimensions); + + auto cmpDirectionAttr = stablehlo::ComparisonDirectionAttr::get( + rewriter.getContext(), stablehlo::ComparisonDirection::LE); + auto cmpTypeAttr = 
stablehlo::ComparisonTypeAttr::get( + rewriter.getContext(), stablehlo::ComparisonType::SIGNED); + auto cmpTy = iotaTy.clone(rewriter.getI1Type()); + Value cmpRes = rewriter.create( + loc, cmpTy, colIdxTensor, shiftedRowIdxTensor, cmpDirectionAttr, + cmpTypeAttr); + + auto resTy = + cast(getTypeConverter()->convertType(op.getType())); + + auto bcastTy = resTy.clone(rewriter.getI1Type()); + auto bcastAttr = rewriter.getDenseI64ArrayAttr({selfRank - 2, selfRank - 1}); + Value bcastedCmpRes = rewriter.create( + loc, bcastTy, cmpRes, bcastAttr); + + auto resElemTy = resTy.getElementType(); + Value zeroTensor; + if (isa(resElemTy)) { + auto constAttr = SplatElementsAttr::get( + resTy, llvm::APFloat::getZero( + cast(resElemTy).getFloatSemantics(), false)); + zeroTensor = rewriter.create(loc, resTy, constAttr); + } else if (isa(resElemTy)) { + auto constAttr = SplatElementsAttr::get( + resTy, + llvm::APInt::getZero(cast(resElemTy).getWidth())); + zeroTensor = rewriter.create(loc, resTy, constAttr); + } else { + return op.emitError("element type is not float or integer"); + } + + rewriter.replaceOpWithNewOp( + op.getOperation(), resTy, bcastedCmpRes, self, zeroTensor); + + return success(); +} + +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenIsfiniteOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value self = adaptor.getSelf(); + auto selfTy = cast(self.getType()); + if (!selfTy) + return rewriter.notifyMatchFailure( + op, "Only Tensor types are currently supported"); + + auto outType = + dyn_cast(getTypeConverter()->convertType(op.getType())); + Type outElemTy = outType.getElementType(); + if (!outElemTy.isInteger(1)) { + return rewriter.notifyMatchFailure( + op, "Only i1 output element type is supported"); + } + + rewriter.replaceOpWithNewOp(op.getOperation(), outType, + self); + + return success(); +} + void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, const TorchToStablehloOptions &options) { @@ -2174,6 +2264,14 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( #undef INSERT_BINARY_LOGICAL_PATTERN +#define INSERT_BINARY_POW_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context) + INSERT_BINARY_POW_PATTERN(AtenPowTensorScalarOp); + INSERT_BINARY_POW_PATTERN(AtenPowTensorTensorOp); + INSERT_BINARY_POW_PATTERN(AtenPowScalarOp); +#undef INSERT_BINARY_ADDSUB_PATTERN + #define INSERT_ATENOP_PATTERN(AtenOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context, options) @@ -2184,8 +2282,6 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( INSERT_ATENOP_PATTERN(ValueTensorLiteralOp); INSERT_ATENOP_PATTERN(AtenTensorIntOp); INSERT_ATENOP_PATTERN(AtenReciprocalOp); - INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp); - INSERT_ATENOP_PATTERN(AtenPowScalarOp); INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp); INSERT_ATENOP_PATTERN(AtenScalarImplicitOp); INSERT_ATENOP_PATTERN(AtenContiguousOp); @@ -2195,6 +2291,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( INSERT_ATENOP_PATTERN(AtenGeluOp); INSERT_ATENOP_PATTERN(AtenLog2Op); INSERT_ATENOP_PATTERN(AtenLog10Op); + INSERT_ATENOP_PATTERN(AtenLogitOp); INSERT_ATENOP_PATTERN(AtenErfOp); INSERT_ATENOP_PATTERN(AtenGeluBackwardOp); @@ -2209,7 +2306,6 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( INSERT_ATENOP_PATTERN(AtenSizeIntOp); 
INSERT_ATENOP_PATTERN(AtenToDtypeOp); INSERT_ATENOP_PATTERN(AtenWhereSelfOp); - INSERT_ATENOP_PATTERN(AtenPowTensorTensorOp); INSERT_ATENOP_PATTERN(AtenEmptyMemoryFormatOp); INSERT_ATENOP_PATTERN(AtenFillScalarOp); @@ -2218,6 +2314,9 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( INSERT_ATENOP_PATTERN(AtenFmodTensorOp); INSERT_ATENOP_PATTERN(AtenBitwiseLeftShiftTensorOp); INSERT_ATENOP_PATTERN(AtenBitwiseRightShiftTensorOp); + + INSERT_ATENOP_PATTERN(AtenTrilOp); + INSERT_ATENOP_PATTERN(AtenIsfiniteOp); #undef INSERT_ATENOP_PATTERN #define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp) \ diff --git a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp index 00c022cc1067..c7a67abebab5 100644 --- a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp +++ b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp @@ -13,15 +13,15 @@ #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" #include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" #include "torch-mlir/Conversion/Utils/Utils.h" -#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" -#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" using namespace mlir; using namespace mlir::torch; @@ -32,6 +32,9 @@ namespace { static Value createInitialValueForGatherScatterOp(Operation *op, RankedTensorType constType, PatternRewriter &rewriter) { + if (!constType.hasStaticShape()) { + return nullptr; + } auto elementTy = constType.getElementType(); if (isa(op)) { if (isa(elementTy)) { @@ -101,6 +104,8 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op, rewriter.getContext(), /*offsetDims=*/offsetDims, /*collapsedSliceDims=*/collapsedSliceDims, + /*operandBatchingDims=*/{}, + /*startIndicesBatchingDims=*/{}, /*startIndexMap=*/startIndexMap, /*indexVecDim=*/indexVecDim); @@ -154,8 +159,8 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, Value builtinTypeStart = adaptor.getStart(); Value builtinTypeEnd = adaptor.getEnd(); - if (torchTypeStart.getType().isa() || - torchTypeEnd.getType().isa()) + if (isa(torchTypeStart.getType()) || + isa(torchTypeEnd.getType())) return rewriter.notifyMatchFailure(op, "unimplemented optional type arg"); int64_t step; @@ -215,49 +220,62 @@ namespace { FailureOr broadcastAndConcatIndices(Operation *op, ConversionPatternRewriter &rewriter, SmallVector indexTensors, - llvm::ArrayRef inputShape, + size_t dimSizeIndexBits, int &maxIndexRank) { // Step 1: broadcast indices tensors - SmallVector indicesShape; - SmallVector expandShape; - SmallVector concatShape; + bool allIndexStaticShape = true; + // concat index tensor into to indices tensor for concat for (size_t i = 0; i < indexTensors.size(); i++) { auto indexTensor = indexTensors[i]; auto indexTensorType = cast(indexTensor.getType()); for (int64_t size : makeShapeTorchCompatible(indexTensorType.getShape())) { if (size == kUnknownSize) - return failure(); + allIndexStaticShape = false; } maxIndexRank = std::max(maxIndexRank, (int)indexTensorType.getRank()); } - SmallVector refinedInputShape = makeShapeTorchCompatible(inputShape); - for (int64_t size : refinedInputShape) { - 
if (size == kUnknownSize) { - return failure(); - } - } - for (int i = 0; i < maxIndexRank; i++) { - indicesShape.push_back(refinedInputShape[i]); - expandShape.push_back(refinedInputShape[i]); - concatShape.push_back(refinedInputShape[i]); + auto bcastSizeInfo = hlo::getBroadcastResultShape(rewriter, op, indexTensors, + dimSizeIndexBits); + if (failed(bcastSizeInfo)) { + return failure(); } + Value bcastSizeTensor = (*bcastSizeInfo).first; + auto indicesShape = (*bcastSizeInfo).second; + SmallVector expandShape(indicesShape.begin(), indicesShape.end()); + SmallVector concatShape(indicesShape.begin(), indicesShape.end()); expandShape.push_back(1); concatShape.push_back(indexTensors.size()); SmallVector broadcastedIndices; - Type indexElemTy = - cast(indexTensors[0].getType()).getElementType(); + Type indexElemTy = rewriter.getI64Type(); RankedTensorType bcastIndexType = RankedTensorType::get(indicesShape, indexElemTy); for (auto indexTensor : indexTensors) { - Value bcastVal = - hlo::promoteAndBroadcast(rewriter, indexTensor, bcastIndexType); + Value bcastVal; RankedTensorType reshapeType = RankedTensorType::get(expandShape, indexElemTy); - bcastVal = rewriter.create(op->getLoc(), reshapeType, - bcastVal); + if (allIndexStaticShape) { + bcastVal = hlo::promoteAndBroadcast(rewriter, indexTensor, bcastIndexType, + std::nullopt); + bcastVal = rewriter.create(op->getLoc(), + reshapeType, bcastVal); + } else { + bcastVal = hlo::promoteAndBroadcast(rewriter, indexTensor, bcastIndexType, + bcastSizeTensor); + auto bcastValShapeTensorVec = + *hlo::getDimSizesOfTensor(rewriter, op, bcastVal, dimSizeIndexBits); + bcastValShapeTensorVec.push_back(rewriter.create( + op->getLoc(), rewriter.getIntegerAttr( + rewriter.getIntegerType(dimSizeIndexBits), 1))); + Value bcastValShapeTensor = rewriter + .create( + op->getLoc(), bcastValShapeTensorVec) + .getResult(); + bcastVal = rewriter.create( + op->getLoc(), reshapeType, bcastVal, bcastValShapeTensor); + } broadcastedIndices.push_back(bcastVal); } @@ -347,11 +365,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "offsets must be a vector with static shape equal to 1"); - if (!op.getPaddingIdx().getType().isa()) + if (!isa(op.getPaddingIdx().getType())) return rewriter.notifyMatchFailure( op, "Unimplemented: padding_idx should be none"); - if (!op.getPerSampleWeights().getType().isa()) + if (!isa(op.getPerSampleWeights().getType())) return rewriter.notifyMatchFailure( op, "Unimplemented: per_sample_weights should be none"); @@ -434,16 +452,14 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.create(op->getLoc(), addResult); } - auto outShapeInfo = - hlo::getDimSizesOfTensor(rewriter, op, weight, options.dimSizeIndexBits); + auto outShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, weight); if (failed(outShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); } auto outShapeVec = *outShapeInfo; auto one = rewriter.create( - op->getLoc(), rewriter.getIntegerAttr( - rewriter.getIntegerType(options.dimSizeIndexBits), 1)); + op->getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); outShapeVec[0] = one; auto outShapeTensor = rewriter.create(op->getLoc(), outShapeVec); @@ -451,25 +467,22 @@ LogicalResult ConvertAtenOp::matchAndRewrite( loc, getTypeConverter()->convertType(op.getType(0)), stablehloReduceOp.getResult(0), outShapeTensor); - RankedTensorType resultType = getTypeConverter() - ->convertType(op->getResult(1).getType()) - .cast(); + RankedTensorType 
resultType = cast( + getTypeConverter()->convertType(op->getResult(1).getType())); Value resultB = createInitialValueForGatherScatterOp(op, resultType, rewriter); if (!resultB) return failure(); - resultType = getTypeConverter() - ->convertType(op->getResult(2).getType()) - .cast(); + resultType = cast( + getTypeConverter()->convertType(op->getResult(2).getType())); Value resultC = createInitialValueForGatherScatterOp(op, resultType, rewriter); if (!resultC) return failure(); - resultType = getTypeConverter() - ->convertType(op->getResult(3).getType()) - .cast(); + resultType = cast( + getTypeConverter()->convertType(op->getResult(3).getType())); Value resultD = createInitialValueForGatherScatterOp(op, resultType, rewriter); if (!resultD) @@ -536,16 +549,13 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "only constant boolean `sparse_grad` param supported"); } - auto options = getOptions(); - auto indexShapeInfo = - hlo::getDimSizesOfTensor(rewriter, op, index, options.dimSizeIndexBits); + auto indexShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, index); if (failed(indexShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dim sizes of `index` param"); } - auto intType = rewriter.getIntegerType(options.dimSizeIndexBits); auto one = rewriter.create( - loc, rewriter.getIntegerAttr(intType, 1)); + loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); auto toConcatIndexShapeValueVec = *indexShapeInfo; toConcatIndexShapeValueVec.push_back(one); auto toConcatIndexShape = @@ -585,6 +595,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.getContext(), /*offsetDims=*/{}, /*collapsedSliceDims=*/collapsedDims, + /*operandBatchingDims=*/{}, + /*startIndicesBatchingDims=*/{}, /*startIndexMap=*/startIndexMap, /*indexVecDim=*/indexVecDim); @@ -607,33 +619,100 @@ LogicalResult ConvertAtenOp::matchAndRewrite( const TypeConverter *typeConverter = getTypeConverter(); auto input = adaptor.getSelf(); + RankedTensorType inputType = cast(input.getType()); - RankedTensorType resultType = - typeConverter->convertType(op->getResult(0).getType()) - .cast(); + RankedTensorType resultType = cast( + typeConverter->convertType(op->getResult(0).getType())); - SmallVector resultShape; - SmallVector offsets; - SmallVector strides; - if (failed(prepareArgumentsForSlicingOp( - op, adaptor, rewriter, resultShape, offsets, strides))) { - return failure(); + int64_t dim; + if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) { + return op->emitError("unimplemented: dim is not constant"); + } + + int64_t inputRank = inputType.getRank(); + dim = toPositiveDim(dim, inputRank); + if (!isValidDim(dim, inputRank)) { + return rewriter.notifyMatchFailure(op, "dim is statically invalid"); + } + + auto inputShape = inputType.getShape(); + auto dimSize = inputShape[dim]; + int64_t step; + if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) { + return op->emitError("unimplemented: step is not constant"); + } + + int64_t start; + if (!matchPattern(op.getStart(), m_TorchConstantInt(&start))) { + return op->emitError("unimplemented: start is not constant"); + } else if (ShapedType::isDynamic(dimSize) and start < 0) { + return op->emitError("unimplemented: not support dynamic dimSize when " + "start smaller than 0."); + } + start = start >= 0 ? 
start : dimSize + start; + + int64_t end; + if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) { + return op->emitError("unimplemented: end is not constant"); + } else if (ShapedType::isDynamic(dimSize) and end < 0) { + return op->emitError( + "unimplemented: not support dynamic dimSize when end smaller than 0."); + } + end = end >= 0 ? end : dimSize + end; + + int64_t size = 0; + std::vector indicesVec; + for (int64_t i = start; i < end; i += step) { + indicesVec.push_back(i); + ++size; + } + ArrayRef indices(indicesVec); + std::vector tmp_shape = {size, 1}; + ArrayRef shape(tmp_shape); + RankedTensorType constType = + RankedTensorType::get(shape, rewriter.getIntegerType(64)); + auto constAttr = DenseElementsAttr::get( + RankedTensorType::get(shape, rewriter.getIntegerType(64)), indices); + auto const_op = + rewriter.create(loc, constType, constAttr); + Value scatterIndices = const_op.getResult(); + + SmallVector updateWindowDims; + for (int64_t i = 0; i < inputType.getRank(); ++i) { + if (i == dim) { + continue; + } + updateWindowDims.push_back(i); } + auto scatterArgs = stablehlo::ScatterDimensionNumbersAttr::get( + rewriter.getContext(), + /*updateWindowDims=*/updateWindowDims, + /*insertedWindowDims=*/{dim}, + /*inputBatchingDims=*/{}, + /*scatterIndicesBatchingDims=*/{}, + /*scatterDimsToOperandDim=*/{dim}, + /*indexVectorDim=*/1); + Value src = adaptor.getSrc(); - auto srcType = cast(src.getType()); - int64_t srcRank = srcType.getRank(); - SmallVector srcAbstractSizes(srcRank, kUnknownSize); - auto abstractSrcType = RankedTensorType::get( - makeShapeLLVMCompatible(srcAbstractSizes), srcType.getElementType()); - Value abstractSrc = - rewriter.create(loc, abstractSrcType, src); + auto scatterOp = rewriter.create( + loc, resultType, input, scatterIndices, src, scatterArgs, false, false); - Value result = rewriter.create( - loc, abstractSrc, input, offsets, resultShape, strides); + Block &block = scatterOp.getUpdateComputation().emplaceBlock(); + auto blockArgumentType = + RankedTensorType::get({}, inputType.getElementType()); + block.addArgument(blockArgumentType, loc); + block.addArgument(blockArgumentType, loc); - rewriter.replaceOpWithNewOp(op, resultType, result); + auto *lhs = block.args_begin(); + auto *rhs = std::next(lhs); + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&block); + rewriter.create(loc, *rhs); + } + + rewriter.replaceOp(op, scatterOp.getResults()); return success(); } @@ -670,24 +749,20 @@ class ConvertAtenScatterOp : public ConvertAtenOp { return rewriter.notifyMatchFailure(op, "invalid `dim` param detected"); } - auto options = this->getOptions(); - - auto indexShapeInfo = - hlo::getDimSizesOfTensor(rewriter, op, index, options.dimSizeIndexBits); + auto indexShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, index); if (failed(indexShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dim sizes of `index` param"); } - auto intType = rewriter.getIntegerType(options.dimSizeIndexBits); // slice src tensor to have the same shape bound of index tensor in the // leading dimensions. PyTorch has guaranteed that src tensor size will not // be smaller than that of index tensor. 
REF: // https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html#torch.Tensor.scatter_ auto zero = rewriter.create( - loc, rewriter.getIntegerAttr(intType, 0)); + loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); auto one = rewriter.create( - loc, rewriter.getIntegerAttr(intType, 1)); + loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); SmallVector sliceIndicies(srcType.getRank(), zero); SmallVector sliceStrides(srcType.getRank(), one); @@ -745,6 +820,8 @@ class ConvertAtenScatterOp : public ConvertAtenOp { rewriter.getContext(), /*updateWindowDims=*/{}, /*insertedWindowDims=*/insertedWindowDims, + /*inputBatchingDims=*/{}, + /*scatterIndicesBatchingDims=*/{}, /*scatterDimsToOperandDim=*/scatterDimOperandDimMap, /*indexVectorDim=*/indexVecDim); @@ -791,7 +868,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto inputTensorType = cast(input.getType()); auto outType = cast(getTypeConverter()->convertType(op.getType())); - auto outShape = outType.getShape(); Value indexList = op.getIndices(); SmallVector indicesTorchType; if (!getListConstructElements(indexList, indicesTorchType)) @@ -802,8 +878,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( indicesTorchType); int maxIndexRank = -1; - auto gatherIndicesInfo = broadcastAndConcatIndices(op, rewriter, indexTensors, - outShape, maxIndexRank); + auto gatherIndicesInfo = broadcastAndConcatIndices( + op, rewriter, indexTensors, options.dimSizeIndexBits, maxIndexRank); if (failed(gatherIndicesInfo)) { return rewriter.notifyMatchFailure( op, "failed to generate broadcasted indices"); @@ -827,6 +903,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.getContext(), /*offsetDims=*/offsetDims, /*collapsedSliceDims=*/collapsedDims, + /*operandBatchingDims=*/{}, + /*startIndicesBatchingDims=*/{}, /*startIndexMap=*/startIndexMap, /*indexVecDim=*/indexVecDim); @@ -858,8 +936,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto outType = cast(getTypeConverter()->convertType(op.getType())); auto inputType = cast(input.getType()); - int64_t inputRank = inputType.getRank(); + auto inputShape = inputType.getShape(); + auto inputRank = inputType.getRank(); auto valuesType = cast(values.getType()); + int64_t valueRank = valuesType.getRank(); auto valuesShape = valuesType.getShape(); bool accumulate; if (!matchPattern(op.getAccumulate(), m_TorchConstantBool(&accumulate))) { @@ -871,36 +951,82 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (!getListConstructElements(indexList, indicesTorchType)) return op.emitError( "unimplemented: the tensor list is not from list construct"); + int64_t indexCnt = indicesTorchType.size(); auto indexTensors = getTypeConvertedValues(rewriter, loc, getTypeConverter(), indicesTorchType); int maxIndexRank = -1; auto scatterIndicesInfo = broadcastAndConcatIndices( - op, rewriter, indexTensors, valuesShape, maxIndexRank); + op, rewriter, indexTensors, options.dimSizeIndexBits, maxIndexRank); if (failed(scatterIndicesInfo)) { return rewriter.notifyMatchFailure( op, "failed to generate broadcasted indices"); } auto scatterIndices = *scatterIndicesInfo; + // broadcast `values` tensor to match expectedValuesShape. 
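// A minimal standalone sketch (not part of the patch; helper name and example
// values are illustrative) of the shape arithmetic used next for
// AtenIndexPutHackedTwinOp: `values` is broadcast to the leading dims of the
// broadcasted index tensor followed by the trailing, non-indexed input dims.
// The patch computes the same thing on dynamic shape tensors.
#include <cstdint>
#include <vector>

static std::vector<int64_t>
expectedValuesShapeSketch(const std::vector<int64_t> &scatterIndicesShape,
                          const std::vector<int64_t> &inputShape,
                          int64_t maxIndexRank, int64_t indexCnt) {
  std::vector<int64_t> shape;
  // Leading dims come from the broadcasted indices (dropping the trailing
  // index-vector dimension).
  shape.insert(shape.end(), scatterIndicesShape.begin(),
               scatterIndicesShape.begin() + maxIndexRank);
  // Trailing dims are the input dims that are not indexed.
  shape.insert(shape.end(), inputShape.begin() + indexCnt, inputShape.end());
  return shape;
}
// Example: input (4, 5, 6), two index tensors broadcast to shape (3,):
// maxIndexRank = 1, indexCnt = 2, so values must broadcast to (3, 6).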
+ SmallVector scatterIndicesDims; + for (int64_t i = 0; i < maxIndexRank; ++i) { + scatterIndicesDims.push_back(i); + } + auto expectedValuesShapeTensorInfo = + hlo::getDimSizesOfTensor(rewriter, op, scatterIndices, scatterIndicesDims, + options.dimSizeIndexBits); + if (failed(expectedValuesShapeTensorInfo)) { + return rewriter.notifyMatchFailure( + op, "failed to get shape of broadcasted indices"); + } + auto expectedValuesShapeTensors = *expectedValuesShapeTensorInfo; + SmallVector trailingInputDims; + for (int64_t i = indexCnt; i < inputRank; ++i) { + trailingInputDims.push_back(i); + } + auto trailingInputShapeTensorInfo = hlo::getDimSizesOfTensor( + rewriter, op, input, trailingInputDims, options.dimSizeIndexBits); + if (failed(trailingInputShapeTensorInfo)) { + return rewriter.notifyMatchFailure(op, "failed to get shape of input"); + } + expectedValuesShapeTensors.append((*trailingInputShapeTensorInfo).begin(), + (*trailingInputShapeTensorInfo).end()); + + llvm::ArrayRef scatterIndicesShape = + (cast(scatterIndices.getType())).getShape(); + SmallVector expectedValuesShape( + scatterIndicesShape.begin(), scatterIndicesShape.begin() + maxIndexRank); + for (int64_t i = indexCnt; i < inputRank; i++) { + expectedValuesShape.push_back(inputShape[i]); + } + + valuesType = + RankedTensorType::get(expectedValuesShape, valuesType.getElementType()); + values = + hlo::promoteAndBroadcast(rewriter, values, valuesType, + rewriter + .create( + op->getLoc(), expectedValuesShapeTensors) + .getResult()); + valueRank = valuesType.getRank(); + valuesShape = valuesType.getShape(); + // create stablehlo::ScatterOp int64_t indexVecDim = maxIndexRank; SmallVector scatterDimOperandDimMap; SmallVector insertedWindowDims; SmallVector updateWindowDims; - for (int64_t i = 0; i < maxIndexRank; ++i) { + for (int64_t i = 0; i < indexCnt; ++i) { scatterDimOperandDimMap.push_back(i); insertedWindowDims.push_back(i); } - for (int64_t i = maxIndexRank; i < inputRank; ++i) { + for (int64_t i = maxIndexRank; i < valueRank; ++i) { updateWindowDims.push_back(i); } - llvm::outs() << "maxIndexRank: " << maxIndexRank << "\n"; auto scatterDimensionNumbers = stablehlo::ScatterDimensionNumbersAttr::get( rewriter.getContext(), /*updateWindowDims=*/updateWindowDims, /*insertedWindowDims=*/insertedWindowDims, + /*inputBatchingDims=*/{}, + /*scatterIndicesBatchingDims=*/{}, /*scatterDimsToOperandDim=*/scatterDimOperandDimMap, /*indexVectorDim=*/indexVecDim); @@ -935,6 +1061,417 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +// AtenGridSamplerOp +// See +// https://github.com/pytorch/pytorch/blob/ec58f1f74ebcec744d2ab90ad34abd09c1018e92/torch/_decomp/decompositions.py#L3923-L4086 +namespace { +template +static Value getConstantLike(OpBuilder &b, Location loc, T constant, + Value val) { + Type ty = getElementTypeOrSelf(val.getType()); + auto getAttr = [&]() -> Attribute { + if (isa(ty)) + return b.getIntegerAttr(ty, constant); + if (isa(ty)) + return b.getFloatAttr(ty, constant); + if (auto complexTy = dyn_cast(ty)) + return complex::NumberAttr::get(complexTy, constant, 0); + llvm_unreachable("unhandled element type"); + }; + return b.create(loc, cast(getAttr()), + val); +} + +template +static Value getConstTensor(ConversionPatternRewriter &rewriter, Operation *op, + ArrayRef values, ArrayRef shape, + Type ty) { + Location loc = op->getLoc(); + RankedTensorType valueType = RankedTensorType::get(shape, ty); + auto valueAttr = DenseElementsAttr::get(valueType, values); + return rewriter.create(loc, valueType, 
valueAttr); +} + +template +static Value getConstScalarTensor(ConversionPatternRewriter &rewriter, + Operation *op, T value, Type ty) { + return getConstTensor(rewriter, op, ArrayRef{value}, {}, ty); +} + +// Helper function to lower AtenGridSamplerOp. +static Value unnormalize(ConversionPatternRewriter &rewriter, Operation *op, + Value coords, int64_t size, Type elemTy, + bool alignCorners) { + Location loc = op->getLoc(); + APFloat pointFive(cast(elemTy).getFloatSemantics(), "0.5"); + APFloat sizeFloat = + APFloat(cast(elemTy).getFloatSemantics(), size); + APFloat one = APFloat(cast(elemTy).getFloatSemantics(), 1); + APFloat zero = APFloat(cast(elemTy).getFloatSemantics(), 0); + + // double mul = alignCorners ? (size * 0.5 - 0.5) : (size * 0.5); + // double ofs = size * 0.5 - 0.5; + APFloat mul = + alignCorners ? sizeFloat * pointFive - pointFive : sizeFloat * pointFive; + APFloat ofs = sizeFloat * pointFive - pointFive; + Value constMul = getConstScalarTensor(rewriter, op, mul, elemTy); + Value constOfs = getConstScalarTensor(rewriter, op, ofs, elemTy); + + // use chlo::BroadcastMulOp to multiply constMul with coords. + DenseI64ArrayAttr bcastDimensions; + Value mulResult = rewriter.create(loc, coords, constMul, + bcastDimensions); + // use chlo::BroadcastAddOp to add constOfs to mulResult. + Value result = rewriter.create(loc, mulResult, constOfs, + bcastDimensions); + return result; +} + +static Value computeCoordinates(ConversionPatternRewriter &rewriter, + Operation *op, Value coords, int64_t size, + Type elemTy, int64_t padding_mode) { + // TODO: add support for padding_mode 1 and 2. + return coords; +} + +static Value computeSourceIndex(ConversionPatternRewriter &rewriter, + Operation *op, Value coords, int64_t size, + Type elemTy, int64_t padding_mode, + bool alignCorners) { + Value coordsUn = + unnormalize(rewriter, op, coords, size, elemTy, alignCorners); + return computeCoordinates(rewriter, op, coordsUn, size, elemTy, padding_mode); +} + +// def in_bounds_cond(xs: Tensor, ys: Tensor) -> Tensor: +// return torch.logical_and( +// 0 <= xs, torch.logical_and(xs < iW, torch.logical_and(0 <= ys, ys +// < iH)) +// ) +static Value inBoundsCond(ConversionPatternRewriter &rewriter, Operation *op, + Value xs, Value ys, int64_t ih, int64_t iw, + Type elemTy) { + Location loc = op->getLoc(); + APFloat zeroFloat = + APFloat(cast(elemTy).getFloatSemantics(), 0); + Value zero = getConstScalarTensor(rewriter, op, zeroFloat, elemTy); + APFloat iwFloat = + APFloat(cast(elemTy).getFloatSemantics(), iw); + APFloat ihFloat = + APFloat(cast(elemTy).getFloatSemantics(), ih); + + Value iwFloatValue = getConstScalarTensor(rewriter, op, iwFloat, elemTy); + Value ihFloatValue = getConstScalarTensor(rewriter, op, ihFloat, elemTy); + + chlo::ComparisonTypeAttr compareTypeAttr = chlo::ComparisonTypeAttr::get( + rewriter.getContext(), chlo::ComparisonType::FLOAT); + chlo::ComparisonDirectionAttr compareLTAttr = + chlo::ComparisonDirectionAttr::get(rewriter.getContext(), + chlo::ComparisonDirection::LT); + chlo::ComparisonDirectionAttr compareGEAttr = + chlo::ComparisonDirectionAttr::get(rewriter.getContext(), + chlo::ComparisonDirection::GE); + DenseI64ArrayAttr bcastDimensions; + Value cond1 = rewriter.create( + loc, xs, zero, bcastDimensions, compareGEAttr, compareTypeAttr); + Value cond2 = rewriter.create( + loc, xs, iwFloatValue, bcastDimensions, compareLTAttr, compareTypeAttr); + Value cond3 = rewriter.create( + loc, ys, zero, bcastDimensions, compareGEAttr, compareTypeAttr); + Value cond4 = 
rewriter.create( + loc, ys, ihFloatValue, bcastDimensions, compareLTAttr, compareTypeAttr); + Value cond5 = + rewriter.create(loc, cond1, cond2, bcastDimensions); + Value cond6 = + rewriter.create(loc, cond3, cond4, bcastDimensions); + return rewriter.create(loc, cond5, cond6, + bcastDimensions); +} +// def clip(xs: Tensor, ys: Tensor, ws: Tensor) -> TensorSequenceType: +// cond = in_bounds_cond(xs, ys) +// # To clip to inside valid coordinates, we map the coordinates +// # to (x, y) = (0, 0) and also set the weight to 0 +// # We also change the shape of the tensor to the appropriate one for +// # broadcasting with N_idx, C_idx for the purposes of advanced +// indexing c = C if _expand_grid else 1 +// return tuple( +// torch.where(cond, t, 0).view(N, c, oH, oW) +// for t in (xs.to(dtype=torch.int64), ys.to(dtype=torch.int64), ws) +// ) +SmallVector clip(ConversionPatternRewriter &rewriter, Operation *op, + Value xs, Value ys, Value ws, int64_t N, int64_t oH, + int64_t oW, int64_t iH, int64_t iW, Type elemTy) { + Location loc = op->getLoc(); + auto indexElemTy = rewriter.getI64Type(); + auto indexTy = RankedTensorType::get(mlir::ArrayRef{1}, indexElemTy); + + Value zeroIntValue = rewriter.create( + loc, indexTy, DenseIntElementsAttr::get(indexTy, ArrayRef{0})); + + APFloat zeroAPFloat = + APFloat(cast(elemTy).getFloatSemantics(), 0); + Value zeroFloatValue = + getConstScalarTensor(rewriter, op, zeroAPFloat, elemTy); + Value cond = inBoundsCond(rewriter, op, xs, ys, iH, iW, elemTy); + Value xsInt = rewriter.create(loc, xs, indexElemTy); + Value ysInt = rewriter.create(loc, ys, indexElemTy); + + Value selectXs = rewriter.create( + loc, ArrayRef{cond, xsInt, zeroIntValue}); + Value selectYs = rewriter.create( + loc, ArrayRef{cond, ysInt, zeroIntValue}); + Value selectWs = rewriter.create( + loc, ArrayRef{cond, ws, zeroFloatValue}); + + SmallVector sizes = {N, 1, oH, oW}; + Value reshapedXs = rewriter.create( + loc, RankedTensorType::get(sizes, indexElemTy), selectXs); + Value reshapedYs = rewriter.create( + loc, RankedTensorType::get(sizes, indexElemTy), selectYs); + Value reshapedWs = rewriter.create( + loc, RankedTensorType::get(sizes, elemTy), selectWs); + return SmallVector{reshapedXs, reshapedYs, reshapedWs}; +} + +Value getSummand(ConversionPatternRewriter &rewriter, Operation *op, + Value input, Value ix, Value iy, Value w, int64_t N, + int64_t oH, int64_t oW, int64_t iH, int64_t iW, Value Nidx, + Value CIdx, RankedTensorType outType, Type elemTy, + size_t dimSizeIndexBits) { + Location loc = op->getLoc(); + auto inputTensorType = cast(input.getType()); + SmallVector clipValues = + clip(rewriter, op, ix, iy, w, N, oH, oW, iH, iW, elemTy); + Value idxX = clipValues[0]; + Value idxY = clipValues[1]; + Value idxW = clipValues[2]; + SmallVector indexTensors{Nidx, CIdx, idxY, idxX}; + + int maxIndexRank = -1; + auto gatherIndicesInfo = + broadcastAndConcatIndices(input.getDefiningOp(), rewriter, indexTensors, + dimSizeIndexBits, maxIndexRank); + auto gatherIndices = *gatherIndicesInfo; + int64_t numIndicesDim = indexTensors.size(); + int64_t indexVecDim = maxIndexRank; + + SmallVector offsetDims; + SmallVector collapsedDims; + SmallVector startIndexMap; + for (int64_t i = 0; i < numIndicesDim; ++i) { + collapsedDims.push_back(i); + startIndexMap.push_back(i); + } + for (int64_t i = numIndicesDim; i < inputTensorType.getRank(); i++) { + offsetDims.push_back(i + maxIndexRank - numIndicesDim); + } + auto dimsAttr = stablehlo::GatherDimensionNumbersAttr::get( + rewriter.getContext(), + 
/*offsetDims=*/offsetDims, + /*collapsedSliceDims=*/collapsedDims, + /*operandBatchingDims=*/{}, + /*startIndicesBatchingDims=*/{}, + /*startIndexMap=*/startIndexMap, + /*indexVecDim=*/indexVecDim); + + SmallVector sliceSizes; + auto inputShape = makeShapeTorchCompatible(inputTensorType.getShape()); + for (int64_t i = 0; i < inputTensorType.getRank(); ++i) { + if (i < numIndicesDim) { + sliceSizes.push_back(1); + } else { + sliceSizes.push_back(inputShape[i]); + } + } + + Value gather = rewriter.create( + loc, input, gatherIndices, dimsAttr, + rewriter.getDenseI64ArrayAttr(sliceSizes)); + // use chlo::BroadcastMulOp to multiply idxW with gather. + DenseI64ArrayAttr bcastDimensions; + return rewriter.create(loc, gather, idxW, + bcastDimensions); +} + +} // namespace +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenGridSamplerOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Location loc = op->getLoc(); + Value input = adaptor.getInput(); + Value grid = adaptor.getGrid(); + + int64_t interpolationMode; + if (!matchPattern(op.getInterpolationMode(), + m_TorchConstantInt(&interpolationMode))) + return rewriter.notifyMatchFailure( + op, "interpolation_mode must be an integer constant"); + int64_t paddingMode; + if (!matchPattern(op.getPaddingMode(), m_TorchConstantInt(&paddingMode))) + return rewriter.notifyMatchFailure( + op, "padding_mode must be an integer constant"); + + if (interpolationMode != 0 && interpolationMode != 1) + return rewriter.notifyMatchFailure( + op, "only support interpolation_mode = 0 (bilinear) or 1(nearest)"); + + if (paddingMode != 0) + return rewriter.notifyMatchFailure(op, + "only support paddingMode = 0 (Zero)"); + + bool alignCorners = false; + if (!matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCorners))) + return rewriter.notifyMatchFailure( + op, "alignCorners must be a boolean constant"); + + RankedTensorType inputTy = cast(input.getType()); + RankedTensorType gridTy = cast(grid.getType()); + RankedTensorType outTy = + cast(getTypeConverter()->convertType(op.getType())); + Type elemTy = inputTy.getElementType(); + if (inputTy.getRank() != 4) + return rewriter.notifyMatchFailure(op, "input must be a 4D tensor"); + if (gridTy.getRank() != 4) + return rewriter.notifyMatchFailure(op, "grid must be a 4D tensor"); + + auto inputSize = inputTy.getShape(); + auto gridSize = gridTy.getShape(); + int64_t N = inputSize[0]; + int64_t C = inputSize[1]; + int64_t iH = inputSize[2]; + int64_t iW = inputSize[3]; + int64_t oH = gridSize[1]; + int64_t oW = gridSize[2]; + // grid is a 4D tensor with shape (N, oH, oW, 2) + + Type indexElemTy = rewriter.getI64Type(); + RankedTensorType indexTy = + RankedTensorType::get(mlir::ArrayRef{1}, indexElemTy); + Value constN = rewriter.create( + loc, indexTy, DenseIntElementsAttr::get(indexTy, {N})); + Value constC = rewriter.create( + loc, indexTy, DenseIntElementsAttr::get(indexTy, {C})); + APFloat one = APFloat(cast(elemTy).getFloatSemantics(), 1); + APFloat zero = APFloat(cast(elemTy).getFloatSemantics(), 0); + + Value constOneFloat = getConstScalarTensor(rewriter, op, one, elemTy); + + auto NidxFlatten = rewriter.create( + loc, RankedTensorType::get(mlir::ArrayRef{N}, indexElemTy), + constN, 0); + auto CidxFlatten = rewriter.create( + loc, RankedTensorType::get(mlir::ArrayRef{C}, indexElemTy), + constC, 0); + + // Reshape NidxFlatten to 4D tensor (N, 1, 1, 1) + auto NidxSizes = mlir::SmallVector{N, 1, 1, 1}; + auto Nidx = rewriter.create( + loc, RankedTensorType::get(NidxSizes, 
indexElemTy), NidxFlatten); + + // Reshape CidxFlatten to 4D tensor (1, C, 1, 1) + auto CidxSizes = mlir::SmallVector{1, C, 1, 1}; + auto Cidx = rewriter.create( + loc, RankedTensorType::get(CidxSizes, indexElemTy), CidxFlatten); + + llvm::SmallVector stride(4, 1); + auto gridX = rewriter.create( + loc, + RankedTensorType::get(mlir::SmallVector{N, oH, oW, 1}, + gridTy.getElementType()), + grid, mlir::SmallVector{0, 0, 0, 0}, + mlir::SmallVector{N, oH, oW, 1}, stride); + auto gridY = rewriter.create( + loc, + RankedTensorType::get(mlir::SmallVector{N, oH, oW, 1}, + gridTy.getElementType()), + grid, mlir::SmallVector{0, 0, 0, 1}, + mlir::SmallVector{N, oH, oW, 2}, stride); + // squeeze last dimension + auto gridXshape = mlir::SmallVector{N, oH, oW}; + + auto gridXReshape = rewriter.create( + loc, RankedTensorType::get(gridXshape, gridTy.getElementType()), gridX); + auto gridYReshape = rewriter.create( + loc, RankedTensorType::get(gridXshape, gridTy.getElementType()), gridY); + + if (interpolationMode == 0) { + Value ix = computeSourceIndex(rewriter, op, gridXReshape, iW, elemTy, + paddingMode, alignCorners); + Value iy = computeSourceIndex(rewriter, op, gridYReshape, iH, elemTy, + paddingMode, alignCorners); + Value ix_nw = rewriter.create(loc, ix); + Value iy_nw = rewriter.create(loc, iy); + + DenseI64ArrayAttr bcastDimensions; + Value ix_ne = rewriter.create( + loc, ix_nw, constOneFloat, bcastDimensions); + Value iy_ne = iy_nw; + Value ix_sw = ix_nw; + Value iy_sw = rewriter.create( + loc, iy_nw, constOneFloat, bcastDimensions); + Value ix_se = ix_ne; + Value iy_se = iy_sw; + + // w_nw = (ix_se - ix) * (iy_se - iy) + // w_ne = (ix - ix_sw) * (iy_sw - iy) + // w_sw = (ix_ne - ix) * (iy - iy_ne) + // w_se = (ix - ix_nw) * (iy - iy_nw) + Value w_nw = rewriter.create( + loc, + rewriter.create(loc, ix_se, ix, bcastDimensions), + rewriter.create(loc, iy_se, iy, bcastDimensions), + bcastDimensions); + Value w_ne = rewriter.create( + loc, + rewriter.create(loc, ix, ix_sw, bcastDimensions), + rewriter.create(loc, iy_sw, iy, bcastDimensions), + bcastDimensions); + Value w_sw = rewriter.create( + loc, + rewriter.create(loc, ix_ne, ix, bcastDimensions), + rewriter.create(loc, iy, iy_ne, bcastDimensions), + bcastDimensions); + Value w_se = rewriter.create( + loc, + rewriter.create(loc, ix, ix_nw, bcastDimensions), + rewriter.create(loc, iy, iy_nw, bcastDimensions), + bcastDimensions); + + Value summand_nw = + getSummand(rewriter, op, input, ix_nw, iy_nw, w_nw, N, oH, oW, iH, iW, + Nidx, Cidx, outTy, elemTy, options.dimSizeIndexBits); + Value summand_ne = + getSummand(rewriter, op, input, ix_ne, iy_ne, w_ne, N, oH, oW, iH, iW, + Nidx, Cidx, outTy, elemTy, options.dimSizeIndexBits); + Value summand_sw = + getSummand(rewriter, op, input, ix_sw, iy_sw, w_sw, N, oH, oW, iH, iW, + Nidx, Cidx, outTy, elemTy, options.dimSizeIndexBits); + Value summand_se = + getSummand(rewriter, op, input, ix_se, iy_se, w_se, N, oH, oW, iH, iW, + Nidx, Cidx, outTy, elemTy, options.dimSizeIndexBits); + + // summand_nw + summand_ne + summand_sw + summand_se + Value sum = rewriter.create(loc, summand_nw, summand_ne); + sum = rewriter.create(loc, sum, summand_sw); + sum = rewriter.create(loc, sum, summand_se); + rewriter.replaceOp(op, sum); + } else if (interpolationMode == 1) { + Value ix = computeSourceIndex(rewriter, op, gridXReshape, iW, elemTy, + paddingMode, alignCorners); + Value iy = computeSourceIndex(rewriter, op, gridYReshape, iH, elemTy, + paddingMode, alignCorners); + Value ix_round = rewriter.create(loc, ix); + Value 
iy_round = rewriter.create(loc, iy); + Value oneTensor = getConstantLike(rewriter, loc, 1.0, ix_round); + Value summand = getSummand(rewriter, op, input, ix_round, iy_round, + oneTensor, N, oH, oW, iH, iW, Nidx, Cidx, outTy, + elemTy, options.dimSizeIndexBits); + rewriter.replaceOp(op, summand); + } + return success(); +} + void mlir::torch::torch_to_stablehlo:: populateGatherScatterOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, @@ -951,6 +1488,7 @@ void mlir::torch::torch_to_stablehlo:: INSERT_ATENOP_PATTERN(AtenSliceScatterOp); INSERT_ATENOP_PATTERN(AtenIndexTensorHackedTwinOp); INSERT_ATENOP_PATTERN(AtenIndexPutHackedTwinOp); + INSERT_ATENOP_PATTERN(AtenGridSamplerOp); #undef INSERT_ATENOP_PATTERN #define INSERT_ATEN_SCATTER_PATTERN(AtenOp, reduceType) \ diff --git a/lib/Conversion/TorchToStablehlo/Linear.cpp b/lib/Conversion/TorchToStablehlo/Linear.cpp index 70028cd2df49..b42ed7cc7722 100644 --- a/lib/Conversion/TorchToStablehlo/Linear.cpp +++ b/lib/Conversion/TorchToStablehlo/Linear.cpp @@ -18,7 +18,6 @@ #include "stablehlo/dialect/StablehloOps.h" #include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" #include "torch-mlir/Conversion/Utils/Utils.h" -#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" @@ -149,10 +148,9 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs, std::vector newShape(rhsShape.begin(), rhsShape.begin() + leadingRank); newShape.insert(newShape.end(), lhsShape.begin(), lhsShape.end()); - auto newDimSizes = *hlo::getDimSizesOfTensor(rewriter, op, rhs, leadingDims, - dimSizeIndexBits); - auto lhsDimSizes = - *hlo::getDimSizesOfTensor(rewriter, op, lhs, dimSizeIndexBits); + auto newDimSizes = + *hlo::getDimIndexOfTensor(rewriter, op, rhs, leadingDims); + auto lhsDimSizes = *hlo::getDimIndexOfTensor(rewriter, op, lhs); newDimSizes.insert(newDimSizes.end(), lhsDimSizes.begin(), lhsDimSizes.end()); lhs = getBroadcastTensor(rewriter, op, lhs, newShape, newDimSizes, @@ -161,10 +159,9 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs, std::vector newShape(lhsShape.begin(), lhsShape.begin() + leadingRank); newShape.insert(newShape.end(), rhsShape.begin(), rhsShape.end()); - auto newDimSizes = *hlo::getDimSizesOfTensor(rewriter, op, lhs, leadingDims, - dimSizeIndexBits); - auto rhsDimSizes = - *hlo::getDimSizesOfTensor(rewriter, op, rhs, dimSizeIndexBits); + auto newDimSizes = + *hlo::getDimIndexOfTensor(rewriter, op, lhs, leadingDims); + auto rhsDimSizes = *hlo::getDimIndexOfTensor(rewriter, op, rhs); newDimSizes.insert(newDimSizes.end(), rhsDimSizes.begin(), rhsDimSizes.end()); rhs = getBroadcastTensor(rewriter, op, rhs, newShape, newDimSizes, @@ -208,10 +205,8 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs, return; } - auto lhsDimSizes = - *hlo::getDimSizesOfTensor(rewriter, op, lhs, dimSizeIndexBits); - auto rhsDimSizes = - *hlo::getDimSizesOfTensor(rewriter, op, rhs, dimSizeIndexBits); + auto lhsDimSizes = *hlo::getDimIndexOfTensor(rewriter, op, lhs); + auto rhsDimSizes = *hlo::getDimIndexOfTensor(rewriter, op, rhs); if (!lhsBroadcastDims.empty()) { SmallVector lhsNewShape(newBatchShape); @@ -330,7 +325,8 @@ class ConvertAtenMatmulBaseOp : public ConvertAtenOp { lhsContractingDim, rhsContractingDim); output = rewriter .create(op->getLoc(), outTy, lhs, rhs, - 
dotDimensionNumbers, nullptr) + dotDimensionNumbers, nullptr, + nullptr) .getResult(); return success(); } @@ -351,9 +347,9 @@ class ConvertAtenMatmulBaseOp : public ConvertAtenOp { rewriter.replaceOpWithNewOp( op, - ConvertAtenOp::getTypeConverter() - ->convertType(op.getType()) - .template cast(), + cast( + ConvertAtenOp::getTypeConverter()->convertType( + op.getType())), output); return success(); @@ -499,7 +495,7 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp { /*lhsContractingDimensions=*/{lhsContractingDim}, /*rhsContractingDimensions=*/{rhsContractingDim}); Value matmulOutput = rewriter.create( - op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr); + op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr, nullptr); Value matmulPlusBias = matmulOutput; if (!isa(biasTy)) { @@ -527,16 +523,15 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { auto weightTy = cast(weight.getType()); auto weightElemTy = weightTy.getElementType(); auto rank = weightTy.getRank(); - const auto &options = getOptions(); - SmallVector weightShapeVec = *hlo::getDimSizesOfTensor( - rewriter, op, weight, options.dimSizeIndexBits); + SmallVector weightShapeVec = + *hlo::getDimIndexOfTensor(rewriter, op, weight); auto weightShape = weightTy.getShape(); SmallVector weightShapeInt(rank); std::copy(weightShape.begin(), weightShape.end(), weightShapeInt.begin()); // 1. [H, W, ..., OC, IC] => [H, W, ..., OC, G, IC//G] Value GValue = rewriter.create( - op->getLoc(), rewriter.getI64IntegerAttr(groups)); + op->getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), groups)); Value ICDivGValue = rewriter.create( op->getLoc(), weightShapeVec[rank - 1], GValue); Value OCMulGValue = rewriter.create( @@ -592,25 +587,32 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { auto weightShape = weightTy.getShape(); auto nDims = inputTy.getRank(); + auto weightDims = weightTy.getRank(); + auto kernelDims = weightDims - 2; + auto nSpatialDims = nDims - 2; auto convOutTy = outType; // Transpose weight SmallVector perm(nDims); SmallVector transposeShape(nDims); - for (int i = 0; i < nDims; i++) { - if (i < 2) - perm[i] = nDims - 2 + i; + // 1d: kernelDims = 1, [0, 1, 2] => [2, 1, 0] + // 2d: kernelDims = 2, [0, 1, 2, 3] => [2, 3, 1, 0] + // 3d: kernelDims = 3, [0, 1, 2, 3, 4] => [2, 3, 4, 1, 0] + for (int i = 0; i < weightDims; i++) { + if (i < kernelDims) + perm[i] = 2 + i; else - perm[i] = nDims - i - 1; + perm[i] = kernelDims + 1 - i; transposeShape[i] = weightShape[perm[i]]; } + auto reverseDim = llvm::to_vector<4>(llvm::seq(0, kernelDims)); auto transposeTy = RankedTensorType::get(transposeShape, weightTy.getElementType()); auto transposeOp = rewriter.create( op->getLoc(), transposeTy, weight, perm); auto reverseOp = rewriter.create( - op->getLoc(), transposeOp, ArrayRef{0, 1}); + op->getLoc(), transposeOp, reverseDim); // Prepare for transposed convolution SmallVector stablehloStrideVec(nSpatialDims, 1); @@ -731,9 +733,8 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { // If transposed is set to true, // the weight shape changes to [IC, (OC//G), KH, KW] auto weightTy = cast(weight.getType()); - auto outTy = getTypeConverter() - ->convertType(op.getType()) - .template cast(); + auto outTy = + cast(getTypeConverter()->convertType(op.getType())); if (!inputTy || !weightTy || !outTy) { return op.emitError("input, weight and output must be ranked tensors"); } @@ -834,10 +835,9 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { auto inputUnsqzDims = 
llvm::to_vector<4>(llvm::seq(-nSpatialDims, 0)); - const auto &options = getOptions(); - bias = *hlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims, - options.dimSizeIndexBits); - bias = hlo::promoteType(rewriter, op.getLoc(), bias, outTy); + bias = *hlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims); + bias = + hlo::promoteType(rewriter, op.getLoc(), bias, outTy.getElementType()); DenseI64ArrayAttr bcastDimensions; rewriter.replaceOpWithNewOp( @@ -863,6 +863,7 @@ void mlir::torch::torch_to_stablehlo::populateLinearOpPatternsAndLegality( patterns.add>(typeConverter, context, options) INSERT_MM_ATENOP_PATTERN(AtenMmOp); INSERT_MM_ATENOP_PATTERN(AtenBmmOp); + INSERT_MM_ATENOP_PATTERN(Aten_IntMmOp); #undef INSERT_MM_ATEMOP_PATTERN #define INSERT_LINEAR_ATENOP_PATTERN(AtenOp) \ diff --git a/lib/Conversion/TorchToStablehlo/Pooling.cpp b/lib/Conversion/TorchToStablehlo/Pooling.cpp index 9219b4af355f..8ad5cefc0bf3 100644 --- a/lib/Conversion/TorchToStablehlo/Pooling.cpp +++ b/lib/Conversion/TorchToStablehlo/Pooling.cpp @@ -18,12 +18,9 @@ #include "stablehlo/dialect/StablehloOps.h" #include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" #include "torch-mlir/Conversion/Utils/Utils.h" -#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" -#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" -#include #include using namespace mlir; @@ -55,7 +52,7 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy, // Max pooling if (isa(op)) { + AtenMaxPool1dWithIndicesOp, AtenMaxPool2dWithIndicesOp>(op)) { if (isa(elementTy)) { auto constAttr = DenseElementsAttr::get( constType, @@ -76,6 +73,161 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy, return nullptr; } +// AtenMaxPool1dWithIndicesOp +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenMaxPool1dWithIndicesOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value input = adaptor.getSelf(); + auto inputTy = cast(input.getType()); + auto inputElemTy = inputTy.getElementType(); + auto inputShape = inputTy.getShape(); + auto inputRank = inputTy.getRank(); + + auto outValTy = + cast(getTypeConverter()->convertType(op.getType(0))); + auto outIdxTy = + cast(getTypeConverter()->convertType(op.getType(1))); + + if (inputRank <= 1) { + return op.emitError( + "max_pooling1d only supports inputs with rank higher than 1"); + } + + SmallVector padding, kernelSize, stride, dilation; + bool ceilMode = false; + + if (!(matchPattern(op.getKernelSize(), + m_TorchListOfConstantInts(kernelSize)))) { + return rewriter.notifyMatchFailure( + op, "non-const int kernel size unsupported!"); + } + if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) { + return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!"); + } + if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) { + return rewriter.notifyMatchFailure(op, + "non-const int padding unsupported!"); + } + if (!(matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilation)))) { + return rewriter.notifyMatchFailure(op, + "non-const int dilation unsupported!"); + } + if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) { + return rewriter.notifyMatchFailure(op, + "non-const bool ceil_mode unsupported!"); + } + + SmallVector stablehloStride(inputRank, 1); + 
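// A simplified sketch (not from the patch; struct and function names are
// illustrative) of how the surrounding declarations and std::copy calls expand
// the rank-1 max_pool1d parameters into full-rank reduce_window attributes:
// only the last dimension carries the user's kernel/stride/dilation/padding,
// every other dimension uses the identity window.
#include <cstdint>
#include <vector>

struct Window1dSketch {
  std::vector<int64_t> kernel, stride, dilation, padding; // padding: rank * 2
};

static Window1dSketch expandPool1dParams(int64_t inputRank, int64_t kernel,
                                         int64_t stride, int64_t dilation,
                                         int64_t pad) {
  Window1dSketch w;
  w.kernel.assign(inputRank, 1);
  w.stride.assign(inputRank, 1);
  w.dilation.assign(inputRank, 1);
  w.padding.assign(inputRank * 2, 0);
  w.kernel.back() = kernel;   // pool only over the last (spatial) dimension
  w.stride.back() = stride;
  w.dilation.back() = dilation;
  w.padding[inputRank * 2 - 2] = pad; // low padding of the last dim
  w.padding[inputRank * 2 - 1] = pad; // high padding of the last dim
  return w;
}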
SmallVector stablehloDilation(inputRank, 1); + SmallVector stablehloKernelSize(inputRank, 1); + SmallVector stablehloPadding(inputRank * 2, 0); + + std::copy(stride.begin(), stride.end(), + stablehloStride.begin() + inputRank - 1); + std::copy(dilation.begin(), dilation.end(), + stablehloDilation.begin() + inputRank - 1); + std::copy(kernelSize.begin(), kernelSize.end(), + stablehloKernelSize.begin() + inputRank - 1); + stablehloPadding[stablehloPadding.size() - 1] = padding[0]; + stablehloPadding[stablehloPadding.size() - 2] = padding[0]; + + Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); + + auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize); + auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride); + auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation); + DenseIntElementsAttr pad = DenseIntElementsAttr::get( + RankedTensorType::get( + {static_cast(inputRank), static_cast(2)}, + rewriter.getI64Type()), + stablehloPadding); + DenseI64ArrayAttr baseDilations; + + auto inputShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input); + if (failed(inputShapeInfo)) { + return rewriter.notifyMatchFailure( + op, "failed to get dimension sizes of the input"); + } + auto inputShapeVec = *inputShapeInfo; + auto inputShapeTensor = rewriter.create( + op->getLoc(), inputShapeVec); + + // no need to reshape here for max_pool_1d. Need to make sure the iota + // dimension. dim=inputRank-2 or dim=inputRank-1? + auto indexTensor = + rewriter + .create( + op->getLoc(), + RankedTensorType::get(inputShape, rewriter.getI64Type()), + inputShapeTensor, static_cast(inputRank - 1)) + .getResult(); + Value initIdx = hlo::getConstTensor(rewriter, op, {0}, {}).value(); + + auto reduceWindowOp = rewriter.create( + op->getLoc(), mlir::TypeRange{outValTy, outIdxTy}, + mlir::ValueRange{input, indexTensor}, mlir::ValueRange{initVal, initIdx}, + windowDimensions, windowStrides, baseDilations, windowDilations, pad); + + // add block. 
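// A scalar model (not from the patch; names are illustrative) of the variadic
// reduce_window built above: two operands (values, running indices) and two
// init values yield a reduction region with four scalar block arguments and a
// two-value return, applied independently to each window.
#include <cstdint>
#include <utility>
#include <vector>

template <typename Reducer>
std::pair<float, int64_t>
reduceWindowPairSketch(const std::vector<float> &windowVals,
                       const std::vector<int64_t> &windowIdxs, float initVal,
                       int64_t initIdx, Reducer reduce) {
  std::pair<float, int64_t> acc{initVal, initIdx};
  for (size_t i = 0; i < windowVals.size(); ++i)
    acc = reduce(acc.first, acc.second, windowVals[i], windowIdxs[i]);
  return acc; // (max value, index of that max) for the reducer built below
}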
+ Block &block = reduceWindowOp.getBody().emplaceBlock(); + auto blockValArgumentType = RankedTensorType::get({}, inputElemTy); + auto blockIdxArgumentType = RankedTensorType::get({}, rewriter.getI64Type()); + auto compareResultType = RankedTensorType::get({}, rewriter.getI1Type()); + block.addArgument(blockValArgumentType, op->getLoc()); + block.addArgument(blockIdxArgumentType, op->getLoc()); + block.addArgument(blockValArgumentType, op->getLoc()); + block.addArgument(blockIdxArgumentType, op->getLoc()); + auto *firstValArg = block.args_begin(); + auto *firstIdxArg = std::next(firstValArg); + auto *secondValArg = std::next(firstIdxArg); + auto *secondIdxArg = std::next(secondValArg); + + stablehlo::ComparisonTypeAttr compareTypeAttr; + if (isa(inputTy.getElementType())) { + compareTypeAttr = stablehlo::ComparisonTypeAttr::get( + rewriter.getContext(), stablehlo::ComparisonType::FLOAT); + } else if (isa(inputTy.getElementType())) { + compareTypeAttr = stablehlo::ComparisonTypeAttr::get( + rewriter.getContext(), stablehlo::ComparisonType::SIGNED); + } + + stablehlo::ComparisonDirectionAttr compareGeDirectionAttr = + stablehlo::ComparisonDirectionAttr::get( + rewriter.getContext(), stablehlo::ComparisonDirection::GE); + stablehlo::ComparisonDirectionAttr compareEqDirectionAttr = + stablehlo::ComparisonDirectionAttr::get( + rewriter.getContext(), stablehlo::ComparisonDirection::EQ); + + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&block); + + Value compareGeResult = rewriter.create( + op->getLoc(), compareResultType, *firstValArg, *secondValArg, + compareGeDirectionAttr, compareTypeAttr); + Value retValResult = rewriter.create( + op->getLoc(), compareGeResult, *firstValArg, *secondValArg); + + // Get smaller index if compared values are equal. 
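// A scalar sketch (not part of the patch; function name is illustrative) of
// the select chain emitted below: keep the larger value, and when the two
// values compare equal, keep the smaller index so the returned indices are
// deterministic.
#include <algorithm>
#include <cstdint>
#include <utility>

static std::pair<float, int64_t> maxPairSketch(float lhsVal, int64_t lhsIdx,
                                               float rhsVal, int64_t rhsIdx) {
  bool ge = lhsVal >= rhsVal;                // compareGeResult
  float retVal = ge ? lhsVal : rhsVal;       // retValResult
  int64_t idx = (lhsVal == rhsVal)
                    ? std::min(lhsIdx, rhsIdx)   // minIdx on ties
                    : (ge ? lhsIdx : rhsIdx);    // idxWithGeVal otherwise
  return {retVal, idx};
}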
+ Value compareEqResult = rewriter.create( + op->getLoc(), compareResultType, *firstValArg, *secondValArg, + compareEqDirectionAttr, compareTypeAttr); + Value minIdx = rewriter.create(op->getLoc(), *firstIdxArg, + *secondIdxArg); + Value idxWithGeVal = rewriter.create( + op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg); + Value retIdxResult = rewriter.create( + op->getLoc(), compareEqResult, minIdx, idxWithGeVal); + + rewriter.create( + op->getLoc(), mlir::ValueRange{retValResult, retIdxResult}); + } + + rewriter.replaceOp(op, reduceWindowOp.getResults()); + return success(); +} + // AtenMaxPool2dWithIndicesOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( @@ -149,9 +301,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.getI64Type()), stablehloPadding); - const auto &options = getOptions(); - auto inputShapeInfo = - hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); + auto inputShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input); if (failed(inputShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); @@ -219,10 +369,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto *secondIdxArg = std::next(secondValArg); stablehlo::ComparisonTypeAttr compareTypeAttr; - if (inputTy.getElementType().isa()) { + if (isa(inputTy.getElementType())) { compareTypeAttr = stablehlo::ComparisonTypeAttr::get( rewriter.getContext(), stablehlo::ComparisonType::FLOAT); - } else if (inputTy.getElementType().isa()) { + } else if (isa(inputTy.getElementType())) { compareTypeAttr = stablehlo::ComparisonTypeAttr::get( rewriter.getContext(), stablehlo::ComparisonType::SIGNED); } @@ -398,9 +548,8 @@ class ConvertAtenAvgPoolOp : public ConvertAtenOp { RankedTensorType inputTy = cast(input.getType()); Type inputElemTy = inputTy.getElementType(); int64_t inputRank = inputTy.getRank(); - RankedTensorType outTy = ConvertAtenOp::getTypeConverter() - ->convertType(op.getType()) - .template cast(); + RankedTensorType outTy = cast( + ConvertAtenOp::getTypeConverter()->convertType(op.getType())); auto outShape = outTy.getShape(); if (inputRank <= Dim) { @@ -528,7 +677,8 @@ class ConvertAtenAvgPoolOp : public ConvertAtenOp { } else { assert(false && "Unsupported pooling dimension"); } - divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy); + divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, + outTy.getElementType()); DenseI64ArrayAttr bcastDimensions; rewriter.replaceOpWithNewOp( op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions); @@ -538,11 +688,9 @@ class ConvertAtenAvgPoolOp : public ConvertAtenOp { // Use another mhlo.ReduceWindowOp to get the divisor Value windowSizeConst = hlo::getConstTensor(rewriter, op, {1.0}, {}).value(); - windowSizeConst = - hlo::promoteType(rewriter, op.getLoc(), windowSizeConst, outTy); - const auto &options = ConvertAtenOp::getOptions(); - auto inputShapeVec = *hlo::getDimSizesOfTensor(rewriter, op, input, - options.dimSizeIndexBits); + windowSizeConst = hlo::promoteType(rewriter, op.getLoc(), windowSizeConst, + outTy.getElementType()); + auto inputShapeVec = *hlo::getDimIndexOfTensor(rewriter, op, input); auto inputShapeTensor = rewriter.create( op->getLoc(), inputShapeVec); @@ -591,7 +739,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto inputTy = cast(input.getType()); auto outTy = cast(getTypeConverter()->convertType(op.getType())); - input = hlo::promoteType(rewriter, op.getLoc(), input, outTy); + input = + hlo::promoteType(rewriter, op.getLoc(), 
input, outTy.getElementType()); inputTy = cast(input.getType()); auto inputElemTy = inputTy.getElementType(); auto inputRank = inputTy.getRank(); @@ -663,6 +812,7 @@ void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality( #define INSERT_ATEN_POOLING_PATTERN(AtenOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context, options) + INSERT_ATEN_POOLING_PATTERN(AtenMaxPool1dWithIndicesOp); INSERT_ATEN_POOLING_PATTERN(AtenMaxPool2dWithIndicesOp); INSERT_ATEN_POOLING_PATTERN(AtenCumsumOp); #undef INSERT_ATEN_POOLING_PATTERN diff --git a/lib/Conversion/TorchToStablehlo/Reduction.cpp b/lib/Conversion/TorchToStablehlo/Reduction.cpp index 81a1a1f564d1..bca69906d5ad 100644 --- a/lib/Conversion/TorchToStablehlo/Reduction.cpp +++ b/lib/Conversion/TorchToStablehlo/Reduction.cpp @@ -18,11 +18,9 @@ #include "stablehlo/dialect/StablehloOps.h" #include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" #include "torch-mlir/Conversion/Utils/Utils.h" -#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" -#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" #include #include @@ -32,6 +30,18 @@ using namespace mlir::torch; using namespace mlir::torch::Torch; using namespace mlir::torch::torch_to_stablehlo; +static SmallVector getReduceOutputShape(ArrayRef inputShape, + ArrayRef dims) { + std::unordered_set dimsSet(dims.begin(), dims.end()); + SmallVector reduceResultShape; + for (size_t i = 0; i < inputShape.size(); i++) { + if (dimsSet.find(i) == dimsSet.end()) { + reduceResultShape.push_back(inputShape[i]); + } + } + return reduceResultShape; +} + static Value createInitialValueForReduceOp(Operation *op, Type elementTy, PatternRewriter &rewriter) { auto constType = RankedTensorType::get({}, elementTy); @@ -44,8 +54,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy, /*negative=*/false)}); return rewriter.create(op->getLoc(), constType, constAttr); - } else if (isa(elementTy) && - elementTy.getIntOrFloatBitWidth() != 8) { + } else if (isa(elementTy)) { auto constAttr = DenseElementsAttr::get( constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())}); return rewriter.create(op->getLoc(), constType, @@ -53,7 +62,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy, } } - if (isa(op)) { + if (isa(op)) { if (isa(elementTy)) { auto constAttr = DenseElementsAttr::get( constType, @@ -61,8 +70,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy, /*negative=*/true)}); return rewriter.create(op->getLoc(), constType, constAttr); - } else if (isa(elementTy) && - elementTy.getIntOrFloatBitWidth() != 8) { + } else if (isa(elementTy)) { auto constAttr = DenseElementsAttr::get( constType, {APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())}); @@ -71,7 +79,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy, } } - if (isa(op)) { + if (isa(op)) { if (isa(elementTy)) { auto constAttr = DenseElementsAttr::get( constType, @@ -79,8 +87,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy, /*negative=*/false)}); return rewriter.create(op->getLoc(), constType, constAttr); - } else if (isa(elementTy) && - elementTy.getIntOrFloatBitWidth() != 8) { + } else if (isa(elementTy)) { auto constAttr = DenseElementsAttr::get( constType, 
{APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth())}); @@ -95,8 +102,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy, auto constAttr = DenseElementsAttr::get(constType, one); return rewriter.create(op->getLoc(), constType, constAttr); - } else if (isa(elementTy) && - elementTy.getIntOrFloatBitWidth() != 8) { + } else if (isa(elementTy)) { APInt one(elementTy.getIntOrFloatBitWidth(), 1); auto constAttr = DenseElementsAttr::get(constType, one); return rewriter.create(op->getLoc(), constType, @@ -104,14 +110,16 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy, } } - if (isa(op)) { - auto constAttr = DenseElementsAttr::get(constType, {APInt(1, 1)}); + if (isa(op)) { + auto constAttr = + DenseElementsAttr::get(constType, {APInt(/*numBits=*/1, 1)}); return rewriter.create(op->getLoc(), constType, constAttr); } - if (isa(op)) { - auto constAttr = DenseElementsAttr::get(constType, {APInt(1, 0)}); + if (isa(op)) { + auto constAttr = + DenseElementsAttr::get(constType, {APInt(/*numBits=*/1, 0)}); return rewriter.create(op->getLoc(), constType, constAttr); } @@ -121,11 +129,67 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy, return nullptr; } -// Util for converting AtenArgmaxOp and AtenMaxDimOp +static Value createReduceOpWithSingleRegionOp(Operation *op, Value input, + Type outTy, + ArrayRef dims, + PatternRewriter &rewriter) { + auto inputTy = dyn_cast(input.getType()); + if (!inputTy) + return nullptr; + Value initValue = + createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter); + if (!initValue) + return nullptr; + + stablehlo::ReduceOp reduce = rewriter.create( + op->getLoc(), outTy, input, initValue, + rewriter.getDenseI64ArrayAttr(dims)); + + Block &block = reduce.getBody().emplaceBlock(); + auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); + block.addArgument(blockArgumentTy, op->getLoc()); + block.addArgument(blockArgumentTy, op->getLoc()); + auto *firstArgument = block.args_begin(); + auto secondArgument = block.args_rbegin(); + + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&block); + Value result; + if (isa(op)) { + result = rewriter.create( + op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); + } else if (isa(op)) { + result = rewriter.create( + op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); + } else if (isa(op)) { + result = rewriter.create( + op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); + } else if (isa(op)) { + result = rewriter.create( + op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); + } else if (isa(op)) { + result = rewriter.create( + op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); + } else if (isa(op)) { + result = rewriter.create( + op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); + } else { + op->emitError("unimplemented lowering in " + "createReduceOpWithSingleRegionOp"); + return nullptr; + } + rewriter.create(op->getLoc(), result); + } + return reduce.getResults()[0]; +} + +// Util for converting AtenMaxDimOp/AtenMinDimOp static std::optional -getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, - ArrayRef inputShapeVec, int64_t dim, - size_t dimSizeIndexBits) { +createReduceOpReturnIndices(ConversionPatternRewriter &rewriter, Operation *op, + Value &input, ArrayRef inputShapeVec, + int64_t dim, size_t dimSizeIndexBits) { auto inputTy = cast(input.getType()); if (!inputTy) { 
return std::nullopt; @@ -146,8 +210,7 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, initIndex = hlo::getConstTensor(rewriter, op, {0}, {}).value(); } - std::vector outputShape(inputShape.begin(), inputShape.end()); - outputShape.erase(outputShape.begin() + dim); + auto outputShape = getReduceOutputShape(inputShape, {dim}); auto outputTy = RankedTensorType::get(outputShape, inputElemTy); auto outputIndexTy = RankedTensorType::get(outputShape, rewriter.getIntegerType(64)); @@ -189,16 +252,19 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, auto *secondIdxArg = std::next(secondValArg); stablehlo::ComparisonTypeAttr compareTypeAttr; - if (inputTy.getElementType().isa()) { + if (isa(inputTy.getElementType())) { compareTypeAttr = stablehlo::ComparisonTypeAttr::get( rewriter.getContext(), stablehlo::ComparisonType::FLOAT); - } else if (inputTy.getElementType().isa()) { + } else if (isa(inputTy.getElementType())) { compareTypeAttr = stablehlo::ComparisonTypeAttr::get( rewriter.getContext(), stablehlo::ComparisonType::SIGNED); } stablehlo::ComparisonDirectionAttr compareGeDirectionAttr = stablehlo::ComparisonDirectionAttr::get( rewriter.getContext(), stablehlo::ComparisonDirection::GE); + stablehlo::ComparisonDirectionAttr compareLeDirectionAttr = + stablehlo::ComparisonDirectionAttr::get( + rewriter.getContext(), stablehlo::ComparisonDirection::LE); stablehlo::ComparisonDirectionAttr compareEqDirectionAttr = stablehlo::ComparisonDirectionAttr::get( rewriter.getContext(), stablehlo::ComparisonDirection::EQ); @@ -207,11 +273,21 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&block); - Value compareGeResult = rewriter.create( - op->getLoc(), compareResultType, *firstValArg, *secondValArg, - compareGeDirectionAttr, compareTypeAttr); + Value compareResult; + if (isa(op)) { + compareResult = rewriter.create( + op->getLoc(), compareResultType, *firstValArg, *secondValArg, + compareGeDirectionAttr, compareTypeAttr); + } else if (isa(op)) { + compareResult = rewriter.create( + op->getLoc(), compareResultType, *firstValArg, *secondValArg, + compareLeDirectionAttr, compareTypeAttr); + } else { + op->emitError("unimplement lowering of createReduceOpReturnIndices"); + return std::nullopt; + } Value retValResult = rewriter.create( - op->getLoc(), compareGeResult, *firstValArg, *secondValArg); + op->getLoc(), compareResult, *firstValArg, *secondValArg); // get smaller index value if compared nums are equal. 
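// A brief sketch (not from the patch; names are illustrative): the only
// difference between the max.dim and min.dim regions generated here is the
// comparison direction; the value select and the equal-case index tie-break
// that follow are shared.
#include <cstdint>

enum class ReduceKind { Max, Min };

static bool keepLhsSketch(double lhs, double rhs, ReduceKind kind) {
  // GE for AtenMaxDimOp, LE for AtenMinDimOp, mirroring compareGeDirectionAttr
  // and compareLeDirectionAttr above.
  return kind == ReduceKind::Max ? lhs >= rhs : lhs <= rhs;
}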
Value compareEqResult = rewriter.create( @@ -220,16 +296,33 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, Value minIdx = rewriter.create(op->getLoc(), *firstIdxArg, *secondIdxArg); Value idxWithGeVal = rewriter.create( - op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg); + op->getLoc(), compareResult, *firstIdxArg, *secondIdxArg); Value retIdxResult = rewriter.create( op->getLoc(), compareEqResult, minIdx, idxWithGeVal); rewriter.create( - op->getLoc(), mlir::ValueRange{retValResult, retIdxResult}); + op->getLoc(), ValueRange{retValResult, retIdxResult}); } return stablehloReduceOp.getResults(); } +static Value reshapeReduceResultWhenKeepDim(ConversionPatternRewriter &rewriter, + Location loc, Value reduceResult, + ArrayRef inputShapeVec, + Type outType, + ArrayRef dims) { + SmallVector outShapeVec(inputShapeVec); + Value one = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); + for (auto dim : dims) { + outShapeVec[dim] = one; + } + auto outShapeTensor = + rewriter.create(loc, outShapeVec); + return rewriter.create( + loc, outType, reduceResult, outShapeTensor); +} + namespace { template class ConvertAtenReductionOp : public ConvertAtenOp { @@ -238,836 +331,443 @@ class ConvertAtenReductionOp : public ConvertAtenOp { using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; + ConversionPatternRewriter &rewriter) const override { + assert(false && "Unimplemented"); + return failure(); + }; }; -} // namespace -// AtenArgmaxOp -namespace { -template <> -LogicalResult ConvertAtenReductionOp::matchAndRewrite( - AtenArgmaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Value input = adaptor.getSelf(); - auto inputTy = cast(input.getType()); - if (!inputTy) { - return rewriter.notifyMatchFailure( - op, "only Tensor types supported in StableHLO"); - } +template +class ConvertAtenReduceAllDimsOp : public ConvertAtenReductionOp { +public: + using ConvertAtenReductionOp::ConvertAtenReductionOp; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value input = adaptor.getSelf(); + auto inputTy = dyn_cast(input.getType()); + auto outTy = dyn_cast( + ConvertAtenReductionOp::getTypeConverter()->convertType( + op.getType())); + if (!inputTy || !outTy) { + return rewriter.notifyMatchFailure( + op, "only Tensor types supported in StableHLO"); + } - auto inputElemTy = inputTy.getElementType(); - if (!inputElemTy.isIntOrFloat()) { - return op.emitError( - "only floating-point or integer datatype legalization supported"); - } - // Currently, (u)int8 dtype is not supported! 
- if (isa(inputElemTy) && - inputElemTy.getIntOrFloatBitWidth() == 8) { - return rewriter.notifyMatchFailure( - op, "IntegerType with bitwidth 8 unsupported in convertion from " - "AtenArgmaxOp to StableHLO"); - } + auto inputElemTy = inputTy.getElementType(); + if (!inputElemTy.isIntOrFloat()) { + return op.emitError( + "only floating-point or integer datatype legalization supported"); + } + if (inputElemTy != outTy.getElementType()) { + // use output type as computation type + input = rewriter.create(op->getLoc(), input, + outTy.getElementType()); + } - int64_t dim; - if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) { - return rewriter.notifyMatchFailure(op, "non-int dim unsupported"); - } - dim = toPositiveDim(dim, inputTy.getRank()); - if (!isValidDim(dim, inputTy.getRank())) { - return rewriter.notifyMatchFailure(op, "dim is not a valid dim"); + SmallVector dims = + llvm::to_vector(llvm::seq(0, inputTy.getRank())); + Value result = + createReduceOpWithSingleRegionOp(op, input, outTy, dims, rewriter); + if (!result) { + return op->emitError("createReduceOpWithSingleRegionOp return nullptr"); + } + rewriter.replaceOp(op, result); + return success(); } +}; - bool keepDim = false; - if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) { - return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported"); - } +template +class ConvertAtenReduceOneDimOp : public ConvertAtenReductionOp { +public: + using ConvertAtenReductionOp::ConvertAtenReductionOp; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value input = adaptor.getSelf(); + auto inputTy = dyn_cast(input.getType()); + auto outTy = dyn_cast( + ConvertAtenReductionOp::getTypeConverter()->convertType( + op.getType())); + if (!inputTy || !outTy) { + return rewriter.notifyMatchFailure( + op, "only Tensor types supported in StableHLO"); + } - const auto &options = getOptions(); - auto inputShapeInfo = - hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); - if (failed(inputShapeInfo)) { - return rewriter.notifyMatchFailure( - op, "failed to get dimension sizes of the input"); + auto inputElemTy = inputTy.getElementType(); + if (!inputElemTy.isIntOrFloat()) { + return op.emitError( + "only floating-point or integer datatype legalization supported"); + } + if (inputElemTy != outTy.getElementType()) { + // use output type as computation type + input = rewriter.create(op->getLoc(), input, + outTy.getElementType()); + } + + bool keepDim = false; + if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) { + return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported"); + } + + int64_t dim; + if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) { + return rewriter.notifyMatchFailure( + op, "non-const integer `dim` is not supported"); + } + dim = toPositiveDim(dim, inputTy.getRank()); + SmallVector reduceResultShape = + getReduceOutputShape(inputTy.getShape(), {dim}); + + Value reduceResult = createReduceOpWithSingleRegionOp( + op, input, + RankedTensorType::get(reduceResultShape, outTy.getElementType()), {dim}, + rewriter); + if (!reduceResult) { + return op->emitError("createReduceOpWithSingleRegionOp return nullptr"); + } + + if (keepDim) { + auto outShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input); + if (failed(outShapeInfo)) { + return rewriter.notifyMatchFailure( + op, "failed to get dimension sizes of the input"); + } + reduceResult = 
reshapeReduceResultWhenKeepDim( + rewriter, op->getLoc(), reduceResult, *outShapeInfo, outTy, {dim}); + } + rewriter.replaceOp(op, reduceResult); + return success(); } - auto inputShapeVec = *inputShapeInfo; - auto stablehloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, - dim, options.dimSizeIndexBits) - .value(); +}; - if (keepDim) { - auto outShapeVec = inputShapeVec; - outShapeVec[dim] = rewriter.create( - op->getLoc(), - rewriter.getIntegerAttr( - rewriter.getIntegerType(options.dimSizeIndexBits), 1)); - - auto outShapeTensor = rewriter.create( - op->getLoc(), outShapeVec); - rewriter.replaceOpWithNewOp( - op, typeConverter->convertType(op.getType()), stablehloReduceResults[1], - outShapeTensor); +template +class ConvertAtenReduceDimsOp : public ConvertAtenReductionOp { +public: + using ConvertAtenReductionOp::ConvertAtenReductionOp; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value input = adaptor.getSelf(); + auto inputTy = dyn_cast(input.getType()); + auto outTy = dyn_cast( + ConvertAtenReductionOp::getTypeConverter()->convertType( + op.getType())); + if (!inputTy || !outTy) { + return rewriter.notifyMatchFailure( + op, "only Tensor types supported in StableHLO"); + } + + auto inputElemTy = inputTy.getElementType(); + if (!inputElemTy.isIntOrFloat()) { + return op.emitError( + "only floating-point or integer datatype legalization supported"); + } + if (inputElemTy != outTy.getElementType()) { + // use output type as computation type + input = rewriter.create(op->getLoc(), input, + outTy.getElementType()); + } + + bool keepDim = false; + if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) { + return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported"); + } + + SmallVector inputDims; + SmallVector dims; + if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(inputDims))) { + return rewriter.notifyMatchFailure( + op, "non-const integer `dim` is not supported"); + } + if (inputDims.size() == 0) { + dims = llvm::to_vector(llvm::seq(0, inputTy.getRank())); + } else { + for (auto d : inputDims) { + d = toPositiveDim(d, inputTy.getRank()); + // Drop invalid dims + if (isValidDim(d, inputTy.getRank())) { + dims.push_back(d); + } + } + llvm::sort(dims.begin(), dims.end()); + } + SmallVector reduceResultShape = + getReduceOutputShape(inputTy.getShape(), dims); + + Value reduceResult = createReduceOpWithSingleRegionOp( + op, input, + RankedTensorType::get(reduceResultShape, outTy.getElementType()), dims, + rewriter); + if (!reduceResult) { + return op->emitError("createReduceOpWithSingleRegionOp return nullptr"); + } + + if (keepDim) { + auto outShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input); + if (failed(outShapeInfo)) { + return rewriter.notifyMatchFailure( + op, "failed to get dimension sizes of the input"); + } + reduceResult = reshapeReduceResultWhenKeepDim( + rewriter, op->getLoc(), reduceResult, *outShapeInfo, outTy, dims); + } + rewriter.replaceOp(op, reduceResult); return success(); } +}; - rewriter.replaceOp(op, stablehloReduceResults[1]); - return success(); -} +template +class ConvertAtenReduceWithIndicesOp : public ConvertAtenReductionOp { +public: + using ConvertAtenReductionOp::ConvertAtenReductionOp; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value input = adaptor.getSelf(); 
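ConvertAtenReduceDimsOp above normalizes the dim list in the usual Torch way: an empty list means every axis, negative axes are wrapped by the rank, out-of-range axes are dropped, and the result is sorted. A standalone sketch of that normalization (normalizeDims is an illustrative name, mirroring how toPositiveDim and isValidDim are used here):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// Empty -> all dims, negative -> wrapped by rank, out-of-range -> dropped, then sorted.
std::vector<int64_t> normalizeDims(std::vector<int64_t> dims, int64_t rank) {
  if (dims.empty()) {
    dims.resize(rank);
    std::iota(dims.begin(), dims.end(), 0);
    return dims;
  }
  std::vector<int64_t> out;
  for (int64_t d : dims) {
    if (d < 0)
      d += rank;                  // toPositiveDim
    if (d >= 0 && d < rank)       // isValidDim
      out.push_back(d);
  }
  std::sort(out.begin(), out.end());
  return out;
}

int main() {
  for (int64_t d : normalizeDims({-1, 0, 5}, 3)) std::cout << d << ' '; // 0 2
  std::cout << '\n';
}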
+ auto inputTy = dyn_cast(input.getType()); + if (!inputTy) { + return rewriter.notifyMatchFailure( + op, "only Tensor types supported in StableHLO"); + } + auto inputElemTy = inputTy.getElementType(); + if (!inputElemTy.isIntOrFloat()) { + return op.emitError( + "Only floating-point or integer datatype legalization supported"); + } + + RankedTensorType valResultType = cast( + ConvertAtenReductionOp::getTypeConverter()->convertType( + op.getResult(0).getType())); + RankedTensorType idxResultType = cast( + ConvertAtenReductionOp::getTypeConverter()->convertType( + op.getResult(1).getType())); + Type idxElementType = idxResultType.getElementType(); + if (!isa(idxElementType)) { + return op.emitError("indices result should be integer type"); + } + + int64_t dim; + if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) { + return rewriter.notifyMatchFailure(op, "non-int dim unsupported"); + } + dim = toPositiveDim(dim, inputTy.getRank()); + if (!isValidDim(dim, inputTy.getRank())) { + return rewriter.notifyMatchFailure(op, "dim is not a valid dim"); + } + bool keepDim = false; + if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) { + return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported"); + } + + const auto &options = ConvertAtenReductionOp::getOptions(); + auto inputShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input); + if (failed(inputShapeInfo)) { + return rewriter.notifyMatchFailure( + op, "failed to get dimension sizes of the input"); + } + auto inputShapeVec = *inputShapeInfo; + + if (op.getResult(1).use_empty()) { + llvm::SmallVector outputShape(inputTy.getShape()); + outputShape.erase(outputShape.begin() + dim); + Value reduceResult = createReduceOpWithSingleRegionOp( + op, input, RankedTensorType::get(outputShape, inputElemTy), + ArrayRef{dim}, rewriter); + if (!reduceResult) { + return op->emitError("createReduceOpWithSingleRegionOp return nullptr"); + } + + if (keepDim) { + reduceResult = + reshapeReduceResultWhenKeepDim(rewriter, op->getLoc(), reduceResult, + inputShapeVec, valResultType, {dim}); + } + rewriter.replaceOp(op, {reduceResult, Value()}); + return success(); + } else { + ValueRange stablehloReduceResults = + createReduceOpReturnIndices(rewriter, op, input, inputShapeVec, dim, + options.dimSizeIndexBits) + .value(); + SmallVector reduceResults(stablehloReduceResults); + if (keepDim) { + reduceResults[0] = reshapeReduceResultWhenKeepDim( + rewriter, op->getLoc(), reduceResults[0], inputShapeVec, + valResultType, {dim}); + reduceResults[1] = reshapeReduceResultWhenKeepDim( + rewriter, op->getLoc(), reduceResults[1], inputShapeVec, + idxResultType, {dim}); + } + rewriter.replaceOp(op, reduceResults); + return success(); + } + }; +}; } // namespace -// AtenMaxDimOp +// AtenSumDimIntListOp namespace { template <> -LogicalResult ConvertAtenReductionOp::matchAndRewrite( - AtenMaxDimOp op, OpAdaptor adaptor, +LogicalResult ConvertAtenReductionOp::matchAndRewrite( + AtenSumDimIntListOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.getSelf(); auto inputTy = dyn_cast(input.getType()); + auto outTy = + dyn_cast(getTypeConverter()->convertType(op.getType())); if (!inputTy) { return rewriter.notifyMatchFailure( op, "only Tensor types supported in StableHLO"); } + if (inputTy.getElementType() != outTy.getElementType()) { + // Use output element type as computation type.
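ConvertAtenReduceWithIndicesOp above covers aten.max.dim and aten.min.dim, which return both the reduced values and the winning indices along one axis, and it falls back to a value-only reduce when the index result has no uses. The snippet below is a plain-C++ reference for the semantics being lowered (max over dim 1 of a row-major 2x3 array, ties keeping the first index), not for the lowering itself:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // 2x3 row-major input.
  std::vector<double> x{1.0, 5.0, 5.0,
                        7.0, 2.0, 9.0};
  int64_t rows = 2, cols = 3;
  std::vector<double> values(rows);
  std::vector<int64_t> indices(rows);
  for (int64_t r = 0; r < rows; ++r) {
    values[r] = x[r * cols];
    indices[r] = 0;
    for (int64_t c = 1; c < cols; ++c) {
      double v = x[r * cols + c];
      if (v > values[r]) {          // strict '>' keeps the smaller index on ties
        values[r] = v;
        indices[r] = c;
      }
    }
  }
  for (int64_t r = 0; r < rows; ++r)
    std::cout << values[r] << " @ " << indices[r] << "\n"; // 5 @ 1, then 9 @ 2
}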
+ auto dstElemTy = outTy.getElementType(); + input = + rewriter.create(op->getLoc(), input, dstElemTy); + inputTy = dyn_cast(input.getType()); + } auto inputElemTy = inputTy.getElementType(); if (!inputElemTy.isIntOrFloat()) { return op.emitError( "Only floating-point or integer datatype legalization supported"); } - // Currently, (u)int8 dtype is not supported - if (isa(inputElemTy) && - inputElemTy.getIntOrFloatBitWidth() == 8) { - return rewriter.notifyMatchFailure( - op, "IntegerType with bitwidth 8 unsupported in convertion from " - "AtenMaxDimOp to StableHLO"); - } - - RankedTensorType valResultType = getTypeConverter() - ->convertType(op.getResult(0).getType()) - .template cast(); - RankedTensorType idxResultType = getTypeConverter() - ->convertType(op.getResult(1).getType()) - .template cast(); - Type idxElementType = idxResultType.getElementType(); - if (!isa(idxElementType)) { - return op.emitError("Aten.max.dim needs integer-like result"); - } - int64_t dim; - if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) { - return rewriter.notifyMatchFailure(op, "non-int dim unsupported"); + SmallVector inputDims; + SmallVector dims; + if (failed(checkNotNone(rewriter, op, op.getDim()))) { + inputDims = llvm::to_vector<4>(llvm::seq(0, inputTy.getRank())); + } else { + if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(inputDims))) { + return rewriter.notifyMatchFailure( + op, "non-const integer `dim` is not supported"); + } + if (inputDims.size() == 0) { + inputDims = llvm::to_vector<4>(llvm::seq(0, inputTy.getRank())); + } } - dim = toPositiveDim(dim, inputTy.getRank()); - if (!isValidDim(dim, inputTy.getRank())) { - return rewriter.notifyMatchFailure(op, "dim is not a valid dim"); + for (auto d : inputDims) { + d = toPositiveDim(d, inputTy.getRank()); + // Drop invalid dims + if (isValidDim(d, inputTy.getRank())) { + dims.push_back(d); + } } + llvm::sort(dims.begin(), dims.end()); + + SmallVector reduceResultShape = + getReduceOutputShape(inputTy.getShape(), dims); + bool keepDim = false; if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) { return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported"); } - const auto &options = getOptions(); - auto inputShapeInfo = - hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); - if (failed(inputShapeInfo)) { - return rewriter.notifyMatchFailure( - op, "failed to get dimension sizes of the input"); + Value reduceResult = createReduceOpWithSingleRegionOp( + op, input, + RankedTensorType::get(reduceResultShape, outTy.getElementType()), dims, + rewriter); + if (!reduceResult) { + return op->emitError("createReduceOpWithSingleRegionOp return nullptr"); } - auto inputShapeVec = *inputShapeInfo; - auto stablehloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, - dim, options.dimSizeIndexBits) - .value(); if (keepDim) { - auto outShapeVec = inputShapeVec; - outShapeVec[dim] = rewriter.create( - op->getLoc(), - rewriter.getIntegerAttr( - rewriter.getIntegerType(options.dimSizeIndexBits), 1)); - auto outShapeTensor = rewriter.create( - op->getLoc(), outShapeVec); - - auto stablehloReduceValueResult = - rewriter.create( - op->getLoc(), valResultType, stablehloReduceResults[0], - outShapeTensor); - auto stablehloReduceIndexResult = - rewriter.create( - op->getLoc(), idxResultType, stablehloReduceResults[1], - outShapeTensor); - rewriter.replaceOp( - op, {stablehloReduceValueResult, stablehloReduceIndexResult}); - return success(); + auto outShapeInfo = 
hlo::getDimIndexOfTensor(rewriter, op, input); + if (failed(outShapeInfo)) { + return rewriter.notifyMatchFailure( + op, "failed to get dimension sizes of the input"); + } + reduceResult = reshapeReduceResultWhenKeepDim( + rewriter, op->getLoc(), reduceResult, *outShapeInfo, outTy, dims); } - - rewriter.replaceOp(op, - {stablehloReduceResults[0], stablehloReduceResults[1]}); + rewriter.replaceOp(op, reduceResult); return success(); } } // namespace -// AtenSumOp +// AtenFrobeniusNormDimOp +// aten.frobenius_norm.dim => stablehlo.reduce(calculate square sum along given +// dims) + stablehlo.sqrt namespace { template <> -LogicalResult ConvertAtenReductionOp::matchAndRewrite( - AtenSumOp op, OpAdaptor adaptor, +LogicalResult ConvertAtenReductionOp::matchAndRewrite( + AtenFrobeniusNormDimOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + Value input = adaptor.getSelf(); - auto inputTy = dyn_cast(input.getType()); - auto outTy = getTypeConverter() - ->convertType(op.getType()) - .template dyn_cast(); - if (!inputTy) { - return rewriter.notifyMatchFailure( - op, "only Tensor types supported in StableHLO"); - } - if (inputTy.getElementType() != outTy.getElementType()) { - // Use output element type as computation type. - auto dstElemTy = outTy.getElementType(); - input = - rewriter.create(op->getLoc(), input, dstElemTy); - inputTy = dyn_cast(input.getType()); - } - auto inputElemTy = inputTy.getElementType(); - if (!inputElemTy.isIntOrFloat()) { + auto inputType = dyn_cast(input.getType()); + if (!inputType) { return op.emitError( - "only floating-point or integer datatype legalization supported"); + "only ranked tensor input supported in AtenFrobeniusNormDimOp"); } - // Currently, (u)int8 dtype is not supported - if (isa(inputElemTy) && - inputElemTy.getIntOrFloatBitWidth() == 8) { - return rewriter.notifyMatchFailure( - op, "IntegerType with bitwidth 8 unsupported in convertion from " - "AtenSumOp to StableHLO"); + auto inputRank = inputType.getRank(); + auto inputElemType = inputType.getElementType(); + if (!isa(inputElemType)) { + return op.emitError( + "only float dtype allowed in input tensor of AtenFrobeniusNormDimOp"); } SmallVector dims; - for (int64_t i = 0; i < inputTy.getRank(); i++) { - dims.push_back(i); + if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dims))) { + return rewriter.notifyMatchFailure( + op, "non-const integer `dim` is not supported"); } - Value initValue = - createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter); - if (!initValue) - return failure(); - - llvm::sort(dims.begin(), dims.end()); - auto stablehloReduceOp = rewriter.create( - op.getLoc(), RankedTensorType::get({}, outTy.getElementType()), input, - initValue, rewriter.getDenseI64ArrayAttr(dims)); + for (auto &dim : dims) { + dim = toPositiveDim(dim, inputRank); + if (!isValidDim(dim, inputRank)) { + return rewriter.notifyMatchFailure(op, + "invalid dimension detected in `dim`"); + } + } + // Sort the dims in ascending order, making the conversion + // stable with unordered dims. 
+ std::sort(dims.begin(), dims.end()); - Block &block = stablehloReduceOp.getBody().emplaceBlock(); - auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); + SmallVector reduceResultShape = + getReduceOutputShape(inputType.getShape(), dims); - block.addArgument(blockArgumentTy, op->getLoc()); - block.addArgument(blockArgumentTy, op->getLoc()); + bool keepDim = false; + if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) { + return rewriter.notifyMatchFailure( + op, "non-const bool `keepdim` is not supported"); + } - auto *firstArgument = block.args_begin(); - auto secondArgument = block.args_rbegin(); + auto squareOp = rewriter.create(op->getLoc(), input, input); - { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(&block); - Value addResult = rewriter.create( - op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); - rewriter.create(op->getLoc(), addResult); + Value reduceResult = createReduceOpWithSingleRegionOp( + op, squareOp.getResult(), + RankedTensorType::get(reduceResultShape, inputElemType), dims, rewriter); + if (!reduceResult) { + return op->emitError("createReduceOpWithSingleRegionOp return nullptr"); } - rewriter.replaceOpWithNewOp(op, outTy, - stablehloReduceOp.getResults()); + Value output = rewriter.create(op->getLoc(), reduceResult); + + if (keepDim) { + auto outShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input); + if (failed(outShapeInfo)) { + return rewriter.notifyMatchFailure( + op, "failed to get dimension sizes of the input"); + } + output = reshapeReduceResultWhenKeepDim( + rewriter, op->getLoc(), output, *outShapeInfo, + getTypeConverter()->convertType(op.getType()), dims); + } + rewriter.replaceOp(op, output); return success(); } } // namespace -// AtenAllOp -namespace { -template <> -LogicalResult ConvertAtenReductionOp::matchAndRewrite( - AtenAllOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Value input = adaptor.getSelf(); - auto inputTy = input.getType().dyn_cast(); - if (!inputTy) { - return rewriter.notifyMatchFailure( - op, "only Tensor types supported in StableHLO"); - } - auto inputElemTy = inputTy.getElementType(); - - // Currently, (u)int8 dtype is not supported - if (isa(inputElemTy) && - inputElemTy.getIntOrFloatBitWidth() == 8) { - return rewriter.notifyMatchFailure( - op, "IntegerType with bitwidth 8 unsupported in convertion from " - "AtenAllOp to StableHLO"); - } - auto outTy = getTypeConverter() - ->convertType(op.getType()) - .template dyn_cast(); - - if (inputElemTy != outTy.getElementType()) { - // Use output bool type as computation type. 
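The new AtenFrobeniusNormDimOp lowering follows the decomposition stated in its header comment: square the input, sum over the requested dims, then take the square root; AtenLinalgVectorNormOp further down generalizes the same recipe to pow(|x|, p), reduce, pow(., 1/p). A small numeric check of the Frobenius formula:

#include <cmath>
#include <iostream>
#include <vector>

int main() {
  // Frobenius norm over both dims of a 2x3 matrix: sqrt(sum of squares).
  std::vector<double> x{1, 2, 3, 4, 5, 6};
  double sumSq = 0.0;
  for (double v : x)
    sumSq += v * v;
  std::cout << std::sqrt(sumSq) << "\n"; // sqrt(91) ~= 9.5394
}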
- auto dstElemTy = outTy.getElementType(); - input = - rewriter.create(op->getLoc(), input, dstElemTy); - inputTy = input.getType().dyn_cast(); - inputElemTy = inputTy.getElementType(); - } - - SmallVector dims; - for (int64_t i = 0; i < inputTy.getRank(); i++) { - dims.push_back(i); - } - - Value initValue = - createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter); - if (!initValue) - return failure(); - llvm::sort(dims.begin(), dims.end()); - auto stablehloReduceOp = rewriter.create( - op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue, - rewriter.getDenseI64ArrayAttr(dims)); - - Block &block = stablehloReduceOp.getBody().emplaceBlock(); - auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); - - block.addArgument(blockArgumentTy, op->getLoc()); - block.addArgument(blockArgumentTy, op->getLoc()); - - auto *firstArgument = block.args_begin(); - auto secondArgument = block.args_rbegin(); - - { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(&block); - Value allResult = rewriter.create( - op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); - rewriter.create(op->getLoc(), allResult); - } - - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), - stablehloReduceOp.getResults()); - return success(); -} -} // namespace - -// AtenAnyOp -namespace { -template <> -LogicalResult ConvertAtenReductionOp::matchAndRewrite( - AtenAnyOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Value input = adaptor.getSelf(); - auto inputTy = input.getType().dyn_cast(); - if (!inputTy) { - return rewriter.notifyMatchFailure( - op, "only Tensor types supported in StableHLO"); - } - auto inputElemTy = inputTy.getElementType(); - - // Currently, (u)int8 dtype is not supported - if (isa(inputElemTy) && - inputElemTy.getIntOrFloatBitWidth() == 8) { - return rewriter.notifyMatchFailure( - op, "IntegerType with bitwidth 8 unsupported in convertion from " - "AtenAllOp to StableHLO"); - } - auto outTy = getTypeConverter() - ->convertType(op.getType()) - .template dyn_cast(); - - if (inputElemTy != outTy.getElementType()) { - // Use output bool type as computation type. 
- auto dstElemTy = outTy.getElementType(); - input = - rewriter.create(op->getLoc(), input, dstElemTy); - inputTy = input.getType().dyn_cast(); - inputElemTy = inputTy.getElementType(); - } - - SmallVector dims; - for (int64_t i = 0; i < inputTy.getRank(); i++) { - dims.push_back(i); - } - - Value initValue = - createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter); - if (!initValue) - return failure(); - llvm::sort(dims.begin(), dims.end()); - auto stablehloReduceOp = rewriter.create( - op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue, - rewriter.getDenseI64ArrayAttr(dims)); - - Block &block = stablehloReduceOp.getBody().emplaceBlock(); - auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); - - block.addArgument(blockArgumentTy, op->getLoc()); - block.addArgument(blockArgumentTy, op->getLoc()); - - auto *firstArgument = block.args_begin(); - auto secondArgument = block.args_rbegin(); - - { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(&block); - Value anyResult = rewriter.create( - op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); - rewriter.create(op->getLoc(), anyResult); - } - - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), - stablehloReduceOp.getResults()); - return success(); -} -} // namespace - -// AtenProdOp -namespace { -template <> -LogicalResult ConvertAtenReductionOp::matchAndRewrite( - AtenProdOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Value input = adaptor.getSelf(); - auto inputTy = dyn_cast(input.getType()); - auto outTy = getTypeConverter() - ->convertType(op.getType()) - .template dyn_cast(); - if (!inputTy) { - return rewriter.notifyMatchFailure( - op, "only Tensor types supported in StableHLO"); - } - if (inputTy.getElementType() != outTy.getElementType()) { - // Use output element type as computation type. 
- auto dstElemTy = outTy.getElementType(); - input = - rewriter.create(op->getLoc(), input, dstElemTy); - inputTy = dyn_cast(input.getType()); - } - auto inputElemTy = inputTy.getElementType(); - if (!inputElemTy.isIntOrFloat()) { - return op.emitError( - "only floating-point or integer datatype legalization supported"); - } - // Currently, (u)int8 dtype is not supported - if (isa(inputElemTy) && - inputElemTy.getIntOrFloatBitWidth() == 8) { - return rewriter.notifyMatchFailure( - op, "IntegerType with bitwidth 8 unsupported in convertion from " - "AtenProdOp to StableHLO"); - } - - SmallVector dims; - for (int64_t i = 0; i < inputTy.getRank(); i++) { - dims.push_back(i); - } - Value initValue = - createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter); - if (!initValue) - return failure(); - - llvm::sort(dims.begin(), dims.end()); - auto stablehloReduceOp = rewriter.create( - op.getLoc(), RankedTensorType::get({}, outTy.getElementType()), input, - initValue, rewriter.getDenseI64ArrayAttr(dims)); - - Block &block = stablehloReduceOp.getBody().emplaceBlock(); - auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); - - block.addArgument(blockArgumentTy, op->getLoc()); - block.addArgument(blockArgumentTy, op->getLoc()); - - auto *firstArgument = block.args_begin(); - auto secondArgument = block.args_rbegin(); - - { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(&block); - Value mulResult = rewriter.create( - op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); - rewriter.create(op->getLoc(), mulResult); - } - - rewriter.replaceOpWithNewOp(op, outTy, - stablehloReduceOp.getResults()); - - return success(); -} -} // namespace - -// AtenMaxOp -namespace { -template <> -LogicalResult ConvertAtenReductionOp::matchAndRewrite( - AtenMaxOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Value input = adaptor.getSelf(); - auto inputTy = dyn_cast(input.getType()); - if (!inputTy) { - return rewriter.notifyMatchFailure( - op, "only Tensor types supported in StableHLO"); - } - auto inputElemTy = inputTy.getElementType(); - if (!inputElemTy.isIntOrFloat()) { - return op.emitError( - "only floating-point or integer datatype legalization supported"); - } - // Currently, (u)int8 dtype is not supported - if (isa(inputElemTy) && - inputElemTy.getIntOrFloatBitWidth() == 8) { - return rewriter.notifyMatchFailure( - op, "IntegerType with bitwidth 8 unsupported in convertion from " - "AtenMaxOp to StableHLO"); - } - - SmallVector dims; - for (int64_t i = 0; i < inputTy.getRank(); i++) { - dims.push_back(i); - } - - Value initValue = - createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter); - if (!initValue) - return failure(); - llvm::sort(dims.begin(), dims.end()); - auto stablehloReduceOp = rewriter.create( - op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue, - rewriter.getDenseI64ArrayAttr(dims)); - - Block &block = stablehloReduceOp.getBody().emplaceBlock(); - auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); - - block.addArgument(blockArgumentTy, op->getLoc()); - block.addArgument(blockArgumentTy, op->getLoc()); - - auto *firstArgument = block.args_begin(); - auto secondArgument = block.args_rbegin(); - - { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(&block); - Value maxResult = rewriter.create( - op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); - rewriter.create(op->getLoc(), maxResult); - } 
- - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), - stablehloReduceOp.getResults()); - return success(); -} -} // namespace - -// AtenMinOp -namespace { -template <> -LogicalResult ConvertAtenReductionOp::matchAndRewrite( - AtenMinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Value input = adaptor.getSelf(); - auto inputTy = dyn_cast(input.getType()); - if (!inputTy) { - return rewriter.notifyMatchFailure( - op, "only Tensor types supported in StableHLO"); - } - auto inputElemTy = inputTy.getElementType(); - if (!inputElemTy.isIntOrFloat()) { - return op.emitError( - "only floating-point or integer datatype legalization supported"); - } - // Currently, (u)int8 dtype is not supported - if (isa(inputElemTy) && - inputElemTy.getIntOrFloatBitWidth() == 8) { - return rewriter.notifyMatchFailure( - op, "IntegerType with bitwidth 8 unsupported in convertion from " - "AtenMinOp to StableHLO"); - } - - SmallVector dims; - for (int64_t i = 0; i < inputTy.getRank(); i++) { - dims.push_back(i); - } - - Value initValue = - createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter); - if (!initValue) - return failure(); - llvm::sort(dims.begin(), dims.end()); - auto stablehloReduceOp = rewriter.create( - op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue, - rewriter.getDenseI64ArrayAttr(dims)); - - Block &block = stablehloReduceOp.getBody().emplaceBlock(); - auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); - - block.addArgument(blockArgumentTy, op->getLoc()); - block.addArgument(blockArgumentTy, op->getLoc()); - - auto *firstArgument = block.args_begin(); - auto secondArgument = block.args_rbegin(); - - { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(&block); - Value minResult = rewriter.create( - op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); - rewriter.create(op->getLoc(), minResult); - } - - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), - stablehloReduceOp.getResults()); - return success(); -} -} // namespace - -// AtenSumDimIntListOp -namespace { -template <> -LogicalResult ConvertAtenReductionOp::matchAndRewrite( - AtenSumDimIntListOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Value input = adaptor.getSelf(); - auto inputTy = dyn_cast(input.getType()); - auto outTy = getTypeConverter() - ->convertType(op.getType()) - .template dyn_cast(); - if (!inputTy) { - return rewriter.notifyMatchFailure( - op, "only Tensor types supported in StableHLO"); - } - if (inputTy.getElementType() != outTy.getElementType()) { - // Use output element type as computation type. 
- auto dstElemTy = outTy.getElementType(); - input = - rewriter.create(op->getLoc(), input, dstElemTy); - inputTy = dyn_cast(input.getType()); - } - auto inputElemTy = inputTy.getElementType(); - if (!inputElemTy.isIntOrFloat()) { - return op.emitError( - "Only floating-point or integer datatype legalization supported"); - } - - // Currently, (u)int8 dtype is not supported - if (isa(inputElemTy) && - inputElemTy.getIntOrFloatBitWidth() == 8) { - return rewriter.notifyMatchFailure( - op, "IntegerType with bitwidth 8 unsupported in convertion from " - "AtenSumDimIntListOp to StableHLO"); - } - - SmallVector inputDims; - SmallVector dims; - - if (failed(checkNotNone(rewriter, op, op.getDim()))) { - inputDims = llvm::to_vector<4>(llvm::seq(0, inputTy.getRank())); - } else { - if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(inputDims))) { - return rewriter.notifyMatchFailure( - op, "non-const integer `dim` is not supported"); - } - if (inputDims.size() == 0) { - inputDims = llvm::to_vector<4>(llvm::seq(0, inputTy.getRank())); - } - } - - for (auto d : inputDims) { - d = toPositiveDim(d, inputTy.getRank()); - // Drop invalid dims - if (isValidDim(d, inputTy.getRank())) { - dims.push_back(d); - } - } - - std::unordered_set dimsSet(dims.begin(), dims.end()); - SmallVector reduceResultShape; - for (int64_t i = 0; i < inputTy.getRank(); i++) { - if (dimsSet.find(i) == dimsSet.end()) { - reduceResultShape.push_back(inputTy.getDimSize(i)); - } - } - - bool keepDim = false; - if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) { - return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported"); - } - Value initValue = - createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter); - if (!initValue) - return failure(); - - llvm::sort(dims.begin(), dims.end()); - auto stablehloReduceOp = rewriter.create( - op.getLoc(), - RankedTensorType::get(reduceResultShape, outTy.getElementType()), input, - initValue, rewriter.getDenseI64ArrayAttr(dims)); - - Region ®ion = stablehloReduceOp.getBody(); - Block &block = region.emplaceBlock(); - auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); - - block.addArgument(blockArgumentTy, op->getLoc()); - block.addArgument(blockArgumentTy, op->getLoc()); - - auto *firstArgument = block.args_begin(); - auto secondArgument = block.args_rbegin(); - - { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(&block); - Value addResult = rewriter.create( - op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); - rewriter.create(op->getLoc(), addResult); - } - - if (keepDim) { - const auto &options = getOptions(); - auto outShapeInfo = - hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); - if (failed(outShapeInfo)) { - return rewriter.notifyMatchFailure( - op, "failed to get dimension sizes of the input"); - } - auto outShapeVec = *outShapeInfo; - auto one = rewriter.create( - op->getLoc(), - rewriter.getIntegerAttr( - rewriter.getIntegerType(options.dimSizeIndexBits), 1)); - for (int64_t i : dims) { - outShapeVec[i] = one; - } - auto outShapeTensor = rewriter.create( - op->getLoc(), outShapeVec); - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), - stablehloReduceOp.getResult(0), outShapeTensor); - return success(); - } - rewriter.replaceOpWithNewOp(op, outTy, - stablehloReduceOp.getResults()); - return success(); -} -} // namespace - -// AtenFrobeniusNormDimOp -// aten.frobenius_norm.dim => stablehlo.reduce(calculate square 
sum along given -// dims) + stablehlo.sqrt -namespace { -template <> -LogicalResult ConvertAtenReductionOp::matchAndRewrite( - AtenFrobeniusNormDimOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - const TorchToStablehloOptions &options = getOptions(); - - Value input = adaptor.getSelf(); - auto inputType = dyn_cast(input.getType()); - if (!inputType) { - return op.emitError( - "only ranked tensor input supported in AtenFrobeniusNormDimOp"); - } - auto inputRank = inputType.getRank(); - auto inputElemType = inputType.getElementType(); - if (!isa(inputElemType)) { - return op.emitError( - "only float dtype allowed in input tensor of AtenFrobeniusNormDimOp"); - } - - SmallVector dims; - if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dims))) { - return rewriter.notifyMatchFailure( - op, "non-const integer `dim` is not supported"); - } - for (auto &dim : dims) { - dim = toPositiveDim(dim, inputRank); - if (!isValidDim(dim, inputRank)) { - return rewriter.notifyMatchFailure(op, - "invalid dimension detected in `dim`"); - } - } - - // Sort the dims in ascending order, making the conversion - // stable with unordered dims. - std::sort(dims.begin(), dims.end()); - - std::unordered_set dimsSet(dims.begin(), dims.end()); - SmallVector reduceResultShape; - for (int64_t i = 0; i < inputRank; i++) { - if (dimsSet.find(i) == dimsSet.end()) { - reduceResultShape.push_back(inputType.getDimSize(i)); - } - } - - bool keepDim = false; - if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) { - return rewriter.notifyMatchFailure( - op, "non-const bool `keepdim` is not supported"); - } - - auto squareOp = rewriter.create(op->getLoc(), input, input); - - auto initValue = createInitialValueForReduceOp(op, inputElemType, rewriter); - if (!initValue) { - return failure(); - } - - auto reduceOp = rewriter.create( - op->getLoc(), RankedTensorType::get(reduceResultShape, inputElemType), - squareOp.getResult(), initValue, rewriter.getDenseI64ArrayAttr(dims)); - - Region ®ion = reduceOp.getBody(); - Block &block = region.emplaceBlock(); - auto blockArgumentTy = RankedTensorType::get({}, inputElemType); - - block.addArgument(blockArgumentTy, op->getLoc()); - block.addArgument(blockArgumentTy, op->getLoc()); - - auto firstArgument = *block.args_begin(); - auto secondArgument = *block.args_rbegin(); - - { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(&block); - - auto addResult = rewriter.create( - op->getLoc(), firstArgument, secondArgument); - rewriter.create(op->getLoc(), addResult.getResult()); - } - - auto output = - rewriter.create(op->getLoc(), reduceOp.getResult(0)); - - if (keepDim) { - auto outShapeInfo = - hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); - if (failed(outShapeInfo)) { - return rewriter.notifyMatchFailure( - op, "failed to get dimension sizes of the input"); - } - auto outShapeVec = *outShapeInfo; - auto one = rewriter.create( - op->getLoc(), - rewriter.getIntegerAttr( - rewriter.getIntegerType(options.dimSizeIndexBits), 1)); - for (int64_t i : dims) { - outShapeVec[i] = one; - } - auto outShapeTensor = rewriter.create( - op->getLoc(), outShapeVec); - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), output, - outShapeTensor); - return success(); - } - rewriter.replaceOp(op, output.getResult()); - return success(); -} -} // namespace - -// AtenLinalgVectorNormOp +// AtenLinalgVectorNormOp namespace { template <> LogicalResult 
ConvertAtenReductionOp::matchAndRewrite( AtenLinalgVectorNormOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - const TorchToStablehloOptions &options = getOptions(); Value input = adaptor.getSelf(); auto inputType = dyn_cast(input.getType()); @@ -1113,13 +813,8 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( std::sort(dims.begin(), dims.end()); } - std::unordered_set dimsSet(dims.begin(), dims.end()); - SmallVector reduceResultShape; - for (int64_t i = 0; i < inputType.getRank(); i++) { - if (dimsSet.find(i) == dimsSet.end()) { - reduceResultShape.push_back(inputType.getDimSize(i)); - } - } + SmallVector reduceResultShape = + getReduceOutputShape(inputType.getShape(), dims); bool keepDim = false; if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) { @@ -1127,71 +822,38 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( op, "non-const bool `keepdim` is not supported"); } - auto initValue = createInitialValueForReduceOp(op, outElemType, rewriter); - if (!initValue) { - return failure(); - } - Value absValue = rewriter.create(op->getLoc(), input); Value powValue = rewriter.create(op->getLoc(), absValue, ord, nullptr); - auto reduceOp = rewriter.create( - op->getLoc(), RankedTensorType::get(reduceResultShape, outElemType), - powValue, initValue, rewriter.getDenseI64ArrayAttr(dims)); - - Region ®ion = reduceOp.getBody(); - Block &block = region.emplaceBlock(); - auto blockArgumentTy = RankedTensorType::get({}, outElemType); - - block.addArgument(blockArgumentTy, op->getLoc()); - block.addArgument(blockArgumentTy, op->getLoc()); - - auto firstArgument = *block.args_begin(); - auto secondArgument = *block.args_rbegin(); - - { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(&block); - - auto addResult = rewriter.create( - op->getLoc(), firstArgument, secondArgument); - rewriter.create(op->getLoc(), addResult.getResult()); + Value reduceResult = createReduceOpWithSingleRegionOp( + op, powValue, RankedTensorType::get(reduceResultShape, outElemType), dims, + rewriter); + if (!reduceResult) { + return op->emitError("createReduceOpWithSingleRegionOp return nullptr"); } + + auto scalarType = RankedTensorType::get({}, outElemType); auto constantOne = rewriter.create( - op->getLoc(), blockArgumentTy, + op->getLoc(), scalarType, DenseElementsAttr::get( - blockArgumentTy, + scalarType, APFloat(cast(outElemType).getFloatSemantics(), 1))); auto reciprocalOrd = rewriter.create( - op->getLoc(), blockArgumentTy, constantOne, ord); - auto output = rewriter.create( - op->getLoc(), reduceOp.getResult(0), reciprocalOrd, nullptr); + op->getLoc(), scalarType, constantOne, ord); + Value output = rewriter.create( + op->getLoc(), reduceResult, reciprocalOrd, nullptr); if (keepDim) { - auto outShapeInfo = - hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); + auto outShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input); if (failed(outShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); } - auto outShapeVec = *outShapeInfo; - auto one = rewriter.create( - op->getLoc(), - rewriter.getIntegerAttr( - rewriter.getIntegerType(options.dimSizeIndexBits), 1)); - for (int64_t i : dims) { - outShapeVec[i] = one; - } - auto outShapeTensor = rewriter.create( - op->getLoc(), outShapeVec); - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), output, - outShapeTensor); - return success(); + output = reshapeReduceResultWhenKeepDim(rewriter, 
op->getLoc(), output, + *outShapeInfo, outType, dims); } - - rewriter.replaceOp(op, output.getResult()); + rewriter.replaceOp(op, output); return success(); } } // namespace @@ -1203,16 +865,43 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality( #define INSERT_ATEN_REDUCTION_OP_PATTERN(AtenOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context, options) - INSERT_ATEN_REDUCTION_OP_PATTERN(AtenArgmaxOp); - INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxDimOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp); - INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumOp); - INSERT_ATEN_REDUCTION_OP_PATTERN(AtenProdOp); - INSERT_ATEN_REDUCTION_OP_PATTERN(AtenAllOp); - INSERT_ATEN_REDUCTION_OP_PATTERN(AtenAnyOp); - INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxOp); - INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMinOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenFrobeniusNormDimOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenLinalgVectorNormOp); #undef INSERT_ATEN_REDUCTION_OP_PATTERN + +#define INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context, \ + options) + INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenMaxOp); + INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenMinOp); + INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenSumOp); + INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenProdOp); + INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenAllOp); + INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenAnyOp); +#undef INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN + +#define INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context, \ + options) + INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN(AtenAnyDimOp); + INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN(AtenAllDimOp); +#undef INSERT_ATEN_REDUCTION_ONE_DIM_OP_PATTERN + +#define INSERT_ATEN_REDUCTION_DIMS_OP_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context, options) + INSERT_ATEN_REDUCTION_DIMS_OP_PATTERN(AtenAmaxOp); + INSERT_ATEN_REDUCTION_DIMS_OP_PATTERN(AtenAminOp); +#undef INSERT_ATEN_REDUCTION_DIMS_OP_PATTERN + +#define INSERT_ATEN_REDUCTION_WITH_INDICES_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context, \ + options) + INSERT_ATEN_REDUCTION_WITH_INDICES_PATTERN(AtenMaxDimOp); + INSERT_ATEN_REDUCTION_WITH_INDICES_PATTERN(AtenMinDimOp); +#undef INSERT_ATEN_REDUCTION_WITH_INDICES_PATTERN } diff --git a/lib/Conversion/TorchToStablehlo/Rng.cpp b/lib/Conversion/TorchToStablehlo/Rng.cpp index 3cd440c957e9..340c5198bf11 100644 --- a/lib/Conversion/TorchToStablehlo/Rng.cpp +++ b/lib/Conversion/TorchToStablehlo/Rng.cpp @@ -12,13 +12,10 @@ #include "../PassDetail.h" #include "./PopulatePatterns.h" -#include "mlir/IR/BuiltinTypes.h" #include "stablehlo/dialect/StablehloOps.h" #include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" -#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" -#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" using namespace mlir; using namespace mlir::torch; diff --git a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp index c4d629d4f5bb..b22dc3e6ed30 100644 --- a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp +++ b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp @@ -9,11 +9,12 @@ #include 
"torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" -#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include @@ -24,6 +25,31 @@ using namespace mlir::torch::Torch; namespace mlir { namespace hlo { +// Create chlo::ConstantLikeOp +template +Value getConstantLike(OpBuilder &rewriter, Location loc, T constant, + Value val) { + Type ty = getElementTypeOrSelf(val.getType()); + auto getAttr = [&]() -> Attribute { + if (isa(ty)) + return rewriter.getIntegerAttr(ty, constant); + if (isa(ty)) + return rewriter.getFloatAttr(ty, constant); + if (auto complexTy = dyn_cast(ty)) + return mlir::complex::NumberAttr::get(complexTy, constant, 0); + llvm_unreachable("unhandled element type"); + }; + return rewriter.create( + loc, cast(getAttr()), val); +} + +// Template instantiation +template Value getConstantLike(OpBuilder &rewriter, Location loc, + int64_t constant, Value val); + +template Value getConstantLike(OpBuilder &rewriter, Location loc, + double constant, Value val); + // Create a 32-bit float constant operator from a float Value getStablehloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op, float val) { @@ -144,24 +170,24 @@ Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter, } Value promoteType(PatternRewriter &rewriter, Location loc, Value input, - TensorType outType) { - TensorType in_type = cast(input.getType()); - - if (in_type.getElementType() != outType.getElementType()) { - TensorType promotedType = - in_type.cloneWith(in_type.getShape(), outType.getElementType()); - return rewriter.create(loc, promotedType, input); + Type outElementType) { + TensorType inType = cast(input.getType()); + if (inType.getElementType() != outElementType) { + return rewriter.create(loc, input, outElementType); } return input; } Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input, - TensorType outType) { + TensorType outType, + std::optional bcastSizeTensor) { // Two tensors are “broadcastable” if the following rules hold: // - Each tensor has at least one dimension. // - When iterating over the dimension sizes, starting at the trailing // dimension, the dimension sizes must either be equal, one of them is 1, or // one of them does not exist. + // If one provide bcastSizeTensor, we emit stablehlo::DynamicBroadcastInDimOp + // instead of stablehlo::BroadcastInDimOp to support dynamic shape. 
Operation *op = input.getDefiningOp(); TensorType in_type = dyn_cast(input.getType()); @@ -199,6 +225,11 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input, return input; } auto bcast_attr = rewriter.getDenseI64ArrayAttr(bcastDims); + if (bcastSizeTensor.has_value()) { + auto bcast_op = rewriter.create( + op->getLoc(), outType, input, bcastSizeTensor.value(), bcast_attr); + return bcast_op.getResult(); + } auto bcast_op = rewriter.create( op->getLoc(), outType, input, bcast_attr); return bcast_op.getResult(); @@ -253,9 +284,130 @@ FailureOr> getDimSizesOfTensor(PatternRewriter &rewriter, return getDimSizesOfTensor(rewriter, op, value, dims, dimSizeIndexBits); } +// Get the dimension sizes of the input tensor, given the dimension axes +FailureOr> +getDimIndexOfTensor(PatternRewriter &rewriter, Operation *op, Value value, + ArrayRef inpDims) { + auto valueTy = dyn_cast(value.getType()); + if (!valueTy) { + return rewriter.notifyMatchFailure( + op, "getDimIndexOfTensor(): the input is not a ranked tensor"); + } + + auto rank = valueTy.getRank(); + auto dims = toPositiveDims(inpDims, rank); + SmallVector dimSizes; + dimSizes.reserve(dims.size()); + + auto loc = op->getLoc(); + for (auto d : dims) { + dimSizes.emplace_back(rewriter.create(loc, value, d)); + } + return dimSizes; +} + +// Get the dimension sizes of the input tensor +FailureOr> +getDimIndexOfTensor(PatternRewriter &rewriter, Operation *op, Value value) { + auto valueTy = dyn_cast(value.getType()); + if (!valueTy) { + return rewriter.notifyMatchFailure( + op, "getDimIndexOfTensor(): the input is not a ranked tensor"); + } + + auto rank = valueTy.getRank(); + // Get int vector [0, 1, ..., rank-1] + std::vector dims(rank); + std::iota(dims.begin(), dims.end(), 0); + return getDimIndexOfTensor(rewriter, op, value, dims); +} + +FailureOr>> +getBroadcastResultShape(PatternRewriter &rewriter, Operation *op, + ArrayRef tensors, size_t dimSizeIndexBits) { + SmallVector> tensorSizes; + + int maxRank = 0; + for (auto tensor : tensors) { + auto tensorType = cast(tensor.getType()); + auto tensorRank = tensorType.getRank(); + + tensorSizes.emplace_back(tensorType.getShape()); + maxRank = std::max(maxRank, static_cast(tensorRank)); + } + + SmallVector bcastSizeTensors; + SmallVector bcastSizes; + for (int outDim = 0; outDim < maxRank; ++outDim) { // loop dimensions. + int dynamicDimCnt = 0; + int staticDimCnt = 0; + int64_t dimSize = -1; + Value dimSizeTensor = rewriter.create( + op->getLoc(), + rewriter.getIntegerAttr(rewriter.getIntegerType(dimSizeIndexBits), 1)); + + for (size_t i = 0; i < tensorSizes.size(); ++i) { // loop tensors. + int inDim = tensorSizes[i].size() - 1 - outDim; + if (inDim < 0) + continue; + + // dim size: 1 + if (tensorSizes[i][inDim] == 1) { + if (dimSize == -1) + dimSize = 1; + continue; + } + // dim size: dynamic + if (tensorSizes[i][inDim] == ShapedType::kDynamic || + tensorSizes[i][inDim] == kUnknownSize) { + dynamicDimCnt++; + dimSize = ShapedType::kDynamic; + auto dimSizeTensorInfo = hlo::getDimSizesOfTensor( + rewriter, op, tensors[i], {inDim}, dimSizeIndexBits); + if (failed(dimSizeTensorInfo)) { + return failure(); + } + dimSizeTensor = (*dimSizeTensorInfo)[0]; + continue; + } + // dim size: static + // we already found dynamic dim size, fail. + if (dynamicDimCnt > 0) { + return failure(); + } + // we already found static dim size not equal with this, fail. 
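promoteAndBroadcast and getBroadcastResultShape above implement the broadcasting rule spelled out in the comment: shapes are aligned at the trailing dimension, and each pair of sizes must be equal, be 1, or be absent, with the bcastSizeTensor path emitting stablehlo.dynamic_broadcast_in_dim for dynamic shapes. For purely static shapes the resulting shape can be computed as in this sketch (dynamic sizes are out of scope here; broadcastShapes is an illustrative name):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <optional>
#include <utility>
#include <vector>

// Broadcast two static shapes per the rule above; nullopt if they are incompatible.
std::optional<std::vector<int64_t>> broadcastShapes(std::vector<int64_t> a,
                                                    std::vector<int64_t> b) {
  if (a.size() < b.size())
    std::swap(a, b);
  std::vector<int64_t> out(a.size());
  for (size_t i = 0; i < a.size(); ++i) {
    int64_t da = a[a.size() - 1 - i];
    // Missing leading dims of the shorter shape behave like size 1.
    int64_t db = i < b.size() ? b[b.size() - 1 - i] : 1;
    if (da != db && da != 1 && db != 1)
      return std::nullopt;
    out[out.size() - 1 - i] = std::max(da, db);
  }
  return out;
}

int main() {
  auto s = broadcastShapes({8, 1, 6, 1}, {7, 1, 5});
  for (int64_t d : *s) std::cout << d << ' '; // 8 7 6 5
  std::cout << '\n';
}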
+ if (staticDimCnt > 0 && dimSize != tensorSizes[i][inDim]) { + return failure(); + } + + staticDimCnt++; + dimSize = tensorSizes[i][inDim]; + auto dimSizeTensorInfo = hlo::getDimSizesOfTensor( + rewriter, op, tensors[i], {inDim}, dimSizeIndexBits); + if (failed(dimSizeTensorInfo)) { + return failure(); + } + dimSizeTensor = (*dimSizeTensorInfo)[0]; + } + + // TODO: Relax this check, by assuming all dynamic shape is same. + // if (dynamicDimCnt > 1) { + // return failure(); + // } + bcastSizes.push_back(dimSize); + bcastSizeTensors.push_back(dimSizeTensor); + } + std::reverse(bcastSizes.begin(), bcastSizes.end()); + std::reverse(bcastSizeTensors.begin(), bcastSizeTensors.end()); + return std::pair>( + rewriter.create(op->getLoc(), bcastSizeTensors) + .getResult(), + bcastSizes); +} + FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, - Value tensor, ArrayRef inputUnsqzDims, - size_t dimSizeIndexBits) { + Value tensor, + ArrayRef inputUnsqzDims) { // Returns a new tensor with dims of size 1 inserted at the specified // position. // @@ -263,8 +415,7 @@ FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, // tensor) are specified with unsqzDims. Indices must be in-order, and in // range of tensor rank. Thus, unsqueeze a rank 1 tensor with {0, 2}, {0, 1, // 3}, {0, 1, 2} are all valid dimension sets, but {0, 3}, {2} are not. - auto dimSizesInfo = - getDimSizesOfTensor(rewriter, op, tensor, dimSizeIndexBits); + auto dimSizesInfo = getDimIndexOfTensor(rewriter, op, tensor); if (failed(dimSizesInfo)) return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); @@ -281,9 +432,8 @@ FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, auto loc = op->getLoc(); auto rankTy = dyn_cast(tensor.getType()); auto oldShape = rankTy.getShape(); - Type intType = rewriter.getIntegerType(dimSizeIndexBits); auto one = rewriter.create( - loc, rewriter.getIntegerAttr(intType, 1)); + loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); std::vector newDimSizes; std::vector newShape; @@ -309,12 +459,9 @@ FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, FailureOr collapseTensor(PatternRewriter &rewriter, Operation *op, Value tensor, int64_t collapseStartDim, - int64_t collapseEndDim, - size_t dimSizeIndexBits) { - - auto dimSizesInfo = - getDimSizesOfTensor(rewriter, op, tensor, dimSizeIndexBits); + int64_t collapseEndDim) { + auto dimSizesInfo = getDimIndexOfTensor(rewriter, op, tensor); if (failed(dimSizesInfo)) return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); @@ -330,7 +477,6 @@ FailureOr collapseTensor(PatternRewriter &rewriter, Operation *op, auto loc = op->getLoc(); auto rankTy = dyn_cast(tensor.getType()); auto oldShape = rankTy.getShape(); - Type intType = rewriter.getIntegerType(dimSizeIndexBits); std::vector newDimSizes; std::vector newShape; @@ -338,7 +484,7 @@ FailureOr collapseTensor(PatternRewriter &rewriter, Operation *op, newShape.reserve(newRank); Value collapseDimSize = rewriter.create( - loc, rewriter.getIntegerAttr(intType, 1)); + loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); int64_t collapseShape = 1; for (int64_t k = collapseStartDim; k <= collapseEndDim; ++k) { @@ -376,10 +522,8 @@ FailureOr collapseTensor(PatternRewriter &rewriter, Operation *op, // TODO: support splitDim & outerLength to be Value FailureOr splitTensor(PatternRewriter &rewriter, Operation *op, Value tensor, int64_t splitDim, - int64_t outerLength, size_t dimSizeIndexBits) { - auto 
dimSizesInfo = - getDimSizesOfTensor(rewriter, op, tensor, dimSizeIndexBits); - + int64_t outerLength) { + auto dimSizesInfo = getDimIndexOfTensor(rewriter, op, tensor); if (failed(dimSizesInfo)) return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); @@ -391,7 +535,6 @@ FailureOr splitTensor(PatternRewriter &rewriter, Operation *op, auto loc = op->getLoc(); auto rankTy = dyn_cast(tensor.getType()); auto oldShape = rankTy.getShape(); - Type intType = rewriter.getIntegerType(dimSizeIndexBits); if (splitDim < 0 || splitDim >= rank) { return rewriter.notifyMatchFailure( @@ -400,7 +543,7 @@ FailureOr splitTensor(PatternRewriter &rewriter, Operation *op, int64_t newRank = rank + 1; auto outerLengthValue = rewriter.create( - loc, rewriter.getIntegerAttr(intType, outerLength)); + loc, rewriter.getIntegerAttr(rewriter.getIndexType(), outerLength)); auto innerLengthValue = rewriter.create( loc, dimSizes[splitDim], outerLengthValue); diff --git a/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp b/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp index 9a3360bf9069..ec9aa7a45493 100644 --- a/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp +++ b/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp @@ -14,15 +14,9 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Traits.h" -#include "mlir/IR/Matchers.h" -#include "mlir/Transforms/DialectConversion.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" -#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" -#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" -#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" using namespace mlir; @@ -57,7 +51,8 @@ class ConvertTorchToStablehlo TypeConverter typeConverter; typeConverter.addConversion([](Type type) { return type; }); - TorchConversion::setupBackendTypeConversion(target, typeConverter); + TorchConversion::setupBackendTypeConversionForStablehlo(target, + typeConverter); RewritePatternSet patterns(context); diff --git a/lib/Conversion/TorchToStablehlo/ViewLike.cpp b/lib/Conversion/TorchToStablehlo/ViewLike.cpp index 04952d84343a..71b675b5ea2a 100644 --- a/lib/Conversion/TorchToStablehlo/ViewLike.cpp +++ b/lib/Conversion/TorchToStablehlo/ViewLike.cpp @@ -17,12 +17,8 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "stablehlo/dialect/StablehloOps.h" #include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" -#include "torch-mlir/Conversion/Utils/Utils.h" -#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" -#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" -#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" #include @@ -165,19 +161,77 @@ class ConvertAtenViewOp : public ConvertAtenOp { using ConvertAtenOp::ConvertAtenOp; using OpAdaptor = typename AtenOpT::Adaptor; + unsigned getBitWidth(Type type) const { + if (auto complexTy = dyn_cast(type)) + return 2 * getBitWidth(complexTy.getElementType()); + return type.getIntOrFloatBitWidth(); + } + LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto rankType = 
dyn_cast(adaptor.getSelf().getType()); if (!rankType) - return op.emitError("Only ranked tensor types are currently supported"); + return op.emitError("Only ranked tensor types are currently supported."); + auto loc = op.getLoc(); + + // support AtenViewDtypeOp + if (isa(op)) { + auto self = adaptor.getSelf(); + auto baseResultTy = dyn_cast(op.getType()); + + // infer the result shape + auto operandElt = rankType.getElementType(); + auto targetElt = baseResultTy.getDtype(); + auto operandEltBitWidth = getBitWidth(operandElt); + auto targetEltBitWidth = getBitWidth(targetElt); + auto operandSizes = rankType.getShape(); + SmallVector castShape(operandSizes); + if (operandEltBitWidth > targetEltBitWidth) { + int64_t last_size = operandEltBitWidth / targetEltBitWidth; + castShape.push_back(last_size); + } else if (operandEltBitWidth < targetEltBitWidth) { + int64_t last_size = targetEltBitWidth / operandEltBitWidth; + if (!ShapedType::isDynamic(castShape.back()) and + last_size != castShape.back()) { + return rewriter.notifyMatchFailure( + op, "The last dim size is not equal to targetEltBitWidth / " + "operandEltBitWidth."); + } else { + castShape.pop_back(); + } + } + + auto resultType = + OpConversionPattern::getTypeConverter()->convertType( + baseResultTy); + if (!dyn_cast(resultType).hasStaticShape()) { + return rewriter.notifyMatchFailure( + op, "Currently only support static output shape."); + } + + auto castType = + baseResultTy.getWithSizesAndDtype(castShape, baseResultTy.getDtype()); + auto cast = rewriter.create( + loc, + OpConversionPattern::getTypeConverter()->convertType( + castType), + self); + + auto reshape = + rewriter.create(loc, resultType, cast); + + rewriter.replaceOp(op, reshape); + + return success(); + } + // collect Value of dims SmallVector dimSizes; if (!getAtenViewOpSizes(op, adaptor, rewriter, dimSizes)) { return op.emitError("Dims size must be a list of Scalar"); } - auto loc = op.getLoc(); if (dimSizes.size() == 0 || rankType.getRank() == 0) { rewriter.replaceOpWithNewOp( op, @@ -187,6 +241,20 @@ class ConvertAtenViewOp : public ConvertAtenOp { return success(); } + // collect constant dim size which == -1 + SmallVector negOneIndex; + for (size_t i = 0; i < dimSizes.size(); i++) { + int64_t dim; + if (matchPattern(dimSizes[i], m_TorchConstantInt(&dim))) { + if (dim == -1) { + negOneIndex.push_back(i); + } + } + } + if (negOneIndex.size() > 1) { + return op.emitError("Only support at most one -1 in view target dims"); + } + std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value &dSize) { dSize = rewriter.create(loc, dSize).getResult(); return dSize; @@ -194,16 +262,29 @@ class ConvertAtenViewOp : public ConvertAtenOp { Value numel = rewriter.create( loc, rewriter.create(loc, adaptor.getSelf())); + numel = + rewriter.create(loc, rewriter.getI64Type(), numel); + + // note: assuming that -1 doesn't arise from dynamic value + if (negOneIndex.size() == 1) { + size_t index = negOneIndex[0]; + Value realDim = numel; + for (size_t i = 0; i < dimSizes.size(); i++) { + if (i != index) { + realDim = rewriter.create(loc, realDim, dimSizes[i]); + } + } + // update -1 to realDim + dimSizes[index] = realDim; + } Value stablehloShape = rewriter.create(loc, dimSizes); - Value computedShape = rewriter.create( - loc, stablehloShape.getType(), numel, stablehloShape); rewriter.replaceOpWithNewOp( op, OpConversionPattern::getTypeConverter()->convertType( op.getType()), - adaptor.getSelf(), computedShape); + adaptor.getSelf(), stablehloShape); return success(); } @@ -212,6 +293,13 @@ 
class ConvertAtenViewOp : public ConvertAtenOp { SmallVector &dimSizes) const; }; +template <> +bool ConvertAtenViewOp::getAtenViewOpSizes( + AtenViewDtypeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, + SmallVector &dimSizes) const { + return false; +} + template <> bool ConvertAtenViewOp::getAtenViewOpSizes( AtenViewOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter, @@ -247,7 +335,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure(op, "dim is statically invalid"); auto getOptionalVal = [&](Value val) -> std::optional { - if (val.getType().isa()) { + if (isa(val.getType())) { return std::nullopt; } else { return val; @@ -299,8 +387,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } - auto newDimSizesInfo = hlo::getDimSizesOfTensor(rewriter, op, self, dims, - options.dimSizeIndexBits); + auto newDimSizesInfo = hlo::getDimIndexOfTensor(rewriter, op, self, dims); if (failed(newDimSizesInfo)) return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); @@ -351,8 +438,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, getTypeConverter()->convertType(op.getType()), self); return success(); } - auto newDimSizesInfo = hlo::getDimSizesOfTensor(rewriter, op, self, dims, - options.dimSizeIndexBits); + auto newDimSizesInfo = hlo::getDimIndexOfTensor(rewriter, op, self, dims); if (failed(newDimSizesInfo)) return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); @@ -382,8 +468,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (!isValidDim(dim, inputRank + 1)) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); - auto unsqzTensorInfo = hlo::unsqueezeTensor(rewriter, op, adaptor.getSelf(), - {dim}, options.dimSizeIndexBits); + auto unsqzTensorInfo = + hlo::unsqueezeTensor(rewriter, op, adaptor.getSelf(), {dim}); if (failed(unsqzTensorInfo)) return rewriter.notifyMatchFailure(op, "failed to create unsqueezed tensor"); @@ -414,8 +500,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "only constant end is currently supported"); - auto collapseTensorInfo = hlo::collapseTensor( - rewriter, op, adaptor.getA(), start, end, options.dimSizeIndexBits); + auto collapseTensorInfo = + hlo::collapseTensor(rewriter, op, adaptor.getA(), start, end); if (failed(collapseTensorInfo)) return rewriter.notifyMatchFailure(op, "failed to create collapsed tensor"); @@ -427,7 +513,7 @@ template <> LogicalResult ConvertAtenOp::matchAndRewrite( PrimsSplitDimOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - auto selfType = adaptor.getA().getType().dyn_cast(); + auto selfType = dyn_cast(adaptor.getA().getType()); if (!selfType) { return op.emitError("only tensor types are currently supported"); } @@ -445,8 +531,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "only constant outerLength is currently supported"); - auto splitTensorInfo = hlo::splitTensor( - rewriter, op, adaptor.getA(), dim, outerLength, options.dimSizeIndexBits); + auto splitTensorInfo = + hlo::splitTensor(rewriter, op, adaptor.getA(), dim, outerLength); if (failed(splitTensorInfo)) return rewriter.notifyMatchFailure(op, "failed to create split tensor"); @@ -474,6 +560,7 @@ void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality( #define INSERT_VIEW_OP_PATTERN(AtenOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context, options) + INSERT_VIEW_OP_PATTERN(AtenViewDtypeOp); 
INSERT_VIEW_OP_PATTERN(AtenViewOp); INSERT_VIEW_OP_PATTERN(AtenReshapeOp); #undef INSERT_VIEW_OP_PATTERN diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp index 40367138bd27..1e9d63b63af5 100644 --- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp +++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp @@ -11,16 +11,10 @@ #include "../PassDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/MLIRContext.h" #include "mlir/IR/Matchers.h" -#include "mlir/IR/ValueRange.h" #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h" #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h" #include "torch-mlir/Conversion/Utils/Utils.h" @@ -153,8 +147,8 @@ convertTorchScatterIndexAndSrcToTMScatterIndexAndSrc(PatternRewriter &rewriter, } // Replace the original index with the index specified // by the scatter. - yieldVals[dim] = b.create( - loc, rewriter.getI32Type(), extractIndexValue); + yieldVals[dim] = convertScalarToDtype( + rewriter, loc, extractIndexValue, rewriter.getI32Type()); yieldVals.push_back(extractSrcValue); b.create(loc, yieldVals); }) @@ -260,6 +254,44 @@ static Value createTMTensorScanOp( return scanOp->getResult(0); } +static FailureOr createIntOrFloatCompareOp(PatternRewriter &rewriter, + Location loc, + Type elementType, Value lhs, + Value rhs, bool isDescending, + bool isEqual) { + + Value compareOp; + if (auto intType = dyn_cast(elementType)) { + // Case for using arith::CmpIOp. + arith::CmpIPredicate g = + isEqual ? arith::CmpIPredicate::sge : arith::CmpIPredicate::sgt; + arith::CmpIPredicate l = + isEqual ? arith::CmpIPredicate::sle : arith::CmpIPredicate::slt; + if (intType.isUnsignedInteger()) { + g = isEqual ? arith::CmpIPredicate::uge : arith::CmpIPredicate::ugt; + l = isEqual ? arith::CmpIPredicate::ule : arith::CmpIPredicate::ult; + } + arith::CmpIPredicate predicate = isDescending ? g : l; + compareOp = rewriter.create(loc, predicate, lhs, rhs); + return compareOp; + } + + if (isa(elementType)) { + // Case for using arith::CmpFOp. + arith::CmpFPredicate g = + isEqual ? arith::CmpFPredicate::OGE : arith::CmpFPredicate::OGT; + arith::CmpFPredicate l = + isEqual ? arith::CmpFPredicate::OLE : arith::CmpFPredicate::OLT; + + arith::CmpFPredicate predicate = isDescending ? g : l; + compareOp = rewriter.create(loc, predicate, lhs, rhs); + return compareOp; + } + + return rewriter.notifyMatchFailure( + loc, "Only Integer and Floating element type expected."); +} + // Utility function to create a TMTensor::SortOp. static FailureOr> createTMTensorSortOp(PatternRewriter &rewriter, Location sortOpLoc, @@ -286,45 +318,122 @@ createTMTensorSortOp(PatternRewriter &rewriter, Location sortOpLoc, } // Step 3. Create comparison op which will be used as the sorting predicate. - Value compareOp; - if (auto intType = dyn_cast(elementTypes[0])) { - // Case for using arith::CmpIOp. - arith::CmpIPredicate ge = arith::CmpIPredicate::sge; - arith::CmpIPredicate le = arith::CmpIPredicate::sle; - if (intType.isUnsignedInteger()) { - ge = arith::CmpIPredicate::uge; - le = arith::CmpIPredicate::ule; - } - arith::CmpIPredicate predicate = isDescending ? 
ge : le; - compareOp = rewriter.create( - loc, predicate, block->getArgument(0), block->getArgument(1)); - } else if (elementTypes[0].isa()) { - // Case for using arith::CmpFOp. - arith::CmpFPredicate predicate = - isDescending ? arith::CmpFPredicate::OGE : arith::CmpFPredicate::OLE; - compareOp = rewriter.create( - loc, predicate, block->getArgument(0), block->getArgument(1)); - } else { + auto compareOpRetVal = createIntOrFloatCompareOp( + rewriter, loc, elementTypes[0], block->getArgument(0), + block->getArgument(1), isDescending, true); + + if (failed(compareOpRetVal)) return rewriter.notifyMatchFailure( - sortOpLoc, "Only Integer and Floating element type expected."); - } + loc, "Only Integer and Floating element type expected."); // Step 4. Create yield op for yielding the sorting predicate. - rewriter.create(loc, compareOp); + rewriter.create(loc, compareOpRetVal.value()); return SmallVector(sortOp.getResults()); } +static FailureOr> createTMTensorTopkOp( + PatternRewriter &rewriter, Location topkOpLoc, llvm::ArrayRef inputs, + llvm::ArrayRef outputs, llvm::ArrayRef elementTypes, + int64_t dimension, bool isMinK) { + + // Generate output types. + SmallVector topkResultTypes; + for (Value val : outputs) { + topkResultTypes.push_back(val.getType()); + } + + // Create empty TopkOp, add body later. + auto topkOp = rewriter.create( + topkOpLoc, topkResultTypes, inputs, outputs, + rewriter.getI64IntegerAttr(dimension)); + + Region *body = &topkOp.getRegion(); + Block *block = rewriter.createBlock(body); + Location loc = body->getLoc(); + // Add arguments for each passed body region element type. + for (Type elementType : elementTypes) { + block->addArgument({elementType}, {loc}); + } + + // Generate compare operator. If minK is chosen, isDescending should be false. + // Is equal should be false, because we do not want equality to cause element + // swap. + auto compareOpRetVal = createIntOrFloatCompareOp( + rewriter, loc, elementTypes[0], block->getArgument(0), + block->getArgument(1), /*isDescending=*/!isMinK, /*isEqual=*/false); + + // Check if correct element types are passed. + if (failed(compareOpRetVal)) + return rewriter.notifyMatchFailure( + loc, "Only Integer and Floating element type expected."); + + // Yield the comparison result. 
+ rewriter.create(loc, compareOpRetVal.value()); + return SmallVector(topkOp.getResults()); +} + +static FailureOr +repeatTensorElementsForDim(Operation *op, ConversionPatternRewriter &rewriter, + Type resType, Value self, int64_t repeats, + int64_t dim) { + Location loc = op->getLoc(); + auto context = op->getContext(); + auto selfTy = cast(self.getType()); + + int64_t inputRank = selfTy.getSizes().size(); + dim = toPositiveDim(dim, inputRank); + Value dimValue = + rewriter.create(loc, rewriter.getI64IntegerAttr(dim)); + Value dimValuePlusOne = + rewriter.create(loc, rewriter.getI64IntegerAttr(dim + 1)); + + auto unsqueezedInfo = unsqueezeTensor(rewriter, op, self, dimValuePlusOne); + if (failed(unsqueezedInfo)) + return rewriter.notifyMatchFailure(op, + "cannot generate unsqueeze tensor op"); + self = *unsqueezedInfo; + + Value constMinusOne = + rewriter.create(loc, rewriter.getI64IntegerAttr(-1)); + SmallVector expandShapeValueList(inputRank + 1, constMinusOne); + expandShapeValueList[dim + 1] = + rewriter.create(loc, rewriter.getI64IntegerAttr(repeats)); + Value expandShapeList = rewriter.create( + loc, ListType::get(IntType::get(context)), expandShapeValueList); + + SmallVector expandShape(inputRank + 1); + for (int64_t i = 0; i <= dim; i++) { + expandShape[i] = selfTy.getSizes()[i]; + } + expandShape[dim + 1] = repeats; + for (int64_t i = dim + 1; i < inputRank; i++) { + expandShape[i + 1] = selfTy.getSizes()[i]; + } + + BaseTensorType expandTy = + rewriter.getType(expandShape, selfTy.getOptionalDtype()); + Value expandSelf = + rewriter.create(loc, expandTy, self, expandShapeList); + + Value result = rewriter.create(loc, resType, expandSelf, + dimValue, dimValuePlusOne); + return result; +} + namespace { -class ConvertAtenScatterSrcOp : public OpConversionPattern { +template +class ConvertAtenScatterOp : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult - matchAndRewrite(AtenScatterSrcOp op, OpAdaptor adaptor, + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); Location loc = op.getLoc(); - const TypeConverter *typeConverter = getTypeConverter(); + const TypeConverter *typeConverter = + OpConversionPattern::getTypeConverter(); Value self = adaptor.getSelf(); Value index = adaptor.getIndex(); Value src = adaptor.getSrc(); @@ -352,11 +461,23 @@ class ConvertAtenScatterSrcOp : public OpConversionPattern { /*dimensionsMap=*/createDefaultDimMap(indices), /*uniqueIndices=*/false, [&](OpBuilder &b, Location loc, Value updatesElement, Value inputElement) { - b.create(loc, updatesElement); + if (isa(op)) { + b.create(loc, updatesElement); + } else if (isa(op)) { + if (isa(selfType.getElementType())) { + Value add = + b.create(loc, inputElement, updatesElement); + b.create(loc, add); + } else if (isa(selfType.getElementType())) { + Value add = + b.create(loc, inputElement, updatesElement); + b.create(loc, add); + } + } }); - auto resultType = typeConverter->convertType(op->getResult(0).getType()) - .cast(); + auto resultType = cast( + typeConverter->convertType(op->getResult(0).getType())); rewriter.replaceOpWithNewOp(op, resultType, scatterOp); return success(); } @@ -387,7 +508,7 @@ class ConvertAtenBincountOp : public OpConversionPattern { // Check whether the input is a 1-d tensor of integer type or not. 
RankedTensorType inputType = cast(input.getType()); if (inputType.getRank() != 1 || - !inputType.getElementType().isa()) + !isa(inputType.getElementType())) return rewriter.notifyMatchFailure( op, "Input tensor has to be a one-dimensional tensor of integer type."); @@ -401,7 +522,7 @@ class ConvertAtenBincountOp : public OpConversionPattern { "Unimplemented: Integer width not equal to 64 are not supported."); // TODO: Incorporate the weight argument. - if (!weights.getType().isa()) + if (!isa(weights.getType())) return rewriter.notifyMatchFailure( op, "Unimplemented: the weights operand is not incorporated."); @@ -445,8 +566,8 @@ class ConvertAtenBincountOp : public OpConversionPattern { indices = typeConverter->materializeTargetConversion( rewriter, loc, typeConverter->convertType(indices.getType()), indices); - auto resultType = typeConverter->convertType(op->getResult(0).getType()) - .cast(); + auto resultType = cast( + typeConverter->convertType(op->getResult(0).getType())); Type resultElemType = resultType.getElementType(); SmallVector inputSizeDynamic = @@ -483,19 +604,9 @@ class ConvertAtenBincountOp : public OpConversionPattern { namespace { -Value combinePutIndices(Location loc, llvm::ArrayRef indicesRef, - OpBuilder b) { - llvm::SmallVector indices(indicesRef); - // Declare commonly used constants up front: - Value torchCstZero = - b.create(loc, b.getI64IntegerAttr(0)); - Value torchCstOne = - b.create(loc, b.getI64IntegerAttr(1)); - Value torchCstNegOne = - b.create(loc, b.getI64IntegerAttr(-1)); - - // Determine the broadcast sizes and materialize missing implicit end - // dimensions: +// Determine the common broadcast shape of all the index tensors. +std::pair, llvm::SmallVector> +getBroadcastShape(Location loc, llvm::ArrayRef indices, OpBuilder b) { int64_t indicesRank = 0; for (auto index : indices) { auto indexTy = cast(index.getType()); @@ -509,6 +620,8 @@ Value combinePutIndices(Location loc, llvm::ArrayRef indicesRef, return std::max(dim0, dim1); }; + Value torchCstOne = + b.create(loc, b.getI64IntegerAttr(1)); llvm::SmallVector broadcastSizes(indicesRank, torchCstOne); llvm::SmallVector broadcastShape(indicesRank, 0); for (auto index : indices) { @@ -527,6 +640,21 @@ Value combinePutIndices(Location loc, llvm::ArrayRef indicesRef, broadcastShape[idx] = maxDim(size, broadcastShape[idx]); } } + return std::make_pair(broadcastSizes, broadcastShape); +} + +Value combinePutIndices(Location loc, llvm::ArrayRef indicesRef, + OpBuilder b) { + llvm::SmallVector indices(indicesRef); + // Declare commonly used constants up front: + Value torchCstZero = + b.create(loc, b.getI64IntegerAttr(0)); + Value torchCstOne = + b.create(loc, b.getI64IntegerAttr(1)); + Value torchCstNegOne = + b.create(loc, b.getI64IntegerAttr(-1)); + + auto [broadcastSizes, broadcastShape] = getBroadcastShape(loc, indicesRef, b); auto mulDim = [](int64_t dim0, int64_t dim1) { if (dim0 == Torch::kUnknownSize || dim1 == Torch::kUnknownSize) @@ -675,6 +803,34 @@ static Value collapseAndMoveBatchDims(Location loc, Value values, int64_t batch, return b.create(loc, valuesTy, values, outDimsList); } +// Broadcast the `values` tensor to the slice size created by the list of index +// tensors. 
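+// For example (illustrative shapes only): an input of size [8, 8, 3] indexed
+// by two index tensors that broadcast to [5] selects a slice of shape [5, 3]
+// (the broadcast shape of the indices plus the un-indexed trailing dimension),
+// so a `values` tensor of shape [3] or [1, 3] is broadcast to [5, 3] before
+// the scatter.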
+static Value broadcastValuesToSliceSize(Location loc, Value input, Value values, + llvm::ArrayRef indices, + OpBuilder b) { + auto inputType = cast(input.getType()); + ArrayRef inputStaticShape = inputType.getSizes(); + auto valuesType = cast(values.getType()); + + // In the case where the input rank is greater than the number of index + // tensors, the remaining dimensions of the input are indexed in their + // entirety. Thus, we need to append the remaining dimensions to get the shape + // of the indexed slice. + auto [resultShape, resultStaticShape] = getBroadcastShape(loc, indices, b); + for (size_t i = indices.size(); i < inputStaticShape.size(); i++) { + Value dim = b.create(loc, b.getI64IntegerAttr(i)); + resultShape.push_back(b.create(loc, input, dim)); + resultStaticShape.push_back(inputStaticShape[i]); + } + + auto resultType = b.getType( + resultStaticShape, valuesType.getOptionalDtype()); + Value broadcastShapeList = b.create( + loc, Torch::ListType::get(b.getType()), resultShape); + return b.create(loc, resultType, values, + broadcastShapeList); +} + class ConvertAtenIndexPutHackedTwinOp : public OpConversionPattern { public: @@ -692,8 +848,8 @@ class ConvertAtenIndexPutHackedTwinOp auto valuesType = cast(values.getType()); int64_t inputRank = inputType.getSizes().size(); auto valuesTensorType = cast(op.getValues().getType()); - auto resultType = typeConverter->convertType(op->getResult(0).getType()) - .cast(); + auto resultType = cast( + typeConverter->convertType(op->getResult(0).getType())); if (!valuesTensorType.hasSizes()) return rewriter.notifyMatchFailure( @@ -722,6 +878,8 @@ class ConvertAtenIndexPutHackedTwinOp if (optionalIndicesCount == 0) return rewriter.notifyMatchFailure(op, "Indices list must not be empty."); + values = broadcastValuesToSliceSize(loc, input, values, optionalIndicesList, + rewriter); // Filter to available indices and get the indicesMap: SmallVector indicesList; SmallVector indicesMap; @@ -822,17 +980,20 @@ class ConvertAtenIndexPutHackedTwinOp // 2.) `values` is mapped to `updates` in scatter op. // 3.) `input` is mapped to `original` in scatter op. bool invalidInputTypeFound = false; + // If accumulate == false, the behavior is undefined if the indicies aren't + // unique. 
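+    // This mirrors torch.Tensor.index_put_, whose documentation states the
+    // result is undefined for duplicate indices when accumulate is false, so
+    // passing uniqueIndices=true to the scatter here is a legal choice that
+    // lets it assume non-conflicting writes.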
+ bool uniqueIndices = !accumulate; Value scatterOp = createTMTensorScatterOp( rewriter, loc, values, indices, input, indicesMap, - /*uniqueIndices=*/false, + /*uniqueIndices=*/uniqueIndices, [&](OpBuilder &b, Location loc, Value valuesElement, Value inputElement) { Value yieldValue = valuesElement; if (accumulate) { - if (inputElement.getType().isa()) { + if (isa(inputElement.getType())) { yieldValue = b.create(loc, inputElement, valuesElement); - } else if (inputElement.getType().isa()) { + } else if (isa(inputElement.getType())) { yieldValue = b.create(loc, inputElement, valuesElement); } else { @@ -1048,10 +1209,10 @@ class ConvertAtenMaxPool2dWithIndicesBackwardOp [&](OpBuilder &b, Location loc, Value valuesElement, Value inputElement) { Value yieldValue = valuesElement; - if (inputElement.getType().isa()) { + if (isa(inputElement.getType())) { yieldValue = b.create(loc, inputElement, valuesElement); - } else if (inputElement.getType().isa()) { + } else if (isa(inputElement.getType())) { yieldValue = b.create(loc, inputElement, valuesElement); } else { @@ -1210,33 +1371,33 @@ class ConvertAtenScatterReduceTwoOp Value result; if (reduceEnum == torch_upstream::ReductionType::SUM || reduceEnum == torch_upstream::ReductionType::MEAN) { - if (update.getType().isa()) { + if (isa(update.getType())) { result = b.create(loc, update, current); - } else if (update.getType().isa()) { + } else if (isa(update.getType())) { result = b.create(loc, update, current); } else { llvm_unreachable("Only integer/float types supported!"); } } else if (reduceEnum == torch_upstream::ReductionType::PROD) { - if (update.getType().isa()) { + if (isa(update.getType())) { result = b.create(loc, update, current); - } else if (update.getType().isa()) { + } else if (isa(update.getType())) { result = b.create(loc, update, current); } else { llvm_unreachable("Only integer/float types supported!"); } } else if (reduceEnum == torch_upstream::ReductionType::MAX) { - if (update.getType().isa()) { + if (isa(update.getType())) { result = b.create(loc, update, current); - } else if (update.getType().isa()) { + } else if (isa(update.getType())) { result = b.create(loc, update, current); } else { llvm_unreachable("Only integer/float types supported!"); } } else if (reduceEnum == torch_upstream::ReductionType::MIN) { - if (update.getType().isa()) { + if (isa(update.getType())) { result = b.create(loc, update, current); - } else if (update.getType().isa()) { + } else if (isa(update.getType())) { result = b.create(loc, update, current); } else { llvm_unreachable("Only integer/float types supported!"); @@ -1291,9 +1452,8 @@ class ConvertAtenScatterReduceTwoOp }) .getResult()[0]; } - auto resultType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); + auto resultType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); rewriter.replaceOpWithNewOp(op, resultType, scatterOp); return success(); @@ -1388,6 +1548,79 @@ class ConvertAtenSortOp : public OpConversionPattern { }; } // namespace +namespace { +class ConvertAtenCumprodOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenCumprodOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + Location loc = op.getLoc(); + Value input = adaptor.getSelf(); + auto resultType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); + Type elementType = resultType.getElementType(); + Type inputElementType = + 
cast(input.getType()).getElementType(); + + // Converting the input element type to the result's element type. + // The only possible mismatch would be when the input element type is an + // integer but not `si64`. Therefore, we directly convert the input to + // `si64`. Rest all cases are handled in the dtype definition for this op. + if (elementType != inputElementType) { + Value torchInput = convertTensorToDtype( + rewriter, loc, op.getSelf(), + rewriter.getIntegerType(64, IntegerType::Signed)); + input = typeConverter->materializeTargetConversion( + rewriter, loc, typeConverter->convertType(torchInput.getType()), + torchInput); + } + + int64_t inputRank = resultType.getRank(); + Value dtype = op.getDtype(); + if (!isa(dtype.getType())) + return rewriter.notifyMatchFailure( + op, "unsupported: dtype argument not supported"); + + int64_t dim; + if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) + return rewriter.notifyMatchFailure( + op, "unimplemented: only constant dim value is supported"); + dim = toPositiveDim(dim, inputRank); + if (!isValidDim(dim, inputRank)) + return rewriter.notifyMatchFailure(op, "invalid dim"); + + SmallVector sizes = getTensorSizes(rewriter, loc, input); + Value output = createOneInitTensor(rewriter, loc, sizes, elementType); + output = rewriter.create(loc, resultType, output); + + SmallVector accSizes(sizes); + accSizes.erase(accSizes.begin() + dim); + SmallVector accStatic( + makeShapeTorchCompatible(resultType.getShape())); + accStatic.erase(accStatic.begin() + dim); + Value acc = createOneInitTensor(rewriter, loc, accSizes, elementType); + Type accType = + RankedTensorType::get(makeShapeLLVMCompatible(accStatic), elementType); + acc = rewriter.create(loc, accType, acc); + + Value result = createTMTensorScanOp( + rewriter, loc, input, output, acc, dim, /*inclusive=*/true, + [](OpBuilder &b, Location loc, Value input, Value acc) { + Value prod = + (isa(input.getType()) + ? b.create(loc, input, acc)->getResult(0) + : b.create(loc, input, acc)->getResult(0)); + b.create(loc, prod); + }); + + rewriter.replaceOpWithNewOp(op, resultType, result); + return success(); + } +}; +} // namespace + namespace { class ConvertAtenCumsumOp : public OpConversionPattern { public: @@ -1398,9 +1631,8 @@ class ConvertAtenCumsumOp : public OpConversionPattern { Location loc = op.getLoc(); Value input = adaptor.getSelf(); - auto resultType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); + auto resultType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); Type elementType = resultType.getElementType(); Type inputElementType = cast(input.getType()).getElementType(); @@ -1420,7 +1652,7 @@ class ConvertAtenCumsumOp : public OpConversionPattern { int64_t inputRank = resultType.getRank(); Value dtype = op.getDtype(); - if (!dtype.getType().isa()) + if (!isa(dtype.getType())) return rewriter.notifyMatchFailure( op, "unsupported: dtype argument not supported"); @@ -1450,7 +1682,7 @@ class ConvertAtenCumsumOp : public OpConversionPattern { rewriter, loc, input, output, acc, dim, /*inclusive=*/true, [](OpBuilder &b, Location loc, Value input, Value acc) { Value sum = - (input.getType().isa() + (isa(input.getType()) ? 
b.create(loc, input, acc)->getResult(0) : b.create(loc, input, acc)->getResult(0)); b.create(loc, sum); @@ -1467,44 +1699,210 @@ class ConvertAtenScaledDotProductAttentionOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; + + static LogicalResult + preProcessGroupQueryAttentionInput(AtenScaledDotProductAttentionOp op, + ConversionPatternRewriter &rewriter, + const TypeConverter *typeConverter, + Value query, Value &key, Value &value) { + auto queryTy = cast(query.getType()); + auto valueTy = cast(value.getType()); + auto keyTy = cast(key.getType()); + + int64_t rank = queryTy.getRank(); + + int64_t qNumHeads = queryTy.getDimSize(rank - 3); + int64_t kNumHeads = valueTy.getDimSize(rank - 3); + int64_t vNumHeads = keyTy.getDimSize(rank - 3); + + if (llvm::any_of(llvm::ArrayRef{qNumHeads, kNumHeads, vNumHeads}, + [](int64_t d) { return d == Torch::kUnknownSize; })) { + return llvm::failure(); + } + + if (llvm::all_equal( + llvm::ArrayRef{qNumHeads, kNumHeads, vNumHeads})) + return llvm::success(); + + if ((qNumHeads % kNumHeads) && (qNumHeads % vNumHeads)) + return llvm::failure(); + + int64_t repeatKeyShape = qNumHeads / kNumHeads; + int64_t repeatValueShape = qNumHeads / vNumHeads; + + Location loc = op.getLoc(); + FailureOr keyRepeated = repeatTensorElementsForDim( + op.getOperation(), rewriter, /*resType=*/op.getQuery().getType(), + op.getKey(), + /*repeats=*/repeatKeyShape, /*dim=*/rank - 3); + if (failed(keyRepeated)) + return rewriter.notifyMatchFailure( + loc, "Failed to repeat the tensor elements for key."); + + FailureOr valueRepeated = repeatTensorElementsForDim( + op.getOperation(), rewriter, /*resType=*/op.getQuery().getType(), + op.getValue(), + /*repeats=*/repeatValueShape, /*dim=*/rank - 3); + if (failed(valueRepeated)) + return rewriter.notifyMatchFailure( + loc, "Failed to repeat the tensor elements for value."); + + key = typeConverter->materializeTargetConversion( + rewriter, loc, + typeConverter->convertType(keyRepeated.value().getType()), + keyRepeated.value()); + value = typeConverter->materializeTargetConversion( + rewriter, loc, + typeConverter->convertType(valueRepeated.value().getType()), + valueRepeated.value()); + return success(); + } + LogicalResult matchAndRewrite(AtenScaledDotProductAttentionOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Value mask = op.getAttnMask(); + + auto opTy = cast(op.getType()).toBuiltinTensor(); + auto query = adaptor.getQuery(); + auto value = adaptor.getValue(); + auto key = adaptor.getKey(); + auto mask = adaptor.getAttnMask(); + auto queryTy = cast(query.getType()); + auto valueTy = cast(value.getType()); + auto keyTy = cast(key.getType()); + + auto loc = op.getLoc(); Value dropoutP = op.getDropoutP(); Value isCausal = op.getIsCausal(); Value scale = op.getScale(); + Value enableGQA = op.getEnableGqa(); Type elementType = cast(adaptor.getQuery().getType()).getElementType(); - // Verify inputs (only support defaults) - if (!mask.getType().isa()) - return rewriter.notifyMatchFailure(op.getLoc(), - "attention masking not supported"); double dropout; if (!matchPattern(dropoutP, m_TorchConstantFloat(&dropout)) || dropout > 0.0) - return rewriter.notifyMatchFailure(op.getLoc(), "dropout not supported"); + return rewriter.notifyMatchFailure(loc, "dropout not supported"); + bool causal; - if (!matchPattern(isCausal, m_TorchConstantBool(&causal)) || causal) - return rewriter.notifyMatchFailure( - op.getLoc(), "causal attention masking not supported"); - if 
(!scale.getType().isa()) { + if (!matchPattern(isCausal, m_TorchConstantBool(&causal)) || causal) { + if (!isa(mask.getType())) { + return rewriter.notifyMatchFailure( + loc, "expected no attention mask when isCausal is true"); + } + + SmallVector maskStatic; + SmallVector maskDyn; + for (int i = 0, s = queryTy.getRank() - 1; i < s; ++i) { + maskStatic.push_back(queryTy.getDimSize(i)); + if (maskStatic.back() == ShapedType::kDynamic) + maskDyn.push_back(rewriter.create(loc, query, i)); + } + + maskStatic.push_back(keyTy.getDimSize(keyTy.getRank() - 2)); + if (maskStatic.back() == ShapedType::kDynamic) + maskDyn.push_back( + rewriter.create(loc, key, keyTy.getRank() - 2)); + + Type maskType = getElementTypeOrSelf(queryTy); + Value emptyMask = + rewriter.create(loc, maskStatic, maskType, maskDyn); + + Value zero = rewriter.create( + loc, rewriter.getFloatAttr(getElementTypeOrSelf(maskType), 0.0)); + Value negInf = rewriter.create( + loc, + rewriter.getFloatAttr(getElementTypeOrSelf(maskType), -INFINITY)); + + mask = rewriter.create(loc, zero, emptyMask).getResult(0); + + int64_t rank = cast(queryTy).getRank(); + AffineMap maskMap = rewriter.getMultiDimIdentityMap(rank); + SmallVector iteratorTypes( + rank, utils::IteratorType::parallel); + auto genericOp = rewriter.create( + loc, mask.getType(), ValueRange{}, mask, + SmallVector{maskMap}, iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value i = b.create(loc, queryTy.getRank() - 2); + Value j = b.create(loc, queryTy.getRank() - 1); + + Value cond = + b.create(loc, arith::CmpIPredicate::sge, i, j); + Value select = b.create(loc, cond, zero, negInf); + b.create(loc, select); + }); + mask = genericOp.getResult(0); + } + + // Broadcast the batch dimensions of the mask: + if (!isa(mask.getType())) { + auto maskTy = cast(mask.getType()); + int64_t rank = maskTy.getRank(); + bool needsBroadcast = false; + for (int i = 0, s = rank - 2; i < s; ++i) { + needsBroadcast |= maskTy.getDimSize(i) != keyTy.getDimSize(i); + } + + if (needsBroadcast) { + SmallVector maskShape; + SmallVector maskDynDims; + + SmallVector maskExprs; + for (int i = 0, s = rank - 2; i < s; ++i) { + maskShape.push_back(keyTy.getDimSize(i)); + + if (maskTy.getDimSize(i) != keyTy.getDimSize(i)) { + maskExprs.push_back(rewriter.getAffineConstantExpr(0)); + } else { + maskExprs.push_back(rewriter.getAffineDimExpr(i)); + } + + if (keyTy.isDynamicDim(i)) { + maskDynDims.push_back(rewriter.create(loc, key, i)); + } + } + + maskExprs.push_back(rewriter.getAffineDimExpr(rank - 2)); + maskExprs.push_back(rewriter.getAffineDimExpr(rank - 1)); + maskShape.push_back(maskTy.getDimSize(rank - 2)); + maskShape.push_back(maskTy.getDimSize(rank - 1)); + if (maskTy.isDynamicDim(rank - 2)) + maskDynDims.push_back( + rewriter.create(loc, mask, rank - 2)); + if (maskTy.isDynamicDim(rank - 1)) + maskDynDims.push_back( + rewriter.create(loc, mask, rank - 1)); + + SmallVector affineMaps = { + AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, maskExprs, + op.getContext()), + rewriter.getMultiDimIdentityMap(rank)}; + SmallVector findMaxIteratorTypes( + rank, utils::IteratorType::parallel); + + Value emptyMask = rewriter.create( + loc, maskShape, maskTy.getElementType(), maskDynDims); + Value newMask = + rewriter + .create( + loc, emptyMask.getType(), mask, ValueRange({emptyMask}), + affineMaps, findMaxIteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + b.create(loc, args[0]); + }) + .getResult(0); + mask = newMask; + } + } + + if (!isa(scale.getType())) { double 
scaleFloat; if (!matchPattern(scale, m_TorchConstantFloat(&scaleFloat)) || scaleFloat != 1.0) - return rewriter.notifyMatchFailure(op.getLoc(), - "only default scale supported"); + return rewriter.notifyMatchFailure(loc, "only default scale supported"); } - auto opTy = cast(op.getType()).toBuiltinTensor(); - auto query = adaptor.getQuery(); - auto value = adaptor.getValue(); - auto key = adaptor.getKey(); - auto queryTy = cast(query.getType()); - auto valueTy = cast(value.getType()); - auto keyTy = cast(key.getType()); - if (queryTy.getRank() != valueTy.getRank() || queryTy.getRank() != keyTy.getRank()) return rewriter.notifyMatchFailure(op, "operand ranks do not match"); @@ -1512,13 +1910,28 @@ class ConvertAtenScaledDotProductAttentionOp if (queryTy.getRank() < 3) return rewriter.notifyMatchFailure(op, "missing batch dimension"); + bool isGQAEnabled; + if (!matchPattern(enableGQA, m_TorchConstantBool(&isGQAEnabled))) + return rewriter.notifyMatchFailure( + loc, "Expected enable_gqa flag to be constant bool"); + + // For the cases when `enable_gqa` flag is set to true, we have to + // pre-process the inputs specifically key and value by repeating the + // elements for the head dim. + // The reference code is available here: + // https://github.com/pytorch/pytorch/pull/132689/files#diff-e726853e9795dfb6c74ab1e10945f5d5f24540eb7bc633e5c76f69bc258f24d6R612 + if (enableGQA) { + if (failed(preProcessGroupQueryAttentionInput( + op, rewriter, getTypeConverter(), query, key, value))) + return failure(); + } + llvm::SmallVector reassociation(3); for (int i = 0, s = valueTy.getRank() - 2; i < s; ++i) reassociation.front().push_back(i); reassociation[1].push_back(valueTy.getRank() - 2); reassociation[2].push_back(valueTy.getRank() - 1); - auto loc = op.getLoc(); auto collapseBatch = [&rewriter, &reassociation, loc](Value value) -> Value { auto valueTy = cast(value.getType()); @@ -1545,26 +1958,32 @@ class ConvertAtenScaledDotProductAttentionOp query = collapseBatch(query); key = collapseBatch(key); value = collapseBatch(value); + if (!isa(mask.getType())) { + mask = collapseBatch(mask); + } SmallVector outSizes(cast(query.getType()).getShape()); SmallVector valueSizes( cast(value.getType()).getShape()); outSizes[outSizes.size() - 1] = valueSizes[valueSizes.size() - 1]; - SmallVector outSizesDynamic( - getTensorSizes(rewriter, op.getLoc(), query)); + SmallVector outSizesDynamic(getTensorSizes(rewriter, loc, query)); outSizesDynamic[outSizesDynamic.size() - 1] = - getTensorSizes(rewriter, op.getLoc(), value)[valueSizes.size() - 1]; + getTensorSizes(rewriter, loc, value)[valueSizes.size() - 1]; Type outType = RankedTensorType::get(outSizes, elementType); - Value output = createZeroInitTensor(rewriter, op.getLoc(), outSizesDynamic, - elementType); + Value output = + createZeroInitTensor(rewriter, loc, outSizesDynamic, elementType); + + SmallVector inputs = SmallVector{query, key, value}; + + if (!isa(mask.getType())) { + inputs.push_back(mask); + } // Overwrite with tm_tensor::attention - Value attention = - rewriter - .create(loc, outType, - SmallVector{query, key, value}, - SmallVector{output}) - .getResult()[0]; + Value attention = rewriter + .create(loc, outType, inputs, + SmallVector{output}) + .getResult()[0]; if (opTy != outType) { attention = rewriter.create(loc, opTy, attention, @@ -1578,6 +1997,456 @@ class ConvertAtenScaledDotProductAttentionOp }; } // namespace +namespace { +class ConvertAtenKthvalueOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; 
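+  // Lowering outline (summarizing the code below): a TMTensor TopkOp with
+  // isMinK=true extracts the k smallest elements along `dim` together with
+  // their positions in the input; a linalg generic max-reduction over those k
+  // values then yields the k-th smallest value and its position within the
+  // topk slice, and a final lookup through the topk index tensor recovers the
+  // index into the original input.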
+ LogicalResult + matchAndRewrite(AtenKthvalueOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + const llvm::StringRef opName = op->getName().getStringRef(); + + Location loc = op.getLoc(); + auto typec = this->getTypeConverter(); + + Value input = adaptor.getSelf(); + auto inputType = cast(input.getType()); + unsigned inputRank = inputType.getRank(); + Type inputElementType = inputType.getElementType(); + + auto valResultType = + cast(typec->convertType(op.getResult(0).getType())); + auto valResultElementType = + getElementTypeOrSelf(typec->convertType(valResultType)); + + auto idxResultType = + cast(typec->convertType(op.getResult(1).getType())); + auto idxResultElementType = + getElementTypeOrSelf(typec->convertType(idxResultType)); + + // get keepdim and check it is bool + bool keepDim = false; + if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) + return rewriter.notifyMatchFailure( + op, opName + " requires boolean value for keepdim"); + + // get dim, check it is constant int + int64_t dim; + if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) + return rewriter.notifyMatchFailure( + op, "unimplemented: only constant dim value is supported"); + + // turn dim into positive if negative, and check it is in the valid range + dim = toPositiveDim(dim, inputRank); + if (!isValidDim(dim, inputRank)) { + return rewriter.notifyMatchFailure(op, "dim is statically invalid"); + } + + // get k, check it is a constant int + int64_t k; + if (!matchPattern(op.getK(), m_TorchConstantInt(&k))) + return rewriter.notifyMatchFailure( + op, "unimplemented: only constant k value is supported"); + + // check if element type is float, int, or unsigned + bool isUnsigned = false; + if (!isa(inputElementType)) { + if (!isa(inputElementType)) { + return rewriter.notifyMatchFailure( + op, opName + " to linalg.* requires Float or Integer " + "input element type"); + } + + auto integerTy = dyn_cast( + cast(op.getSelf().getType()).getDtype()); + isUnsigned = integerTy.isUnsigned(); + } + + // Create the values to fill initial output tensors for + // topk op and linalg generic op for finding max value. 
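+    // Seeding choice: the topk output starts at the largest representable
+    // value (+inf / max int) so that every real element beats the seed in the
+    // min-k comparison, while the find-max tensor starts at the smallest
+    // representable value (-inf / min int) so that the reduction below always
+    // picks up an actual element of the topk slice.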
+ Value fillValLinalgFindMax; + Value fillValTopK; + if (isa(inputElementType)) { + // max float for topk tensor + fillValTopK = rewriter.create( + loc, + rewriter.getFloatAttr( + inputElementType, + APFloat::getInf( + cast(inputElementType).getFloatSemantics(), + /*Negative=*/false))); + // min float for linalg generic op tensor + fillValLinalgFindMax = rewriter.create( + loc, + rewriter.getFloatAttr( + inputElementType, + APFloat::getInf( + cast(inputElementType).getFloatSemantics(), + /*Negative=*/true))); + } else if (!isUnsigned) { + auto width = cast(inputElementType).getWidth(); + // max signed int for topk op tensor + auto init = APSInt::getSignedMaxValue(width); + fillValTopK = rewriter.create( + loc, rewriter.getIntegerAttr(inputElementType, init)); + // min signed int for linalg generic op tensor + init = APSInt::getSignedMinValue(width); + fillValLinalgFindMax = rewriter.create( + loc, rewriter.getIntegerAttr(inputElementType, init)); + } else if (isUnsigned) { + auto width = cast(inputElementType).getWidth(); + // max unsigned int for topk op tensor + auto init = APInt::getMaxValue(width); + fillValTopK = rewriter.create( + loc, rewriter.getIntegerAttr(inputElementType, init)); + // min unsigned int for linalg generic op tensor + init = APInt::getMinValue(width); + fillValLinalgFindMax = rewriter.create( + loc, rewriter.getIntegerAttr(inputElementType, init)); + } + + auto i32Type = rewriter.getI32Type(); + + // ======== BEGIN: Topk op section ======== + // Based on iree docs: + // https://iree.dev/reference/mlir-dialects/LinalgExt/#iree_linalg_extsort-linalgextsortop + + // Create the output shape of topk op. + // For every dimension, topkShape[dimension] = inputShape[dimension], + // except topkShape[dim] = k. + SmallVector topkShape; + for (unsigned i = 0; i < inputRank; i++) { + auto currentDimSize = rewriter.create(loc, input, i); + topkShape.push_back(currentDimSize); + } + auto dimSize = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getI64Type(), k)); + topkShape[dim] = dimSize; + + // Fill the initial topk op output tensor. + Value topkOutputVal = createInitTensor(rewriter, loc, topkShape, + valResultElementType, fillValTopK); + + // Create the initial value to fill the topk output indices tensor. + // It is equal to the max 32-bit signless integer. + auto signlessType = mlir::IntegerType::get(op.getContext(), 32, + mlir::IntegerType::Signless); + auto initIdx = getNumericLimit(rewriter, signlessType, /*getMin=*/false); + auto fillValTopkIdx = rewriter.create(loc, initIdx); + // Fill the initial topk op output indices tensor. + Value topkOutputIdx = + createInitTensor(rewriter, loc, topkShape, i32Type, fillValTopkIdx); + + // Input arguments for topk op contain only the input tensor. + // Input indices will be inferred based on input shape. + // (See docs link above). + SmallVector topkInputs; + topkInputs.push_back(input); + + // Outputs contain both the values and the indices tensors. + SmallVector topkOutputs; + topkOutputs.push_back(topkOutputVal); + topkOutputs.push_back(topkOutputIdx); + + // Element types of the arguments passed to the topk op region. + // The region accepts the next value N, and the current output + // candidate K (see docs link above). + // Both N and K are values from the input tensors, thus the + // element types are the same and are taken from inputType. 
+ SmallVector topkElementTypes; + topkElementTypes.push_back(inputType.getElementType()); + topkElementTypes.push_back(inputType.getElementType()); + + // Create the TMTensor TopkOp. + FailureOr> topkOp; + { + OpBuilder::InsertionGuard guard(rewriter); + topkOp = createTMTensorTopkOp(rewriter, loc, topkInputs, topkOutputs, + topkElementTypes, dim, /*isMinK=*/true); + } + // Topk op creation fails with invalid element types. + if (failed(topkOp)) + return rewriter.notifyMatchFailure( + loc, "Only Integer and Floating element type expected."); + + auto topkOpVal = topkOp.value(); + // ======== END: Topk op section ======== + + // ======== BEGIN: Linalg generic to find max in topk result ======== + + // Create result shape as both a vector of Value and of int64_t types. + // We assume that keepdim is false, and fix the result later if true. + // Result shape is equal to inputShape, with dim dimension removed. + SmallVector resultShape; + SmallVector resultShapeInt; + for (int64_t i = 0; i < inputType.getRank(); i++) { + if (dim != i) { + auto currentDimSize = rewriter.create(loc, input, i); + resultShape.push_back(currentDimSize); + resultShapeInt.push_back(inputType.getShape()[i]); + } + } + + // Fill the initial output tensor for linalg op for finding max value. + Value findMaxOutputVal = createInitTensor( + rewriter, loc, resultShape, inputElementType, fillValLinalgFindMax); + + // Fill the initial output indices tensor for linalg op for finding max + // value with zeros. + Value findMaxOutputIdx = + createZeroInitTensor(rewriter, loc, resultShape, idxResultElementType); + + // Reduce along dim. + SmallVector findMaxIteratorTypes( + inputType.getRank(), utils::IteratorType::parallel); + findMaxIteratorTypes[dim] = utils::IteratorType::reduction; + + SmallVector findMaxMapExprs; + SmallVector findMaxMapResultExprs; + for (auto size : + llvm::enumerate(makeShapeTorchCompatible(inputType.getShape()))) { + findMaxMapExprs.push_back(rewriter.getAffineDimExpr(size.index())); + if (unsigned(dim) != size.index()) + findMaxMapResultExprs.push_back( + rewriter.getAffineDimExpr(size.index())); + } + + auto findMaxMaps = AffineMap::inferFromExprList( + {findMaxMapExprs, findMaxMapResultExprs, findMaxMapResultExprs}, + rewriter.getContext()); + + // Create linalg op for finding the max value in the extracted topk values. + auto findMaxLinalg = rewriter.create( + loc, + ArrayRef( + {findMaxOutputVal.getType(), findMaxOutputIdx.getType()}), + topkOpVal.front(), ValueRange({findMaxOutputVal, findMaxOutputIdx}), + findMaxMaps, findMaxIteratorTypes, + [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange blockArgs) { + // Linalg generic body is the same as the decomposition for + // AtenMinDim: lib/Conversion/TorchToLinalg/Reduction.cpp + + Value newValue = blockArgs[0]; + Value oldValue = blockArgs[1]; + Value oldIndex = blockArgs[2]; + + Value newIndex = rewriter.create( + nestedLoc, oldIndex.getType(), + rewriter.create(nestedLoc, dim)); + + Value resultVal, predicate; + if (isa(inputElementType)) { + resultVal = rewriter.create(nestedLoc, newValue, + oldValue); + predicate = rewriter.create( + nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue); + } else { + arith::CmpIPredicate predType; + predType = isUnsigned ? 
arith::CmpIPredicate::ugt + : arith::CmpIPredicate::sgt; + if (isUnsigned) { + resultVal = rewriter.create(nestedLoc, newValue, + oldValue); + } else { + resultVal = rewriter.create(nestedLoc, newValue, + oldValue); + } + predicate = rewriter.create(nestedLoc, predType, + newValue, oldValue); + } + auto resultIndex = rewriter.create( + nestedLoc, predicate, newIndex, oldIndex); + nestedBuilder.create( + nestedLoc, ValueRange{resultVal, resultIndex}); + }); + + auto findMaxVal = findMaxLinalg.getResult(0); + auto findMaxIdx = findMaxLinalg.getResult(1); + auto findMaxIdxType = cast(findMaxIdx.getType()); + + // ======== END: Linalg generic to find max in topk result ======== + + // ======== BEGIN: Linalg generic for index extraction ======== + // The linalg op for finding max returned idx of max elements in the + // tensor returned by the topk op. We need the idx of those elements + // in the original input. The topk op returned the idx of the top k + // extracted elements in the original input. Using the linalg idx + // results to index the topk idx results returns the idx of kth + // max value in the original input. Example: + // input = [1, 7, 3, 6, 2, 8, 9, 5], k = 4 + // topk_val = [1, 3, 2, 5], topk_idx = [0, 2, 4, 7] + // linalg_max_val = [5], linalg_max_idx = [3] (5 is at idx 3 in topk_val) + // index the topk_idx using linalg_max_idx -> topk_idx[3] = 7 + // kth_val = [5], kth_idx = [7] + + // Create a tensor for the resulting idx. + Value filledTensorExtractedIdx = createZeroInitTensor( + rewriter, loc, getTensorSizes(rewriter, loc, findMaxIdx), i32Type); + + // We iterate through the idx tensor returned by the linalg generic op for + // finding max. + SmallVector extractedIdxIteratorTypes( + findMaxIdxType.getRank(), utils::IteratorType::parallel); + + SmallVector extractedIdxMapExprs; + for (auto size : + llvm::enumerate(makeShapeTorchCompatible(findMaxIdxType.getShape()))) { + extractedIdxMapExprs.push_back(rewriter.getAffineDimExpr(size.index())); + } + + auto extractedIdxMaps = AffineMap::inferFromExprList( + {extractedIdxMapExprs, extractedIdxMapExprs}, rewriter.getContext()); + + // Linalg generic op for indexing the topk output idx tensor using + // the idx tensor returned by the linalg generic op for finding max. + // Only the idx tensor from the linalg generic op is sent as input. + auto extractedIdxLinalg = rewriter.create( + loc, ArrayRef({filledTensorExtractedIdx.getType()}), findMaxIdx, + filledTensorExtractedIdx, extractedIdxMaps, extractedIdxIteratorTypes, + [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange blockArgs) { + // Get the current input idx. + Value index = rewriter.create( + loc, rewriter.getIndexType(), blockArgs[0]); + + // Create idx to index the topk idx tensor. + // Index the dim dimension using the current input idx. + SmallVector indexTarget; + for (unsigned i = 0; i < dim; i++) + indexTarget.push_back(rewriter.create(loc, i)); + indexTarget.push_back(index); + for (unsigned i = dim; i < findMaxIdxType.getRank(); i++) + indexTarget.push_back(rewriter.create(loc, i)); + + // Extract the element from the topk idx tensor. 
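+          // e.g. with dim == 1 on a rank-3 input, indexTarget is
+          // (linalg.index 0, looked-up position, linalg.index 1), i.e.
+          // topk_idx[i][p][j], where p is the position of the maximum within
+          // the k extracted values.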
+ Value extractedElement = rewriter.create( + loc, topkOpVal.back(), indexTarget); + rewriter.create(loc, extractedElement); + }); + + auto extractedIdx = extractedIdxLinalg.getResult(0); + auto extractedIdxType = cast(extractedIdx.getType()); + + // ======== END: Linalg generic for index extraction ======== + + // ======== BEGIN: Linalg generic for topk idx cast ======== + // Casts from i32 to idx result type of the Kthvalue op. + + // Create the initial tensor for the cast result. + Value filledTensorCastedIdx = createZeroInitTensor( + rewriter, loc, getTensorSizes(rewriter, loc, extractedIdx), + idxResultElementType); + + SmallVector castedIdxIteratorTypes( + extractedIdxType.getRank(), utils::IteratorType::parallel); + + SmallVector castedIdxMapExprs; + for (auto size : llvm::enumerate( + makeShapeTorchCompatible(extractedIdxType.getShape()))) { + castedIdxMapExprs.push_back(rewriter.getAffineDimExpr(size.index())); + } + + auto castedIdxMaps = AffineMap::inferFromExprList( + {castedIdxMapExprs, castedIdxMapExprs}, rewriter.getContext()); + + // Linalg generic op for casting topk idx output tensor elements from i32 to + // result idx tensor element type. + auto castedIdxLinalg = rewriter.create( + loc, ArrayRef({filledTensorCastedIdx.getType()}), extractedIdx, + filledTensorCastedIdx, castedIdxMaps, castedIdxIteratorTypes, + [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange blockArgs) { + Value oldIdx = blockArgs[0]; + + // Cast from i32 to index. + Value oldIdxToIndexType = rewriter.create( + nestedLoc, rewriter.getIndexType(), oldIdx); + // Cast from index to result idx element type. + Value resultIdx = rewriter.create( + nestedLoc, idxResultElementType, oldIdxToIndexType); + + nestedBuilder.create(nestedLoc, resultIdx); + }); + + auto castedIdx = castedIdxLinalg.getResult(0); + + // ======== END: Linalg generic for topk idx cast ======== + + // Create output value type ("squeezed" since we assume keepdim=False). + auto topkValResultType = + cast(topkOpVal.front().getType()); + auto squeezedValType = topkValResultType.cloneWith( + resultShapeInt, + cast(findMaxVal.getType()).getElementType()); + + // Create output idx type ("squeezed" since we assume keepdim=False). + auto castedIdxType = cast(castedIdx.getType()); + auto squeezedIdxType = castedIdxType.cloneWith( + resultShapeInt, findMaxIdxType.getElementType()); + + if (!keepDim) { + // If keepdim=false, cast the the outputs to appropriate type and return. + Value retVal = + rewriter.create(loc, squeezedValType, findMaxVal); + Value retIdx = + rewriter.create(loc, squeezedIdxType, castedIdx); + llvm::SmallVector res{retVal, retIdx}; + rewriter.replaceOp(op, res); + return success(); + } + + // If keepdim is false, unsqueeze. 
+ // Unsqueezing implementation taken from AteMinMaxDimOp lowering: + // lib/Conversion/TorchToLinalg/Reduction.cpp + llvm::SmallVector valShape(valResultType.getShape()); + llvm::SmallVector idxShape(idxResultType.getShape()); + for (int i = dim, s = valShape.size() - 1; i < s; ++i) { + valShape[i] = valShape[i + 1]; + idxShape[i] = idxShape[i + 1]; + } + + valShape.resize(valShape.size() - 1); + idxShape.resize(idxShape.size() - 1); + + Value retVal = rewriter.create( + loc, squeezedValType.clone(valShape), findMaxLinalg.getResult(0)); + Value retIdx = rewriter.create( + loc, squeezedIdxType.clone(idxShape), castedIdx); + + SmallVector reassociation(valShape.size()); + if (reassociation.size() > 0) { + for (int i = 0; i < dim; ++i) + reassociation[i].push_back(i); + reassociation[std::max(0, dim - 1)].push_back(dim); + for (int i = dim, s = reassociation.size(); i < s; ++i) + reassociation[i].push_back(i + 1); + } + + valShape.push_back(0); + idxShape.push_back(0); + for (int i = dim, s = valShape.size() - 1; i < s; ++i) { + valShape[i + 1] = valShape[i]; + idxShape[i + 1] = idxShape[i]; + } + + valShape[dim] = 1; + idxShape[dim] = 1; + + Value unsqueezeVal = rewriter.create( + loc, valResultType, retVal, reassociation); + + Value unsqueezeIdx = rewriter.create( + loc, idxResultType, retIdx, reassociation); + + // Return unsqueezed. + llvm::SmallVector unsqueezes = {unsqueezeVal, unsqueezeIdx}; + rewriter.replaceOp(op, unsqueezes); + return success(); + } +}; +} // namespace + // ----------------------------------------------------------------------------- // The pass // ----------------------------------------------------------------------------- @@ -1621,12 +2490,20 @@ class ConvertTorchToTMTensor patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); - patterns.add(typeConverter, context); + patterns.add>(typeConverter, + context); + target.addIllegalOp(); + patterns.add>(typeConverter, + context); + target.addIllegalOp(); + patterns.add(typeConverter, context); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) diff --git a/lib/Conversion/TorchToTensor/TorchToTensor.cpp b/lib/Conversion/TorchToTensor/TorchToTensor.cpp index f3ec5c01095f..76b9b87cbfe9 100644 --- a/lib/Conversion/TorchToTensor/TorchToTensor.cpp +++ b/lib/Conversion/TorchToTensor/TorchToTensor.cpp @@ -13,13 +13,9 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Traits.h" -#include "mlir/IR/Matchers.h" #include "mlir/Transforms/DialectConversion.h" -#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" -#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" using namespace mlir; diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index c7c52d08791f..5bf8a3387d88 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -1,5 +1,5 @@ //===----------------------------------------------------------------------===// -// +//// // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
// See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception @@ -12,8 +12,8 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "mlir/IR/Matchers.h" -#include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h" #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h" @@ -22,9 +22,12 @@ #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" -#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" +#include "llvm/ADT/TypeSwitch.h" +#include +#include #include +#include using namespace mlir; using namespace mlir::torch; @@ -32,10 +35,10 @@ using namespace mlir::torch::Torch; namespace { -// These legalizations are for unary ops with only for floating point datatypes. -// There is no supported quantized integer mode for these. +// These legalizations are for unary ops with promoting input to floating-point +// datatypes only. There is no supported quantized integer mode for these. template -class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern { +class ConvertAtenUnaryPromoteToFPOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; @@ -49,17 +52,22 @@ class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "Only Tensor types supported in TOSA"); - if (selfTy.getElementType().isa()) { - rewriter.replaceOpWithNewOp( - op, - OpConversionPattern::getTypeConverter()->convertType( - op.getType()), - self); - return success(); - } else { + auto resultTy = dyn_cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); + + if (!isa(resultTy.getElementType())) return rewriter.notifyMatchFailure( - op, "Only floating-point datatype legalization supported"); - } + op, "Only floating-point datatype result types are supported"); + + // Non floating point inputs are not supported in TOSA so we cast the input + // to result type + if (!isa(selfTy.getElementType())) + self = tosa::promoteType(rewriter, self, resultTy); + + rewriter.replaceOpWithNewOp(op, resultTy, self); + + return success(); } }; @@ -73,11 +81,16 @@ class ConvertAtenUnaryOp : public OpConversionPattern { LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp( - op, + auto self = adaptor.getSelf(); + + auto outType = dyn_cast( OpConversionPattern::getTypeConverter()->convertType( - op.getType()), - adaptor.getSelf()); + op.getType())); + + self = tosa::promoteType(rewriter, self, outType); + + rewriter.replaceOpWithNewOp(op, outType, self); + return success(); } }; @@ -101,13 +114,35 @@ class ConvertAtenBinaryOp : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "Only Tensor types supported in TOSA"); - auto outTy = OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template cast(); + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), lhs, rhs).failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); - auto binaryOp = - 
tosa::createBinaryOpAndCast(rewriter, op, outTy, lhs, rhs); - rewriter.replaceOp(op, binaryOp.getResult()); + auto outTy = cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); + + Value binaryOp; + + if constexpr (std::is_same()) { + // TOSA ArithmeticRightShiftOp has a round parameter. + binaryOp = rewriter.create(op->getLoc(), outTy, lhs, rhs, + /*round=*/false); + } else if constexpr (std::is_same() || + std::is_same()) { + lhs = tosa::promoteType(rewriter, lhs, outTy); + rhs = tosa::promoteType(rewriter, rhs, outTy); + // Use default NaN Propagation mode "PROPAGATE" for tosa.maximum and + // tosa.minimum + binaryOp = rewriter.create( + op->getLoc(), outTy, lhs, rhs, + /*nan_mode=*/rewriter.getStringAttr("PROPAGATE")); + } else { + binaryOp = + tosa::createBinaryOpAndCast(rewriter, op, outTy, lhs, rhs); + } + + rewriter.replaceOp(op, binaryOp); return success(); } }; @@ -116,15 +151,14 @@ template static bool isInValidRange(bool isFloat, const double &doubleValue, bool isInt, const int64_t &intValue) { if (isFloat) { - // Do a round-trip check here instead of numeric limits due to - // compiler warnings around double <-> int conversion. - return (doubleValue == static_cast(static_cast(doubleValue))); - } else { - assert(isInt); + return (doubleValue >= + static_cast(std::numeric_limits::min())) && + (doubleValue <= static_cast(std::numeric_limits::max())); + } else if (isInt) { return (intValue >= static_cast(std::numeric_limits::min())) && (intValue <= static_cast(std::numeric_limits::max())); } - return true; + return false; } // FIXME: This will eventually go into a Tosa*Utils file. @@ -144,19 +178,25 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter, return rewriter.notifyMatchFailure(op, "Unable to extract the scalar constant"); + int64_t numElem = 1; + for (int64_t dim : dshape) + numElem *= dim; + if (isa(dtype)) { - tosaTensor = tosa::getConstTensor(rewriter, op, - (isFloat ? doubleValue : intValue), - dshape, dtype) - .value(); + tosaTensor = + tosa::getConstTensor( + rewriter, op, + SmallVector(numElem, (isFloat ? doubleValue : intValue)), + dshape, dtype) + .value(); } else if (auto intType = dyn_cast(dtype)) { - auto w = intType.getWidth(); - if (w != 1 && w != 32 && w != 64) + auto width = intType.getWidth(); + if (width != 1 && width != 8 && width != 32 && width != 64) return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { diag << "Unsupported integer type: " << intType; }); - if (w == 1) { + if (width == 1) { if (!isInValidRange(isFloat, doubleValue, isInt, intValue)) { return rewriter.notifyMatchFailure( op, "Supplied value of scalar constant exceeds limits " @@ -164,9 +204,21 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter, } bool d = isFloat ? static_cast(doubleValue) : static_cast(intValue); - tosaTensor = - tosa::getConstTensor(rewriter, op, {d}, dshape).value(); - } else if (w == 32) { + tosaTensor = tosa::getConstTensor( + rewriter, op, SmallVector(numElem, d), dshape) + .value(); + } else if (width == 8) { + if (!isInValidRange(isFloat, doubleValue, isInt, intValue)) { + return rewriter.notifyMatchFailure( + op, "Supplied value of scalar constant exceeds limits " + "of destination type"); + } + int8_t d = isFloat ? 
static_cast(doubleValue) + : static_cast(intValue); + tosaTensor = tosa::getConstTensor( + rewriter, op, SmallVector(numElem, d), dshape) + .value(); + } else if (width == 32) { if (!isInValidRange(isFloat, doubleValue, isInt, intValue)) { return rewriter.notifyMatchFailure( op, "Supplied value of scalar constant exceeds limits " @@ -174,17 +226,19 @@ LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter, } int32_t d = isFloat ? static_cast(doubleValue) : static_cast(intValue); - tosaTensor = - tosa::getConstTensor(rewriter, op, {d}, dshape).value(); - } else if (w == 64) { + tosaTensor = tosa::getConstTensor( + rewriter, op, SmallVector(numElem, d), dshape) + .value(); + } else if (width == 64) { if (!isInValidRange(isFloat, doubleValue, isInt, intValue)) { return rewriter.notifyMatchFailure( op, "Supplied value of scalar constant exceeds limits " "of destination type"); } int64_t d = (isFloat ? static_cast(doubleValue) : intValue); - tosaTensor = - tosa::getConstTensor(rewriter, op, {d}, dshape).value(); + tosaTensor = tosa::getConstTensor( + rewriter, op, SmallVector(numElem, d), dshape) + .value(); } } else { return rewriter.notifyMatchFailure(op, "Usupported element type"); @@ -250,9 +304,9 @@ class ConvertAtenAddSubOp : public OpConversionPattern { } // Get output type: tensor - auto outType = OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template cast(); + auto outType = cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); Type outElemTy = outType.getElementType(); if (!outElemTy.isIntOrFloat()) { @@ -260,6 +314,28 @@ class ConvertAtenAddSubOp : public OpConversionPattern { op, "Only floating-point or integer datatype legalization supported"); } + if (!rhsType) { + if (failed(torchScalarToTosaTensor(rewriter, op, op.getOther(), rhs, + outElemTy, {}))) { + return rewriter.notifyMatchFailure( + op, "Currently only scalar constants are supported for " + "conversion in TOSA operation"); + } + rhsType = dyn_cast(rhs.getType()); + } + + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), lhs, rhs).failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + + rhsType = dyn_cast(rhs.getType()); + + // aten.rsub(lhs, rhs, alpha) computes rhs - lhs * alpha + if constexpr (std::is_same::value) { + std::swap(lhs, rhs); + std::swap(lhsType, rhsType); + } + Type rhsAlphaMulElemType; if (isa(outElemTy)) { rhsAlphaMulElemType = outElemTy; @@ -268,25 +344,12 @@ class ConvertAtenAddSubOp : public OpConversionPattern { rhsAlphaMulElemType = rewriter.getIntegerType(32); } - // if right is scalar, rhgType==None, which need to be manually cast to - // TensorType else right is tensor, rhsType==tensor - Value rhsAsTensor; - if (!rhsType) { - if (failed(torchScalarToTosaTensor(rewriter, op, op.getOther(), - rhsAsTensor, rhsAlphaMulElemType, {}))) - return rewriter.notifyMatchFailure( - op, "Currently only scalar constants are supported for " - "conversion in TOSA operation"); - } else if (rhsType.getElementType() != rhsAlphaMulElemType) { + if (rhsType.getElementType() != rhsAlphaMulElemType) { // right is tensor, rhsType == tensor // right must be cast to same type as the alpha, so MulOp success - rhs = rewriter.create( - op->getLoc(), - RankedTensorType::get(rhsType.getShape(), rhsAlphaMulElemType), rhs); - // reinitialize right value type to tensor - rhsType = dyn_cast(rhs.getType()); + rhsType = RankedTensorType::get(rhsType.getShape(), rhsAlphaMulElemType); + rhs = 
rewriter.create(op->getLoc(), rhsType, rhs); } - auto rhsTensor = rhsType ? rhs : rhsAsTensor; // Handle scalar value alpha. // It should be either f32/i32 @@ -299,11 +362,13 @@ class ConvertAtenAddSubOp : public OpConversionPattern { op, "Currently only scalar constants are supported for " "alpha in conversion to TOSA operation"); } + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), lhs, alphaTensor) + .failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); - auto mulAlphaOp = tosa::createMulOpAndCast( - rewriter, op, - rhsType ? rhsType : RankedTensorType::get({}, rhsAlphaMulElemType), - rhsTensor, alphaTensor, /*shift=*/0); + auto mulAlphaOp = tosa::createMulOpAndCast(rewriter, op, rhsType, rhs, + alphaTensor, /*shift=*/0); if (outElemTy.isInteger(64)) { // Tosa doesn't support 64-bit elementwise addition and subtraction. @@ -353,6 +418,7 @@ class ConvertAtenCompareOp : public OpConversionPattern { // For bitwise operators, only integer datatype legalization is supported constexpr bool isBitwiseOp = std::is_same() || + std::is_same() || std::is_same() || std::is_same(); if (isa(lhsElemTy) && isBitwiseOp) { @@ -364,25 +430,64 @@ class ConvertAtenCompareOp : public OpConversionPattern { Value rhsAsTensor; if (!rhsTy) { if (failed(torchScalarToTosaTensor(rewriter, op, op.getOther(), - rhsAsTensor, lhsElemTy, {}))) + rhsAsTensor, rhs.getType(), {}))) return rewriter.notifyMatchFailure( op, "Currently only scalar constants are supported for " "conversion in TOSA operation"); } + auto rhsTensor = rhsTy ? rhs : rhsAsTensor; + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), lhs, rhsTensor) + .failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + + auto rhsTensorTy = dyn_cast(rhsTensor.getType()); + auto rhsElemTy = rhsTensorTy.getElementType(); + // There is no Lesser operator in TOSA. constexpr auto swapLhsRhs = (std::is_same() || - std::is_same()); + std::is_same() || + std::is_same() || + std::is_same()); // Promote lhs and rhs dtypes for bitwise operators. - TensorType resultTy = OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template cast(); + TensorType resultTy = cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); if (isBitwiseOp) { lhs = tosa::promoteType(rewriter, lhs, resultTy); rhsTensor = tosa::promoteType(rewriter, rhsTensor, resultTy); } + // Support different types comparisons + auto isLhsElemFloat = isa(lhsElemTy); + auto isRhsElemFloat = isa(rhsElemTy); + + if (lhsElemTy != rhsElemTy && !isBitwiseOp) { + if (isLhsElemFloat && !isRhsElemFloat) { + rhsTensor = tosa::promoteType(rewriter, rhsTensor, lhsTy); + } else if (!isLhsElemFloat && isRhsElemFloat) { + lhs = tosa::promoteType(rewriter, lhs, rhsTensorTy); + } else if (isLhsElemFloat && isRhsElemFloat) { + auto lhsElemFloatTy = dyn_cast(lhsElemTy); + auto rhsElemFloatTy = dyn_cast(rhsElemTy); + if (lhsElemFloatTy.getWidth() > rhsElemFloatTy.getWidth()) { + rhsTensor = tosa::promoteType(rewriter, rhsTensor, lhsTy); + } else { + lhs = tosa::promoteType(rewriter, lhs, rhsTensorTy); + } + } else { + auto lhsElemIntTy = dyn_cast(lhsElemTy); + auto rhsElemIntTy = dyn_cast(rhsElemTy); + if (lhsElemIntTy.getWidth() > rhsElemIntTy.getWidth()) { + rhsTensor = tosa::promoteType(rewriter, rhsTensor, lhsTy); + } else { + lhs = tosa::promoteType(rewriter, lhs, rhsTensorTy); + } + } + } + auto resultOp = rewriter.create(op.getLoc(), resultTy, (swapLhsRhs ? 
rhsTensor : lhs), (swapLhsRhs ? lhs : rhsTensor)); @@ -392,9 +497,7 @@ class ConvertAtenCompareOp : public OpConversionPattern { std::is_same()) { rewriter.replaceOpWithNewOp(op, resultTy, resultOp.getResult()); - } - - else { + } else { rewriter.replaceOp(op, resultOp.getResult()); } @@ -418,9 +521,9 @@ class ConvertAtenMulOp : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "Only Tensor types supported in TOSA"); - auto outType = OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template cast(); + auto outType = cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); Type outElemTy = outType.getElementType(); if (!outElemTy.isIntOrFloat()) @@ -445,10 +548,15 @@ class ConvertAtenMulOp : public OpConversionPattern { rhsTensor = rhsType ? rhs : rhsAsTensor; } + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), lhs, rhsTensor) + .failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + if (isa(outElemTy) || isa(outElemTy)) { - auto outType = OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template cast(); + auto outType = cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); auto mulOp = tosa::createMulOpAndCast(rewriter, op, outType, lhs, rhsTensor, /*shift=*/0); @@ -463,6 +571,136 @@ class ConvertAtenMulOp : public OpConversionPattern { } }; +// Function to perform division with trunc rounding mode (rounding result +// towards zero) for float type inputs. +// This function takes in the division result between lhs and rhs rather +// than takes in the original lhs and rhs tensors as parameters. +std::optional truncFloatDivWithDivResult(PatternRewriter &rewriter, + Operation *op, + TensorType outType, + Value divResult) { + // To implement trunc mode for float inputs, multiply the floored abs + // of the tensor with the elementwise signedness of the tensor. 
+ // div_result = lhs / rhs + // trunc_val = floor(abs(div_result)) * sign(div_result) + auto zero = + tosa::getConstTensor(rewriter, op, 0, {}, outType.getElementType()) + .value(); + + auto one = + tosa::getConstTensor(rewriter, op, 1, {}, outType.getElementType()) + .value(); + + auto minusOne = tosa::getConstTensor(rewriter, op, -1, {}, + outType.getElementType()) + .value(); + + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), divResult, one) + .failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), divResult, zero) + .failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), divResult, minusOne) + .failed()) + return std::nullopt; + + auto cond = rewriter.create( + op->getLoc(), + RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(1)), + divResult, zero); + + auto selectOp = rewriter.create(op->getLoc(), outType, cond, + one, minusOne); + + auto absDivResult = + rewriter.create(op->getLoc(), outType, divResult); + + auto flooredAbsDivResult = + rewriter.create(op->getLoc(), outType, absDivResult); + + Value result = + tosa::createMulOpAndCast(rewriter, op, outType, flooredAbsDivResult, + selectOp, /*shift=*/0) + .getResult(); + + return result; +} + +// Function to perform division with trunc rounding mode (rounding result +// towards zero) for float type inputs +Value truncFloatDiv(PatternRewriter &rewriter, Operation *op, + TensorType outType, Value lhs, Value rhs) { + rhs = tosa::promoteType(rewriter, rhs, outType); + + auto rhsRcp = + rewriter.create(op->getLoc(), rhs.getType(), rhs); + + auto divResult = tosa::createMulOpAndCast(rewriter, op, outType, lhs, rhsRcp, + /*shift=*/0); + + return truncFloatDivWithDivResult(rewriter, op, outType, divResult).value(); +} + +// Function to perform division with floor rounding mode (rounding result +// down) for integer type inputs. 
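The float helper above and the integer helper defined next both rest on small scalar identities: trunc(a/b) = floor(|a/b|) * sign(a/b), and for integers the floor result equals the truncating result minus one exactly when the division is inexact and the operands have opposite signs. Below is a standalone sanity check of both identities in plain C++ with arbitrary sample values; it only illustrates the math the lowerings apply elementwise.

#include <cassert>
#include <cmath>
#include <cstdint>

int main() {
  // Float path: trunc(a/b) == floor(|a/b|) * sign(a/b).
  for (double a : {7.0, -7.0}) {
    double q = a / 2.0;                   // 3.5 or -3.5
    double sign = (q < 0.0) ? -1.0 : 1.0;
    assert(std::trunc(q) == std::floor(std::fabs(q)) * sign);
  }

  // Integer path: floor_div = trunc_div - ((trunc_div * rhs != lhs) &&
  //                                        (sign(lhs) != sign(rhs))).
  for (int64_t lhs : {7, -7}) {
    int64_t rhs = 2;
    int64_t truncDiv = lhs / rhs;         // C++ '/' truncates toward zero
    bool inexact = truncDiv * rhs != lhs;
    bool oppositeSigns = (lhs < 0) != (rhs < 0);
    int64_t floorDiv = truncDiv - (inexact && oppositeSigns);
    assert(floorDiv == (int64_t)std::floor((double)lhs / rhs)); // -4 for -7/2
  }
  return 0;
}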
+std::optional floorIntDiv(PatternRewriter &rewriter, Operation *op, + TensorType outType, Value lhs, Value rhs) { + // To implement floor mode int input, utilize tosa::IntDivOp (trunc div + // result) with the following formula elementwise: + // floor_val = trunc_val - ((trunc_val * rhs != lhs) + // && (sign(lhs) != sign(rhs))) + + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), lhs, rhs).failed()) + return std::nullopt; + + // TOSA IntDiv requires inputs to be i32 + auto i32Type = + RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(32)); + lhs = tosa::promoteType(rewriter, lhs, i32Type); + rhs = tosa::promoteType(rewriter, rhs, i32Type); + + auto intDivOp = + rewriter.create(op->getLoc(), i32Type, lhs, rhs); + + auto zero = tosa::getConstTensor(rewriter, op, 0, {}).value(); + + auto one = tosa::getConstTensor(rewriter, op, 1, {}).value(); + + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), lhs, one).failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), lhs, zero).failed()) + return std::nullopt; + + auto boolType = + RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(1)); + + auto lhsMulRhs = rewriter.create(op->getLoc(), i32Type, lhs, rhs, + /*shift=*/0); + + auto lhsRhsDifferentSign = + rewriter.create(op->getLoc(), boolType, zero, lhsMulRhs); + + auto truncMulRhs = rewriter.create(op->getLoc(), i32Type, + intDivOp, rhs, /*shift=*/0); + + auto truncMulRhsEqualLhs = + rewriter.create(op->getLoc(), boolType, truncMulRhs, lhs); + + auto truncMulRhsNotEqualLhs = rewriter.create( + op->getLoc(), boolType, truncMulRhsEqualLhs); + + auto truncMinusOne = + rewriter.create(op->getLoc(), i32Type, intDivOp, one); + + auto cond = rewriter.create( + op->getLoc(), boolType, lhsRhsDifferentSign, truncMulRhsNotEqualLhs); + + auto selectOp = rewriter.create(op->getLoc(), i32Type, cond, + truncMinusOne, intDivOp); + + Value result = tosa::promoteType(rewriter, selectOp, outType); + + return result; +} + template class ConvertAtenDivOp : public OpConversionPattern { public: @@ -494,29 +732,74 @@ class ConvertAtenDivOp : public OpConversionPattern { "conversion in TOSA operation"); } auto rhsTensor = rhsTy ? rhs : rhsAsTensor; - auto outType = OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template cast(); + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), lhs, rhsTensor) + .failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + + auto outType = cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); + + // Get rounding mode for aten.div.Tensor_mode + std::string roundMode; + if constexpr (std::is_same() || + std::is_same()) { + if (!matchPattern(op.getRoundingMode(), m_TorchConstantStr(roundMode))) + return rewriter.notifyMatchFailure( + op, "Non-const rounding mode parameter unsupported"); + } - // auto result; Value result; if (isa(outType.getElementType())) { - // The input to the reciprocal is an integer sometimes, and we may need to - // promote it to a floating point. Per TOSA specification, the input types - // can only be floating point for tosa::ReciprocalOp. - Value rhsCasted = tosa::promoteType(rewriter, rhsTensor, outType); - auto rcpOp = rewriter.create( - op->getLoc(), rhsCasted.getType(), rhsCasted); - - result = tosa::createMulOpAndCast(rewriter, op, outType, lhs, - rcpOp.getResult(), /*shift=*/0) - .getResult(); + // The input to the reciprocal is an integer sometimes, and we may need + // to promote it to a floating point. 
Per TOSA specification, the input + // types can only be floating point for tosa::ReciprocalOp. + rhsTensor = tosa::promoteType(rewriter, rhsTensor, outType); + auto rhsRcp = rewriter.create( + op->getLoc(), rhsTensor.getType(), rhsTensor); + + auto divResult = tosa::createMulOpAndCast(rewriter, op, outType, lhs, + rhsRcp, /*shift=*/0); + + // Round result based on rounding mode + if (roundMode.compare("floor") == 0) { + // "floor": rounds the results of the division down. Equivalent to + // floor division in Python (the // operator). + auto floorOp = + rewriter.create(op->getLoc(), outType, divResult); + + result = floorOp.getResult(); + } else if (roundMode.compare("trunc") == 0) { + // "trunc": rounds the results of the division towards zero. Equivalent + // to C-style integer division. + result = truncFloatDivWithDivResult(rewriter, op, outType, divResult) + .value(); + } else { + // None: No rounding mode + result = divResult.getResult(); + } } else { - // The output type can be different than the input types (e.g. dividing an - // int tensor results in a floating point tensor). - result = tosa::createBinaryOpAndCast( - rewriter, op, outType, lhs, rhsTensor) - .getResult(); + if (roundMode.compare("floor") == 0) { + // "floor": rounds the results of the division down. Equivalent to floor + // division in Python (the // operator). + result = floorIntDiv(rewriter, op, outType, lhs, rhsTensor).value(); + } else { + // "trunc": rounds the results of the division towards zero. Equivalent + // to C-style integer division. + // None: no rounding mode. + + // TOSA IntDiv requires inputs to be i32 + auto i32Type = RankedTensorType::get(outType.getShape(), + rewriter.getIntegerType(32)); + lhs = tosa::promoteType(rewriter, lhs, i32Type); + rhsTensor = tosa::promoteType(rewriter, rhsTensor, i32Type); + + auto intDivOp = rewriter.create(op->getLoc(), i32Type, + lhs, rhsTensor); + + result = tosa::promoteType(rewriter, intDivOp, outType); + } } rewriter.replaceOp(op, {result}); @@ -536,39 +819,37 @@ class ConvertAtenOp : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override; }; -template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - AtenTanhOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Value self = adaptor.getSelf(); - auto selfTy = cast(self.getType()); - if (selfTy && selfTy.getElementType().isa()) { - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), self); - return success(); - } - // Sigmoid legalization in TOSA for quantized element-type uses specialized - // tosa.table construct. 
- return rewriter.notifyMatchFailure( - op, "Only floating-point datatype legalization currently supported"); -} +template +class ConvertAtenActivationFunctionOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value self = adaptor.getSelf(); + auto selfTy = dyn_cast(self.getType()); + + if (!selfTy) + return rewriter.notifyMatchFailure(op, "Only Tensor types supported"); + + auto resultTy = dyn_cast( + this->getTypeConverter()->convertType(op.getType())); + + if (!isa(resultTy.getElementType())) + return rewriter.notifyMatchFailure( + op, "Only floating-point datatype result types are supported"); + + // Non floating point inputs are not supported for activation functions + // (erf, sigmoid, tanh) in TOSA so we cast the input to result type + if (!isa(selfTy.getElementType())) + self = tosa::promoteType(rewriter, self, resultTy); + + rewriter.replaceOpWithNewOp(op, resultTy, self); -template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - AtenSigmoidOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Value self = adaptor.getSelf(); - auto selfTy = cast(self.getType()); - if (selfTy && selfTy.getElementType().isa()) { - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), self); return success(); } - // Sigmoid legalization in TOSA for quantized element-type uses - // specialized tosa.table construct. - return rewriter.notifyMatchFailure( - op, "Only floating-point datatype legalization currently supported"); -} +}; template <> LogicalResult ConvertAtenOp::matchAndRewrite( @@ -586,16 +867,19 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } // Rescale the clampIn for quantized types. 
TBD - if (!selfTy.getElementType().isa()) { + if (!isa(selfTy.getElementType())) { return rewriter.notifyMatchFailure( op, "Only floating-point datatype legalization currently supported"); } + + // Use default NaN Propagation mode "PROPAGATE" for tosa.clamp rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), clampIn, rewriter.getI64IntegerAttr(clampMin), rewriter.getI64IntegerAttr(std::numeric_limits::max()), rewriter.getF32FloatAttr(0.0f), - rewriter.getF32FloatAttr(std::numeric_limits::max())); + rewriter.getF32FloatAttr(std::numeric_limits::max()), + /*nan_mode=*/rewriter.getStringAttr("PROPAGATE")); return success(); } @@ -606,7 +890,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value self = adaptor.getSelf(); auto selfTy = cast(self.getType()); - if (!selfTy.getElementType().isa()) { + if (!isa(selfTy.getElementType())) { return rewriter.notifyMatchFailure( op, "Only floating-point datatype legalization currently supported"); } @@ -618,10 +902,18 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "Negative slope needs to be a scalar constant for conversion to " "TOSA LeakyReLU operation"); + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), alphaTensor, self) + .failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); auto zero = tosa::getConstTensor(rewriter, op, 0, {}, selfTy.getElementType()) .value(); + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), zero, self).failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + auto cond = rewriter.create( op->getLoc(), RankedTensorType::get(selfTy.getShape(), rewriter.getIntegerType(1)), @@ -669,13 +961,60 @@ class ConvertAtenReductionOp : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "Only Tensor types supported in TOSA"); - auto outputTy = OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template cast(); + auto outputTy = cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); if (!outputTy) return rewriter.notifyMatchFailure( op, "Only ranked tensor type outputs permitted for reduce_mean"); + auto selfElemTy = selfTy.getElementType(); + if (!selfElemTy.isIntOrFloat()) + return rewriter.notifyMatchFailure( + op, "Only floating-point or integer datatype legalization supported"); + + // TOSA ReduceAll and ReduceAny ops only accept bool input + if constexpr (std::is_same() || + std::is_same() || + std::is_same() || + std::is_same()) { + self = tosa::promoteType( + rewriter, self, + RankedTensorType::get(selfTy.getShape(), rewriter.getIntegerType(1))); + } + + // Handle dtype output and bool elem type for ReduceSum and ReduceProd ops + if constexpr (std::is_same() || + std::is_same() || + std::is_same() || + std::is_same()) { + auto dtype = op.getDtype(); + int64_t dtypeInt; + if (!isa(dtype.getType())) { + if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt))) + return rewriter.notifyMatchFailure(op, "dtype is not a constant int"); + + FailureOr maybeDtypeType = getTypeForScalarType( + op.getContext(), (torch_upstream::ScalarType)dtypeInt); + if (failed(maybeDtypeType)) { + return rewriter.notifyMatchFailure(op, "dtype is undefined"); + } else { + Type dtypeType = maybeDtypeType.value(); + + if (isa(dtypeType)) + dtypeType = + rewriter.getIntegerType(dtypeType.getIntOrFloatBitWidth()); + + self = tosa::promoteType( + rewriter, self, + RankedTensorType::get(selfTy.getShape(), 
dtypeType)); + } + } else { + if (selfElemTy.isInteger(1)) + self = tosa::promoteType(rewriter, self, outputTy); + } + } + ElementsAttr reduceDimsAttr; bool keepDims; @@ -689,8 +1028,6 @@ class ConvertAtenReductionOp : public OpConversionPattern { if (!result) return failure(); - // TBD - support dtype casting. - rewriter.replaceOp(op, {result.value()}); return success(); @@ -710,13 +1047,17 @@ class ConvertAtenMultipleDimsReductionOp ConversionPatternRewriter &rewriter, ElementsAttr &reduceDimsAttr, bool &keepDims) const override { - SmallVector reduceDims; - if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(reduceDims))) - return rewriter.notifyMatchFailure(op, - "non-const dim parameter unsupported"); - int64_t N = reduceDims.size(); int64_t inputRank = cast(adaptor.getSelf().getType()).getRank(); + + SmallVector reduceDims; + // If dim list is none, all dimensions are reduced + if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(reduceDims))) { + for (int64_t i = 0; i < inputRank; i++) + reduceDims.push_back(i); + } + + int64_t N = reduceDims.size(); for (unsigned i = 0; i < N; i++) { reduceDims[i] = toPositiveDim(reduceDims[i], inputRank); if (!isValidDim(reduceDims[i], inputRank)) @@ -830,9 +1171,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "non-const keepdim parameter unsupported"); - auto resultTy = getTypeConverter() - ->convertType(op.getResult().getType()) - .cast(); + auto resultTy = cast( + getTypeConverter()->convertType(op.getResult().getType())); auto outputETy = resultTy.getElementType(); // Create a single instance of tosa.argmax. @@ -858,10 +1198,13 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.getI32Type()); auto reduceDimAttr = rewriter.getIntegerAttr(rewriter.getI64Type(), reduceDim); + + // Use default NaN Propagation mode "PROPAGATE" for tosa.argmax return rewriter - .create(op->getLoc(), - getTypeConverter()->convertType(outputReduceTy), - input, reduceDimAttr) + .create( + op->getLoc(), getTypeConverter()->convertType(outputReduceTy), + input, reduceDimAttr, + /*nan_mode=*/rewriter.getStringAttr("PROPAGATE")) .getResult(); }; @@ -929,9 +1272,9 @@ class ConvertAtenSqueezeOp : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "Squeeze could not compute new shape"); - auto resultTy = OpConversionPattern::getTypeConverter() - ->convertType(op.getResult().getType()) - .template cast(); + auto resultTy = cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getResult().getType())); auto resultElemTy = resultTy.getElementType(); auto newOutputTy = RankedTensorType::get( @@ -1007,40 +1350,167 @@ class ConvertAtenSqueezeAllDimsOp : public ConvertAtenSqueezeOp { } }; +template +class ConvertAtenPowOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto outType = + cast(this->getTypeConverter()->convertType(op.getType())); + + if (!isa(outType.getElementType())) + return rewriter.notifyMatchFailure( + op, "Only floating-point datatype result types are supported"); + + Value selfTensor; + if constexpr (std::is_same()) { + Value selfScalar = op.getSelf(); + if (failed(torchScalarToTosaTensor(rewriter, op, selfScalar, selfTensor, + outType.getElementType(), {}))) + return rewriter.notifyMatchFailure( + op, "Currently only scalar constants are supported for " + 
"conversion in TOSA PowScalar operation"); + } else { + selfTensor = adaptor.getSelf(); + auto selfTy = cast(selfTensor.getType()); + + if (!selfTy) + return rewriter.notifyMatchFailure( + op, "Only ranked tensor types supported in TOSA Pow"); + + // Non floating point inputs are not supported for tosa.pow so we cast the + // input to result type + if (!isa(selfTy.getElementType())) + selfTensor = tosa::promoteType(rewriter, selfTensor, outType); + } + + Value expTensor; + if constexpr (std::is_same()) { + Value expScalar = op.getExponent(); + if (failed(torchScalarToTosaTensor(rewriter, op, expScalar, expTensor, + outType.getElementType(), {}))) + return rewriter.notifyMatchFailure( + op, "Currently only scalar constants are supported for " + "conversion in TOSA Pow operation"); + } else { + expTensor = adaptor.getExponent(); + auto expTy = cast(expTensor.getType()); + + if (!expTy) + return rewriter.notifyMatchFailure( + op, "Only ranked tensor types supported in TOSA Pow"); + + // Non floating point exponents are not supported for tosa.pow so we cast + // the exponent to result type + if (!isa(expTy.getElementType())) + expTensor = tosa::promoteType(rewriter, expTensor, outType); + } + + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), selfTensor, expTensor) + .failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + + auto powOp = tosa::createBinaryOpAndCast( + rewriter, op, outType, selfTensor, expTensor); + rewriter.replaceOp(op, powOp.getResult()); + + return success(); + } +}; + template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - AtenPowTensorScalarOp op, OpAdaptor adaptor, +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenPowTensorTensorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value self = adaptor.getSelf(); - auto selfTy = cast(self.getType()); + auto selfTy = dyn_cast(self.getType()); + auto outType = + dyn_cast(getTypeConverter()->convertType(op.getType())); + Value expTensor = adaptor.getExponent(); + auto expTensorTy = dyn_cast(expTensor.getType()); - if (!selfTy) + if (!selfTy || !outType || !expTensorTy) { return rewriter.notifyMatchFailure( op, "Only ranked tensor types supported in TOSA Pow"); + } - if (!selfTy.getElementType().isa()) + if (!isa(selfTy.getElementType())) { return rewriter.notifyMatchFailure( op, "Only floating-point datatype legalization supported"); + } - auto outType = - cast(getTypeConverter()->convertType(op.getType())); - - Value expTensor; - Value expScalar = op.getExponent(); - if (failed(torchScalarToTosaTensor(rewriter, op, expScalar, expTensor, - outType.getElementType(), {}))) - return rewriter.notifyMatchFailure( - op, "Currently only scalar constants are supported for " - "conversion in TOSA Pow operation"); + if (expTensorTy.getElementType() != selfTy.getElementType()) { + expTensor = rewriter.createOrFold( + op->getLoc(), + RankedTensorType::get(expTensorTy.getShape(), selfTy.getElementType()), + expTensor); + } auto powOp = tosa::createBinaryOpAndCast(rewriter, op, outType, self, expTensor); rewriter.replaceOp(op, powOp.getResult()); - return success(); } +Type getMatMulOutputType(Type inputElemTy, Type outputElemTy, + PatternRewriter &rewriter) { + Type tosaOutputElemTy; + if (auto floatTy = dyn_cast(inputElemTy)) { + if (inputElemTy.isF16() && outputElemTy.isF16()) { + return rewriter.getF16Type(); + } + if (floatTy.isBF16() || floatTy.isF16() || floatTy.isF32()) { + // Always accumulate on f32 + tosaOutputElemTy = rewriter.getF32Type(); 
+ } + } else if (auto integerTy = dyn_cast(inputElemTy)) { + if (integerTy.isInteger(/*width=*/8)) { + tosaOutputElemTy = rewriter.getIntegerType(/*width=*/32); + } else if (integerTy.isInteger(/*width=*/16)) { + tosaOutputElemTy = rewriter.getIntegerType(/*width=*/48); + } + } + return tosaOutputElemTy; +} + +RankedTensorType getCastedInputTypeForMatmul(Value inputValue, + PatternRewriter &rewriter) { + // Check to see if the inputs to the matmul where casted from another type + auto preCastType = + TypeSwitch(inputValue.getDefiningOp()) + .Case([](tosa::CastOp op) { + return cast(op->getOperand(0).getType()); + }) + .Default([](Operation * /*op*/) { return RankedTensorType(); }); + if (!preCastType) { + return preCastType; + } + Type castOutputTy = + cast(inputValue.getType()).getElementType(); + // The FxImporter does not support si48 and neither does torch-mlir so for now + // we reject this case for the future when the dialect and importer may + // support it. + if (castOutputTy.isInteger(48) && + (castOutputTy.isSignedInteger() || castOutputTy.isSignlessInteger())) { + return RankedTensorType(); + } + // Calculate the expected accumulator type based on the input type of the cast + auto accumulatorType = + getMatMulOutputType(preCastType.getElementType(), castOutputTy, rewriter); + // If the expected accumulatorType for the given input type of the + // cast matches the output type of the cast then we can fold the + // casting into the matmul. The tosa matmul is defined to cast the + // inputs to the output type first, so we do not need explicit + // casts up front. + return accumulatorType == castOutputTy ? preCastType : RankedTensorType(); +} + // Perform the basic n-dim matmul operation encompassing the handling of // broadcasting and dynamic shape propagation. // All PyTorch ops that leverage matrix multiplication will derive this and @@ -1082,6 +1552,39 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "Matmul: input datatypes mismatched"); + // Step: check if the inputs have been casted from a supported input type to + // an accumulator type and insert casts back to the original type if true + RankedTensorType lhsPreCastedType = + getCastedInputTypeForMatmul(lhs, rewriter); + RankedTensorType rhsPreCastedType = + getCastedInputTypeForMatmul(rhs, rewriter); + if (lhsPreCastedType && rhsPreCastedType && + (lhsPreCastedType.getElementType() == + rhsPreCastedType.getElementType())) { + lhs = rewriter.create( + lhs.getLoc(), + OpConversionPattern::getTypeConverter()->convertType( + lhsPreCastedType), + lhs); + rhs = rewriter.create( + rhs.getLoc(), + OpConversionPattern::getTypeConverter()->convertType( + rhsPreCastedType), + rhs); + lhsElemTy = cast(lhsPreCastedType).getElementType(); + rhsElemTy = cast(rhsPreCastedType).getElementType(); + } + + auto torchMatmulOutputType = + cast(op.getType()).getDtype(); + auto outputElemTy = + getMatMulOutputType(lhsElemTy, torchMatmulOutputType, rewriter); + if (!outputElemTy) { + return rewriter.notifyMatchFailure( + op, "Only i8 and i16 integer and bf16, f16 and " + "f32 float types are valid"); + } + // Legalization constructs may offer input shapes but expect output shapes // to be inferred, e.g. 
// func @forward(%arg0: !torch.vtensor<[14,19],f32>, @@ -1455,12 +1958,6 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { SmallVector matmulOutputShape( {matmulLhsShape[0], matmulLhsShape[1], matmulRhsShape[2]}); - Type outputElemTy; - if (isa(lhsElemTy)) { - outputElemTy = lhsElemTy; - } else { // qint8 emits i32 matmul output - outputElemTy = rewriter.getIntegerType(32); - } auto mmOutputTy = RankedTensorType::get( makeShapeLLVMCompatible(matmulOutputShape), outputElemTy); @@ -1473,8 +1970,17 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { matmulLhs, matmulRhs) .getResult(); - // Perform the reshape to output shape. This is always required unless max - // input rank=3 and there was no broadcasting, in which case the tosa.matmul + auto torchOpOutputType = lhsTy.getElementType(); + auto castOutputTy = RankedTensorType::get( + makeShapeLLVMCompatible(matmulOutputShape), torchOpOutputType); + auto castResult = rewriter.createOrFold( + op->getLoc(), + OpConversionPattern::getTypeConverter()->convertType( + castOutputTy), + mmOpResult); + + // Perform the reshape to output shape. This is always required unless max + // input rank=3 and there was no broadcasting, in which case the tosa.matmul // output itself is correctly shaped. bool performOpReshape = !(maxInputRank == 3 && !performBatchDimBroadcast); @@ -1573,12 +2079,12 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { // Perform reshape auto reshapedOpType = RankedTensorType::get( - makeShapeLLVMCompatible(reshapedOpShape), outputElemTy); + makeShapeLLVMCompatible(reshapedOpShape), torchOpOutputType); auto reshapedOp = rewriter.create( op->getLoc(), OpConversionPattern::getTypeConverter()->convertType( reshapedOpType), - mmOpResult, rewriter.getDenseI64ArrayAttr(reshapedOpShape)); + castResult, rewriter.getDenseI64ArrayAttr(reshapedOpShape)); if (opNeedsTranspose) { @@ -1589,7 +2095,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { /*shape=*/{static_cast(transposedOpDims.size())}); auto transposedOpType = RankedTensorType::get( - makeShapeLLVMCompatible(transposedOpShape), outputElemTy); + makeShapeLLVMCompatible(transposedOpShape), torchOpOutputType); output = rewriter .create( op->getLoc(), @@ -1602,7 +2108,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { output = reshapedOp.getResult(); } } else { - output = mmOpResult; + output = castResult; } return success(); @@ -1624,13 +2130,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "Failed to perform matmul operation"); - rewriter.replaceOpWithNewOp( - op, - OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template cast(), - output); - + rewriter.replaceOp(op, output); return success(); } }; @@ -1746,6 +2246,10 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp { auto bias = adaptor.getBias(); auto biasTy = bias.getType(); + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), lhs, bias).failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + // TOSA does not mandate that elementwise op tensors need to be ranked. 
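For reference, aten.linear computes y = x * W^T + b: a matmul against the transposed weight followed by a bias add that broadcasts over the leading dimensions, which is why the rank-1 bias is rank-equalized above before the elementwise add. The snippet below is a tiny scalar illustration of that computation with made-up values, not the conversion code itself.

#include <array>
#include <cassert>

int main() {
  // aten.linear semantics: y[i][j] = sum_k x[i][k] * w[j][k] + b[j]
  std::array<std::array<double, 3>, 2> x{{{1, 2, 3}, {4, 5, 6}}};
  std::array<std::array<double, 3>, 2> w{{{1, 0, 1}, {0, 1, 0}}}; // [out, in]
  std::array<double, 2> b{0.5, -0.5};

  std::array<std::array<double, 2>, 2> y{};
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j) {
      double acc = b[j]; // bias broadcasts over the batch/row dimension
      for (int k = 0; k < 3; ++k)
        acc += x[i][k] * w[j][k];
      y[i][j] = acc;
    }
  assert(y[0][0] == 4.5 && y[0][1] == 1.5);  // row 0: 1+3+0.5, 2-0.5
  assert(y[1][0] == 10.5 && y[1][1] == 4.5); // row 1: 4+6+0.5, 5-0.5
  return 0;
}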
if (!isa(biasTy) && !isa(biasTy)) return rewriter.notifyMatchFailure( @@ -1799,14 +2303,7 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp { matmulOutput, bias) .getResult(); } - - rewriter.replaceOpWithNewOp( - op, - OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template cast(), - matmulPlusBias); - + rewriter.replaceOp(op, matmulPlusBias); return success(); } }; @@ -1825,34 +2322,93 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "Only ranked tensor types supported in TOSA Rsub"); - if (!selfTy.getElementType().isa()) - return rewriter.notifyMatchFailure( - op, "Only floating-point datatype legalization supported"); + auto resultTy = + dyn_cast(getTypeConverter()->convertType(op.getType())); + auto resultElemTy = resultTy.getElementType(); + + self = tosa::promoteType(rewriter, self, resultTy); Value otherTensor, alphaTensor; if (failed(torchScalarToTosaTensor(rewriter, op, otherScalar, otherTensor, - selfTy.getElementType(), {}))) + resultElemTy, {}))) return rewriter.notifyMatchFailure( op, "Currently only scalar constants are supported for " "conversion in TOSA Rsub operation"); if (failed(torchAlphaToTosaTensor(rewriter, op.getOperation(), alphaScalar, - alphaTensor, selfTy.getElementType(), + alphaTensor, resultElemTy, /*checkForUnity=*/true))) return failure(); - auto multTensor = rewriter.create( - op->getLoc(), getTypeConverter()->convertType(op.getType()), self, - alphaTensor, /*shift=*/0); + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, otherTensor) + .failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, alphaTensor) + .failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), otherTensor, - multTensor); + auto multTensor = rewriter.create(op->getLoc(), resultTy, self, + alphaTensor, /*shift=*/0); + + rewriter.replaceOpWithNewOp(op, resultTy, otherTensor, + multTensor); return success(); } +/// tosa.conv2d does not support group convolution. +/// Therefore, we create multiple ops where the input, kernel +/// and bias are slices of the original inputs. +/// Afterwards we concat the results into a single tensor. +/// This is inspired by the legalization done in onnx-mlir. 
+Value createConvInGroups(PatternRewriter &rewriter, Operation *op, + Type &resultType, + const llvm::ArrayRef weightShape, + Value &input, Value &weights, Value &bias, + const int64_t groups, DenseI64ArrayAttr pads, + DenseI64ArrayAttr strides, DenseI64ArrayAttr dilations, + TypeAttr accType) { + // Set up constants outside of loop + const int64_t sizeOfSliceInput = weightShape[1]; + const int64_t sizeOfSliceKernel = weightShape[0] / groups; + auto inputShape = cast(input.getType()).getShape(); + + llvm::SmallVector inputSize = {inputShape[0], inputShape[1], + inputShape[2], sizeOfSliceInput}; + llvm::SmallVector kernelSize = {sizeOfSliceKernel, weightShape[2], + weightShape[3], weightShape[1]}; + llvm::SmallVector sliceValues; + Type outputType = RankedTensorType::get( + llvm::SmallVector(4, ShapedType::kDynamic), + cast(resultType).getElementType()); + for (int64_t i = 0; i < groups; i++) { + // Slice input + Value sliceInput = tosa::buildSlice( + rewriter, input, {0, 0, 0, i * sizeOfSliceInput}, inputSize); + + // Slice kernel + Value sliceWeight = tosa::buildSlice( + rewriter, weights, {i * sizeOfSliceKernel, 0, 0, 0}, kernelSize); + + // Slice bias + Value sliceBias = tosa::buildSlice(rewriter, bias, {i * sizeOfSliceKernel}, + {sizeOfSliceKernel}); + + // Create conv + Value tempConv2D = tosa::CreateOpAndInfer( + rewriter, input.getLoc(), outputType, sliceInput, sliceWeight, + sliceBias, pads, strides, dilations, accType); + // Add value to vector + sliceValues.push_back(tempConv2D); + } + + constexpr int64_t channelDim = 3; + // Create concat op + return tosa::CreateOpAndInfer( + rewriter, op->getLoc(), outputType, sliceValues, channelDim); +} + template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenConvolutionOp op, OpAdaptor adaptor, @@ -1871,9 +2427,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto inputTy = cast(input.getType()); auto weightTy = cast(weight.getType()); - auto outputTy = getTypeConverter() - ->convertType(op.getType()) - .template cast(); + auto outputTy = + cast(getTypeConverter()->convertType(op.getType())); if (!inputTy || !weightTy || !outputTy) return rewriter.notifyMatchFailure( @@ -1907,7 +2462,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } else { SmallVector zeroVec(weightShape[0], 0); bias = tosa::getConstTensor(rewriter, op, zeroVec, - {static_cast(weightShape[0])}) + {static_cast(weightShape[0])}, + inputElemTy) .value(); } } else { @@ -1921,10 +2477,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite( int64_t groups; if (!matchPattern(op.getGroups(), m_TorchConstantInt(&groups))) { return rewriter.notifyMatchFailure(op, "non-const group size unsupported"); - } else if (groups != 1 && weightShape[1] != 1) { - return rewriter.notifyMatchFailure( - op, "group size must be 1 (convolution) or weight.dim(1) must be 1 " - "(depthwise convolution)"); } SmallVector stride; @@ -1936,6 +2488,13 @@ LogicalResult ConvertAtenOp::matchAndRewrite( m_TorchListOfConstantInts(padding_2d))) return rewriter.notifyMatchFailure(op, "non-const padding list unsupported"); + + if (padding_2d.size() != 2) { + // pytorch 2.5 generates one element padding = {0} for + // Conv2dWithValidPaddingModule + return rewriter.notifyMatchFailure(op, "unexpected number of paddings"); + } + // TOSA uses 4D padding {t, b, l, r} while Torch defines 2D padding {t, l}. // The Torch OFM computation uses 2*pad in each spatial direction, implying // the same t=b and l=r values for TOSA. 
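createConvInGroups above turns one grouped convolution into `groups` ordinary convolutions by slicing the input channels and the output-channel (filter) dimension of the kernel per group, then concatenating the per-group results along the channel axis. The following is a standalone sketch of just that slice bookkeeping, with made-up sizes and channels-last (NHWC) indexing as in the code above.

#include <cstdint>
#include <iostream>

int main() {
  // Hypothetical sizes: 16 input channels, 32 output channels, 4 groups.
  const int64_t inChannels = 16, outChannels = 32, groups = 4;
  const int64_t sliceIn = inChannels / groups;   // channels seen by each conv
  const int64_t sliceOut = outChannels / groups; // filters owned by each conv

  for (int64_t g = 0; g < groups; ++g) {
    // Group g convolves input channels [g*sliceIn, (g+1)*sliceIn) with
    // kernels [g*sliceOut, (g+1)*sliceOut); the per-group outputs are then
    // concatenated on the channel dimension.
    std::cout << "group " << g << ": input channels [" << g * sliceIn << ", "
              << (g + 1) * sliceIn << "), kernels [" << g * sliceOut << ", "
              << (g + 1) * sliceOut << ")\n";
  }
  return 0;
}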
@@ -1947,6 +2506,12 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure(op, "non-const dilation list unsupported"); + TypeAttr accType; + if (failed(tosa::getConvOpsAccType(rewriter, inputTy, weightTy, outputTy, + accType))) + return rewriter.notifyMatchFailure( + op, "failed to get accumulator type for convolution ops"); + // TOSA works in NHWC and takes OHWI (conv) / HWIM (depthwise conv) weights. // Perform the necessary transformations. std::optional nchwToNhwcTransposeConst = @@ -1969,8 +2534,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( RankedTensorType transformedWeightType; Value transformedWeight; int64_t outputCDim; - if (groups == 1) { - // full convolution: O(I/G)HW-> OHWI + if (groups == 1 || weightShape[1] != 1) { + // full (group) convolution: O(I/G)HW-> OHWI transformedWeightShape = {weightShape[0], weightShape[2], weightShape[3], weightShape[1]}; transformedWeightType = RankedTensorType::get( @@ -1983,7 +2548,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( nchwToNhwcTransposeConst.value()) .getResult(); outputCDim = transformedWeightShape[0]; - } else if (weightShape[1] == 1) { + } else { // depthwise convolution: O(I/G)HW-> HWIM) // transpose: O(I/G)HW -> HWO(I/G) std::optional transposeConst = @@ -2025,8 +2590,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite( transposedWeight, rewriter.getDenseI64ArrayAttr(transformedWeightShape)) .getResult(); - } else { - llvm_unreachable("Unhandled convolution type"); } int64_t outputHDim, outputWDim; @@ -2052,23 +2615,24 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // quantized input is i32, which gets rescaled down to quantized output range. SmallVector outputShape = {transposedInputShape[0], outputHDim, outputWDim, outputCDim}; - auto convOpTy = - RankedTensorType::get(makeShapeLLVMCompatible(outputShape), biasElemTy); - Value convOpResult; if (groups == 1) { // full convolution + auto convOpTy = + RankedTensorType::get(makeShapeLLVMCompatible(outputShape), biasElemTy); convOpResult = rewriter - .create(op->getLoc(), - getTypeConverter()->convertType(convOpTy), - transposedInput, transformedWeight, bias, - rewriter.getDenseI64ArrayAttr(padding), - rewriter.getDenseI64ArrayAttr(stride), - rewriter.getDenseI64ArrayAttr(dilation)) + .create( + op->getLoc(), getTypeConverter()->convertType(convOpTy), + transposedInput, transformedWeight, bias, + rewriter.getDenseI64ArrayAttr(padding), + rewriter.getDenseI64ArrayAttr(stride), + rewriter.getDenseI64ArrayAttr(dilation), accType) .getResult(); } else if (weightShape[1] == 1) { // depthwise convolution + auto convOpTy = + RankedTensorType::get(makeShapeLLVMCompatible(outputShape), biasElemTy); convOpResult = rewriter .create( @@ -2076,10 +2640,15 @@ LogicalResult ConvertAtenOp::matchAndRewrite( transposedInput, transformedWeight, bias, rewriter.getDenseI64ArrayAttr(padding), rewriter.getDenseI64ArrayAttr(stride), - rewriter.getDenseI64ArrayAttr(dilation)) + rewriter.getDenseI64ArrayAttr(dilation), accType) .getResult(); } else { - llvm_unreachable("Unhandled convolution type"); + // general group convolution + convOpResult = createConvInGroups( + rewriter, op, outputTy, weightShape, transposedInput, transformedWeight, + bias, groups, rewriter.getDenseI64ArrayAttr(padding), + rewriter.getDenseI64ArrayAttr(stride), + rewriter.getDenseI64ArrayAttr(dilation), accType); } std::optional nhwcToNchwTransposeConst = @@ -2146,9 +2715,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } -Value computeBatchNorm(Operation *op, 
ConversionPatternRewriter &rewriter, - Type outType, Value input, Value variance, Value eps, - Value mean, Value weight, Value bias) { +std::optional computeBatchNorm(Operation *op, + ConversionPatternRewriter &rewriter, + Type outType, Value input, Value variance, + Value eps, Value mean, Value weight, + Value bias) { // For PyTorch: // scale = gamma = weight // offset = beta = bias @@ -2172,6 +2743,15 @@ Value computeBatchNorm(Operation *op, ConversionPatternRewriter &rewriter, // op5 = mul(op4, bscale) // op6 = add(op5, boffset) + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), input, mean).failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), input, variance) + .failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), input, eps).failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), input, weight) + .failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), input, bias).failed()) + return std::nullopt; + auto op1SubInputMean = rewriter.create(op->getLoc(), outType, input, mean); @@ -2210,7 +2790,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Note: cudnn_enabled is not handled. // FIXME: Handle training and momentum. - if (op.getMomentum().getType().isa()) + if (isa(op.getMomentum().getType())) return rewriter.notifyMatchFailure(op, "Unsupported None for momentum"); auto meanType = dyn_cast(adaptor.getRunningMean().getType()); @@ -2280,7 +2860,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto batchNorm = computeBatchNorm(op, rewriter, outType, adaptor.getInput(), varianceVal, - epsilonConst, meanVal, weightVal, biasVal); + epsilonConst, meanVal, weightVal, biasVal) + .value(); rewriter.replaceOp(op, {batchNorm}); @@ -2300,11 +2881,13 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // eventually being reshaped for broadcasting. // Not a ranked tensor output - if (!dyn_cast(adaptor.getInput().getType())) + auto input = adaptor.getInput(); + auto inputType = dyn_cast(input.getType()); + + if (!inputType) return rewriter.notifyMatchFailure( op, "Only ranked tensor types are supported"); - auto inputType = cast(adaptor.getInput().getType()); if (inputType.getRank() > 4) return rewriter.notifyMatchFailure(op, "Only up to 4D tensors are supported"); @@ -2314,13 +2897,16 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Note: cudnn_enabled is not handled. // FIXME: Handle the None cases for the optional parameters. - if (adaptor.getWeight().getType().isa()) + auto weight = adaptor.getWeight(); + if (isa(weight.getType())) return rewriter.notifyMatchFailure(op, "Unsupported None for weight"); - if (adaptor.getBias().getType().isa()) + + auto bias = adaptor.getBias(); + if (isa(bias.getType())) return rewriter.notifyMatchFailure(op, "Unsupported None for bias"); - auto weightType = cast(adaptor.getWeight().getType()); - auto biasType = cast(adaptor.getBias().getType()); + auto weightType = cast(weight.getType()); + auto biasType = cast(bias.getType()); int64_t inputRank = inputType.getRank(); Type elemTy = inputType.getElementType(); SmallVector inputTypeShape( @@ -2385,6 +2971,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value elemCntRcp = rewriter.create( op.getLoc(), elemCntConst.getType(), elemCntConst); + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), input, elemCntRcp) + .failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + // Broadcast type and shape for various intermediate values. 
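Both the reworked computeBatchNorm helper above and the layer-norm path below boil down to the same elementwise formula, out = (x - mean) * rsqrt(var + eps) * weight + bias; layer norm additionally computes mean and var over the normalized dimensions first. A scalar sanity check of that arithmetic with arbitrary values:

#include <cassert>
#include <cmath>

int main() {
  // out = (x - mean) / sqrt(var + eps) * weight + bias
  const double x = 3.0, mean = 1.0, var = 4.0, eps = 1e-5;
  const double weight = 2.0, bias = 0.5; // gamma and beta in PyTorch terms
  double normalized = (x - mean) / std::sqrt(var + eps);
  double out = normalized * weight + bias;
  assert(std::fabs(out - 2.5) < 1e-3); // (3-1)/2 * 2 + 0.5 is about 2.5
  return 0;
}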
SmallVector bcastOutShape; for (auto en : llvm::enumerate(inputTypeShape)) { @@ -2396,14 +2987,14 @@ LogicalResult ConvertAtenOp::matchAndRewrite( RankedTensorType::get(makeShapeLLVMCompatible(bcastOutShape), elemTy); // Compute mean. - Value sum = computeSumAndReshape(adaptor.getInput(), inputType, bcastOutType, - bcastOutShape); + Value sum = + computeSumAndReshape(input, inputType, bcastOutType, bcastOutShape); Value meanVal = rewriter.create(op.getLoc(), bcastOutType, sum, elemCntRcp, /*shift=*/0); // Compute variance. - Value squareSumSub = rewriter.create( - op.getLoc(), inputType, adaptor.getInput(), meanVal); + Value squareSumSub = + rewriter.create(op.getLoc(), inputType, input, meanVal); Value squareSum = rewriter.create(op.getLoc(), inputType, squareSumSub, squareSumSub, 0); @@ -2424,11 +3015,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( makeShapeLLVMCompatible(weightAndBiasBcastShape), elemTy); Value weightVal = rewriter.create( - op.getLoc(), weightAndMeanBcastType, adaptor.getWeight(), + op.getLoc(), weightAndMeanBcastType, weight, rewriter.getDenseI64ArrayAttr(weightAndBiasBcastShape)); Value biasVal = rewriter.create( - op.getLoc(), weightAndMeanBcastType, adaptor.getBias(), + op.getLoc(), weightAndMeanBcastType, bias, rewriter.getDenseI64ArrayAttr(weightAndBiasBcastShape)); double eps; @@ -2440,9 +3031,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( .value(); // Compute layer norm. - auto layerNorm = - computeBatchNorm(op, rewriter, outType, adaptor.getInput(), varianceVal, - epsilonConst, meanVal, weightVal, biasVal); + auto layerNorm = computeBatchNorm(op, rewriter, outType, input, varianceVal, + epsilonConst, meanVal, weightVal, biasVal) + .value(); rewriter.replaceOp(op, {layerNorm, meanVal, varianceVal}); @@ -2455,9 +3046,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ValueTensorLiteralOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - auto outputTy = getTypeConverter() - ->convertType(op.getType()) - .template cast(); + auto outputTy = + cast(getTypeConverter()->convertType(op.getType())); // Tensors with integer types need to be converted to signless integer // element type. All tensors with element types other than integer can reuse @@ -2610,21 +3200,23 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "Only ranked tensor types with static shapes are currently supported"); - SmallVector dimListInt; - if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(dimListInt))) + SmallVector dimListInt64; + if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(dimListInt64))) return rewriter.notifyMatchFailure( op, "Only constant dimensions are currently supported"); + SmallVector dimListInt32; + copy(dimListInt64, std::back_inserter(dimListInt32)); int64_t selfRank = selfType.getRank(); // TODO: If this is already verified on the op then we can drop checking here. 
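The dim-normalization loop below applies the usual PyTorch convention that toPositiveDim implements: a negative dimension counts from the end, so for rank r a dim d becomes d + r. A one-function standalone version of that rule:

#include <cassert>
#include <cstdint>

// Same normalization rule as toPositiveDim: negative dims count from the end.
// (Validity against the rank is checked separately, as isValidDim does above.)
static int64_t toPositive(int64_t dim, int64_t rank) {
  return dim < 0 ? dim + rank : dim;
}

int main() {
  assert(toPositive(-1, 4) == 3);
  assert(toPositive(0, 4) == 0);
  assert(toPositive(-4, 4) == 0);
  return 0;
}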
- for (auto &d : dimListInt) { + for (auto &d : dimListInt32) { d = toPositiveDim(d, selfRank); if (!isValidDim(d, selfRank)) return rewriter.notifyMatchFailure(op, "Not all dims are valid"); } - auto transposeDimsConst = mlir::tosa::getConstTensor( - rewriter, op.getOperation(), dimListInt, {selfRank}); + auto transposeDimsConst = mlir::tosa::getConstTensor( + rewriter, op.getOperation(), dimListInt32, {selfRank}); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(), @@ -2637,24 +3229,36 @@ template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenLog2Op op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + auto self = adaptor.getSelf(); // Not a tensor type. - auto selfType = dyn_cast(adaptor.getSelf().getType()); + auto selfType = dyn_cast(self.getType()); if (!selfType) return rewriter.notifyMatchFailure( op, "Only tensor types are currently supported"); + auto outType = + dyn_cast(getTypeConverter()->convertType(op.getType())); + + // If input is not a float type then cast it to output type + auto selfElemTy = selfType.getElementType(); + if (!isa(selfElemTy)) + self = tosa::promoteType(rewriter, self, outType); + // Constant value of ln2. SmallVector ln2Shape(selfType.getRank(), 1); auto ln2Op = tosa::getConstTensor(rewriter, op, {0.69314718056f}, - ln2Shape, selfType.getElementType()) + ln2Shape, outType.getElementType()) .value(); + + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, ln2Op).failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + auto rcpOp = rewriter.create(op.getLoc(), ln2Op.getType(), ln2Op); - auto outType = getTypeConverter()->convertType(op.getType()); - auto logOp = - rewriter.create(op.getLoc(), outType, adaptor.getSelf()); + auto logOp = rewriter.create(op.getLoc(), outType, self); rewriter.replaceOpWithNewOp(op, outType, logOp, rcpOp, /*shift=*/0); @@ -2665,9 +3269,10 @@ template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenThresholdOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + auto self = adaptor.getSelf(); // Not a tensor type. - auto selfType = dyn_cast(adaptor.getSelf().getType()); + auto selfType = dyn_cast(self.getType()); if (!selfType) return rewriter.notifyMatchFailure( op, "Only tensor types are currently supported"); @@ -2677,12 +3282,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "Only floating-point or integer datatype legalization supported"); - // Integer types with width > 32 are not supported - auto selfIntType = dyn_cast(selfElemTy); - if (selfIntType && selfIntType.getWidth() > 32) { - return rewriter.notifyMatchFailure( - op, "Integer types with width greater than 32 are not supported"); - } + auto outType = + dyn_cast(getTypeConverter()->convertType(op.getType())); + auto outElemTy = outType.getElementType(); SmallVector constTypeShape(selfType.getRank(), 1); Value threshold, value; @@ -2692,21 +3294,22 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "Only scalar constant is supported for threshold"); if (failed(torchScalarToTosaTensor(rewriter, op, op.getValue(), value, - selfElemTy, constTypeShape))) + outElemTy, constTypeShape))) return rewriter.notifyMatchFailure( op, "Only scalar constant is supported for value"); - // Threshold only clamps the upper values. tosa::ClampOp has the same - // value for both threshold and clamped value so cannot be used. 
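On the aten.threshold lowering above: the result keeps x where x exceeds the threshold and substitutes `value` elsewhere, which is why a single clamp cannot express it (the clamp bound and the replacement value differ) and a compare-plus-select is used instead. A minimal scalar sketch, plain C++ and illustrative only:

#include <cassert>

float thresholdScalar(float x, float threshold, float value) {
  bool keep = x > threshold; // the boolean cmpOp in the pattern
  return keep ? x : value;   // the final select
}

int main() {
  assert(thresholdScalar(2.0f, 1.0f, 0.0f) == 2.0f); // above threshold: kept
  assert(thresholdScalar(0.5f, 1.0f, 0.0f) == 0.0f); // below: replaced
  return 0;
}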
- auto outType = getTypeConverter()->convertType(op.getType()); + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, threshold) + .failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, value).failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); auto cmpOp = rewriter.create( op.getLoc(), RankedTensorType::get(selfType.getShape(), rewriter.getIntegerType(1)), - adaptor.getSelf(), threshold); + self, threshold); - rewriter.replaceOpWithNewOp(op, outType, cmpOp, - adaptor.getSelf(), value); + rewriter.replaceOpWithNewOp(op, outType, cmpOp, self, value); return success(); } @@ -2862,8 +3465,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } -static Value approximateErfOp(ConversionPatternRewriter &rewriter, - Operation *op, Value x, Type dtype) { +static std::optional +approximateErfOp(ConversionPatternRewriter &rewriter, Operation *op, Value x, + Type dtype) { // Using: // https://en.wikipedia.org/wiki/Error_function#Numerical_approximations with // maximum error as 5 x 10^-4 where a1 = 0.278393, a2 = 0.230389, a3 = @@ -2876,26 +3480,34 @@ static Value approximateErfOp(ConversionPatternRewriter &rewriter, auto absX = rewriter.create(loc, outType, x); auto zero = tosa::getConstTensor(rewriter, op, 0, {}, dtype).value(); auto one = tosa::getConstTensor(rewriter, op, 1, {}, dtype).value(); - auto a1 = tosa::getConstTensor(rewriter, op, 0.278393f, {}, dtype).value(); + auto a2 = + tosa::getConstTensor(rewriter, op, 0.230389f, {}, dtype).value(); + auto a3 = + tosa::getConstTensor(rewriter, op, 0.000972f, {}, dtype).value(); + auto a4 = + tosa::getConstTensor(rewriter, op, 0.078108f, {}, dtype).value(); + + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, zero).failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, one).failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, a1).failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, a2).failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, a3).failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, a4).failed()) + return std::nullopt; + auto a1X = rewriter.create(loc, outType, a1, absX, /*shift=*/0); auto sum = rewriter.create(loc, outType, a1X, one); - auto a2 = - tosa::getConstTensor(rewriter, op, 0.230389f, {}, dtype).value(); auto x2 = rewriter.create(loc, outType, absX, absX, /*shift=*/0); auto a2X = rewriter.create(loc, outType, a2, x2, /*shift=*/0); sum = rewriter.create(loc, outType, sum, a2X); - auto a3 = - tosa::getConstTensor(rewriter, op, 0.000972f, {}, dtype).value(); auto x3 = rewriter.create(loc, outType, x2, absX, /*shift=*/0); auto a3X = rewriter.create(loc, outType, a3, x3, /*shift=*/0); sum = rewriter.create(loc, outType, sum, a3X); - auto a4 = - tosa::getConstTensor(rewriter, op, 0.078108f, {}, dtype).value(); auto x4 = rewriter.create(loc, outType, x3, absX, /*shift=*/0); auto a4X = rewriter.create(loc, outType, a4, x4, /*shift=*/0); sum = rewriter.create(loc, outType, sum, a4X); @@ -2917,10 +3529,22 @@ static Value approximateErfOp(ConversionPatternRewriter &rewriter, return rewriter.create(loc, outType, cond, erf, negateErf); } -static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter, - Operation *op, Value x, Type dtype) { +static std::optional +buildUnitNormalCdf(ConversionPatternRewriter &rewriter, Operation *op, Value x, + Type dtype) { auto zero = tosa::getConstTensor(rewriter, op, 0, {}, dtype).value(); auto one = 
tosa::getConstTensor(rewriter, op, 1, {}, dtype).value(); + auto oneHalf = + tosa::getConstTensor(rewriter, op, 0.5, {}, dtype).value(); + // rsqrt of 2 + auto rsqrt2 = + tosa::getConstTensor(rewriter, op, 0.70710678f, {}, dtype).value(); + + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, zero).failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, one).failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, oneHalf).failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, rsqrt2).failed()) + return std::nullopt; auto loc = op->getLoc(); @@ -2928,16 +3552,11 @@ static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter, auto outType = x.getType(); auto mean = zero; Value xMinusMean = rewriter.create(loc, outType, x, mean); - // rsqrt of 2 - Value rsqrt2 = - tosa::getConstTensor(rewriter, op, 0.70710678f, {}, dtype).value(); Value erfArg = rewriter.create(loc, outType, xMinusMean, rsqrt2, /*shift=*/0); - Value erf = approximateErfOp(rewriter, op, erfArg, dtype); + Value erf = approximateErfOp(rewriter, op, erfArg, dtype).value(); Value erfPlus1 = rewriter.create(loc, outType, one, erf); - Value oneHalf = - tosa::getConstTensor(rewriter, op, 0.5, {}, dtype).value(); Value normalCdf = rewriter.create(loc, outType, oneHalf, erfPlus1, /*shift=*/0); @@ -2949,9 +3568,10 @@ template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenGeluOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + auto self = adaptor.getSelf(); // Not a tensor type. - auto selfType = dyn_cast(adaptor.getSelf().getType()); + auto selfType = dyn_cast(self.getType()); if (!selfType) return rewriter.notifyMatchFailure( op, "Only tensor types are currently supported"); @@ -2962,21 +3582,108 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "Only floating-point datatype legalization supported"); } - // TODO: Handle approximate. 
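The approximateErfOp / buildUnitNormalCdf helpers above use the polynomial erf approximation their comment cites (maximum error about 5e-4, coefficients a1..a4), i.e. erf(|x|) ~= 1 - 1/(1 + a1|x| + a2|x|^2 + a3|x|^3 + a4|x|^4)^4 with the sign restored afterwards, and the unit normal CDF 0.5 * (1 + erf(x / sqrt(2))). A standalone scalar sketch under that assumption, plain C++ and illustrative only:

#include <cassert>
#include <cmath>

float erfApprox(float x) {
  const float a1 = 0.278393f, a2 = 0.230389f, a3 = 0.000972f, a4 = 0.078108f;
  float ax = std::fabs(x);
  float sum = 1.0f + a1 * ax + a2 * ax * ax + a3 * ax * ax * ax +
              a4 * ax * ax * ax * ax;
  float sum4 = sum * sum * sum * sum;
  float erf = 1.0f - 1.0f / sum4;
  return x >= 0.0f ? erf : -erf; // select on x >= 0, negate otherwise
}

// CDF of the unit normal, and the "none"-approximate GELU(x) = x * CDF(x).
float unitNormalCdf(float x) {
  return 0.5f * (1.0f + erfApprox(x * 0.70710678f)); // 1/sqrt(2)
}
float geluNone(float x) { return x * unitNormalCdf(x); }

int main() {
  // Within the quoted ~5e-4 error bound of the exact erf.
  assert(std::fabs(erfApprox(0.5f) - std::erf(0.5f)) < 5e-4f);
  assert(std::fabs(geluNone(1.0f) - 0.841345f) < 1e-2f);
  return 0;
}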
+ auto resultType = + dyn_cast(getTypeConverter()->convertType(op.getType())); + std::string approximate; - if (!matchPattern(op.getApproximate(), m_TorchConstantStr(approximate)) || - approximate != "none") { - return rewriter.notifyMatchFailure(op, "Unsupported value of approximate"); + if (!matchPattern(op.getApproximate(), m_TorchConstantStr(approximate))) { + return rewriter.notifyMatchFailure( + op, "Non-const approximate value not supported"); } - Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.getSelf(), selfElemTy); - cdf = rewriter.createOrFold( - op->getLoc(), - cast(cdf.getType()).cloneWith({}, selfElemTy), cdf); + if (approximate.compare("none") == 0) { + // GELU(x) = x * CDF(x) + Value cdf = + buildUnitNormalCdf(rewriter, op, adaptor.getSelf(), selfElemTy).value(); + cdf = rewriter.createOrFold( + op->getLoc(), + cast(cdf.getType()).cloneWith({}, selfElemTy), cdf); + + rewriter.replaceOpWithNewOp(op, resultType, self, cdf, + /*shift=*/0); + } else if (approximate.compare("tanh") == 0) { + // "tanh" approximate + // GELU(x) = 0.5 * x * (1 + Tanh(sqrt(2/pi) * (x + 0.044715 * x^3)) + // Formula taken from: + // https://pytorch.org/docs/stable/generated/torch.nn.GELU.html + auto selfShape = selfType.getShape(); + if (!selfType.hasStaticShape()) + return rewriter.notifyMatchFailure( + op, "Only static shape tensor types are currently supported for Tanh " + "approximation"); - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(), cdf, - /*shift=*/0); + auto numElem = std::accumulate(selfShape.begin(), selfShape.end(), 1, + std::multiplies()); + + Value half = tosa::getConstTensor(rewriter, op, + SmallVector(numElem, 0.5f), + selfShape, selfElemTy) + .value(); + Value one = tosa::getConstTensor(rewriter, op, + SmallVector(numElem, 1.0f), + selfShape, selfElemTy) + .value(); + Value three = tosa::getConstTensor(rewriter, op, + SmallVector(numElem, 3.0f), + selfShape, selfElemTy) + .value(); + + // 0.044715 + Value magicNumber = + tosa::getConstTensor(rewriter, op, + SmallVector(numElem, 0.044715f), + selfShape, selfElemTy) + .value(); + + // From header: M_2_PI = 2 / pi + Value twoOverPi = + tosa::getConstTensor( + rewriter, op, + SmallVector(numElem, static_cast(M_2_PI)), selfShape, + selfElemTy) + .value(); + + // 0.5 * x + auto halfInput = rewriter.create(op->getLoc(), resultType, + half, self, /*shift=*/0); + + // sqrt(2/pi) + auto sqrtTwoOverPi = + rewriter.create(op->getLoc(), resultType, twoOverPi, half); + + // x^3 + auto inputPowThree = + rewriter.create(op->getLoc(), resultType, self, three); + + // 0.044715 * x^3 + auto inputPowThreeMul = + rewriter.create(op->getLoc(), resultType, magicNumber, + inputPowThree.getResult(), /*shift=*/0); + + // x + 0.044715 * x^3 + auto inputPowThreeMulAdd = rewriter.create( + op->getLoc(), resultType, self, inputPowThreeMul.getResult()); + + // sqrt(2/pi) * (x + 0.044715 * x^3) + auto sqrtTwoOverPiMul = rewriter.create( + op->getLoc(), resultType, sqrtTwoOverPi.getResult(), + inputPowThreeMulAdd.getResult(), /*shift=*/0); + + // tanh(sqrt(2/pi) * (x + 0.044715 * x^3)) + auto tanh = rewriter.create(op->getLoc(), resultType, + sqrtTwoOverPiMul.getResult()); + + // 1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)) + auto tanhAdd = rewriter.create(op->getLoc(), resultType, one, + tanh.getResult()); + + rewriter.replaceOpWithNewOp( + op, resultType, halfInput.getResult(), tanhAdd.getResult(), + /*shift=*/0); + } else { + return rewriter.notifyMatchFailure(op, + "Unsupported approximation 
algorithm"); + } return success(); } @@ -2988,7 +3695,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { // Not a tensor type. - auto selfType = dyn_cast(adaptor.getSelf().getType()); + auto self = adaptor.getSelf(); + auto selfType = dyn_cast(self.getType()); if (!selfType) return rewriter.notifyMatchFailure( op, "Only tensor types are currently supported"); @@ -3018,15 +3726,21 @@ LogicalResult ConvertAtenOp::matchAndRewrite( .value(); Value negOneHalf = tosa::getConstTensor(rewriter, op, -0.5f, {}, selfElemTy).value(); - Value inputSquared = rewriter.create( - loc, selfType, adaptor.getSelf(), adaptor.getSelf(), /*shift=*/0); + + if (mlir::tosa::EqualizeRanks(rewriter, loc, self, kAlphaHalf).failed() || + mlir::tosa::EqualizeRanks(rewriter, loc, self, negOneHalf).failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + + Value inputSquared = + rewriter.create(loc, selfType, self, self, /*shift=*/0); Value negHalfInputSquared = rewriter.create( loc, selfType, inputSquared, negOneHalf, /*shift=*/0); Value dinput = rewriter.create(loc, selfType, negHalfInputSquared); - Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.getSelf(), selfElemTy); - Value dinputInput = rewriter.create( - loc, selfType, dinput, adaptor.getSelf(), /*shift=*/0); + Value cdf = buildUnitNormalCdf(rewriter, op, self, selfElemTy).value(); + Value dinputInput = + rewriter.create(loc, selfType, dinput, self, /*shift=*/0); Value dinputInputAlpha = rewriter.create( loc, selfType, dinputInput, kAlphaHalf, /*shift=*/0); Value cdfExt = @@ -3045,7 +3759,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { // Not a tensor type. - auto selfType = dyn_cast(adaptor.getSelf().getType()); + auto self = adaptor.getSelf(); + auto selfType = dyn_cast(self.getType()); if (!selfType) { return rewriter.notifyMatchFailure( op, "Only tensor types are currently supported"); @@ -3065,7 +3780,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } Value gradOutput = adaptor.getGradOutput(); - auto gradOutputType = dyn_cast(adaptor.getSelf().getType()); + auto gradOutputType = dyn_cast(gradOutput.getType()); Type gradOutputElemType = gradOutputType.getElementType(); @@ -3090,17 +3805,28 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value replace = tosa::getConstTensor(rewriter, op, 0, {}, selfElemTy).value(); + + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, minVal) + .failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, maxVal) + .failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, gradOutput) + .failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, replace).failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + Type outType = getTypeConverter()->convertType(op.getType()); Value lesser = rewriter.create( op.getLoc(), RankedTensorType::get(selfType.getShape(), rewriter.getIntegerType(1)), - minVal, adaptor.getSelf()); + minVal, self); Value greater = rewriter.create( op.getLoc(), RankedTensorType::get(selfType.getShape(), rewriter.getIntegerType(1)), - adaptor.getSelf(), maxVal); + self, maxVal); Value cmp = rewriter.create( op.getLoc(), @@ -3124,7 +3850,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( cast(typeConverter->convertType(op.getType())); auto indicesType = dyn_cast(indices.getType()); - if (!indicesType || !indicesType.getElementType().isa()) + if (!indicesType || 
!isa(indicesType.getElementType())) return rewriter.notifyMatchFailure( op, "Indices must be of integer tensor type"); @@ -3255,92 +3981,144 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } -template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - AtenMaxDimOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - - auto selfType = dyn_cast(adaptor.getSelf().getType()); - if (!selfType) - return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); +template +class ConvertAtenMinMaxDimOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { - auto indicesType = - dyn_cast(getTypeConverter()->convertType(op.getType(1))); - if (!indicesType) - return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); + auto self = adaptor.getSelf(); + auto selfType = dyn_cast(self.getType()); + if (!selfType) + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); - auto selfElemType = selfType.getElementType(); - auto indicesElemType = indicesType.getElementType(); + const TypeConverter *typeConverter = this->getTypeConverter(); + auto indicesType = + dyn_cast(typeConverter->convertType(op.getType(1))); + if (!indicesType) + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); - // Only statically deducible values are currently supported - int64_t dim; - if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) - return rewriter.notifyMatchFailure(op, "dim must be a Scalar constant"); + auto selfElemType = selfType.getElementType(); + auto indicesElemType = indicesType.getElementType(); - dim = toPositiveDim(dim, selfType.getRank()); + // Only statically deducible values are currently supported + int64_t dim; + if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) + return rewriter.notifyMatchFailure(op, "dim must be a Scalar constant"); - if (!isValidDim(dim, selfType.getRank())) - return rewriter.notifyMatchFailure(op, "dim must be less than tensor rank"); + dim = toPositiveDim(dim, selfType.getRank()); - bool keepDim; - if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) - return rewriter.notifyMatchFailure(op, "keepdim must be a Scalar constant"); + if (!isValidDim(dim, selfType.getRank())) + return rewriter.notifyMatchFailure(op, + "dim must be less than tensor rank"); - SmallVector reducedShape, prunedShape; - for (auto en : - llvm::enumerate(makeShapeTorchCompatible(selfType.getShape()))) { - if (static_cast(en.index()) == dim) { - reducedShape.push_back(1); - continue; + bool keepDim; + if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) + return rewriter.notifyMatchFailure(op, + "keepdim must be a Scalar constant"); + + SmallVector reducedShape, prunedShape; + for (auto en : + llvm::enumerate(makeShapeTorchCompatible(selfType.getShape()))) { + if (static_cast(en.index()) == dim) { + reducedShape.push_back(1); + continue; + } + reducedShape.push_back(en.value()); + prunedShape.push_back(en.value()); } - reducedShape.push_back(en.value()); - prunedShape.push_back(en.value()); - } - auto dimAttr = rewriter.getIntegerAttr(rewriter.getI32Type(), dim); - auto prunedShapeAttr = rewriter.getDenseI64ArrayAttr(prunedShape); + auto dimAttr = rewriter.getIntegerAttr(rewriter.getI32Type(), dim); + auto prunedShapeAttr = rewriter.getDenseI64ArrayAttr(prunedShape); - 
Value reduceMax = rewriter.create( - op->getLoc(), - RankedTensorType::get(makeShapeLLVMCompatible(reducedShape), - selfElemType), - adaptor.getSelf(), dimAttr); + Value reduceOp; + if constexpr (std::is_same() || + std::is_same()) { + // Use default NaN Propagation mode "PROPAGATE" for tosa.reduce_min + // and tosa.reduce_max + reduceOp = rewriter.create( + op->getLoc(), + RankedTensorType::get(makeShapeLLVMCompatible(reducedShape), + selfElemType), + self, dimAttr, /*nan_mode=*/rewriter.getStringAttr("PROPAGATE")); + } else { + reduceOp = rewriter.create( + op->getLoc(), + RankedTensorType::get(makeShapeLLVMCompatible(reducedShape), + selfElemType), + self, dimAttr); + } - Value argMax = rewriter.create( - op->getLoc(), - RankedTensorType::get(makeShapeLLVMCompatible(prunedShape), - indicesElemType), - adaptor.getSelf(), dimAttr); + // To handle ReduceMinDim indices, we apply ArgMaxOp on the negate + // of the input tensor, which will return indices of input's min values + Value argMaxOp; + if constexpr (std::is_same()) { + Value negateOp = + rewriter.create(op->getLoc(), selfType, self); - if (argMax.getType() != indicesType) { - argMax = rewriter.create( - op->getLoc(), indicesType, argMax, - rewriter.getDenseI64ArrayAttr(reducedShape)); - } + // Use default NaN Propagation mode "PROPAGATE" for tosa.argmax + argMaxOp = rewriter.create( + op->getLoc(), + RankedTensorType::get(makeShapeLLVMCompatible(prunedShape), + indicesElemType), + negateOp, dimAttr, /*nan_mode=*/rewriter.getStringAttr("PROPAGATE")); + } else { + // Use default NaN Propagation mode "PROPAGATE" for tosa.argmax + argMaxOp = rewriter.create( + op->getLoc(), + RankedTensorType::get(makeShapeLLVMCompatible(prunedShape), + indicesElemType), + self, dimAttr, /*nan_mode=*/rewriter.getStringAttr("PROPAGATE")); + } - if (!keepDim) { - reduceMax = rewriter.create( - op->getLoc(), - RankedTensorType::get(makeShapeLLVMCompatible(prunedShape), - selfElemType), - reduceMax, prunedShapeAttr); - } + if (argMaxOp.getType() != indicesType) { + argMaxOp = rewriter.create( + op->getLoc(), indicesType, argMaxOp, + rewriter.getDenseI64ArrayAttr(reducedShape)); + } - rewriter.replaceOp(op, {reduceMax, argMax}); + if (!keepDim) { + reduceOp = rewriter.create( + op->getLoc(), + RankedTensorType::get(makeShapeLLVMCompatible(prunedShape), + selfElemType), + reduceOp, prunedShapeAttr); + } - return success(); -} + rewriter.replaceOp(op, {reduceOp, argMaxOp}); + + return success(); + } +}; template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenSliceTensorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + if (op->use_empty()) { + rewriter.eraseOp(op); + return success(); + } + auto selfType = dyn_cast(adaptor.getSelf().getType()); if (!selfType || !selfType.hasStaticShape()) return rewriter.notifyMatchFailure( op, "Only tensor types with static shape are supported"); + auto outTy = + dyn_cast(getTypeConverter()->convertType(op.getType())); + if (!outTy) { + return rewriter.notifyMatchFailure(op, "output type must be ranked"); + } + if (outTy.hasStaticShape() && outTy.getNumElements() == 0) { + return rewriter.notifyMatchFailure(op, + "tosa.slice does not support zero size"); + } + // Only statically deducible values are currently supported int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) @@ -3351,16 +4129,14 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (!isValidDim(dim, selfType.getRank())) return rewriter.notifyMatchFailure(op, "dim must less than tensor rank"); + auto sizeOfDim = 
selfType.getDimSize(dim); + int64_t start; if (!matchPattern(op.getStart(), m_TorchConstantInt(&start))) return rewriter.notifyMatchFailure(op, "start must be a Scalar constant"); - if (start < 0) { - start = toPositiveDim(start, selfType.getShape()[dim]); - if (!isValidDim(start, selfType.getShape()[dim])) - return rewriter.notifyMatchFailure(op, "start is not a valid index"); - } - start = std::min(selfType.getShape()[dim], start); + start = toPositiveDim(start, selfType.getShape()[dim]); + start = std::clamp(start, (int64_t)0, sizeOfDim); int64_t end; if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) { @@ -3373,32 +4149,44 @@ LogicalResult ConvertAtenOp::matchAndRewrite( end = toPositiveDim(end, selfType.getShape()[dim]); // support for end out of upper bound end = (end > selfType.getShape()[dim] ? selfType.getShape()[dim] : end); - - // FIXME: add support for start < 0 and end < start - if (end < start) - return rewriter.notifyMatchFailure(op, - "Currently unsupported: end < start"); + // Handle start > end + end = std::clamp(end, (int64_t)0, sizeOfDim); int64_t step; if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) return rewriter.notifyMatchFailure(op, "step must be a Scalar constant"); - if (step != 1) - return rewriter.notifyMatchFailure( - op, "step value other than 1 is currently unsupported"); + if (sizeOfDim % step != 0) { + return rewriter.notifyMatchFailure(op, "size must be divisible by step"); + } - SmallVector startSlice(selfType.getRank(), 0); - SmallVector sizeSlice = - llvm::to_vector(makeShapeTorchCompatible(selfType.getShape())); + // We handle step by splitting the dimension dim into two dimensions, + // where the second one has size 'step'. + // E.g. to take slice with step 3 out of dim=0 of [6, 10], we first + // reshape into [2, 3, 10]. + SmallVector newShape{selfType.getShape()}; + newShape[dim] /= step; + newShape.insert(newShape.begin() + dim + 1, step); - startSlice[dim] = start; - sizeSlice[dim] = end - start; + auto reshaped = + tosa::reshapeTo(op->getLoc(), rewriter, adaptor.getSelf(), newShape); - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(), - rewriter.getDenseI64ArrayAttr(startSlice), - rewriter.getDenseI64ArrayAttr(sizeSlice)); + SmallVector startSlice(reshaped.getType().getRank(), 0); + + startSlice[dim] = start / step; + startSlice[dim + 1] = start % step; + + SmallVector sliceShape{outTy.getShape()}; + sliceShape.insert(sliceShape.begin() + dim + 1, 1); + + auto slice = rewriter.create( + op.getLoc(), outTy.cloneWith(sliceShape, outTy.getElementType()), + reshaped, rewriter.getDenseI64ArrayAttr(startSlice), + rewriter.getDenseI64ArrayAttr(sliceShape)); + + auto out = tosa::reshapeTo(op->getLoc(), rewriter, slice, outTy.getShape()); + rewriter.replaceOp(op, out); return success(); } @@ -3407,8 +4195,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenBroadcastToOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + auto self = adaptor.getSelf(); // Not a tensor type. 
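The step handling in the aten.slice lowering above relies on the reshape trick described in its comment: a dimension of size N is split into [N/step, step], the strided selection becomes a contiguous 2-D slice starting at [start/step, start%step], and the result is reshaped back. A small index-level sketch of why that picks the same elements, plain C++ and illustrative only:

#include <cassert>
#include <vector>

std::vector<int> stridedViaReshape(int n, int start, int count, int step) {
  assert(n % step == 0); // mirrors the "size must be divisible by step" check
  // After reshaping [n] -> [n/step, step], element k lives at (k/step, k%step).
  int row0 = start / step, col = start % step;
  std::vector<int> picked;
  for (int i = 0; i < count; ++i)
    picked.push_back((row0 + i) * step + col); // a [count, 1] slice, flattened
  return picked;
}

int main() {
  // start=1, step=3 on a dim of size 6 selects elements 1 and 4, exactly what
  // a direct strided slice would pick (the [6, 10] -> [2, 3, 10] example).
  std::vector<int> expected{1, 4};
  assert(stridedViaReshape(/*n=*/6, /*start=*/1, /*count=*/2, /*step=*/3) ==
         expected);
  return 0;
}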
- auto selfType = dyn_cast(adaptor.getSelf().getType()); + auto selfType = dyn_cast(self.getType()); if (!selfType || !selfType.hasStaticShape()) return rewriter.notifyMatchFailure( op, "Only tensor types with static shape are supported"); @@ -3422,19 +4211,43 @@ LogicalResult ConvertAtenOp::matchAndRewrite( SmallVector resultShape; if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(resultShape))) return rewriter.notifyMatchFailure(op, - "size must consist of Scalar constants"); + "Size must consist of Scalar constants"); + + int64_t inputRank = selfType.getRank(); + int64_t outputRank = resultShape.size(); + if (inputRank > outputRank) + return rewriter.notifyMatchFailure( + op, "Input tensor rank cannot be greater than output tensor rank"); + // Get the result type auto resultType = getTypeConverter()->convertType(op.getType()); SmallVector inputShape( makeShapeTorchCompatible(selfType.getShape())); + + // If input rank is smaller than output rank, we reshape the input tensor to + // be the same rank as the output tensor by prepending 1s to the input shape + SmallVector targetInputShape; + for (int64_t i = 0; i < outputRank - inputRank; i++) + targetInputShape.push_back(1); + targetInputShape.append(inputShape); + // Result dimension -1 means not changing the size of that dimension. // Adjust it by assigning its inputShape. - for (auto shape : llvm::enumerate(makeShapeTorchCompatible(inputShape))) { + for (auto shape : + llvm::enumerate(makeShapeTorchCompatible(targetInputShape))) { auto index = shape.index(); if (resultShape[index] == -1) resultShape[index] = shape.value(); } + + for (int64_t i = 0; i < outputRank; i++) { + if (targetInputShape[i] != resultShape[i] && targetInputShape[i] != 1) + return rewriter.notifyMatchFailure( + op, "Input and result shapes should be equal at each dimension or " + "input shape should be 1"); + } + // Check for identity case i.e, for ex: [a, b, c] -> [a, b, c]. If this is // true then we can replace the op result with the input operand directly. if (llvm::equal(inputShape, resultShape)) { @@ -3442,52 +4255,42 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // since the input and result are of same shape. op.replaceAllUsesWith(op.getSelf()); rewriter.eraseOp(op); - return success(); - } else if (selfType.hasRank() && - (selfType.getRank() == (int64_t)resultShape.size() || - selfType.getRank() == 0)) { - // Right now to support limited cases where input and result shape are not - // equal, we can put a constraint that either the input should be of rank - // 0 or the rank of input tensor and result should be equal. And then we - // can check for broadcasting compatibility for the latter case. For - // broadcasting compatibility, either the shape of input and result should - // be equal at each dimenion or one of them should be 1. - if (selfType.getRank() != 0) { - for (unsigned i = 0; i < inputShape.size(); i++) { - if (inputShape[i] != resultShape[i] && inputShape[i] != 1 && - resultShape[i] != 1) { - return rewriter.notifyMatchFailure( - op, "unimplemented: either the shape of input and result should " - "be equal at each dimenion or one of them should be 1."); - } + } else { + // By using reshape and tile ops, support for input rank smaller than result + // rank is allowed. If the rank is smaller, we reshape the input to be the + // same rank as the result, then use tile to expand it. 
The way it was + // handled before involves adding the input tensor to a const zero tensor of + // output shape to utilize the innate broadcast feature of the TOSA add op. + // That poses the danger of sign bit flips for denormalized values. + // Basically, this approach to broadcast_to legalization allows for more + // flexibility in rank differences and also offers more safety. + Value reshapedInput = self; + if (!llvm::equal(inputShape, targetInputShape)) + reshapedInput = rewriter.create( + op->getLoc(), + RankedTensorType::get(makeShapeTorchCompatible(targetInputShape), + selfElemTy), + self, rewriter.getDenseI64ArrayAttr(targetInputShape)); + + SmallVector tileOpShape; + for (int64_t i = 0; i < outputRank; i++) { + if (targetInputShape[i] == 1) { + tileOpShape.push_back(resultShape[i]); + } else { + tileOpShape.push_back(1); } } - // If the above condition hold true then we can directly create a const - // zero tensor of shape same as the result shape. - SmallVector zeroTensorShape{resultShape}; + auto tileOpMultiples = + tosa::getTosaConstShape(rewriter, op->getLoc(), tileOpShape); - // create the 0 constant tensor - int64_t totalNumElements = 1; - for (auto dimSize : zeroTensorShape) { - totalNumElements = dimSize * totalNumElements; - } - // There is some danger here. For edge cases in floating point, x + 0 != x. - // The cases are denormalized values, which may get flushed, and -0 + 0 = - // +0. (sign bit flips). These are probably acceptable in the short term, - // but we should put a comment acknowledging the danger, as there isn't an - // op that avoids the denorm flushing. - Value zeroTensor = - tosa::getZerosLikeTensor(rewriter, op, resultType).value(); - - // Use add broadcast - rewriter.replaceOpWithNewOp(op, resultType, adaptor.getSelf(), - zeroTensor); - return success(); + auto result = rewriter.create(op->getLoc(), resultType, + reshapedInput, tileOpMultiples); + + rewriter.replaceOp(op, {result.getResult()}); } - return rewriter.notifyMatchFailure( - op, - "unimplemented: broadcasts other than same rank or zero ranked tensor."); + + return success(); } template <> @@ -3578,81 +4381,174 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - AtenIndexPutHackedTwinOp op, OpAdaptor adaptor, +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenIndexSelectOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - // a = torch.tensor([[0, 1, 2, 3]]) - // a[..., 1:] = torch.tensor([4, 5, 6]) - // = a[..., 1:4] = torch.tensor([4, 5, 6]) - // = a[[0, 0, 0], [1, 2, 3]] = torch.tensor([4, 5, 6]) # tensor([[0, 4, 5, - // 6]]) = torch.ops.aten.index_put(torch.tensor([[0, 1, 2, 3]]), # input - // (torch.tensor([0, 0, 0]), torch.tensor([1, 2, - // 3])), # indicies torch.tensor([4, 5, 6])) # - // value - // = torch.ops.aten.index_put(torch.tensor([[0, 1, 2, 3]]), # input - // (None, torch.tensor([1, 2, 3]),),# indicies - // torch.tensor([4, 5, 6])) # value - // Not a tensor type. 
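For the broadcast_to legalization above, the multiples passed to tosa.tile are derived exactly as the comment describes: left-pad the input shape with 1s to the output rank, require each padded dim to equal the result dim or be 1, and tile by the result size wherever the padded dim is 1. A shape-level sketch of that computation, plain C++ and illustrative only:

#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> tileMultiples(std::vector<int64_t> inShape,
                                   const std::vector<int64_t> &outShape) {
  // Rank-extend by prepending 1s, as targetInputShape does in the pattern.
  while (inShape.size() < outShape.size())
    inShape.insert(inShape.begin(), 1);
  std::vector<int64_t> multiples(outShape.size(), 1);
  for (size_t i = 0; i < outShape.size(); ++i) {
    assert(inShape[i] == outShape[i] || inShape[i] == 1); // broadcast rule
    if (inShape[i] == 1)
      multiples[i] = outShape[i]; // tile factor for this dim
  }
  return multiples;
}

int main() {
  // Broadcasting a [3, 1] tensor to [2, 3, 4] tiles by [2, 1, 4].
  std::vector<int64_t> expected{2, 1, 4};
  assert((tileMultiples({3, 1}, {2, 3, 4}) == expected));
  return 0;
}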
auto input = adaptor.getSelf(); - auto selfType = dyn_cast(adaptor.getSelf().getType()); - if (!selfType) + auto inputType = dyn_cast(input.getType()); + if (!inputType) return rewriter.notifyMatchFailure( - op, "Only tensor types input are currently supported"); + op, "Only RankedTensorType inputs are currently supported"); - auto fillValues = adaptor.getValues(); - auto valuesType = dyn_cast(adaptor.getValues().getType()); - if (!valuesType) - return rewriter.notifyMatchFailure( - op, "Only tensor types input are currently supported"); + auto index = adaptor.getIndex(); + auto indexType = dyn_cast(index.getType()); + auto indexShape = indexType.getShape(); - // Deal with torch.prim.ListConstruct of non const value to get the index - auto tensorList = op.getIndices(); - SmallVector tensorsTorchType; - if (!getListConstructElements(tensorList, tensorsTorchType)) - return op.emitError( - "unimplemented: the tensor list is not from list construct"); - auto indexTensors = getTypeConvertedValues( - rewriter, op->getLoc(), getTypeConverter(), tensorsTorchType); + if (!indexType) + return rewriter.notifyMatchFailure( + op, "Only RankedTensorType indices are currently supported"); - auto outType = getTypeConverter()->convertType(op.getType()); + auto inputShape = inputType.getShape(); + int inputRank = inputType.getRank(); - // convert list of indices with none into indices tensor without none - // indexTensors (none,[1,2,3]) -> ([0,0,0],[1,2,3]) - // ([[0],[0],[0]],[[1],[2],[3]])-> [[0,1],[0,2], [0,3]] - if (indexTensors.size() <= 1) { - return rewriter.notifyMatchFailure( - op, "Only support indexput with multiple index."); + if (indexType.getRank() == 0) { + indexShape = makeShapeTorchCompatible({1}); + index = rewriter.create( + op->getLoc(), + RankedTensorType::get(indexShape, indexType.getElementType()), index, + rewriter.getDenseI64ArrayAttr(indexShape)); } - SmallVector indicesTfConcatTensors; - SmallVector indexesRank; - SmallVector> indexesShape; - // concat index tensor into to indices tensor for concat + // Dynamic shape check + if (!inputType.hasStaticShape() || !indexType.hasStaticShape()) + return rewriter.notifyMatchFailure( + op, "AtenIndexSelectOp: support for dynamic input " + "shape not implemented"); + + // index i64 to i32 for tosa compatible + if (indexType.getElementType() != rewriter.getIntegerType(32)) { + index = rewriter.create( + op->getLoc(), + RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), index); + } + + // Get positive dim + int64_t dim; + if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) + return rewriter.notifyMatchFailure( + op, "Value `dim` should be a torch constant int"); + dim = toPositiveDim(dim, inputRank); + if (!isValidDim(dim, inputRank)) + return rewriter.notifyMatchFailure(op, "Value `dim` is invalid"); + + // Get the output type + auto outType = getTypeConverter()->convertType(op.getType()); + + // Reshape and expand the index tensor to have same rank and same dimensions + // (except for the targeted dim) as the input + // + // For example: + // Input shape = (4, 5, 6) + // Index vector shape = (2) + // Targeted dim = 1 + // Reshaped and expanded index vector shape = (4, 2, 6) + // + // By reshaping and expanding the index vector, we can supply it into the + // gather op to mimic the functionality of aten.index_select + SmallVector indicesInputRankShape; + for (int64_t i = 0; i < inputRank; i++) { + if (i == dim) { + indicesInputRankShape.push_back(indexShape[0]); + } else { + indicesInputRankShape.push_back(1); + } + } + + 
auto indicesInputRankType = + RankedTensorType::get(makeShapeLLVMCompatible(indicesInputRankShape), + rewriter.getIntegerType(32)); + + auto reshapedIndices = rewriter.create( + op->getLoc(), indicesInputRankType, index, + rewriter.getDenseI64ArrayAttr(indicesInputRankShape)); + + SmallVector tileShape(indicesInputRankShape); + SmallVector expandedIndicesShape(indicesInputRankShape); + for (int64_t i = 0; i < inputRank; i++) { + if (tileShape[i] == 1 && i != dim) { + tileShape[i] = inputShape[i]; + expandedIndicesShape[i] = inputShape[i]; + } else { + tileShape[i] = 1; + } + } + + auto tileType = + RankedTensorType::get(makeShapeLLVMCompatible(expandedIndicesShape), + rewriter.getIntegerType(32)); + + auto tileOpMultiples = + tosa::getTosaConstShape(rewriter, op->getLoc(), tileShape); + + auto expandedIndices = rewriter.create( + op->getLoc(), tileType, reshapedIndices.getResult(), tileOpMultiples); + + // convert torch style index and dim into tf style indices + // tensor<[1,4,2],si64> -> tensor<[1,4,2,3],si64> + auto indicesTf = tosa::convertTorchIndexToTfIndices( + rewriter, op, input, expandedIndices.getResult(), dim); + if (!indicesTf) + return rewriter.notifyMatchFailure( + op, "Convert TorchIndex To TfIndices failed"); + + // do the tf gathernd algorithm with tf style indices as input. + auto result = + tosa::convertGatherNdOp(rewriter, op, outType, input, indicesTf.value()); + + if (!result) { + return rewriter.notifyMatchFailure(op, "Convert GatherNdOp failed"); + } + rewriter.replaceOp(op, {result.value()}); + return success(); +} + +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenIndexPutHackedTwinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // Not a tensor type. + auto input = adaptor.getSelf(); + auto selfType = dyn_cast(input.getType()); + if (!selfType) + return rewriter.notifyMatchFailure( + op, "Only tensor types input are currently supported"); + + auto fillValues = adaptor.getValues(); + auto valuesType = dyn_cast(fillValues.getType()); + if (!valuesType) + return rewriter.notifyMatchFailure( + op, "Only tensor types input are currently supported"); + + // Deal with torch.prim.ListConstruct of non const value to get the index + // Index_put-like ops are now decomposed to aten.index_put.hacked_twin with + // stricter semantics, i.e., no None index in indices argument. 
+ auto tensorList = op.getIndices(); + SmallVector tensorsTorchType; + if (!getListConstructElements(tensorList, tensorsTorchType)) + return op.emitError("Tensor list is not from list construct"); + auto indexTensors = getTypeConvertedValues( + rewriter, op->getLoc(), getTypeConverter(), tensorsTorchType); + + auto outType = getTypeConverter()->convertType(op.getType()); + + bool accumulate{false}; + if (!matchPattern(op.getAccumulate(), m_TorchConstantBool(&accumulate))) + return rewriter.notifyMatchFailure( + op, "Accumulate is not a constant bool value"); + + // No support for accumulate mode yet + if (accumulate) + return rewriter.notifyMatchFailure( + op, "Accumulate mode is not currently supported"); + + SmallVector indicesTfConcatTensors; + SmallVector indexesRank; + SmallVector> indexesShape; + + // concat index tensor into to indices tensor for concat for (size_t i = 0; i < indexTensors.size(); i++) { auto index = indexTensors[i]; - auto indexTorch = tensorsTorchType[i]; - // TODO add support for none index other than i==0, like (index0, None) - // (None, index1) - if (i == 0 && indexTorch.getType().isa()) { - // convert None to [0,0,0] - auto indexNext = indexTensors[i + 1]; - auto indexNextTorch = tensorsTorchType[i + 1]; - if (indexNextTorch.getType().isa()) { - return rewriter.notifyMatchFailure( - op, "Multiple None index is not support for now."); - } - auto indexNextType = dyn_cast(indexNext.getType()); - auto indexNextShape = indexNextType.getShape(); - - int64_t size = 1; - for (auto s : indexNextShape) - size *= s; - SmallVector values(size, i); - index = - tosa::getConstTensor(rewriter, op, values, indexNextShape) - .value(); - } auto indexType = dyn_cast(index.getType()); auto indexShape = indexType.getShape(); @@ -3660,20 +4556,19 @@ LogicalResult ConvertAtenOp::matchAndRewrite( indexesRank.push_back(indexType.getRank()); // index i64 to i32 for tosa compatible - if (indexType.getElementType() != rewriter.getIntegerType(32)) { + if (indexType.getElementType() != rewriter.getIntegerType(32)) index = rewriter.create( op->getLoc(), RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), index); - } // Expand last dim of index to tf indices [3] -> [3,1] // convert [0,0,0] to [[0],[0],[0]] SmallVector indiceShapeOneDim; - for (auto shape : indexShape) { + for (auto shape : indexShape) indiceShapeOneDim.push_back(shape); - } indiceShapeOneDim.push_back(1); + auto indicesTfOneDim = tosa::CreateOpAndInfer( rewriter, op->getLoc(), RankedTensorType::get(indiceShapeOneDim, rewriter.getIntegerType(32)), @@ -3690,7 +4585,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( for (auto indexShapeOneDim : indexesShape) { if (!llvm::equal(indexesShape[0], indexShapeOneDim)) { return rewriter.notifyMatchFailure( - op, "unimplemented: Only support multi indexes with same shape"); + op, "Only support indices with same shape"); } } @@ -3704,24 +4599,47 @@ LogicalResult ConvertAtenOp::matchAndRewrite( GetTypeFromTensorShape(indicesShapeConcat, rewriter.getIntegerType(32)), indicesTfConcatTensors, lastDim); - if (!indicesTf) { - return rewriter.notifyMatchFailure(op, - "Convert TorchIndex To TfIndices fail."); - } - // do the tf scatterNd algorithm with tf style indices as input, algorithm - // mostly take from convertGatherNdOp. 
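A note on how the index_put lowering above assembles its indices: each 1-D index tensor is unsqueezed to [N, 1] and the results are concatenated on the last dim, so row k holds the full coordinate (idx0[k], idx1[k], ...) consumed by the scatter/gather helpers (this matches the [[0,1],[0,2],[0,3]] example in the removed comment earlier in this file). A small standalone sketch of that packing, plain C++ and illustrative only:

#include <cassert>
#include <cstdint>
#include <vector>

using Coord = std::vector<int32_t>;

std::vector<Coord>
buildNdIndices(const std::vector<std::vector<int32_t>> &idxTensors) {
  size_t n = idxTensors.front().size();
  std::vector<Coord> rows(n, Coord(idxTensors.size()));
  for (size_t d = 0; d < idxTensors.size(); ++d) // one [N, 1] column per tensor
    for (size_t k = 0; k < n; ++k)
      rows[k][d] = idxTensors[d][k];             // concat along the last dim
  return rows;
}

int main() {
  // indices ([0,0,0], [1,2,3]) become rows (0,1), (0,2), (0,3), i.e. the
  // coordinates written by a[[0,0,0],[1,2,3]] = values.
  auto rows = buildNdIndices({{0, 0, 0}, {1, 2, 3}});
  assert((rows[1] == Coord{0, 2}));
  return 0;
}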
+ if (!indicesTf) + return rewriter.notifyMatchFailure( + op, "Convert PyTorch index to TensorFlow indices failed"); + auto result = tosa::convertScatterNdOp(rewriter, op, outType, input, indicesTf.getResult(), fillValues); - if (!result) { - return rewriter.notifyMatchFailure( - op, "Convert ScatterNdOp fail for index tensor."); - } + if (!result) + return rewriter.notifyMatchFailure(op, "Convert ScatterNdOp failed"); + rewriter.replaceOp(op, {result.value()}); return success(); } +std::optional wrapNegativeIndices(Value index, int maxIndex, + Operation *op, + ConversionPatternRewriter &rewriter) { + + auto zeroValue = tosa::getConstTensor(rewriter, op, 0, {}).value(); + auto maxIndexValue = + tosa::getConstTensor(rewriter, op, maxIndex, {}).value(); + + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), index, zeroValue) + .failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), index, maxIndexValue) + .failed()) + return std::nullopt; + + auto indexType = dyn_cast(index.getType()); + + auto wrappedIndicesOp = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), indexType, maxIndexValue, index); + auto boolType = indexType.clone(rewriter.getIntegerType(1)); + auto isNegativeIndices = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), boolType, zeroValue, index); + return tosa::CreateOpAndInfer(rewriter, op->getLoc(), + indexType, isNegativeIndices, + wrappedIndicesOp, index); +} + template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenIndexTensorHackedTwinOp op, OpAdaptor adaptor, @@ -3753,6 +4671,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto outType = getTypeConverter()->convertType(op.getType()); + Operation *indicesTf; + // Support for multiple indexes if (indexTensors.size() > 1) { // t[i, i] @@ -3786,6 +4706,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( index); } + index = wrapNegativeIndices(index, inputTensorType.getShape()[i], op, + rewriter) + .value(); // Expand last dim of index to tf indices [2,3] -> [2,3,1] SmallVector indiceShapeOneDim; for (auto shape : indexShape) { @@ -3801,13 +4724,132 @@ LogicalResult ConvertAtenOp::matchAndRewrite( indicesTfConcatTensors.push_back(indicesTfOneDim.getResult()); } - // Right now only support multiple indexes with same shape - // TODO for different shape multiple indexes, add broadcast_to for small - // shape + auto getRankExtendedShape = + [](SmallVector inputShape, + SmallVector maxRank1DimShape) -> SmallVector { + SmallVector rankExtendedShape(maxRank1DimShape); + auto inputRank = inputShape.size(); + auto maxRank = maxRank1DimShape.size(); + auto startIdx = maxRank - inputRank; + for (size_t i = startIdx; i < maxRank; i++) { + rankExtendedShape[i] = inputShape[i - startIdx]; + } + return rankExtendedShape; + }; + + bool hasDiffShapedIndexes = false; for (auto indexShapeOneDim : indexesShape) { if (!llvm::equal(indexesShape[0], indexShapeOneDim)) { - return rewriter.notifyMatchFailure( - op, "unimplemented: Only support multi indexes with same shape"); + hasDiffShapedIndexes = true; + break; + } + } + + if (hasDiffShapedIndexes) { + int64_t maxRank = 1; + for (auto idxRank : indexesRank) { + if (idxRank > maxRank) + maxRank = idxRank; + } + // Tensor shape of max rank, each dim being 1 + SmallVector maxRank1DimShape; + for (int i = 0; i < maxRank; i++) + maxRank1DimShape.push_back(1); + // Tensor shape of max rank, each dim being the max dim. 
+ SmallVector maxRankMaxDimShape(maxRank1DimShape); + + auto updateMaxRankMaxDimShape = + [&](SmallVector broadcastedShape) -> LogicalResult { + for (size_t i = 0; i < maxRankMaxDimShape.size(); i++) { + // check for malformed index tensors + if (broadcastedShape[i] != 1 && maxRankMaxDimShape[i] != 1 && + maxRankMaxDimShape[i] != broadcastedShape[i]) { + return failure(); + } + if (broadcastedShape[i] > maxRankMaxDimShape[i]) + maxRankMaxDimShape[i] = broadcastedShape[i]; + } + return success(); + }; + + for (size_t i = 0; i < indexesRank.size(); i++) { + // Reshape all index tensors to same maxRank + auto idxRank = indexesRank[i]; + auto unreshapedIdxTensor = indicesTfConcatTensors[i]; + SmallVector broadcastedShape = + getRankExtendedShape(indexesShape[i], maxRank1DimShape); + + if (idxRank < maxRank) { + auto idxType = + dyn_cast(indicesTfConcatTensors[i].getType()); + // indicesTfConcatTensors has a trailing [1] dim for the final concat. + auto broadcastedShapeTf(broadcastedShape); + broadcastedShapeTf.push_back(1); + auto reshapeOutputTy = RankedTensorType::get( + broadcastedShapeTf, idxType.getElementType()); + // Update the tensor array with the max rank-extended form + indicesTfConcatTensors[i] = rewriter.create( + op->getLoc(), reshapeOutputTy, unreshapedIdxTensor, + rewriter.getDenseI64ArrayAttr(broadcastedShapeTf)); + } + + // Construct the max rank broadcasted form of all index tensors with + // each index tensor. + if (updateMaxRankMaxDimShape(broadcastedShape).failed()) { + return rewriter.notifyMatchFailure( + op, "Malformed index tensors that have mismatched dim shapes"); + } + + // Every index now has the same rank but not yet same shape until + // tosa.tile below. + indexesShape[i] = broadcastedShape; + indexesRank[i] = maxRank; + } + + auto getTileOpShape = [&](SmallVector indexShape, + SmallVector &tileOpShape) -> bool { + bool needsTiling = false; + for (size_t i = 0; i < indexShape.size(); i++) { + if (1 == indexShape[i]) { + tileOpShape.push_back(maxRankMaxDimShape[i]); + needsTiling = true; + } else { + tileOpShape.push_back(1); + } + } + return needsTiling; + }; + + // Use tosa.tile to broadcast in multiple dims so all index tensors have + // the same shape. This materializes new tensors. + for (size_t i = 0; i < indexesRank.size(); i++) { + SmallVector tileOpShape; + bool needsTiling = getTileOpShape(indexesShape[i], tileOpShape); + + if (needsTiling) { + auto idxType = + dyn_cast(indicesTfConcatTensors[i].getType()); + + // indicesTfConcatTensors has a trailing [1] dim for the final concat. 
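The differently-shaped-indices path above computes a common broadcast shape before tiling: every index shape is left-padded with 1s to the maximum rank, and the common shape is the elementwise maximum, with a dim mismatch (neither side 1 nor equal) rejected as malformed. A shape-level sketch of that bookkeeping, plain C++ and illustrative only:

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t>
commonIndexShape(std::vector<std::vector<int64_t>> shapes) {
  size_t maxRank = 0;
  for (auto &s : shapes)
    maxRank = std::max(maxRank, s.size());
  std::vector<int64_t> common(maxRank, 1);
  for (auto &s : shapes) {
    s.insert(s.begin(), maxRank - s.size(), 1); // rank-extend with leading 1s
    for (size_t i = 0; i < maxRank; ++i) {
      // Otherwise the index tensors are malformed (mismatched dim shapes).
      assert(s[i] == 1 || common[i] == 1 || common[i] == s[i]);
      common[i] = std::max(common[i], s[i]);    // maxRankMaxDimShape update
    }
  }
  return common;
}

int main() {
  // Index tensors of shapes [4] and [2, 1] broadcast to the common shape
  // [2, 4]; each index is then tiled wherever its own dim is 1.
  std::vector<int64_t> expected{2, 4};
  assert((commonIndexShape({{4}, {2, 1}}) == expected));
  return 0;
}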
+ auto maxRankMaxDimShapeTf(maxRankMaxDimShape); + maxRankMaxDimShapeTf.push_back(1); + + auto tileOpShapeTf(tileOpShape); + tileOpShapeTf.push_back(1); + + auto tileOutputTy = RankedTensorType::get(maxRankMaxDimShapeTf, + idxType.getElementType()); + auto reshapedIdxTensor = indicesTfConcatTensors[i]; + + auto tileOpMultiples = + tosa::getTosaConstShape(rewriter, op->getLoc(), tileOpShapeTf); + + indicesTfConcatTensors[i] = rewriter.create( + op->getLoc(), tileOutputTy, reshapedIdxTensor, tileOpMultiples); + } + + // Every index tensor now has the same rank and shape + indexesShape[i] = maxRankMaxDimShape; } } @@ -3815,49 +4857,40 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto indicesShapeConcat = indexesShape[0]; uint64_t lastDim = indexesRank[0]; indicesShapeConcat.push_back(indicesTfConcatTensors.size()); - auto indicesTf = tosa::CreateOpAndInfer( + indicesTf = tosa::CreateOpAndInfer( rewriter, op->getLoc(), GetTypeFromTensorShape(indicesShapeConcat, rewriter.getIntegerType(32)), indicesTfConcatTensors, lastDim); - if (!indicesTf) { - return rewriter.notifyMatchFailure( - op, "Convert TorchIndex To TfIndices fail."); - } - // do the tf gathernp algorithm with tf style indices as input. - auto result = tosa::convertGatherNdOp(rewriter, op, outType, input, - indicesTf.getResult()); + } else { - if (!result) { - return rewriter.notifyMatchFailure( - op, "Convert GatherNdOp fail for index tensor."); + // Single index + auto index = indexTensors[0]; + auto indexType = dyn_cast(index.getType()); + auto indexShape = indexType.getShape(); + // index i64 to i32 for tosa compatible + if (indexType.getElementType() != rewriter.getIntegerType(32)) { + index = rewriter.create( + op->getLoc(), + RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), + index); } - rewriter.replaceOp(op, {result.value()}); - return success(); - } - - // Support for multiple index - auto index = indexTensors[0]; - auto indexType = dyn_cast(index.getType()); - auto indexShape = indexType.getShape(); - // index i64 to i32 for tosa compatible - if (indexType.getElementType() != rewriter.getIntegerType(32)) { - index = rewriter.create( - op->getLoc(), - RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), index); - } + index = + wrapNegativeIndices(index, inputTensorType.getShape()[0], op, rewriter) + .value(); - // Expand last dim of index to tf indices [2,3] -> [2,3,1] - SmallVector indicesShape; - for (auto shape : indexShape) { - indicesShape.push_back(shape); + // Expand last dim of index to tf indices [2,3] -> [2,3,1] + SmallVector indicesShape; + for (auto shape : indexShape) { + indicesShape.push_back(shape); + } + indicesShape.push_back(1); + indicesTf = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), + RankedTensorType::get(indicesShape, rewriter.getIntegerType(32)), index, + rewriter.getDenseI64ArrayAttr(indicesShape)); } - indicesShape.push_back(1); - auto indicesTf = tosa::CreateOpAndInfer( - rewriter, op->getLoc(), - RankedTensorType::get(indicesShape, rewriter.getIntegerType(32)), index, - rewriter.getDenseI64ArrayAttr(indicesShape)); if (!indicesTf) { return rewriter.notifyMatchFailure(op, @@ -3865,7 +4898,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } // do the tf gathernp algorithm with tf style indices as input. 
auto result = tosa::convertGatherNdOp(rewriter, op, outType, input, - indicesTf.getResult()); + indicesTf->getResult(0)); if (!result) { return rewriter.notifyMatchFailure( @@ -3876,76 +4909,280 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +// Legalization for aten.scatter.src template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - AtenAbsOp op, OpAdaptor adaptor, +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenScatterSrcOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + // Not a tensor type. - auto selfType = dyn_cast(adaptor.getSelf().getType()); - if (!selfType) + auto input = adaptor.getSelf(); + auto inputType = dyn_cast(input.getType()); + if (!inputType) return rewriter.notifyMatchFailure( - op, "Only tensor types input are currently supported"); + op, "Only RankedTensorType inputs are currently supported"); - auto outType = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, outType, adaptor.getSelf()); + auto inputShape = inputType.getShape(); + auto paramsRank = inputType.getRank(); - return success(); -} + auto index = adaptor.getIndex(); + auto indexType = dyn_cast(index.getType()); + if (!indexType) + return rewriter.notifyMatchFailure( + op, "Only RankedTensorType indices are currently supported"); -template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - AtenWhereSelfOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { + // Check `index` and `input` param should have the same rank + if (indexType.getRank() != paramsRank) + return rewriter.notifyMatchFailure( + op, "Params index and input should have the same rank"); - // Not a tensor type. - auto selfType = dyn_cast(adaptor.getSelf().getType()); - if (!selfType) + auto indexShape = indexType.getShape(); + + auto src = adaptor.getSrc(); + auto srcType = dyn_cast(src.getType()); + if (!srcType) return rewriter.notifyMatchFailure( - op, "Only tensor types input are currently supported"); - auto condType = dyn_cast(adaptor.getCondition().getType()); - if (!condType) + op, "Only RankedTensorType sources are currently supported"); + + // Check `src` and `input` param should have the same rank + if (srcType.getRank() != paramsRank) + return rewriter.notifyMatchFailure( + op, "Src and input should have the same rank"); + + auto srcShape = srcType.getShape(); + + // Dynamic shape check + if (!inputType.hasStaticShape() || !indexType.hasStaticShape() || + !srcType.hasStaticShape()) return rewriter.notifyMatchFailure( - op, "Only tensor types condition are currently supported"); + op, "Support for dynamic shape not implemented"); + + // index i64 to i32 for tosa compatitable + if (indexType.getElementType() != rewriter.getIntegerType(32)) { + index = rewriter.create( + op->getLoc(), + RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), index); + } + + // Get positive dim + int64_t dim{0}; + if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) + return rewriter.notifyMatchFailure(op, + "Dim value should be a constant int"); + + dim = toPositiveDim(dim, paramsRank); + if (!isValidDim(dim, paramsRank)) + return rewriter.notifyMatchFailure(op, "Dim is invalid"); + + // It is also required that index.size(d) <= src.size(d) for all dimensions d, + // and that index.size(d) <= self.size(d) for all dimensions d != dim + for (int64_t d = 0; d < paramsRank; d++) { + if (d != dim) { + if (indexShape[d] > srcShape[d] || indexShape[d] > inputShape[d]) + return rewriter.notifyMatchFailure( + op, "Index size should be smaller 
or equal to src or input size " + "for all dimensions d != dim"); + } + } + // Get the output type auto outType = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp( - op, outType, adaptor.getCondition(), adaptor.getSelf(), - adaptor.getOther()); + // convert PyTorch style index and dim into TensorFlows tyle indices + // tensor<[1,4,2],si64> -> tensor<[1,4,2,3],si64> + auto indicesTf = + tosa::convertTorchIndexToTfIndices(rewriter, op, input, index, dim); + if (!indicesTf) + return rewriter.notifyMatchFailure( + op, "Convert PyTorch index and dim to TensorFlow indices failed"); + + // Perform the TensorFlow ScatterNd algorithm with TensorFlow style indices as + // input. + auto result = tosa::convertScatterNdOp(rewriter, op, outType, input, + indicesTf.value(), src); + + if (!result) + return rewriter.notifyMatchFailure(op, "Convert ScatterNdOp failed"); + + rewriter.replaceOp(op, {result.value()}); return success(); } +// Legalization for aten.slice_scatter template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - AtenLeTensorOp op, OpAdaptor adaptor, +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenSliceScatterOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { // Not a tensor type. - auto selfType = dyn_cast(adaptor.getSelf().getType()); - if (!selfType) + auto input = adaptor.getSelf(); + auto inputType = dyn_cast(input.getType()); + if (!inputType) return rewriter.notifyMatchFailure( - op, "Only tensor types input are currently supported"); - auto otherType = dyn_cast(adaptor.getOther().getType()); - if (!otherType) + op, "Only RankedTensorType inputs are currently supported"); + + auto inputShape = inputType.getShape(); + auto paramsRank = inputType.getRank(); + + auto src = adaptor.getSrc(); + auto srcType = dyn_cast(src.getType()); + if (!srcType) return rewriter.notifyMatchFailure( - op, "Only tensor types condition are currently supported"); + op, "Only RankedTensorType sources are currently supported"); - auto outType = getTypeConverter()->convertType(op.getType()); + // Check `src` and `input` param should have the same rank + if (srcType.getRank() != paramsRank) + return rewriter.notifyMatchFailure( + op, "Src and input should have the same rank"); - auto greaterOp = rewriter.create( - op.getLoc(), outType, adaptor.getSelf(), adaptor.getOther()); + auto srcShape = srcType.getShape(); - rewriter.replaceOpWithNewOp(op, outType, - greaterOp.getOutput()); + // Dynamic shape check + if (!inputType.hasStaticShape() || !srcType.hasStaticShape()) + return rewriter.notifyMatchFailure( + op, "Support for dynamic shape not implemented"); - return success(); -} + // Get positive dim + int64_t dim{0}; + if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) + return rewriter.notifyMatchFailure(op, + "Dim value should be a constant int"); -template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - AtenIscloseOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - // check args + dim = toPositiveDim(dim, paramsRank); + if (!isValidDim(dim, paramsRank)) + return rewriter.notifyMatchFailure(op, "Dim is invalid"); + + // Get start, end, and step params + // If start and end params are not specified, assign them to 0 and + // inputShape[dim], respectively. 
+ int64_t start{0}; + if (!matchPattern(op.getStart(), m_TorchConstantInt(&start))) + return rewriter.notifyMatchFailure(op, + "Start value should be a constant int"); + if (start < 0) + start += inputShape[dim]; + + int64_t end{inputShape[dim]}; + if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) + return rewriter.notifyMatchFailure(op, + "End value should be a constant int"); + if (end < 0) + end += inputShape[dim]; + + if (end > inputShape[dim]) + end = inputShape[dim]; + + if (start >= end) + return rewriter.notifyMatchFailure( + op, "Start value greater than end value not supported"); + + int64_t step{1}; + if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) + return rewriter.notifyMatchFailure(op, + "Step value should be a constant int"); + + // Create PyTorch style scatter index based on start, end, and step values + int64_t outerRepeat{1}, innerRepeat{1}; + for (int64_t i = 0; i < dim; i++) + outerRepeat *= srcShape[i]; + + for (int64_t i = dim + 1; i < paramsRank; i++) + innerRepeat *= srcShape[i]; + + SmallVector indexVec; + for (int64_t i = 0; i < outerRepeat; i++) { + for (int32_t indexVal = start; indexVal < end; indexVal += step) { + for (int64_t j = 0; j < innerRepeat; j++) { + indexVec.push_back(indexVal); + } + } + } + + Value index = + tosa::getConstTensor(rewriter, op, indexVec, srcShape).value(); + + // Get the output type + auto outType = getTypeConverter()->convertType(op.getType()); + + // convert PyTorch style index and dim into TensorFlows tyle indices + // tensor<[1,4,2],si64> -> tensor<[1,4,2,3],si64> + auto indicesTf = + tosa::convertTorchIndexToTfIndices(rewriter, op, input, index, dim); + if (!indicesTf) + return rewriter.notifyMatchFailure( + op, "Convert PyTorch index and dim to TensorFlow indices failed"); + + // Perform the TensorFlow ScatterNd algorithm with TensorFlow style indices as + // input. + auto result = tosa::convertScatterNdOp(rewriter, op, outType, input, + indicesTf.value(), src); + + if (!result) + return rewriter.notifyMatchFailure(op, "Convert ScatterNdOp failed"); + + rewriter.replaceOp(op, {result.value()}); + return success(); +} + +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenAbsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // Not a tensor type. + auto selfType = dyn_cast(adaptor.getSelf().getType()); + if (!selfType) + return rewriter.notifyMatchFailure( + op, "Only tensor types input are currently supported"); + + auto outType = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, outType, adaptor.getSelf()); + + return success(); +} + +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenWhereSelfOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + // Not a tensor type. 
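The index vector built above for aten.slice_scatter enumerates, per outer/inner position, the destination positions start, start+step, ... along dim, so the net effect matches the PyTorch semantics out = self.clone(); out[..., start:end:step, ...] = src. A minimal 1-D reference of that behaviour (plain C++, illustrative only):

#include <cstdio>
#include <vector>

// Reference semantics of aten.slice_scatter on a 1-D tensor (dim = 0):
// copy `self`, then overwrite positions start, start+step, ... (< end) with src.
std::vector<int> sliceScatter1d(std::vector<int> self,
                                const std::vector<int> &src, int start,
                                int end, int step) {
  int s = 0;
  for (int i = start; i < end && s < (int)src.size(); i += step, ++s)
    self[i] = src[s];
  return self;
}

int main() {
  std::vector<int> self = {0, 0, 0, 0, 0, 0};
  std::vector<int> src = {7, 8, 9};
  auto out = sliceScatter1d(self, src, /*start=*/1, /*end=*/6, /*step=*/2);
  for (int v : out)
    std::printf("%d ", v); // prints: 0 7 0 8 0 9
  std::printf("\n");
  return 0;
}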
+ auto self = adaptor.getSelf(); + auto selfType = dyn_cast(self.getType()); + if (!selfType) + return rewriter.notifyMatchFailure( + op, "Only tensor types inputs are currently supported"); + + auto cond = adaptor.getCondition(); + auto condType = dyn_cast(cond.getType()); + if (!condType) + return rewriter.notifyMatchFailure( + op, "Only tensor types conditions are currently supported"); + + auto other = adaptor.getOther(); + auto otherType = dyn_cast(other.getType()); + if (!otherType) + return rewriter.notifyMatchFailure( + op, "Only tensor types inputs are currently supported"); + + auto outType = + dyn_cast(getTypeConverter()->convertType(op.getType())); + + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), cond, self).failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), cond, other).failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, other).failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + + rewriter.replaceOpWithNewOp(op, outType, cond, self, other); + + return success(); +} + +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenIscloseOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // check args double rtol, atol; bool equalNan; if (!matchPattern(op.getRtol(), m_TorchConstantFloat(&rtol))) @@ -3957,33 +5194,46 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "unimplemented: equal_nan is expected to be false"); // check tensor type. - auto selfType = dyn_cast(adaptor.getSelf().getType()); - auto otherType = dyn_cast(adaptor.getOther().getType()); + auto self = adaptor.getSelf(); + auto selfType = dyn_cast(self.getType()); + auto other = adaptor.getOther(); + auto otherType = dyn_cast(other.getType()); if (!selfType || !otherType) return rewriter.notifyMatchFailure( op, "Only tensor types input are currently supported"); if (!selfType.hasStaticShape() || !otherType.hasStaticShape()) return rewriter.notifyMatchFailure( op, "Only tensor types with static shape are supported"); - if (!selfType.getElementType().isa() || - !otherType.getElementType().isa()) { + if (!isa(selfType.getElementType()) || + !isa(otherType.getElementType())) { return rewriter.notifyMatchFailure( op, "unimplemented: only FP element type is supported"); } + auto rtolConstOp = + tosa::getTosaConstTensorSingleF32(rewriter, op, static_cast(rtol)); + auto atolConstOp = + tosa::getTosaConstTensorSingleF32(rewriter, op, static_cast(atol)); + + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, other).failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, rtolConstOp) + .failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, atolConstOp) + .failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + + // Reinitialize selfType and otherType after equalizing ranks + selfType = dyn_cast(self.getType()); + otherType = dyn_cast(other.getType()); - auto rhsSubOp = rewriter.create( - op->getLoc(), selfType, adaptor.getSelf(), adaptor.getOther()); + auto rhsSubOp = + rewriter.create(op->getLoc(), selfType, self, other); auto rhsAbsOp = rewriter.create(op->getLoc(), selfType, rhsSubOp); - auto lhsAbsOp = - rewriter.create(op->getLoc(), otherType, adaptor.getOther()); - auto rtolConstOp = - tosa::getTosaConstTensorSingleF32(rewriter, op, static_cast(rtol)); + auto lhsAbsOp = rewriter.create(op->getLoc(), otherType, other); auto mulOp = rewriter.create(op->getLoc(), otherType, rtolConstOp, lhsAbsOp, 
/*shift=*/0); - auto atolConstOp = - tosa::getTosaConstTensorSingleF32(rewriter, op, static_cast(atol)); auto addOp = rewriter.create(op->getLoc(), otherType, atolConstOp, mulOp); @@ -4047,10 +5297,127 @@ LogicalResult ConvertAtenOp::matchAndRewrite( "max attr should be a torch constant"); } + // Use default NaN Propagation mode "PROPAGATE" for tosa.clamp auto outType = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, outType, adaptor.getSelf(), - min_int, max_int, min_fp, max_fp); + rewriter.replaceOpWithNewOp( + op, outType, adaptor.getSelf(), min_int, max_int, min_fp, max_fp, + /*nan_mode=*/rewriter.getStringAttr("PROPAGATE")); + + return success(); +} + +// Legalization for aten.clamp.Tensor +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenClampTensorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // We are not using tosa.clamp to lower aten.clamp.Tensor, as + // aten.clamp.Tensor's min and max attributes are tensors that can have size + // greater than 1, which is not compatible with tosa.clamp. + // + // Instead, we use the following formula: + // yi = min(max(xi, min_valuei), max_valuei) + auto self = adaptor.getSelf(); + + // Not a tensor type + auto selfType = dyn_cast(self.getType()); + if (!selfType) + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); + auto selfElemTy = selfType.getElementType(); + + auto resultType = + dyn_cast(typeConverter->convertType(op.getType())); + + // Get min tensor. If None, there is no lower bound. + Value min; + if (succeeded(checkNotNone(rewriter, op, adaptor.getMin()))) { + min = adaptor.getMin(); + } else { + min = + TypeSwitch(selfElemTy) + .Case([&](auto) { + return tosa::getConstTensor( + rewriter, op, std::numeric_limits::lowest(), {}, + selfElemTy) + .value(); + }) + .Case([&](auto intType) { + switch (intType.getWidth()) { + case 8: + return tosa::getConstTensor( + rewriter, op, std::numeric_limits::min(), {}) + .value(); + case 32: + return tosa::getConstTensor( + rewriter, op, std::numeric_limits::min(), + {}) + .value(); + case 64: + return tosa::getConstTensor( + rewriter, op, std::numeric_limits::min(), + {}) + .value(); + } + llvm_unreachable("Invalid integer width"); + }); + } + + // Get max tensor. If None, there is no upper bound. 
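The decomposition stated above for aten.clamp.Tensor, yi = min(max(xi, min_valuei), max_valuei) with a missing bound replaced by the numeric limits of the element type, can be exercised with a small scalar sketch (plain C++, illustrative only):

#include <algorithm>
#include <cstdio>
#include <limits>
#include <optional>

// Elementwise clamp with optional bounds, mirroring the min(max(x, lo), hi)
// decomposition used for aten.clamp.Tensor; absent bounds fall back to the
// lowest/highest representable value of the element type.
float clampElem(float x, std::optional<float> lo, std::optional<float> hi) {
  float loV = lo.value_or(std::numeric_limits<float>::lowest());
  float hiV = hi.value_or(std::numeric_limits<float>::max());
  return std::min(std::max(x, loV), hiV);
}

int main() {
  std::printf("%g\n", clampElem(-3.0f, -1.0f, 2.0f));        // -1
  std::printf("%g\n", clampElem(5.0f, std::nullopt, 2.0f));  // 2
  std::printf("%g\n", clampElem(0.5f, -1.0f, std::nullopt)); // 0.5
  return 0;
}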
+ Value max; + if (succeeded(checkNotNone(rewriter, op, adaptor.getMax()))) { + max = adaptor.getMax(); + } else { + max = + TypeSwitch(selfElemTy) + .Case([&](auto) { + return tosa::getConstTensor( + rewriter, op, std::numeric_limits::max(), {}, + selfElemTy) + .value(); + }) + .Case([&](auto intType) { + switch (intType.getWidth()) { + case 8: + return tosa::getConstTensor( + rewriter, op, std::numeric_limits::max(), {}) + .value(); + case 32: + return tosa::getConstTensor( + rewriter, op, std::numeric_limits::max(), + {}) + .value(); + case 64: + return tosa::getConstTensor( + rewriter, op, std::numeric_limits::max(), + {}) + .value(); + } + llvm_unreachable("Invalid integer width"); + }); + } + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, min).failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, max).failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + + self = tosa::promoteType(rewriter, self, resultType); + min = tosa::promoteType(rewriter, min, resultType); + max = tosa::promoteType(rewriter, max, resultType); + + // max(xi, min_valuei) + // Use default NaN Propagation mode "PROPAGATE" for tosa.maximum + auto minThresholdCheck = rewriter.create( + op->getLoc(), resultType, self, min, + /*nan_mode=*/rewriter.getStringAttr("PROPAGATE")); + + // yi = min(max(xi, min_valuei), max_valuei) + // Use default NaN Propagation mode "PROPAGATE" for tosa.minimum + auto result = rewriter.create( + op->getLoc(), resultType, minThresholdCheck, max, + /*nan_mode=*/rewriter.getStringAttr("PROPAGATE")); + + rewriter.replaceOp(op, result); return success(); } @@ -4060,9 +5427,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { const TypeConverter *typeConverter = this->getTypeConverter(); - RankedTensorType resultType = - typeConverter->convertType(op->getResult(0).getType()) - .cast(); + RankedTensorType resultType = cast( + typeConverter->convertType(op->getResult(0).getType())); // At this point all tensors should have value semantics, and hence the // `layout` check can be ignored. @@ -4070,7 +5436,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // TODO: Add support for pin_memory features. // The pin_memory should be either `False` or `none`. bool pinMemory; - if (!op.getPinMemory().getType().isa() && + if (!isa(op.getPinMemory().getType()) && (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) || pinMemory)) { return rewriter.notifyMatchFailure( @@ -4164,10 +5530,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( }; const auto isIntType = - resultType.getElementType().dyn_cast_or_null(); + dyn_cast_or_null(resultType.getElementType()); const auto isDoubleType = - resultType.getElementType().dyn_cast_or_null(); + dyn_cast_or_null(resultType.getElementType()); auto maybeResult = [&]() -> std::optional { // Integer output type, and start / end / range are all integers. 
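For reference, the value sequence the arange lowering here has to materialize follows the usual rule: ceil((end - start) / step) elements with value i equal to start + i * step. A small integer-only sketch of that rule (plain C++, illustrative only):

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Reference behaviour of arange(start, end, step) for integer operands:
// ceil((end - start) / step) elements, element i is start + i * step.
std::vector<int64_t> arangeInt(int64_t start, int64_t end, int64_t step) {
  std::vector<int64_t> out;
  int64_t n = (int64_t)std::ceil((double)(end - start) / (double)step);
  for (int64_t i = 0; i < n; ++i)
    out.push_back(start + i * step);
  return out;
}

int main() {
  for (int64_t v : arangeInt(1, 10, 3))
    std::printf("%lld ", (long long)v); // prints: 1 4 7
  std::printf("\n");
  return 0;
}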
@@ -4220,9 +5586,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { const TypeConverter *typeConverter = this->getTypeConverter(); - RankedTensorType resultType = - typeConverter->convertType(op->getResult(0).getType()) - .cast(); + TensorType resultType = + cast(typeConverter->convertType(op->getResult(0).getType())); + + if (!resultType.hasRank()) + return rewriter.notifyMatchFailure(op, "expected ranked tensor"); // Only supports integer operand type, because for the floating point operand // type result tensor has to be of type `f64` which is not supported in the @@ -4325,7 +5693,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } // Only `none`, `contiguous` and `preserve` memory_format is supported. - if (!op.getMemoryFormat().getType().isa()) { + if (!isa(op.getMemoryFormat().getType())) { int64_t memoryFormat; if (!matchPattern(op.getMemoryFormat(), m_TorchConstantInt(&memoryFormat))) return rewriter.notifyMatchFailure( @@ -4338,9 +5706,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( "memory_format is supported"); } - auto resultTy = getTypeConverter() - ->convertType(op.getResult().getType()) - .cast(); + auto resultTy = cast( + getTypeConverter()->convertType(op.getResult().getType())); Value result; if (failed(tosa::tosaCastTensorToType(rewriter, op, adaptor.getSelf(), @@ -4351,56 +5718,100 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } -template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - AtenRemainderScalarOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { +template +class ConvertAtenRemainderFmodOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { - Value self = adaptor.getSelf(); - auto selfTy = cast(self.getType()); + Value self = adaptor.getSelf(); + auto selfTy = cast(self.getType()); - if (!selfTy) - return rewriter.notifyMatchFailure( - op, "Only ranked tensor types supported in TOSA Remainder"); + if (!selfTy) + return rewriter.notifyMatchFailure( + op, "Only ranked tensor types supported in TOSA Remainder/Fmod"); - auto outType = - cast(getTypeConverter()->convertType(op.getType())); + auto outType = + cast(this->getTypeConverter()->convertType(op.getType())); - Type outElemTy = outType.getElementType(); - if (!outElemTy.isIntOrFloat()) - return rewriter.notifyMatchFailure( - op, "Only floating-point or integer datatype legalization supported"); + Type outElemTy = outType.getElementType(); + if (!outElemTy.isIntOrFloat()) + return rewriter.notifyMatchFailure( + op, "Only floating-point or integer datatype legalization supported"); - Value otherTensor; - Value other = op.getOther(); - if (failed(torchScalarToTosaTensor(rewriter, op, other, otherTensor, - outElemTy, {}))) - return rewriter.notifyMatchFailure( - op, "Currently only scalar constants are supported for " - "conversion in TOSA Remainder operation"); - - if (selfTy.getElementType() != outElemTy) - self = rewriter.create(op.getLoc(), outType, self); - - auto divTensor = self; - if (isa(outElemTy)) { - auto otherTensorReciprocal = rewriter.create( - op.getLoc(), otherTensor.getType(), otherTensor); - divTensor = rewriter.create( - op.getLoc(), outType, self, otherTensorReciprocal, /*shift=*/0); - divTensor = rewriter.create(op.getLoc(), outType, divTensor); - } else { - divTensor = 
rewriter.create(op.getLoc(), outType, self, - otherTensor); - } + Value otherTensor; + if constexpr (std::is_same()) { + Value other = op.getOther(); + if (failed(torchScalarToTosaTensor(rewriter, op, other, otherTensor, + outElemTy, {}))) + return rewriter.notifyMatchFailure( + op, "Currently only scalar constants are supported for " + "conversion in TOSA Remainder/Fmod operation"); + } else { + otherTensor = adaptor.getOther(); + auto otherTy = cast(otherTensor.getType()); + + if (!otherTy) + return rewriter.notifyMatchFailure( + op, "Only ranked tensor types supported in TOSA Remainder/Fmod"); + } - auto mulTensor = - rewriter.create(op.getLoc(), outType, otherTensor, divTensor, - /*shift=*/0); - rewriter.replaceOpWithNewOp(op, outType, self, mulTensor); + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, otherTensor) + .failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + + constexpr bool isRemainderOp = + std::is_same() || + std::is_same() || + std::is_same(); + + if (selfTy.getElementType() != outElemTy) + self = rewriter.create(op.getLoc(), outType, self); + + Value divTensor; + if (isRemainderOp) { + // torch.remainder(a, b) == a - a.div(b, rounding_mode="floor") * b + if (isa(outElemTy)) { + auto otherTensorReciprocal = rewriter.create( + op.getLoc(), otherTensor.getType(), otherTensor); + divTensor = rewriter.create( + op.getLoc(), outType, self, otherTensorReciprocal, /*shift=*/0); + divTensor = + rewriter.create(op.getLoc(), outType, divTensor); + } else { + divTensor = + floorIntDiv(rewriter, op, outType, self, otherTensor).value(); + } + } else { + // torch.fmod(a, b) == a - a.div(b, rounding_mode="trunc") * b + if (isa(outElemTy)) { + divTensor = truncFloatDiv(rewriter, op, outType, self, otherTensor); + } else { + // TOSA IntDiv requires inputs to be i32 + auto i32Type = RankedTensorType::get(outType.getShape(), + rewriter.getIntegerType(32)); + self = tosa::promoteType(rewriter, self, i32Type); + otherTensor = tosa::promoteType(rewriter, otherTensor, i32Type); - return success(); -} + auto intDivTensor = rewriter.create( + op->getLoc(), i32Type, self, otherTensor); + + divTensor = tosa::promoteType(rewriter, intDivTensor, outType); + } + } + + auto mulTensor = rewriter.create(op.getLoc(), outType, + otherTensor, divTensor, + /*shift=*/0); + rewriter.replaceOpWithNewOp(op, outType, self, mulTensor); + + return success(); + } +}; template class ConvertAtenPoolingBaseOp : public OpConversionPattern { @@ -4430,9 +5841,11 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern { } else { int64_t dimSize = inputDim + padBefore + padAfter - dilation * (kernelDim - 1) - 1; - if (ceilMode && (dimSize % stride != 0)) - return dimSize / stride + 2; - return dimSize / stride + 1; + int64_t outputDim = dimSize / stride + 1; + if (ceilMode && (dimSize % stride != 0) && + (outputDim * stride < inputDim + padBefore)) + outputDim++; + return outputDim; } } @@ -4506,9 +5919,11 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern { std::is_same::value, "Expected either tosa::MaxPool2dOp or tosa::AvgPool2dOp"); if constexpr (std::is_same::value) { + // Use default NaN Propagation mode "PROPAGATE" for tosa.max_pool2d pooledOutput = rewriter - .create(op->getLoc(), outputTy, input, kernel, - stride, pad) + .create( + op->getLoc(), outputTy, input, kernel, stride, pad, + /*nan_mode=*/rewriter.getStringAttr("PROPAGATE")) .getResult(); } else if constexpr (std::is_same::value) { TypeAttr accType; @@ -4525,11 
+5940,25 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern { ConvertAtenPoolingBaseOp::transposePoolingOutputToChw( op, rewriter, pooledOutput); - rewriter.replaceOpWithNewOp( - op, + Value result = transposedOutput; + auto resultTy = dyn_cast( OpConversionPattern::getTypeConverter()->convertType( - op.getType()), - transposedOutput); + op.getType())); + + if constexpr (std::is_same() || + std::is_same()) { + auto resultShape = resultTy.getShape(); + auto resultElemTy = resultTy.getElementType(); + + result = rewriter.create( + op->getLoc(), + RankedTensorType::get(makeShapeTorchCompatible(resultShape), + resultElemTy), + transposedOutput, + rewriter.getDenseI64ArrayAttr(makeShapeTorchCompatible(resultShape))); + } + + rewriter.replaceOpWithNewOp(op, resultTy, result); return success(); } @@ -4676,6 +6105,12 @@ static LogicalResult getOutputTypeAndPoolingParameters( return rewriter.notifyMatchFailure( op, "Non-const kernel_size for pooling op unsupported"); + // Expand kernel size parameter to size 2 to be compatible with + // tosa::MaxPool2dOp or tosa::AvgPool2dOp + if constexpr (std::is_same() || + std::is_same()) + kernelSizeInts.push_back(1); + if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strideInts))) return rewriter.notifyMatchFailure( op, "Non-const stride for pooling op unsupported"); @@ -4683,13 +6118,46 @@ static LogicalResult getOutputTypeAndPoolingParameters( // list during import. For such a case, the stride value is the kernel size. // See: // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d - if (strideInts.empty()) + if (strideInts.empty()) { strideInts.assign(kernelSizeInts); + } else { + // Expand stride parameter to size 2 to be compatible with + // tosa::MaxPool2dOp or tosa::AvgPool2dOp + if constexpr (std::is_same() || + std::is_same()) + strideInts.push_back(1); + } if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingInts))) return rewriter.notifyMatchFailure( op, "Non-const padding factor for pooling op unsupported"); + // Expand padding parameter to size 2 to be compatible with + // tosa::MaxPool2dOp or tosa::AvgPool2dOp + if constexpr (std::is_same() || + std::is_same()) + paddingInts.push_back(0); + + if constexpr (std::is_same() || + std::is_same()) { + // Currently, we can not represent `count_include_pad` with the existing + // TOSA AvgPool2d specification. Without the below check, we produce silent + // wrong answer (SWA) when the `count_include_pad` value is `true.` + // + // Note: We need to check for `count_include_pad` only when the `padding` + // value is non-zero. 
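The pooling output-size rule implemented above can be read as: with ceil_mode, the extra window is only counted when it starts inside the input or the left padding, never purely within the right padding. A standalone transcription of that computation with a concrete case (plain C++, mirrors the code above; illustrative only):

#include <cstdint>
#include <cstdio>

// Pooling output size along one dimension. With ceil_mode, a sliding window
// is only counted if it starts within the input or the left padding.
int64_t pooledDim(int64_t inputDim, int64_t padBefore, int64_t padAfter,
                  int64_t kernelDim, int64_t stride, int64_t dilation,
                  bool ceilMode) {
  int64_t dimSize =
      inputDim + padBefore + padAfter - dilation * (kernelDim - 1) - 1;
  int64_t outputDim = dimSize / stride + 1;
  if (ceilMode && (dimSize % stride != 0) &&
      (outputDim * stride < inputDim + padBefore))
    outputDim++;
  return outputDim;
}

int main() {
  // Kernel 3, stride 2, pad 1 on a length-6 input: 3 windows normally,
  // 4 with ceil_mode because the extra window still starts inside the input.
  std::printf("%lld %lld\n",
              (long long)pooledDim(6, 1, 1, 3, 2, 1, /*ceilMode=*/false),
              (long long)pooledDim(6, 1, 1, 3, 2, 1, /*ceilMode=*/true));
  return 0;
}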
+ bool countIncludePad; + if ((paddingInts[0] != 0 || paddingInts[1] != 0) && + (!matchPattern(op.getCountIncludePad(), + m_TorchConstantBool(&countIncludePad)) || + + countIncludePad)) { + return rewriter.notifyMatchFailure( + op, "Unsupported `count_include_pad` value, for tosa AvgPool " + "`count_include_pad` value should be `False`."); + } + } + SmallVector padArr = {paddingInts[0], paddingInts[0], paddingInts[1], paddingInts[1]}; kernel = rewriter.getDenseI64ArrayAttr(kernelSizeInts); @@ -4745,6 +6213,68 @@ class ConvertAtenMaxPool2dOp } }; +// Legalization for aten.max_pool1d +class ConvertAtenMaxPool1dOp + : public ConvertAtenPoolingBaseOp { +public: + using ConvertAtenPoolingBaseOp::ConvertAtenPoolingBaseOp; + LogicalResult processInputs(AtenMaxPool1dOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Value &input, + DenseI64ArrayAttr &kernel, + DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad, + Type &outputTy) const override { + auto self = adaptor.getSelf(); + + // Not a RankedTensorType + auto selfTy = dyn_cast(self.getType()); + if (!selfTy) + return rewriter.notifyMatchFailure( + op, "Only ranked tensor type inputs are supported"); + auto selfShape = selfTy.getShape(); + + // Expected a rank 3 input tensor + if (selfTy.getRank() != 3) + return rewriter.notifyMatchFailure( + op, "Input tensor for MaxPool1d should have rank 3"); + + // Unsqueeze input tensor to rank 4 to be compatible with tosa::MaxPool2dOp + SmallVector rank4Shape(selfShape); + rank4Shape.push_back(1); + auto reshapedSelf = rewriter.create( + op->getLoc(), + RankedTensorType::get(makeShapeTorchCompatible(rank4Shape), + selfTy.getElementType()), + self, rewriter.getDenseI64ArrayAttr(rank4Shape)); + + SmallVector dilationArray; + if (!matchPattern(op.getDilation(), + m_TorchListOfConstantInts(dilationArray))) + return rewriter.notifyMatchFailure( + op, "Non-const dilation for pooling op unsupported."); + // TOSA pooling only supports unit dilation. + if (dilationArray[0] > 1) + return rewriter.notifyMatchFailure( + op, "Cannot process non-unit pooling dilation."); + + // Expand dilation to size 2 to be compatible with tosa::MaxPool2dOp + dilationArray.push_back(1); + + if (failed(getOutputTypeAndPoolingParameters( + op, rewriter, reshapedSelf.getResult(), dilationArray, outputTy, + kernel, stride, pad))) + return rewriter.notifyMatchFailure( + op, "invalid pooling parameters or input type"); + + // Transpose to xHWC + input = ConvertAtenPoolingBaseOp:: + transposePoolingInputToHwc(op, rewriter, reshapedSelf.getResult()); + + return success(); + } +}; + class ConvertAtenAvgPool2dOp : public ConvertAtenPoolingBaseOp { public: @@ -4755,6 +6285,16 @@ class ConvertAtenAvgPool2dOp DenseI64ArrayAttr &kernel, DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad, Type &outputTy) const override { + + // Currently, we can not represent `divisor_override` with the existing TOSA + // AvgPool2d specification. 
Without the below check, we produce silent wrong + // answers (SWA) when the `divisor_override` value is other than `None.` + if (!isa(op.getDivisorOverride().getType())) { + return rewriter.notifyMatchFailure( + op, "Unsupported `divisor_override` value, for tosa AvgPool2dOp " + "`divisor_override` value should be `None`."); + } + SmallVector dilationArray{1, 1}; if (failed(getOutputTypeAndPoolingParameters( @@ -4771,6 +6311,56 @@ class ConvertAtenAvgPool2dOp } }; +// Legalization for aten.avg_pool1d +class ConvertAtenAvgPool1dOp + : public ConvertAtenPoolingBaseOp { +public: + using ConvertAtenPoolingBaseOp::ConvertAtenPoolingBaseOp; + LogicalResult processInputs(AtenAvgPool1dOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, Value &input, + DenseI64ArrayAttr &kernel, + DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad, + Type &outputTy) const override { + auto self = adaptor.getSelf(); + + // Not a RankedTensorType + auto selfTy = dyn_cast(self.getType()); + if (!selfTy) + return rewriter.notifyMatchFailure( + op, "Only ranked tensor type inputs are supported"); + auto selfShape = selfTy.getShape(); + + // Expected a rank 3 input tensor + if (selfTy.getRank() != 3) + return rewriter.notifyMatchFailure( + op, "Input tensor for AvgPool1d should have rank 3"); + + // Unsqueeze input tensor to rank 4 to be compatible with tosa::AvgPool2dOp + SmallVector rank4Shape(selfShape); + rank4Shape.push_back(1); + auto reshapedSelf = rewriter.create( + op->getLoc(), + RankedTensorType::get(makeShapeTorchCompatible(rank4Shape), + selfTy.getElementType()), + self, rewriter.getDenseI64ArrayAttr(rank4Shape)); + + SmallVector dilationArray{1, 1}; + if (failed(getOutputTypeAndPoolingParameters( + op, rewriter, reshapedSelf.getResult(), dilationArray, outputTy, + kernel, stride, pad))) + return rewriter.notifyMatchFailure( + op, "invalid pooling parameters or input type"); + + // Transpose to xHWC + input = ConvertAtenPoolingBaseOp:: + transposePoolingInputToHwc(op, rewriter, reshapedSelf.getResult()); + + return success(); + } +}; + // Ref: Error checking based on the Torch to LinAlg lowering template class ConvertAtenConstPatternOp : public OpConversionPattern { @@ -4781,9 +6371,9 @@ class ConvertAtenConstPatternOp : public OpConversionPattern { matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto outType = OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template dyn_cast(); + auto outType = dyn_cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); if (!outType) return rewriter.notifyMatchFailure(op, @@ -4836,34 +6426,66 @@ class ConvertAtenConstPatternOp : public OpConversionPattern { }; template -class ConvertAtenFillScalarOp : public OpConversionPattern { +class ConvertAtenFillOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto outType = OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template dyn_cast(); + auto outType = dyn_cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); if (!outType || !outType.hasStaticShape()) return rewriter.notifyMatchFailure( op, "Only Tensor types with static shapes are currently supported"); Type outElemTy = outType.getElementType(); - if (!outElemTy.isIntOrFloat()) { + if (!outElemTy.isIntOrFloat()) 
return rewriter.notifyMatchFailure( op, "Only floating-point or integer datatype legalization supported"); + + Value fillValueTargetTensor; + if constexpr (std::is_same()) { + // Reshape value tensor to have same rank and shape as input + auto inputRank = + cast(adaptor.getSelf().getType()).getRank(); + + auto fillValue = adaptor.getValue(); + auto fillValueType = dyn_cast(fillValue.getType()); + if (!fillValueType) + return rewriter.notifyMatchFailure(op, "Fill value is not a tensor"); + auto fillValueElemTy = fillValueType.getElementType(); + + SmallVector fillValueMatchedInputRankShape(inputRank, 1); + + auto fillValueMatchedInputRankType = RankedTensorType::get( + makeShapeTorchCompatible(fillValueMatchedInputRankShape), + fillValueElemTy); + + auto fillValueMatchedInputRankTensor = rewriter.create( + op->getLoc(), fillValueMatchedInputRankType, fillValue, + rewriter.getDenseI64ArrayAttr(fillValueMatchedInputRankShape)); + + auto tileOpMultiples = + tosa::getTosaConstShape(rewriter, op->getLoc(), outType.getShape()); + + fillValueTargetTensor = rewriter.create( + op->getLoc(), + RankedTensorType::get(makeShapeTorchCompatible(outType.getShape()), + fillValueElemTy), + fillValueMatchedInputRankTensor.getResult(), tileOpMultiples); + } else { + if (failed(torchScalarToTosaTensor( + rewriter, op, op.getValue(), fillValueTargetTensor, outElemTy, + makeShapeTorchCompatible(outType.getShape())))) + return rewriter.notifyMatchFailure( + op, "Fill value must be a scalar constant"); } - Value constOp; - if (failed(torchScalarToTosaTensor( - rewriter, op, op.getValue(), constOp, outElemTy, - makeShapeTorchCompatible(outType.getShape())))) - return rewriter.notifyMatchFailure( - op, "Supplied value must be a Scalar constant"); - rewriter.replaceOpWithNewOp(op, outType, constOp); + rewriter.replaceOpWithNewOp(op, outType, + fillValueTargetTensor); return success(); } @@ -4877,9 +6499,9 @@ class ConvertAtenMaskedFillOp : public OpConversionPattern { LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto outType = OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template dyn_cast(); + auto outType = dyn_cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); if (!outType || !outType.hasStaticShape()) return rewriter.notifyMatchFailure( @@ -4892,7 +6514,8 @@ class ConvertAtenMaskedFillOp : public OpConversionPattern { } // Not a tensor type. 
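The fill.Tensor path above broadcasts the value tensor by reshaping it to an all-ones shape of the input rank and then tiling it up to the output shape, so the tile multiples are simply the output dimensions. A tiny sketch of that shape arithmetic (plain C++, illustrative only):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Output shape of the fill, e.g. a [2, 3, 4] tensor.
  std::vector<int64_t> outShape = {2, 3, 4};

  // The single-element fill value is first reshaped to rank(outShape) with
  // every dimension equal to 1 ...
  std::vector<int64_t> reshaped(outShape.size(), 1);

  // ... and then tiled; since every reshaped dim is 1, the tile multiples are
  // exactly the output dimensions.
  std::vector<int64_t> multiples = outShape;

  for (size_t i = 0; i < outShape.size(); ++i)
    std::printf("dim %zu: reshape=%lld tile=%lld -> %lld\n", i,
                (long long)reshaped[i], (long long)multiples[i],
                (long long)(reshaped[i] * multiples[i]));
  return 0;
}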
- auto selfType = dyn_cast(adaptor.getSelf().getType()); + auto self = adaptor.getSelf(); + auto selfType = dyn_cast(self.getType()); if (!selfType || !outType.hasStaticShape()) return rewriter.notifyMatchFailure( op, @@ -4924,8 +6547,13 @@ class ConvertAtenMaskedFillOp : public OpConversionPattern { RankedTensorType::get(rhsTensorType.getShape(), outElemTy), rhsTensor); + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, rhsTensor) + .failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + rewriter.replaceOpWithNewOp(op, outType, adaptor.getMask(), - rhsTensor, adaptor.getSelf()); + rhsTensor, self); return success(); } }; @@ -4949,9 +6577,9 @@ class ConvertAtenCloneOp : public OpConversionPattern { "unimplemented: only contiguous and channels last memory " "format is supported"); } - auto outType = OpConversionPattern::getTypeConverter() - ->convertType(op.getType()) - .template dyn_cast(); + auto outType = dyn_cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); rewriter.replaceOpWithNewOp(op, outType, adaptor.getSelf()); return success(); @@ -5003,12 +6631,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( translatePadsList.push_back(highPadding[i]); } - DenseElementsAttr paddingAttr = DenseIntElementsAttr::get( - RankedTensorType::get({rank, 2}, rewriter.getI64Type()), - translatePadsList); - - Value padsList1 = rewriter.create( - loc, paddingAttr.getType(), paddingAttr); + Value padsList1 = tosa::getTosaConstShape(rewriter, loc, translatePadsList); Value padValue = adaptor.getValue(); Operation *padOp = padValue.getDefiningOp(); @@ -5061,6 +6684,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto builtinTensors = getTypeConvertedValues(rewriter, loc, typeConverter, tensorsTorchType); + for (auto &tensor : builtinTensors) + tensor = tosa::promoteType(rewriter, tensor, outType); + auto result = tosa::CreateOpAndInfer( rewriter, loc, outType, builtinTensors, rewriter.getI32IntegerAttr(dim)); rewriter.replaceOp(op, result.getResult()); @@ -5079,8 +6705,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure(op, "Only Tensor types supported in TOSA"); - auto resultType = typeConverter->convertType(op.getType()) - .template cast(); + auto resultType = + cast(typeConverter->convertType(op.getType())); auto elementType = resultType.getElementType(); if (isa(selfTy.getElementType())) { @@ -5092,276 +6718,2635 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto oneHalf = tosa::getConstTensor(rewriter, op, 0.5, {}, elementType).value(); + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, oneHalf).failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + rewriter.replaceOpWithNewOp(op, resultType, self, oneHalf); return success(); } -} // namespace +template <> +LogicalResult +ConvertAtenOp::matchAndRewrite( + Aten__InterpolateSizeListScaleListOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // Converts torch.aten.__interpolate.size_list_scale_list to tosa.resize + auto input = adaptor.getInput(); + auto inputTy = dyn_cast(input.getType()); + if (!inputTy) + return rewriter.notifyMatchFailure(op, + "Only Tensor types supported in TOSA"); + auto inputRank = inputTy.getRank(); + if (inputRank != 4) + return rewriter.notifyMatchFailure(op, + "TOSA resize() takes rank==4 tensors."); -// ----------------------------------------------------------------------------- -// TorchToTosa Pass 
-// ----------------------------------------------------------------------------- + auto inputShape = inputTy.getShape(); + auto inputElemTy = inputTy.getElementType(); + // TOSA works in NHWC. Perform the necessary transformations. + std::optional nchwToNhwcTransposeConst = + tosa::getConstTensor(rewriter, op, + /*vec=*/{0, 2, 3, 1}, + /*shape=*/{static_cast(4)}); + SmallVector transposedInputShape( + {inputShape[0], inputShape[2], inputShape[3], inputShape[1]}); + auto transposedInputTy = RankedTensorType::get( + makeShapeLLVMCompatible(transposedInputShape), inputElemTy); + auto transposedInput = + rewriter + .create( + op->getLoc(), getTypeConverter()->convertType(transposedInputTy), + input, nchwToNhwcTransposeConst.value()) + .getResult(); -namespace { -class ConvertTorchToTosa : public ConvertTorchToTosaBase { -public: - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - registry.insert(); - registry.insert(); - TorchConversion::getBackendTypeConversionDependentDialects(registry); + auto inputHeight = transposedInputShape[1]; + auto inputWidth = transposedInputShape[2]; + + int outputHeight, outputWidth; + if (!isa(op.getScaleFactor().getType())) { + SmallVector scaleFactor; + if (!matchPattern(op.getScaleFactor(), + m_TorchListOfConstantFloats(scaleFactor))) + return rewriter.notifyMatchFailure( + op, "non-const scale_factor parameter unsupported"); + + outputHeight = inputHeight * scaleFactor[0]; + outputWidth = inputWidth * scaleFactor[1]; + + } else { + if (!isa(op.getSize().getType())) + return rewriter.notifyMatchFailure( + op, "Scale factor and size are both absent!"); + + SmallVector size; + if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(size))) + return rewriter.notifyMatchFailure( + op, "non-const size parameter unsupported"); + outputHeight = size[0]; + outputWidth = size[1]; } - void runOnOperation() override { - MLIRContext *context = &getContext(); - ConversionTarget target(*context); - target.addLegalDialect(); + std::string pyMode; + if (!matchPattern(op.getMode(), m_TorchConstantStr(pyMode))) + return rewriter.notifyMatchFailure(op, + "non-const mode parameter unsupported"); - TypeConverter typeConverter; - typeConverter.addConversion([](Type type) { return type; }); - TorchConversion::setupBackendTypeConversion(target, typeConverter); + // All torch modes listed in + // https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html + if (pyMode != "bilinear" && pyMode != "nearest") + return rewriter.notifyMatchFailure( + op, "Only nearest and bilinear interpolation modes supported"); - // The following ops are never the primary reason why lowering fails. - // The backend contract only allows functions to return tensors thus there - // is always another op using them. - // When we have a chain of torch.constant.int followed by a unsupported - // torch op, we want the pass to mention the unsupported torch op - // in the error message. 
- target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addIllegalDialect(); + std::string mode; + if (pyMode == "bilinear") { + mode = "BILINEAR"; + } else { + mode = "NEAREST_NEIGHBOR"; + } - RewritePatternSet patterns(context); + bool alignCorners; + if (!matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCorners))) + return rewriter.notifyMatchFailure( + op, "non-const align_corners parameter unsupported"); -#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, TosaOp) \ - target.addIllegalOp(); \ - patterns.add>(typeConverter, \ - context); - INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, tosa::LogOp) - INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, tosa::ExpOp) -#undef INSERT_UNARY_FPONLY_PATTERN + bool recomputeScaleFactor; + if (isa(op.getRecomputeScaleFactor().getType())) + recomputeScaleFactor = false; + else if (!matchPattern(op.getRecomputeScaleFactor(), + m_TorchConstantBool(&recomputeScaleFactor))) + return rewriter.notifyMatchFailure( + op, "non-const recompute_scale_factor parameter unsupported"); + if (recomputeScaleFactor) + return rewriter.notifyMatchFailure( + op, "Application of recompute_scale_factor not yet supported"); -#define INSERT_UNARY_PATTERN(AtenOp, TosaOp) \ - target.addIllegalOp(); \ - patterns.add>(typeConverter, context); - INSERT_UNARY_PATTERN(AtenNegOp, tosa::NegateOp) - INSERT_UNARY_PATTERN(AtenFloorOp, tosa::FloorOp) - INSERT_UNARY_PATTERN(AtenRsqrtOp, tosa::RsqrtOp) - INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp) - INSERT_UNARY_PATTERN(AtenCeilOp, tosa::CeilOp) - INSERT_UNARY_PATTERN(AtenReciprocalOp, tosa::ReciprocalOp) -#undef INSERT_UNARY_PATTERN + bool antialias; + if (!matchPattern(op.getAntialias(), m_TorchConstantBool(&antialias))) + return rewriter.notifyMatchFailure( + op, "non-const antialias parameter unsupported"); + if (antialias) + return rewriter.notifyMatchFailure( + op, "Application of antialias not yet supported"); + + SmallVector transposedResizedOpShape( + {inputShape[0], outputHeight, outputWidth, inputShape[1]}); + auto transposedResizedOpTy = RankedTensorType::get( + makeShapeLLVMCompatible(transposedResizedOpShape), inputElemTy); + + // Formatting snake_case to match TOSA spec names for readability + int scale_y_n, scale_y_d, offset_y, border_y; + int scale_x_n, scale_x_d, offset_x, border_x; + + // Align corners sets the scaling ratio to (OH - 1)/(IH - 1) + // rather than OH / IH. Similarly for width. + auto normalize = [&](int input, int output, int &n, int &d, int &offset, + int &border) { + // Dimension is length 1, we are just sampling from one value. + if (input == 1) { + n = output; + d = 1; + offset = 0; + border = output - 1; + return; + } -#define INSERT_BINARY_PATTERN(AtenOp, TosaOp) \ - target.addIllegalOp(); \ - patterns.add>(typeConverter, context); - INSERT_BINARY_PATTERN(AtenMaximumOp, tosa::MaximumOp) - INSERT_BINARY_PATTERN(AtenMinimumOp, tosa::MinimumOp) - INSERT_BINARY_PATTERN(AtenLogicalOrOp, tosa::LogicalOrOp) -#undef INSERT_BINARY_PATTERN + // Apply if aligned and capable to be aligned. + bool apply_aligned = alignCorners && (output > 1); + n = apply_aligned ? (output - 1) : output; + d = apply_aligned ? 
(input - 1) : input; -#define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, TosaOp) \ - target.addIllegalOp(); \ - patterns.add>(typeConverter, context); - INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, tosa::AddOp) - INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, tosa::AddOp) - INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, tosa::SubOp) - INSERT_BINARY_ADDSUB_PATTERN(AtenSubScalarOp, tosa::SubOp) -#undef INSERT_BINARY_ADDSUB_PATTERN + // Simplify the scalers, make sure they are even values. + int gcd = std::gcd(n, d); + n = 2 * n / gcd; + d = 2 * d / gcd; -#define INSERT_BINARY_COMPARE_PATTERN(AtenOp, TosaOp) \ - target.addIllegalOp(); \ - patterns.add>(typeConverter, context); - INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp, tosa::GreaterOp) - INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp, tosa::GreaterEqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp, tosa::GreaterOp) - INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp, tosa::GreaterOp) - INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp, tosa::GreaterOp) - INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp, tosa::EqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenEqScalarOp, tosa::EqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenNeTensorOp, tosa::EqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenNeScalarOp, tosa::EqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndTensorOp, tosa::BitwiseAndOp) - INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseOrTensorOp, tosa::BitwiseOrOp) - INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseXorTensorOp, tosa::BitwiseXorOp) -#undef INSERT_BINARY_COMPARE_PATTERN + offset = 0; -#define INSERT_BINARY_MUL_PATTERN(AtenOp) \ - target.addIllegalOp(); \ - patterns.add>(typeConverter, context); - INSERT_BINARY_MUL_PATTERN(AtenMulTensorOp); - INSERT_BINARY_MUL_PATTERN(AtenMulScalarOp); -#undef INSERT_BINARY_MUL_PATTERN + // If nearest neighbours we need to guarantee we round up. + if (mode == "NEAREST_NEIGHBOR" && alignCorners) { + offset += n / 2; + } -#define INSERT_BINARY_DIV_PATTERN(AtenOp) \ - target.addIllegalOp(); \ - patterns.add>(typeConverter, context); - INSERT_BINARY_DIV_PATTERN(AtenDivTensorOp); - INSERT_BINARY_DIV_PATTERN(AtenDivScalarOp); -#undef INSERT_BINARY_DIV_PATTERN + // TBD: impact of antialias parameter here ? -#define INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \ - target.addIllegalOp(); \ - patterns.add>( \ - typeConverter, context); - INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenMeanDimOp, - mlir::tosa::convertReduceMeanOp) - INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenSumDimIntListOp, - mlir::tosa::convertReduceSumOp) - INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenLinalgVectorNormOp, - mlir::tosa::convertLinalgVectorNormOp) -#undef INSERT_NDIMS_REDUCTION_OP_PATTERN + // We can compute this directly based on previous values. 
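The normalize lambda here encodes the tosa.resize parameterization: the scale is the rational n/d, which becomes (out - 1)/(in - 1) under align_corners, kept even after gcd reduction, and border = d*(out - 1) - n*(in - 1) + offset describes how far the last sample lands relative to the input extent. A standalone transcription for one concrete case (plain C++, mirrors the lambda; illustrative only):

#include <cstdio>
#include <numeric>

// Mirrors the scale/offset/border normalization used when lowering
// __interpolate to tosa.resize (including the align_corners handling).
void normalize(int input, int output, bool alignCorners, bool nearest, int &n,
               int &d, int &offset, int &border) {
  if (input == 1) { // length-1 dimension: sampling a single value
    n = output; d = 1; offset = 0; border = output - 1;
    return;
  }
  bool aligned = alignCorners && (output > 1);
  n = aligned ? (output - 1) : output;
  d = aligned ? (input - 1) : input;
  int g = std::gcd(n, d);
  n = 2 * n / g; // keep the reduced ratio even
  d = 2 * d / g;
  offset = 0;
  if (nearest && alignCorners)
    offset += n / 2; // guarantees rounding up for nearest-neighbour sampling
  border = d * (output - 1) - n * (input - 1) + offset;
}

int main() {
  int n, d, offset, border;
  // Upscale a height of 4 to 8 with bilinear mode and align_corners = true.
  normalize(4, 8, /*alignCorners=*/true, /*nearest=*/false, n, d, offset,
            border);
  std::printf("scale=%d/%d offset=%d border=%d\n", n, d, offset, border);
  return 0;
}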
+ border = d * (output - 1) - n * (input - 1) + offset; + }; -#define INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \ - target.addIllegalOp(); \ - patterns.add>( \ - typeConverter, context); - INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAnyDimOp, - mlir::tosa::convertReduceAnyOp) -#undef INSERT_ONEDIM_REDUCTION_OP_PATTERN + normalize(inputHeight, outputHeight, scale_y_n, scale_y_d, offset_y, + border_y); + normalize(inputWidth, outputWidth, scale_x_n, scale_x_d, offset_x, border_x); -#define INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \ - target.addIllegalOp(); \ - patterns.add>( \ - typeConverter, context); - INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAllOp, - mlir::tosa::convertReduceAllOp) - INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAnyOp, - mlir::tosa::convertReduceAnyOp) - INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenSumOp, - mlir::tosa::convertReduceSumOp) -#undef INSERT_ALLDIMS_REDUCTION_OP_PATTERN + DenseI64ArrayAttr scale = rewriter.getDenseI64ArrayAttr( + {scale_y_n, scale_y_d, scale_x_n, scale_x_d}); + DenseI64ArrayAttr offset = + rewriter.getDenseI64ArrayAttr({offset_y, offset_x}); + DenseI64ArrayAttr border = + rewriter.getDenseI64ArrayAttr({border_y, border_x}); + StringAttr modeAttr = rewriter.getStringAttr(mode); -#define INSERT_SQUEEZE_OP_PATTERN(AtenOp, TemplateForm) \ - target.addIllegalOp(); \ - patterns.add>(typeConverter, context); - INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeOp, ConvertAtenSqueezeAllDimsOp) - INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeDimOp, ConvertAtenSqueezeOneDimOp) -#undef INSERT_SQUEEZE_OP_PATTERN + auto resizeOpResult = + rewriter + .create(op->getLoc(), transposedResizedOpTy, + transposedInput, scale, offset, border, + modeAttr) + .getResult(); -#define INSERT_MATMUL_ATENOP_PATTERN(AtenOp) \ - target.addIllegalOp(); \ - patterns.add>(typeConverter, context); - INSERT_MATMUL_ATENOP_PATTERN(AtenMatmulOp); -#undef INSERT_MATMUL_ATEMOP_PATTERN + auto resultType = + cast(typeConverter->convertType(op.getType())); + std::optional nhwcToNchwTransposeConst = + tosa::getConstTensor(rewriter, op, + /*vec=*/{0, 3, 1, 2}, + /*shape=*/{static_cast(4)}); -#define INSERT_MM_ATENOP_PATTERN(AtenOp) \ - target.addIllegalOp(); \ - patterns.add>(typeConverter, context); - INSERT_MM_ATENOP_PATTERN(AtenMmOp); - INSERT_MM_ATENOP_PATTERN(AtenBmmOp); -#undef INSERT_MM_ATEMOP_PATTERN + rewriter + .replaceOpWithNewOp( + op, getTypeConverter()->convertType(resultType), resizeOpResult, + nhwcToNchwTransposeConst.value()) + .getResult(); -#define INSERT_LINEAR_ATENOP_PATTERN(AtenOp) \ - target.addIllegalOp(); \ - patterns.add>(typeConverter, context); - INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp); -#undef INSERT_LINEAR_ATEMOP_PATTERN + return success(); +} -#define INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenOp, TosaOpT) \ - target.addIllegalOp(); \ - patterns.add>(typeConverter, \ - context); - INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenAdaptiveAvgPool2dOp, - tosa::AvgPool2dOp); -#undef INSERT_ADAPTIVE_POOLING_ATEMOP_PATTERN +// Template to create supporting tril mask tensor for aten.tril +template +Value createTrilMask(PatternRewriter &rewriter, Operation *op, + ArrayRef shape, int64_t h, int64_t w, + int64_t diagonal) { + SmallVector vec; + + for (int64_t i = 0; i < h; i++) { + for (int64_t j = 0; j < w; j++) { + // Positive diagonal value includes as many diagonals above the main + // diagonal, while negative diagonal value excludes as many diagonals + // below the main diagonal. 
+ if (i >= j - diagonal) { + vec.push_back(static_cast(1)); + } else { + vec.push_back(static_cast(0)); + } + } + } - target.addIllegalOp(); - patterns.add(typeConverter, context); + return tosa::getConstTensor(rewriter, op, vec, shape).value(); +} - target.addIllegalOp(); - patterns.add(typeConverter, context); +// Legalization for aten.tril +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenTrilOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto self = adaptor.getSelf(); -#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \ - target.addIllegalOp(); \ - patterns.add>(typeConverter, \ - context); - INSERT_CONSTANT_FILL_PATTERN(AtenOnesOp, 1); - INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0); -#undef INSERT_CONSTANT_FILL_PATTERN + // Not a ranked tensor type + auto selfType = dyn_cast(self.getType()); + if (!selfType) + return rewriter.notifyMatchFailure( + op, "Only ranked tensor types are supported"); -#define INSERT_FILL_SCALAR_PATTERN(AtenOp) \ - target.addIllegalOp(); \ - patterns.add>(typeConverter, context); - INSERT_FILL_SCALAR_PATTERN(AtenFill_ScalarOp); -#undef INSERT_FILL_SCALAR_PATTERN + // Rank below 2 not accepted + auto selfRank = selfType.getRank(); + if (selfRank <= 1) + return rewriter.notifyMatchFailure( + op, "Rank 0 and 1 are not accepted as they cause underflow"); -#define INSERT_MASKED_FILL_PATTERN(AtenOp) \ - target.addIllegalOp(); \ - patterns.add>(typeConverter, context); - INSERT_MASKED_FILL_PATTERN(AtenMaskedFillScalarOp); - INSERT_MASKED_FILL_PATTERN(AtenMaskedFillTensorOp); -#undef INSERT_MASKED_FILL_PATTERN + if (!selfType.hasStaticShape()) + return rewriter.notifyMatchFailure( + op, "Currently only static shapes are supported"); -#define INSERT_ATENOP_PATTERN(AtenOp) \ - target.addIllegalOp(); \ - patterns.add>(typeConverter, context); - INSERT_ATENOP_PATTERN(AtenTanhOp); - INSERT_ATENOP_PATTERN(AtenHardtanhBackwardOp); - INSERT_ATENOP_PATTERN(AtenSigmoidOp); - INSERT_ATENOP_PATTERN(AtenReluOp); - INSERT_ATENOP_PATTERN(AtenLeakyReluOp); - INSERT_ATENOP_PATTERN(AtenArgmaxOp); - INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp); - INSERT_ATENOP_PATTERN(AtenRsubScalarOp); - INSERT_ATENOP_PATTERN(AtenConvolutionOp); - INSERT_ATENOP_PATTERN(ValueTensorLiteralOp); - INSERT_ATENOP_PATTERN(AtenReshapeOp); - INSERT_ATENOP_PATTERN(AtenBatchNormOp); - INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp); - INSERT_ATENOP_PATTERN(AtenFlattenUsingIntsOp); - INSERT_ATENOP_PATTERN(AtenUnflattenIntOp); - INSERT_ATENOP_PATTERN(AtenPermuteOp); - INSERT_ATENOP_PATTERN(AtenLog2Op); - INSERT_ATENOP_PATTERN(AtenThresholdOp); - INSERT_ATENOP_PATTERN(AtenUnsqueezeOp); - INSERT_ATENOP_PATTERN(AtenContiguousOp); - INSERT_ATENOP_PATTERN(AtenDropoutOp); - INSERT_ATENOP_PATTERN(AtenViewOp); - INSERT_ATENOP_PATTERN(AtenGeluOp); - INSERT_ATENOP_PATTERN(AtenGeluBackwardOp); - INSERT_ATENOP_PATTERN(AtenEmbeddingOp); - INSERT_ATENOP_PATTERN(AtenTransposeIntOp); - INSERT_ATENOP_PATTERN(AtenMaxDimOp); - INSERT_ATENOP_PATTERN(AtenSliceTensorOp); - INSERT_ATENOP_PATTERN(AtenBroadcastToOp); - INSERT_ATENOP_PATTERN(AtenGatherOp); - INSERT_ATENOP_PATTERN(AtenIndexPutHackedTwinOp); - INSERT_ATENOP_PATTERN(AtenIndexTensorHackedTwinOp); - INSERT_ATENOP_PATTERN(AtenAbsOp); - INSERT_ATENOP_PATTERN(AtenWhereSelfOp); - INSERT_ATENOP_PATTERN(AtenLeTensorOp); - INSERT_ATENOP_PATTERN(AtenClampOp); - INSERT_ATENOP_PATTERN(AtenArangeStartStepOp); - INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp); - INSERT_ATENOP_PATTERN(AtenCopyOp); - INSERT_ATENOP_PATTERN(AtenToDtypeOp); - 
INSERT_ATENOP_PATTERN(AtenConstantPadNdOp); - INSERT_ATENOP_PATTERN(AtenRemainderScalarOp); - INSERT_ATENOP_PATTERN(AtenCatOp); - INSERT_ATENOP_PATTERN(AtenSqrtOp); - INSERT_ATENOP_PATTERN(AtenIscloseOp); -#undef INSERT_ATENOP_PATTERN + const TypeConverter *typeConverter = this->getTypeConverter(); + RankedTensorType resultType = cast( + typeConverter->convertType(op->getResult(0).getType())); + if (!resultType) + return rewriter.notifyMatchFailure(op, "Result type cannot be empty"); + + // Get height, width of input tensor, and diagonal arg to create + // a const mask tensor to multiply with input. + // This mask tensor has the same height and width of input tensor + // and consists of 1's for the lower triangle part and 0's for the rest. + // For example, with h=4, w=6, diagonal=1: + // tensor([[1, 1, 0, 0, 0, 0], + // [1, 1, 1, 0, 0, 0], + // [1, 1, 1, 1, 0, 0], + // [1, 1, 1, 1, 1, 0]]) + auto selfShape = selfType.getShape(); + int64_t h = selfShape[selfRank - 2]; + int64_t w = selfShape[selfRank - 1]; + int64_t diagonal; + + if (!matchPattern(op.getDiagonal(), m_TorchConstantInt(&diagonal))) + return rewriter.notifyMatchFailure(op, "Diagonal value is not an integer"); + + // Define shape for mask tensor based on rank + SmallVector maskShape; + for (auto i = 0; i < selfRank - 2; i++) + maskShape.push_back(1); + maskShape.push_back(h); + maskShape.push_back(w); + + Value trilMask = TypeSwitch(resultType.getElementType()) + .Case([&](auto) { + return createTrilMask(rewriter, op, maskShape, + h, w, diagonal); + }) + .Case([&](auto intType) { + switch (intType.getWidth()) { + case 1: + return createTrilMask(rewriter, op, maskShape, + h, w, diagonal); + case 32: + return createTrilMask( + rewriter, op, maskShape, h, w, diagonal); + case 64: + return createTrilMask( + rewriter, op, maskShape, h, w, diagonal); + } + llvm_unreachable("Invalid integer width"); + }); + + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, trilMask) + .failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); -#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \ - target.addIllegalOp(); \ - patterns.add>(typeConverter, context); - INSERT_CLONE_ATENOP_PATTERN(AtenCloneOp); -#undef INSERT_CLONE_ATENOP_PATTERN + rewriter.replaceOpWithNewOp(op, resultType, self, trilMask, + /*shift=*/0); - if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) - return signalPassFailure(); + return success(); +} + +// Legalization for aten.flip +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenFlipOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + + auto self = adaptor.getSelf(); + + auto selfTy = dyn_cast(self.getType()); + if (!selfTy) + return rewriter.notifyMatchFailure( + op, "Only ranked tensor types are currently supported"); + + SmallVector dims; + if (!matchPattern(adaptor.getDims(), m_TorchListOfConstantInts(dims))) + return rewriter.notifyMatchFailure( + op, "Only constant dims are currently supported"); + + auto selfRank = selfTy.getRank(); + + auto resultTy = getTypeConverter()->convertType(op.getType()); + Value result = self; + + for (auto &dim : dims) { + dim = toPositiveDim(dim, selfRank); + if (!isValidDim(dim, selfRank)) + return rewriter.notifyMatchFailure(op, "Not all dims are valid"); + + result = rewriter.create(op->getLoc(), resultTy, result, + static_cast(dim)); } -}; -} // namespace + + rewriter.replaceOp(op, result); + return success(); +} + +// Legalization for aten.round: +// Rounds 
elements of input to the nearest integer. +// Implements "round half to even" to break ties when a number is equidistant +// from two integers. +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenRoundOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // To round to the nearest integer, we will consider the fractional part of + // the input element (= input element - integer part of element). If the + // fractional part is smaller than 0.5, round the number down. If the + // fractional part is 0.5, apply "round half to even" rule. If the fractional + // part is greater than 0.5, round up. + // + // if (frac < 0.5 || (frac == 0.5 && floor(input) % 2 == 0)): + // res = floor(input) + // else: + // res = ceil(input) + + auto self = adaptor.getSelf(); + + auto selfTy = dyn_cast(self.getType()); + if (!selfTy) + return rewriter.notifyMatchFailure(op, "Only tensor types supported"); + + auto resultTy = + cast(getTypeConverter()->convertType(op.getType())); + + auto boolTy = + RankedTensorType::get(resultTy.getShape(), rewriter.getIntegerType(1)); + + auto resultElemTy = resultTy.getElementType(); + + auto oneHalf = + tosa::getConstTensor(rewriter, op, 0.5, {}, resultElemTy).value(); + + auto two = + tosa::getConstTensor(rewriter, op, 2, {}, resultElemTy).value(); + + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, oneHalf) + .failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, two).failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + + auto floorInput = + rewriter.create(op->getLoc(), resultTy, self); + + // input - floor(input) + auto fractionalPart = rewriter.create( + op->getLoc(), resultTy, self, floorInput.getResult()); + + auto ceilInput = rewriter.create(op->getLoc(), resultTy, self); + + auto floorInputDivByTwo = rewriter.create( + op->getLoc(), resultTy, floorInput.getResult(), oneHalf, /*shift=*/0); + + auto floorDivResult = rewriter.create( + op->getLoc(), resultTy, floorInputDivByTwo.getResult()); + + // (floor(input) // 2) * 2 + auto evenComparison = rewriter.create( + op->getLoc(), resultTy, floorDivResult.getResult(), two, /*shift=*/0); + + // floor(input) // 2) * 2 == input <=> floor(input) % 2 == 0 + auto floorInputEven = rewriter.create( + op->getLoc(), boolTy, floorInput.getResult(), evenComparison.getResult()); + + auto fracEqualOneHalf = rewriter.create( + op->getLoc(), boolTy, fractionalPart.getResult(), oneHalf); + + auto fracLtOneHalf = rewriter.create( + op->getLoc(), boolTy, oneHalf, fractionalPart.getResult()); + + // (frac == 0.5) && (floor(input) % 2 == 0) + auto fracEqualOneHalfCond = rewriter.create( + op->getLoc(), boolTy, fracEqualOneHalf.getResult(), + floorInputEven.getResult()); + + // (frac < 0.5) || ((frac == 0.5) && (floor(input) % 2 == 0)) + auto floorResultCond = rewriter.create( + op->getLoc(), boolTy, fracLtOneHalf.getResult(), + fracEqualOneHalfCond.getResult()); + + rewriter.replaceOpWithNewOp( + op, resultTy, floorResultCond.getResult(), floorInput.getResult(), + ceilInput.getResult()); + + return success(); +} + +// Template to create supporting diagonal mask tensor for aten.diagonal +template +Value createDiagonalMask(PatternRewriter &rewriter, Operation *op, + ArrayRef shape, int64_t h, int64_t w, + int64_t offset) { + SmallVector vec; + + for (int64_t i = 0; i < h; i++) { + for (int64_t j = 0; j < w; j++) { + // Positive offset value moves above the main diagonal, while negative + // diagonal value moves below the main 
diagonal. + if (i + offset == j) { + vec.push_back(static_cast(1)); + } else { + vec.push_back(static_cast(0)); + } + } + } + + return tosa::getConstTensor(rewriter, op, vec, shape).value(); +} + +// Legalization for aten.diagonal +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenDiagonalOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto self = adaptor.getSelf(); + + // Not a ranked tensor type + auto selfType = dyn_cast(self.getType()); + if (!selfType) + return rewriter.notifyMatchFailure( + op, "Only ranked tensor types are supported"); + + // Rank below 2 not accepted + auto selfRank = selfType.getRank(); + if (selfRank <= 1) + return rewriter.notifyMatchFailure( + op, "Rank 0 and 1 are not accepted as they cause underflow"); + + if (!selfType.hasStaticShape()) + return rewriter.notifyMatchFailure( + op, "Currently only static shapes are supported"); + + const TypeConverter *typeConverter = this->getTypeConverter(); + RankedTensorType resultType = cast( + typeConverter->convertType(op->getResult(0).getType())); + if (!resultType) + return rewriter.notifyMatchFailure(op, "Result type cannot be empty"); + + auto selfElemTy = selfType.getElementType(); + auto resultElemTy = resultType.getElementType(); + + int64_t offset, dim1, dim2; + if (!matchPattern(op.getOffset(), m_TorchConstantInt(&offset))) + offset = 0; + + if (!matchPattern(op.getDim1(), m_TorchConstantInt(&dim1))) { + dim1 = 0; + } else { + dim1 = toPositiveDim(dim1, selfRank); + } + + if (!matchPattern(op.getDim2(), m_TorchConstantInt(&dim2))) { + dim2 = 1; + } else { + dim2 = toPositiveDim(dim2, selfRank); + } + + if (dim1 == dim2) + return rewriter.notifyMatchFailure(op, + "Values dim1 and dim2 cannot be equal"); + + auto selfShape = makeShapeTorchCompatible(selfType.getShape()); + int64_t h = selfShape[dim1]; + int64_t w = selfShape[dim2]; + + // Overflowing offset not supported + if ((offset < 0 && std::abs(offset) >= h) || (offset >= 0 && offset >= w)) + return rewriter.notifyMatchFailure( + op, "Offset greater or equal than shape not supported"); + + int64_t targetDim1 = selfRank - 2; + int64_t targetDim2 = selfRank - 1; + + Value selfTransposed = self; + SmallVector transposedInputShape = selfShape; + RankedTensorType transposedInputType = selfType; + + // If (dim1, dim2) != (rank - 2, rank - 1), transpose the input tensor + // so that dim1 and dim2 become rank - 2 and rank - 1. We do this so that + // we can consistently create the diagonal mask tensor. 
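+  // Illustrative example (values assumed, not from the op spec): for a rank-4
+  // input with dim1 = 1 and dim2 = 2, the loop below collects the untouched
+  // dims first and then appends dim1 and dim2, producing the permutation
+  // [0, 3, 1, 2], i.e. shape (d0, d1, d2, d3) -> (d0, d3, d1, d2).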
+ if (!(dim1 == targetDim1 && dim2 == targetDim2)) { + SmallVector transposedDims; + transposedInputShape.clear(); + + for (int32_t i = 0; i < selfRank; ++i) { + if (i == dim1 || i == dim2) + continue; + transposedDims.push_back(i); + } + transposedDims.push_back(static_cast(dim1)); + transposedDims.push_back(static_cast(dim2)); + + auto transposedDimsConst = tosa::getConstTensor( + rewriter, op, + /*vec=*/transposedDims, + /*shape=*/{static_cast(selfRank)}); + + for (auto &dim : transposedDims) + transposedInputShape.push_back(selfShape[dim]); + + transposedInputType = RankedTensorType::get( + makeShapeLLVMCompatible(transposedInputShape), selfElemTy); + + selfTransposed = rewriter.create( + op->getLoc(), transposedInputType, self, transposedDimsConst.value()); + } + + // Define shape for mask tensor based on rank + SmallVector maskShape; + for (auto i = 0; i < selfRank - 2; i++) + maskShape.push_back(1); + maskShape.push_back(h); + maskShape.push_back(w); + + Value diagonalMask = + TypeSwitch(resultElemTy) + .Case([&](auto) { + return createDiagonalMask(rewriter, op, maskShape, h, w, + offset); + }) + .Case([&](auto intType) { + switch (intType.getWidth()) { + case 1: + return createDiagonalMask(rewriter, op, maskShape, h, w, + offset); + case 32: + return createDiagonalMask(rewriter, op, maskShape, h, w, + offset); + case 64: + return createDiagonalMask(rewriter, op, maskShape, h, w, + offset); + } + llvm_unreachable("Invalid integer width"); + }); + + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, diagonalMask) + .failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + + Value diagonalTensor = rewriter.create( + op->getLoc(), transposedInputType, selfTransposed, diagonalMask, + /*shift=*/0); + + auto resultShape = makeShapeTorchCompatible(resultType.getShape()); + auto targetReduceDim = resultShape[resultType.getRank() - 1]; + + // If transposedInputShape[targetDim1] (or h) is greater than the innermost + // dim of the result, we won't get the correct shape when we reduce sum along + // the innermost dim to get the result. Therefore, we have to slice the + // transposed tensor so that transposedInputShape[targetDim1] == + // targetReduceDim. + if (h > targetReduceDim) { + transposedInputShape[targetDim1] = targetReduceDim; + transposedInputType = RankedTensorType::get( + makeShapeLLVMCompatible(transposedInputShape), selfElemTy); + SmallVector startSlice(selfRank, 0); + SmallVector sizeSlice = + llvm::to_vector(makeShapeTorchCompatible(transposedInputShape)); + if (offset < 0) + startSlice[targetDim1] = std::abs(offset); + diagonalTensor = rewriter.create( + op->getLoc(), transposedInputType, diagonalTensor, + rewriter.getDenseI64ArrayAttr(startSlice), + rewriter.getDenseI64ArrayAttr(sizeSlice)); + } + + // Apply Reduce Sum to get the result + auto reduceDimType = RankedTensorType::get({1}, rewriter.getI64Type()); + auto reduceDimAttr = + DenseIntElementsAttr::get(reduceDimType, llvm::ArrayRef({targetDim2})); + auto result = + mlir::tosa::convertReduceSumOp(rewriter, op, resultType, diagonalTensor, + reduceDimAttr, /*keep_dims=*/false); + + rewriter.replaceOp(op, result.value()); + + return success(); +} + +// Legalization for aten.diag_embed +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenDiagEmbedOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // To perform diag_embed, we will apply scatter with a newly created diagonal + // index tensor over a constant zero tensor. 
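+  // Worked example (illustrative only): for a 1-D input [a, b, c] with
+  // offset = 1 and the default dims, the index tensor built below is
+  // [[1], [2], [3]], the zero tensor has shape (4, 4), and scattering yields
+  //   [[0, a, 0, 0],
+  //    [0, 0, b, 0],
+  //    [0, 0, 0, c],
+  //    [0, 0, 0, 0]]
+  // which matches torch.diag_embed([a, b, c], offset=1).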
+ // To make it simpler, we will only scatter using the diagonal with respect + // to the two innermost dimensions, then permute the output tensor to the + // correct order of dimensions. + auto self = adaptor.getSelf(); + + // Not a ranked tensor type + auto selfType = dyn_cast(self.getType()); + if (!selfType) + return rewriter.notifyMatchFailure( + op, "Only ranked tensor types are supported"); + + auto selfRank = selfType.getRank(); + int64_t outRank = selfRank + 1; + + auto selfShape = makeShapeTorchCompatible(selfType.getShape()); + int64_t diagSize = selfShape[selfRank - 1]; + + if (!selfType.hasStaticShape()) + return rewriter.notifyMatchFailure( + op, "Currently only static shapes are supported"); + + const TypeConverter *typeConverter = this->getTypeConverter(); + RankedTensorType resultType = cast( + typeConverter->convertType(op->getResult(0).getType())); + if (!resultType) + return rewriter.notifyMatchFailure(op, "Result type cannot be empty"); + + auto selfElemTy = selfType.getElementType(); + auto resultElemTy = resultType.getElementType(); + + int64_t offset{0}; + if (!matchPattern(op.getOffset(), m_TorchConstantInt(&offset))) + return rewriter.notifyMatchFailure(op, + "Offset value should be a constant int"); + + // dim1 default is -2 + int64_t dim1{outRank - 2}; + if (!matchPattern(op.getDim1(), m_TorchConstantInt(&dim1))) + return rewriter.notifyMatchFailure(op, + "Dim1 value should be a constant int"); + dim1 = toPositiveDim(dim1, outRank); + + // dim2 default is -1 + int64_t dim2{outRank - 1}; + if (!matchPattern(op.getDim2(), m_TorchConstantInt(&dim2))) + return rewriter.notifyMatchFailure(op, + "Dim2 value should be a constant int"); + dim2 = toPositiveDim(dim2, outRank); + + if (dim1 == dim2) + return rewriter.notifyMatchFailure(op, "Dim1 and dim2 cannot be equal"); + + // If offset is smaller than 0, we will swap dim1 and dim2 and convert offset + // to a positive value + if (offset < 0) { + std::swap(dim1, dim2); + offset = std::abs(offset); + } + + // Create the diagonal index tensor + int64_t repeat = 1; + for (int64_t i = 0; i < selfRank - 1; i++) + repeat *= selfShape[i]; + + SmallVector indexVec; + for (int32_t i = 0; i < repeat; i++) { + for (int32_t j = offset; j < diagSize + offset; j++) + indexVec.push_back(j); + } + + SmallVector indexShape = llvm::to_vector(selfShape); + indexShape.push_back(1); + + auto index = tosa::getConstTensor(rewriter, op, + /*vec=*/indexVec, + /*shape=*/indexShape) + .value(); + + // Reshape the input tensor to be the same shape as the new index tensor to + // act as the src for scattering + auto scatterSrc = rewriter.create( + op->getLoc(), + RankedTensorType::get(makeShapeTorchCompatible(indexShape), selfElemTy), + self, rewriter.getDenseI64ArrayAttr(indexShape)); + + // Create a const zero tensor to scatter the input onto + SmallVector zeroShape; + for (int64_t i = 0; i < selfRank - 1; i++) + zeroShape.push_back(selfShape[i]); + zeroShape.push_back(diagSize + offset); + zeroShape.push_back(diagSize + offset); + + int64_t numElemOfZeroTensor = 1; + for (int64_t &d : zeroShape) + numElemOfZeroTensor *= d; + + Value zero = + TypeSwitch(selfElemTy) + .Case([&](auto) { + return tosa::getConstTensor( + rewriter, op, SmallVector(numElemOfZeroTensor, 0), + zeroShape) + .value(); + }) + .Case([&](auto intType) { + switch (intType.getWidth()) { + case 1: + return tosa::getConstTensor( + rewriter, op, + SmallVector(numElemOfZeroTensor, 0), zeroShape) + .value(); + case 32: + return tosa::getConstTensor( + rewriter, op, + 
SmallVector(numElemOfZeroTensor, 0), + zeroShape) + .value(); + case 64: + return tosa::getConstTensor( + rewriter, op, + SmallVector(numElemOfZeroTensor, 0), + zeroShape) + .value(); + } + llvm_unreachable("Invalid integer width"); + }); + + // Convert PyTorch index and dim to TensorFlow-style indices + auto indicesTf = tosa::convertTorchIndexToTfIndices(rewriter, op, zero, index, + outRank - 1); + if (!indicesTf) + return rewriter.notifyMatchFailure( + op, "Convert PyTorch index and dim to TensorFlow indices failed"); + + // Perform the TensorFlow ScatterNd algorithm with TensorFlow-style indices as + // input + auto diagonalTensor = tosa::convertScatterNdOp( + rewriter, op, + RankedTensorType::get(makeShapeTorchCompatible(zeroShape), resultElemTy), + zero, indicesTf.value(), scatterSrc.getResult()); + if (!diagonalTensor) + return rewriter.notifyMatchFailure(op, "Convert ScatterNdOp failed"); + + // Create the final dims order to permute the scattered tensor + SmallVector permutedDims(outRank, 0); + int32_t currentDim = 0; + int32_t i = 0; + + while (i < outRank) { + if (i == dim1) { + permutedDims[i] = outRank - 2; + i++; + continue; + } + + if (i == dim2) { + permutedDims[i] = outRank - 1; + i++; + continue; + } + + permutedDims[i] = currentDim; + currentDim++; + i++; + } + + auto permutedDimsConst = + tosa::getConstTensor(rewriter, op, + /*vec=*/permutedDims, + /*shape=*/{static_cast(outRank)}); + + auto result = rewriter.create(op->getLoc(), resultType, + diagonalTensor.value(), + permutedDimsConst.value()); + + rewriter.replaceOp(op, result.getResult()); + + return success(); +} + +// Legalization for aten.uniform +// Since TOSA hasn't got a built-in random generator yet, we will use +// std::uniform_real_distribution with the std::default_random_engine from C++ +// library +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenUniformOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto self = adaptor.getSelf(); + + // Not a tensor type + auto selfType = dyn_cast(self.getType()); + if (!selfType) + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); + + auto selfShape = selfType.getShape(); + + auto generator = adaptor.getGenerator(); + if (!isa(generator.getType())) + return rewriter.notifyMatchFailure(op, + "Custom generators are not supported"); + + double fromDouble{0.0}, toDouble{1.0}; + auto isFloat = + matchPattern(op.getFrom(), m_TorchConstantFloat(&fromDouble)) && + matchPattern(op.getTo(), m_TorchConstantFloat(&toDouble)); + + int64_t fromInt{0}, toInt{1}; + auto isInt = matchPattern(op.getFrom(), m_TorchConstantInt(&fromInt)) && + matchPattern(op.getTo(), m_TorchConstantInt(&toInt)); + + if (!isFloat && !isInt) + return rewriter.notifyMatchFailure( + op, "From and To values are not constant values"); + + int64_t numElem = 1; + for (int64_t i = 0; i < selfType.getRank(); i++) + numElem *= selfShape[i]; + + auto resultType = + dyn_cast(typeConverter->convertType(op.getType())); + + std::default_random_engine gen; + + auto from = isFloat ? fromDouble : fromInt; + auto to = isFloat ? 
toDouble : toInt; + + std::uniform_real_distribution uniformDist(from, to); + SmallVector uniformVec; + + for (int64_t i = 0; i < numElem; i++) + uniformVec.push_back(uniformDist(gen)); + + auto result = tosa::getConstTensor(rewriter, op, uniformVec, selfShape, + selfType.getElementType()) + .value(); + + result = tosa::promoteType(rewriter, result, resultType); + + rewriter.replaceOp(op, {result}); + + return success(); +} + +// Legalization for aten.threshold_backward +// result = self <= threshold ? 0 : grad +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenThresholdBackwardOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto self = adaptor.getSelf(); + + // Not a tensor type + auto selfType = dyn_cast(self.getType()); + if (!selfType) + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); + auto selfElemTy = selfType.getElementType(); + + auto selfShape = selfType.getShape(); + + auto resultType = + dyn_cast(typeConverter->convertType(op.getType())); + auto resultElemTy = resultType.getElementType(); + + Value threshold; + if (failed(torchScalarToTosaTensor(rewriter, op, op.getThreshold(), threshold, + selfElemTy, selfShape))) + return rewriter.notifyMatchFailure(op, + "Threshold must be a constant scalar"); + + auto grad = adaptor.getGradOutput(); + + // Not a tensor type + auto gradType = dyn_cast(grad.getType()); + if (!gradType) + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); + + Value zero = + TypeSwitch(resultElemTy) + .Case([&](auto) { + return tosa::getConstTensor(rewriter, op, 0, {}, + resultElemTy) + .value(); + }) + .Case([&](auto intType) { + switch (intType.getWidth()) { + case 1: + return tosa::getConstTensor(rewriter, op, 0, {}).value(); + case 8: + return tosa::getConstTensor(rewriter, op, 0, {}).value(); + case 32: + return tosa::getConstTensor(rewriter, op, 0, {}).value(); + case 64: + return tosa::getConstTensor(rewriter, op, 0, {}).value(); + } + llvm_unreachable("Invalid integer width"); + }); + + // Check: input <= threshold + auto cond = rewriter.create( + op->getLoc(), RankedTensorType::get(selfShape, rewriter.getI1Type()), + threshold, self); + + self = tosa::promoteType(rewriter, self, resultType); + grad = tosa::promoteType(rewriter, grad, resultType); + + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, zero).failed() || + mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, grad).failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + + auto result = rewriter.create(op->getLoc(), resultType, + cond.getResult(), zero, grad); + + rewriter.replaceOp(op, {result.getResult()}); + + return success(); +} + +// Legalization for aten.as_strided +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenAsStridedOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // To lower aten.as_strided to TOSA, we will first reshape the input tensor to + // an 1-D tensor, then calculate the indices of result elements based on the + // output size, stride and storage offset. With the reshaped 1-D tensor and + // the indices, we can apply Gather to extract the required elements into a + // new tensor and then reshape it back to the desired output shape. 
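+  // Worked example (illustrative values): for size = [2, 2], stride = [1, 2]
+  // and storage_offset = 1, the index formula below produces
+  //   index = 1 + 0*1 + 0*2 = 1    (output element [0, 0])
+  //   index = 1 + 0*1 + 1*2 = 3    (output element [0, 1])
+  //   index = 1 + 1*1 + 0*2 = 2    (output element [1, 0])
+  //   index = 1 + 1*1 + 1*2 = 4    (output element [1, 1])
+  // so the gather indices into the flattened input are [1, 3, 2, 4].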
+ auto self = adaptor.getSelf(); + + // Not a tensor type + auto selfType = dyn_cast(self.getType()); + if (!selfType) + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); + auto selfElemTy = selfType.getElementType(); + auto selfShape = selfType.getShape(); + + auto resultType = + dyn_cast(typeConverter->convertType(op.getType())); + auto resultElemTy = resultType.getElementType(); + + // Get output size + SmallVector outputSize; + if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(outputSize))) + return rewriter.notifyMatchFailure( + op, "Only a constant list form of output size is supported"); + + // Get stride + SmallVector stride; + if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(stride))) + return rewriter.notifyMatchFailure( + op, "Only a constant list form of stride is supported"); + + // Get storage offset + int64_t offset; + if (!matchPattern(op.getStorageOffset(), m_TorchConstantInt(&offset))) + offset = 0; + + // Reshape input tensor into an 1-D tensor + int64_t selfNumElems = std::accumulate(selfShape.begin(), selfShape.end(), 1, + std::multiplies()); + + auto self1D = rewriter.create( + op->getLoc(), RankedTensorType::get({selfNumElems}, selfElemTy), self, + rewriter.getDenseI64ArrayAttr({selfNumElems})); + + // Calculate the target elements indices + SmallVector targetIndicesVec; + int64_t outputRank = outputSize.size(); + int64_t outputNumElems = std::accumulate(outputSize.begin(), outputSize.end(), + 1, std::multiplies()); + + for (int64_t i = 0; i < outputNumElems; i++) { + // Index formula: + // index[i] = coord_i_0 * stride[0] + coord_i_1 * stride[1] + ... + + // coord_i_n * stride[n] + int32_t index = offset; + int64_t coordFinder = i; + for (int64_t dim = 0; dim < outputRank; dim++) { + int64_t indexCoord = coordFinder % outputSize[outputRank - dim - 1]; + index += indexCoord * stride[outputRank - dim - 1]; + coordFinder /= outputSize[outputRank - dim - 1]; + } + targetIndicesVec.push_back(index); + } + + auto targetIndices = + tosa::getConstTensor(rewriter, op, targetIndicesVec, + makeShapeTorchCompatible({outputNumElems})) + .value(); + + // Convert PyTorch-style indices and dim into TensorFlow-style indices + auto targetIndicesTf = tosa::convertTorchIndexToTfIndices( + rewriter, op, self1D.getResult(), targetIndices, 0); + if (!targetIndicesTf) + return rewriter.notifyMatchFailure(op, + "Convert PyTorch-style indices and dim " + "to TensorFlow-style indices failed"); + + // Gather the target elements from 1-D input tensor + // Apply TensorFlow GatherNdOp with TensorFlow-style indices to retrieve the + // target elements + auto gatherOp = tosa::convertGatherNdOp( + rewriter, op, + RankedTensorType::get(makeShapeTorchCompatible({outputNumElems}), + resultElemTy), + self1D.getResult(), targetIndicesTf.value()); + + if (!gatherOp) + return rewriter.notifyMatchFailure(op, "Convert GatherNdOp failed"); + + auto result = rewriter.create( + op->getLoc(), resultType, gatherOp.value(), + rewriter.getDenseI64ArrayAttr(outputSize)); + + rewriter.replaceOp(op, {result.getResult()}); + + return success(); +} + +// Legalization for torch.prims.collapse +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + PrimsCollapseOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto self = adaptor.getA(); + + // Not a tensor type + auto selfType = dyn_cast(self.getType()); + if (!selfType) + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); + + auto resultType = + 
dyn_cast(typeConverter->convertType(op.getType())); + auto resultShape = resultType.getShape(); + + int64_t start, end; + if (!matchPattern(op.getStart(), m_TorchConstantInt(&start))) + return rewriter.notifyMatchFailure( + op, "Only constant int start value is supported"); + + if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) + return rewriter.notifyMatchFailure( + op, "Only constant int end value is supported"); + + // Identity case + if (start == end) { + rewriter.replaceOp(op, self); + return success(); + } + + // Technically, I should calculate the output shape based on the input shape, + // start value, and end value. However, that would just give the same result + // as me taking the result shape straight from resultType and applying + // tosa::ReshapeOp to the input. Therefore, I'm opting for the latter approach + // here, which is more simple and quicker. + rewriter.replaceOpWithNewOp( + op, resultType, self, + rewriter.getDenseI64ArrayAttr(makeShapeTorchCompatible(resultShape))); + + return success(); +} + +Value reflectionPadAlongAxis(Value input, ArrayRef unpaddedShape, + int64_t paddingAxisLeft, int64_t paddingAxisRight, + int64_t axis, TensorType resultType, Location loc, + ConversionPatternRewriter &rewriter) { + + SmallVector resultTensors; + auto resultShape = resultType.getShape(); + + auto inputType = dyn_cast(input.getType()); + auto inputRank = inputType.getRank(); + auto inputElemTy = inputType.getElementType(); + + assert(inputRank == resultType.getRank()); + int64_t axisOffset = inputRank - axis - 1; + + // Use tosa.slice and tosa.reverse to get the reflection pads based on the + // padding size + if (paddingAxisLeft > 0) { + SmallVector leftStartSlice(inputRank, 0); + SmallVector leftSizeSlice(unpaddedShape.begin(), + unpaddedShape.end() - axisOffset); + for (int64_t iDim = axisOffset - 1; iDim >= 0; iDim--) { + leftSizeSlice.push_back(resultShape[inputRank - iDim - 1]); + } + + leftStartSlice[axis] = 1; + leftSizeSlice[axis] = paddingAxisLeft; + + SmallVector leftPadShape(unpaddedShape.begin(), + unpaddedShape.end() - (axisOffset + 1)); + leftPadShape.push_back(paddingAxisLeft); + for (int64_t iDim = axisOffset - 1; iDim >= 0; iDim--) { + leftPadShape.push_back(resultShape[inputRank - iDim - 1]); + } + + auto leftPadType = RankedTensorType::get(leftPadShape, inputElemTy); + + auto leftPadSlice = rewriter.create( + loc, leftPadType, input, rewriter.getDenseI64ArrayAttr(leftStartSlice), + rewriter.getDenseI64ArrayAttr(leftSizeSlice)); + + auto leftPad = rewriter.create( + loc, leftPadType, leftPadSlice.getResult(), static_cast(axis)); + + resultTensors.push_back(leftPad.getResult()); + } + + resultTensors.push_back(input); + + if (paddingAxisRight > 0) { + SmallVector rightStartSlice(inputRank, 0); + SmallVector rightSizeSlice(unpaddedShape.begin(), + unpaddedShape.end() - axisOffset); + for (int64_t iDim = axisOffset - 1; iDim >= 0; iDim--) { + rightSizeSlice.push_back(resultShape[inputRank - iDim - 1]); + } + + rightStartSlice[axis] = unpaddedShape[axis] - paddingAxisRight - 1; + rightSizeSlice[axis] = paddingAxisRight; + + SmallVector rightPadShape(unpaddedShape.begin(), + unpaddedShape.end() - (axisOffset + 1)); + rightPadShape.push_back(paddingAxisRight); + for (int64_t iDim = axisOffset - 1; iDim >= 0; iDim--) { + rightPadShape.push_back(resultShape[inputRank - iDim - 1]); + } + + auto rightPadType = RankedTensorType::get(rightPadShape, inputElemTy); + + auto rightPadSlice = rewriter.create( + loc, rightPadType, input, + 
rewriter.getDenseI64ArrayAttr(rightStartSlice), + rewriter.getDenseI64ArrayAttr(rightSizeSlice)); + + auto rightPad = rewriter.create( + loc, rightPadType, rightPadSlice.getResult(), + static_cast(axis)); + + resultTensors.push_back(rightPad.getResult()); + } + + return tosa::CreateOpAndInfer(rewriter, loc, resultType, + resultTensors, axis); +} + +// Legalization for aten.reflection_pad1d +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenReflectionPad1dOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto self = adaptor.getSelf(); + + auto selfType = dyn_cast(self.getType()); + if (!selfType) + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); + + auto selfShape = selfType.getShape(); + auto selfRank = selfType.getRank(); + + auto resultType = + dyn_cast(typeConverter->convertType(op.getType())); + + SmallVector paddingList; + if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingList))) + return rewriter.notifyMatchFailure( + op, "Non-const padding lists are not supported"); + + int64_t paddingLeft = paddingList[0]; + int64_t paddingRight = paddingList[1]; + + if (paddingLeft >= selfShape[selfRank - 1] || + paddingRight >= selfShape[selfRank - 1]) + return rewriter.notifyMatchFailure( + op, "Padding should be less than input boundary size"); + + // Identity case + if (paddingLeft == 0 && paddingRight == 0) { + rewriter.replaceOp(op, self); + return success(); + } + + auto result = + reflectionPadAlongAxis(self, selfShape, paddingLeft, paddingRight, + selfRank - 1, resultType, op->getLoc(), rewriter); + + rewriter.replaceOp(op, result); + return success(); +} + +// Legalization for aten.reflection_pad2d +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenReflectionPad2dOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto self = adaptor.getSelf(); + + auto selfType = dyn_cast(self.getType()); + if (!selfType) + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); + + auto selfShape = selfType.getShape(); + auto selfRank = selfType.getRank(); + auto selfElemTy = selfType.getElementType(); + + auto resultType = + dyn_cast(typeConverter->convertType(op.getType())); + auto resultShape = resultType.getShape(); + + SmallVector paddingList; + if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingList))) + return rewriter.notifyMatchFailure( + op, "Non-const padding lists are not supported"); + + int64_t paddingLeft = paddingList[0]; + int64_t paddingRight = paddingList[1]; + int64_t paddingTop = paddingList[2]; + int64_t paddingBottom = paddingList[3]; + + if (paddingLeft >= selfShape[selfRank - 1] || + paddingRight >= selfShape[selfRank - 1] || + paddingTop >= selfShape[selfRank - 2] || + paddingBottom >= selfShape[selfRank - 2]) + return rewriter.notifyMatchFailure( + op, "Padding must be less than the corresponding input dimension"); + + // Identity case + if (paddingLeft == 0 && paddingRight == 0 && paddingTop == 0 && + paddingBottom == 0) { + rewriter.replaceOp(op, self); + return success(); + } + + SmallVector selfSidePaddedShape(selfShape.begin(), + selfShape.end() - 1); + selfSidePaddedShape.push_back(resultShape.back()); + + auto selfSidePadded = reflectionPadAlongAxis( + self, selfShape, paddingLeft, paddingRight, selfRank - 1, + RankedTensorType::get(selfSidePaddedShape, selfElemTy), op->getLoc(), + rewriter); + + auto result = reflectionPadAlongAxis(selfSidePadded, selfShape, paddingTop, + paddingBottom, selfRank - 2, resultType, 
+ op->getLoc(), rewriter); + + rewriter.replaceOp(op, result); + return success(); +} + +// Legalization for aten.reflection_pad3d +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenReflectionPad3dOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto self = adaptor.getSelf(); + + auto selfType = dyn_cast(self.getType()); + if (!selfType) + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); + + auto selfShape = selfType.getShape(); + auto selfRank = selfType.getRank(); + auto selfElemTy = selfType.getElementType(); + + auto resultType = + dyn_cast(typeConverter->convertType(op.getType())); + auto resultShape = resultType.getShape(); + + SmallVector paddingList; + if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingList))) + return rewriter.notifyMatchFailure( + op, "Non-const padding lists are not supported"); + + int64_t paddingLeft = paddingList[0]; + int64_t paddingRight = paddingList[1]; + int64_t paddingTop = paddingList[2]; + int64_t paddingBottom = paddingList[3]; + int64_t paddingFront = paddingList[4]; + int64_t paddingBack = paddingList[5]; + + if (paddingLeft >= selfShape[selfRank - 1] || + paddingRight >= selfShape[selfRank - 1] || + paddingTop >= selfShape[selfRank - 2] || + paddingBottom >= selfShape[selfRank - 2] || + paddingFront >= selfShape[selfRank - 3] || + paddingBack >= selfShape[selfRank - 3]) + return rewriter.notifyMatchFailure( + op, "Padding must be less than the corresponding input dimension"); + + // Identity case + if (paddingLeft == 0 && paddingRight == 0 && paddingTop == 0 && + paddingBottom == 0 && paddingFront == 0 && paddingBack == 0) { + rewriter.replaceOp(op, self); + return success(); + } + + SmallVector self1dPaddedShape(selfShape.begin(), + selfShape.end() - 1); + self1dPaddedShape.push_back(resultShape.back()); + + auto self1dPadded = reflectionPadAlongAxis( + self, selfShape, paddingLeft, paddingRight, selfRank - 1, + RankedTensorType::get(self1dPaddedShape, selfElemTy), op->getLoc(), + rewriter); + + SmallVector self2dPaddedShape(selfShape.begin(), + selfShape.end() - 2); + self2dPaddedShape.push_back(resultShape[resultShape.size() - 2]); + self2dPaddedShape.push_back(resultShape.back()); + + auto self2dPadded = reflectionPadAlongAxis( + self1dPadded, selfShape, paddingTop, paddingBottom, selfRank - 2, + RankedTensorType::get(self2dPaddedShape, selfElemTy), op->getLoc(), + rewriter); + + auto result = + reflectionPadAlongAxis(self2dPadded, selfShape, paddingFront, paddingBack, + selfRank - 3, resultType, op->getLoc(), rewriter); + + rewriter.replaceOp(op, result); + return success(); +} + +// Legalization for aten.replication_pad2d +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenReplicationPad2dOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto self = adaptor.getSelf(); + + auto selfType = dyn_cast(self.getType()); + if (!selfType) + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); + + auto selfShape = selfType.getShape(); + auto selfRank = selfType.getRank(); + auto selfElemTy = selfType.getElementType(); + + auto resultType = + dyn_cast(typeConverter->convertType(op.getType())); + auto resultShape = resultType.getShape(); + + SmallVector paddingList; + if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingList))) + return rewriter.notifyMatchFailure( + op, "Non-const padding lists are not supported"); + + int64_t paddingLeft = paddingList[0]; + int64_t paddingRight = 
paddingList[1]; + int64_t paddingTop = paddingList[2]; + int64_t paddingBottom = paddingList[3]; + + // Identity case + if (paddingLeft == 0 && paddingRight == 0 && paddingTop == 0 && + paddingBottom == 0) { + rewriter.replaceOp(op, self); + return success(); + } + + // Use tosa.slice to get the reflection pads based on the padding size + SmallVector sideTensors; + + if (paddingLeft > 0) { + SmallVector leftStartSlice(selfRank, 0); + SmallVector leftSizeSlice(selfShape); + + leftStartSlice[selfRank - 1] = 0; + leftSizeSlice[selfRank - 1] = 1; + + SmallVector leftPadSliceShape(selfShape.begin(), + selfShape.end() - 1); + leftPadSliceShape.push_back(1); + + auto leftPadSliceType = + RankedTensorType::get(leftPadSliceShape, selfElemTy); + + auto leftPadSlice = rewriter.create( + op->getLoc(), leftPadSliceType, self, + rewriter.getDenseI64ArrayAttr(leftStartSlice), + rewriter.getDenseI64ArrayAttr(leftSizeSlice)); + + for (int64_t i = 0; i < paddingLeft; i++) + sideTensors.push_back(leftPadSlice.getResult()); + } + + sideTensors.push_back(self); + + if (paddingRight > 0) { + SmallVector rightStartSlice(selfRank, 0); + SmallVector rightSizeSlice(selfShape); + + rightStartSlice[selfRank - 1] = selfShape[selfRank - 1] - 1; + rightSizeSlice[selfRank - 1] = 1; + + SmallVector rightPadSliceShape(selfShape.begin(), + selfShape.end() - 1); + rightPadSliceShape.push_back(1); + + auto rightPadSliceType = + RankedTensorType::get(rightPadSliceShape, selfElemTy); + + auto rightPadSlice = rewriter.create( + op->getLoc(), rightPadSliceType, self, + rewriter.getDenseI64ArrayAttr(rightStartSlice), + rewriter.getDenseI64ArrayAttr(rightSizeSlice)); + + for (int64_t i = 0; i < paddingRight; i++) + sideTensors.push_back(rightPadSlice.getResult()); + } + + SmallVector selfSidePaddedShape(selfShape.begin(), + selfShape.end() - 1); + selfSidePaddedShape.push_back(resultShape.back()); + + auto selfSidePadded = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), + RankedTensorType::get(selfSidePaddedShape, selfElemTy), sideTensors, + selfRank - 1); + + SmallVector resultTensors; + + if (paddingTop > 0) { + SmallVector topStartSlice(selfRank, 0); + SmallVector topSizeSlice(selfShape.begin(), selfShape.end() - 1); + topSizeSlice.push_back(resultShape.back()); + + topStartSlice[selfRank - 2] = 0; + topSizeSlice[selfRank - 2] = 1; + + SmallVector topPadSliceShape(selfShape.begin(), + selfShape.end() - 2); + topPadSliceShape.push_back(1); + topPadSliceShape.push_back(resultShape.back()); + + auto topPadSliceType = RankedTensorType::get(topPadSliceShape, selfElemTy); + + auto topPadSlice = rewriter.create( + op->getLoc(), topPadSliceType, selfSidePadded, + rewriter.getDenseI64ArrayAttr(topStartSlice), + rewriter.getDenseI64ArrayAttr(topSizeSlice)); + + for (int64_t i = 0; i < paddingTop; i++) + resultTensors.push_back(topPadSlice.getResult()); + } + + resultTensors.push_back(selfSidePadded.getResult()); + + if (paddingBottom > 0) { + SmallVector bottomStartSlice(selfRank, 0); + SmallVector bottomSizeSlice(selfShape.begin(), + selfShape.end() - 1); + bottomSizeSlice.push_back(resultShape.back()); + + bottomStartSlice[selfRank - 2] = selfShape[selfRank - 2] - 1; + bottomSizeSlice[selfRank - 2] = 1; + + SmallVector bottomPadSliceShape(selfShape.begin(), + selfShape.end() - 2); + bottomPadSliceShape.push_back(1); + bottomPadSliceShape.push_back(resultShape.back()); + + auto bottomPadSliceType = + RankedTensorType::get(bottomPadSliceShape, selfElemTy); + + auto bottomPadSlice = rewriter.create( + op->getLoc(), bottomPadSliceType, 
selfSidePadded, + rewriter.getDenseI64ArrayAttr(bottomStartSlice), + rewriter.getDenseI64ArrayAttr(bottomSizeSlice)); + + for (int64_t i = 0; i < paddingBottom; i++) + resultTensors.push_back(bottomPadSlice.getResult()); + } + + auto result = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), resultType, resultTensors, selfRank - 2); + + rewriter.replaceOp(op, result); + return success(); +} + +// Legalization for torch.prims.split_dim +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + PrimsSplitDimOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto self = adaptor.getA(); + + // Not a tensor type + auto selfType = dyn_cast(self.getType()); + if (!selfType) + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); + + auto resultType = + dyn_cast(typeConverter->convertType(op.getType())); + auto resultShape = resultType.getShape(); + + int64_t dim, outerLength; + if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) + return rewriter.notifyMatchFailure( + op, "Only constant int dim value is supported"); + + auto selfRank = selfType.getRank(); + dim = toPositiveDim(dim, selfRank); + if (!isValidDim(dim, selfRank)) + return rewriter.notifyMatchFailure(op, "Dim is invalid"); + + if (!matchPattern(op.getOuterLength(), m_TorchConstantInt(&outerLength))) + return rewriter.notifyMatchFailure( + op, "Only constant int outer length value is supported"); + + // Technically, I should calculate the output shape based on the dim and + // outer length values. However, that would just give the same result as me + // taking the result shape straight from resultType and applying + // tosa::ReshapeOp to the input. Therefore, I'm opting for the latter + // approach here, which is more simple and quicker. + rewriter.replaceOpWithNewOp( + op, resultType, self, + rewriter.getDenseI64ArrayAttr(makeShapeTorchCompatible(resultShape))); + + return success(); +} + +// Legalization for aten.outer +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenOuterOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto self = adaptor.getSelf(); + + auto selfType = dyn_cast(self.getType()); + if (!selfType) + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); + + if (selfType.getRank() != 1) + return rewriter.notifyMatchFailure(op, "Only rank 1 vectors are supported"); + + auto vec2 = adaptor.getVec2(); + + auto vec2Type = dyn_cast(vec2.getType()); + if (!vec2Type) + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); + + if (vec2Type.getRank() != 1) + return rewriter.notifyMatchFailure(op, "Only rank 1 vectors are supported"); + + auto resultType = + dyn_cast(typeConverter->convertType(op.getType())); + auto resultShape = resultType.getShape(); + + self = tosa::promoteType(rewriter, self, resultType); + vec2 = tosa::promoteType(rewriter, vec2, resultType); + + SmallVector resultShapeIndex1Replaced({resultShape[0], 1}); + SmallVector resultShapeIndex0Replaced({1, resultShape[1]}); + + // Reshape and tile self to shape {selfShape[0], resultShape[1]} + auto selfReshaped = rewriter.create( + op->getLoc(), + RankedTensorType::get(resultShapeIndex1Replaced, + resultType.getElementType()), + self, rewriter.getDenseI64ArrayAttr(resultShapeIndex1Replaced)); + + auto selfTileOpMultiples = tosa::getTosaConstShape(rewriter, op->getLoc(), + resultShapeIndex0Replaced); + + auto selfTiled = rewriter.create( + op->getLoc(), resultType, selfReshaped.getResult(), selfTileOpMultiples); + + // Reshape and 
tile vec2 to shape {resultShape[0], vec2Shape[0]} + auto vec2Reshaped = rewriter.create( + op->getLoc(), + RankedTensorType::get(resultShapeIndex0Replaced, + resultType.getElementType()), + vec2, rewriter.getDenseI64ArrayAttr(resultShapeIndex0Replaced)); + + auto vec2TileOpMultiples = tosa::getTosaConstShape(rewriter, op->getLoc(), + resultShapeIndex1Replaced); + + auto vec2Tiled = rewriter.create( + op->getLoc(), resultType, vec2Reshaped.getResult(), vec2TileOpMultiples); + + auto result = + tosa::createMulOpAndCast(rewriter, op, resultType, selfTiled.getResult(), + vec2Tiled.getResult(), /*shift=*/0); + + rewriter.replaceOp(op, result); + return success(); +} + +// Legalization for aten.upsample_nearest2d +template +class ConvertUpsampleNearest2dForward : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename AtenOpT::Adaptor; + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // aten.upsample_nearest2d lowering process: + // 1. Reshape input: (N, C, H, W) -> (N, C, H x W) + // 2. Calculate PyTorch-styled gather op indices based on the following + // formula (based on Torch to Linalg UpsampleNearest2d lowering formula): + // for i in range(N x C): + // for heightIndex in range(scaledHeight): + // for widthIndex in range(scaledWidth): + // indices.append(int(heightIndex // scalesH * selfWidth + + // widthIndex // scalesW)) + // 3. Convert PyTorch-styled indices to TensorFlow-styled indices + // 4. Apply TensorFlow-styled ConverGatherOpNd to retrieve the output + // 5. Reshape output to desired output shape + Value self; + if constexpr (std::is_same()) { + self = adaptor.getSelf(); + } else if constexpr (std::is_same()) { + self = adaptor.getInput(); + } else { + return rewriter.notifyMatchFailure( + op, "Expected either AtenUpsampleNearest2dOp or " + "AtenUpsampleNearest2dVecOp"); + } + + auto selfType = dyn_cast(self.getType()); + if (!selfType) + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); + + auto selfShape = selfType.getShape(); + auto selfRank = selfType.getRank(); + auto selfElemTy = selfType.getElementType(); + + auto selfHeight = selfShape[selfRank - 2]; + auto selfWidth = selfShape[selfRank - 1]; + + auto resultType = dyn_cast( + OpConversionPattern::getTypeConverter()->convertType( + op.getType())); + auto resultShape = resultType.getShape(); + auto resultElemTy = resultType.getElementType(); + + // Get op's parameters + SmallVector outputSize; + SmallVector scaleFactors; + double scalesH; + double scalesW; + int64_t outputHeight; + int64_t outputWidth; + if constexpr (std::is_same()) { + if (!matchPattern(op.getOutputSize(), + m_TorchListOfConstantInts(outputSize))) + return rewriter.notifyMatchFailure( + op, "Non-constant output size not supported"); + + outputHeight = outputSize[0]; + outputWidth = outputSize[1]; + + if (isa(op.getScalesH().getType())) { + scalesH = + static_cast(outputHeight) / static_cast(selfHeight); + } else { + if (!matchPattern(op.getScalesH(), m_TorchConstantFloat(&scalesH))) + return rewriter.notifyMatchFailure( + op, "Non-constant height scales not supported"); + + scalesH = std::ceil(scalesH); + } + + if (isa(op.getScalesW().getType())) { + scalesW = + static_cast(outputWidth) / static_cast(selfWidth); + } else { + if (!matchPattern(op.getScalesW(), m_TorchConstantFloat(&scalesW))) + return rewriter.notifyMatchFailure( + op, "Non-constant width scales not supported"); + + scalesW = 
std::ceil(scalesW); + } + } else if constexpr (std::is_same()) { + auto isOutputSizeNone = + isa(op.getOutputSize().getType()); + auto isScaleFactorsNone = + isa(op.getScaleFactors().getType()); + + if ((isOutputSizeNone && isScaleFactorsNone) || + (!isOutputSizeNone && !isScaleFactorsNone)) + return rewriter.notifyMatchFailure( + op, "Must specify exactly one of output size and scale factors"); + + if (!isOutputSizeNone) { + if (!matchPattern(op.getOutputSize(), + m_TorchListOfConstantInts(outputSize))) + return rewriter.notifyMatchFailure( + op, "Non-constant output size not supported"); + + outputHeight = outputSize[0]; + outputWidth = outputSize[1]; + + // Output size values being provided implies that scale values are not + // provided + scalesH = + static_cast(outputHeight) / static_cast(selfHeight); + scalesW = + static_cast(outputWidth) / static_cast(selfWidth); + } else { + if (!matchPattern(op.getScaleFactors(), + m_TorchListOfConstantFloats(scaleFactors))) + return rewriter.notifyMatchFailure( + op, "Non-constant output size not supported"); + + scalesH = std::ceil(scaleFactors[0]); + scalesW = std::ceil(scaleFactors[1]); + + // Scale values being provided implies that output size values are not + // provided + outputHeight = static_cast(scalesH * selfHeight); + outputWidth = static_cast(scalesW * selfWidth); + } + } + + // Reshape input + SmallVector reshapedSelfShape(selfShape.begin(), + selfShape.end() - 2); + reshapedSelfShape.push_back(selfHeight * selfWidth); + + auto reshapedSelf = rewriter.create( + op->getLoc(), RankedTensorType::get(reshapedSelfShape, selfElemTy), + self, rewriter.getDenseI64ArrayAttr(reshapedSelfShape)); + + // Calculate PyTorch-styled gather indices + SmallVector targetIndicesVec; + int64_t indexRepeat = std::accumulate( + selfShape.begin(), selfShape.end() - 2, 1, std::multiplies()); + for (int64_t i = 0; i < indexRepeat; i++) { + for (int64_t heightIndex = 0; heightIndex < outputHeight; heightIndex++) { + for (int64_t widthIndex = 0; widthIndex < outputWidth; widthIndex++) { + targetIndicesVec.push_back(static_cast( + std::floor(heightIndex / scalesH) * selfWidth + + std::floor(widthIndex / scalesW))); + } + } + } + + SmallVector targetIndicesShape(selfShape.begin(), + selfShape.end() - 2); + targetIndicesShape.push_back(outputHeight * outputWidth); + auto targetIndicesTorch = + tosa::getConstTensor(rewriter, op, targetIndicesVec, + targetIndicesShape) + .value(); + + // Convert PyTorch-styled indices to TensorFlow-styled indices + auto targetIndicesTF = tosa::convertTorchIndexToTfIndices( + rewriter, op, reshapedSelf.getResult(), targetIndicesTorch, + selfRank - 2); + if (!targetIndicesTF) + return rewriter.notifyMatchFailure( + op, "Convert PyTorch-styled indices and dim " + "to TensorFlow-styled indices failed"); + // Apply TensorFlow GatherNdOp with TensorFlow-style indices to retrieve + // target elements + auto gatherOp = tosa::convertGatherNdOp( + rewriter, op, RankedTensorType::get(targetIndicesShape, resultElemTy), + reshapedSelf.getResult(), targetIndicesTF.value()); + if (!gatherOp) + return rewriter.notifyMatchFailure(op, "Convert GatherNdOp failed"); + + auto result = rewriter.create( + op->getLoc(), resultType, gatherOp.value(), + rewriter.getDenseI64ArrayAttr(resultShape)); + + rewriter.replaceOp(op, {result.getResult()}); + + return success(); + } +}; + +// Legalization for aten.logit +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenLogitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // 
Logit formula: + // result = log(zi / (1 - zi)) + // Where: if eps is not None: + // zi = input clampled to [eps, 1 - eps] + // else: + // zi = input + auto self = adaptor.getSelf(); + + auto selfType = dyn_cast(self.getType()); + if (!selfType) + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); + + auto resultType = + dyn_cast(typeConverter->convertType(op.getType())); + auto resultElemTy = resultType.getElementType(); + + if (!isa(resultElemTy)) + return rewriter.notifyMatchFailure( + op, "Only floating-point datatype result types are supported"); + + // If input is not a float type then cast it to result element type + auto selfElemTy = selfType.getElementType(); + if (!isa(selfElemTy)) + self = tosa::promoteType(rewriter, self, resultType); + + bool isEpsNone = isa(op.getEps().getType()); + + double eps; + if (!isEpsNone && !matchPattern(op.getEps(), m_TorchConstantFloat(&eps))) + return rewriter.notifyMatchFailure(op, + "Non-const eps value is not supported"); + + auto zi = self; + + // Clamp input to [eps, 1 - eps] when eps is not None + // Use default NaN Propagation mode "PROPAGATE" for tosa.clamp + if (!isEpsNone) { + zi = rewriter + .create( + op->getLoc(), resultType, self, + rewriter.getI64IntegerAttr(static_cast(eps)), + rewriter.getI64IntegerAttr(static_cast(1 - eps)), + rewriter.getF32FloatAttr(static_cast(eps)), + rewriter.getF32FloatAttr(static_cast(1 - eps)), + /*nan_mode=*/rewriter.getStringAttr("PROPAGATE")) + .getResult(); + } + + auto one = + tosa::getConstTensor(rewriter, op, 1.0f, {}, resultElemTy).value(); + + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, one).failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + + auto oneMinusZi = + rewriter.create(op->getLoc(), resultType, one, zi); + + auto oneMinusZiReciprocal = rewriter.create( + op->getLoc(), resultType, oneMinusZi.getResult()); + + auto mulOp = rewriter.create(op->getLoc(), resultType, zi, + oneMinusZiReciprocal.getResult(), + /*shift=*/0); + + auto result = + rewriter.create(op->getLoc(), resultType, mulOp.getResult()); + + rewriter.replaceOp(op, {result.getResult()}); + + return success(); +} + +// Legalization for aten.log1p +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenLog1pOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // log1p formula: + // yi = log(xi + 1) + auto self = adaptor.getSelf(); + + auto selfType = dyn_cast(self.getType()); + if (!selfType) + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); + + auto resultType = + dyn_cast(typeConverter->convertType(op.getType())); + auto resultElemTy = resultType.getElementType(); + + if (!isa(resultElemTy)) + return rewriter.notifyMatchFailure( + op, "Only floating-point datatype result types are supported"); + + // If input is not a float type then cast it to result element type + auto selfElemTy = selfType.getElementType(); + if (!isa(selfElemTy)) + self = tosa::promoteType(rewriter, self, resultType); + + auto one = + tosa::getConstTensor(rewriter, op, 1.0f, {}, resultElemTy).value(); + + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, one).failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + + auto addOp = + rewriter.create(op->getLoc(), resultType, self, one); + + auto result = + rewriter.create(op->getLoc(), resultType, addOp.getResult()); + + rewriter.replaceOp(op, {result.getResult()}); + + return success(); +} 
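+// Note on the aten.log1p lowering above: as with aten.expm1 further below,
+// computing log(x + 1) directly may lose precision for |x| << 1. For example,
+// with float32 inputs and x = 1e-8, (x + 1) rounds to exactly 1.0 (float32
+// machine epsilon is ~1.19e-7), so the lowered result is log(1.0) = 0.0,
+// whereas log1p(1e-8) ~= 1e-8.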
+ +// Legalization for aten.log10 +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenLog10Op op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // log10 formula (using log base changing formula since TOSA doesn't have a + // builtin log10 op): + // yi = log(xi) / log(10) + auto self = adaptor.getSelf(); + + auto selfType = dyn_cast(self.getType()); + if (!selfType) + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); + + auto resultType = + dyn_cast(typeConverter->convertType(op.getType())); + auto resultElemTy = resultType.getElementType(); + + if (!isa(resultElemTy)) + return rewriter.notifyMatchFailure( + op, "Only floating-point datatype result types are supported"); + + // If input is not a float type then cast it to result element type + auto selfElemTy = selfType.getElementType(); + if (!isa(selfElemTy)) + self = tosa::promoteType(rewriter, self, resultType); + + auto ten = tosa::getConstTensor(rewriter, op, 10.0f, {}, resultElemTy) + .value(); + + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, ten).failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + + auto logOfSelf = rewriter.create(op->getLoc(), resultType, self); + + auto constTenType = RankedTensorType::get( + dyn_cast(ten.getType()).getShape(), resultElemTy); + + auto logOfTen = rewriter.create(op->getLoc(), constTenType, ten); + + auto reciprocalOp = rewriter.create( + op->getLoc(), constTenType, logOfTen.getResult()); + + auto result = rewriter.create( + op->getLoc(), resultType, logOfSelf.getResult(), reciprocalOp.getResult(), + /*shift=*/0); + + rewriter.replaceOp(op, {result.getResult()}); + + return success(); +} + +// Legalization for aten.expm1 +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenExpm1Op op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // expm1 formula: + // yi = exp(x) - 1 + // Note: This lowering might not provide as great precision as aten.expm1 + // since TOSA doesn't have a built-in expm1 op. 
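+  // For example, with float32 inputs and x = 1e-10, exp(x) rounds to exactly
+  // 1.0, so exp(x) - 1 yields 0.0, whereas expm1(1e-10) ~= 1e-10.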
+ auto self = adaptor.getSelf(); + + auto selfType = dyn_cast(self.getType()); + if (!selfType) + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); + + auto resultType = + dyn_cast(typeConverter->convertType(op.getType())); + auto resultElemTy = resultType.getElementType(); + + if (!isa(resultElemTy)) + return rewriter.notifyMatchFailure( + op, "Only floating-point datatype result types are supported"); + + // If input is not a float type then cast it to result element type + auto selfElemTy = selfType.getElementType(); + if (!isa(selfElemTy)) + self = tosa::promoteType(rewriter, self, resultType); + + auto one = + tosa::getConstTensor(rewriter, op, 1.0f, {}, resultElemTy).value(); + + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), self, one).failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + + auto expOp = rewriter.create(op->getLoc(), resultType, self); + + auto result = rewriter.create(op->getLoc(), resultType, + expOp.getResult(), one); + + rewriter.replaceOp(op, {result.getResult()}); + + return success(); +} + +// Legalization for aten.tan +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenTanOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // tan = sin / cos + auto self = adaptor.getSelf(); + + auto selfType = dyn_cast(self.getType()); + if (!selfType) + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); + + auto resultType = + dyn_cast(typeConverter->convertType(op.getType())); + + if (!isa(resultType.getElementType())) + return rewriter.notifyMatchFailure( + op, "Only floating-point datatype result types are supported"); + + // Non floating point inputs are not supported in TOSA so we cast the input + // to result type + if (!isa(selfType.getElementType())) + self = tosa::promoteType(rewriter, self, resultType); + + auto sinOp = rewriter.create(op->getLoc(), resultType, self); + + auto cosOp = rewriter.create(op->getLoc(), resultType, self); + + auto reciprocalOp = + rewriter.create(op->getLoc(), resultType, cosOp); + + auto result = rewriter.create( + op->getLoc(), resultType, sinOp.getResult(), reciprocalOp.getResult(), + /*shift=*/0); + + rewriter.replaceOp(op, {result.getResult()}); + + return success(); +} + +// Legalization for aten.unfold +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenUnfoldOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // Approach: Use GatherOp to retrieve target elements from target dim and then + // reshape the output into slices according to the output shape + // + // Lowering steps: + // 1. Create PyTorch-style indices tensor corresponding to target elements and + // reshape them to (d_0, d_1, ..., nWindows * size, ..., d_(rank - 1)) + // with d_x being the dimension size of the input at dim x. + // The indices vector will be calculated using the following formula: + // for i in range(d_0 * d_1 * ... * d_(target_dim - 1)): + // for window in range(nWindows): + // for elementIndex in range(size): + // for j in range(d_(target_dim + 1) * ... * d_(rank-1)): + // indices_vec.push_back(elementIndex + window * step) + // 2. Convert PyTorch-style indices and target dim to TensorFlow-style indices + // 3. Apply TensorFlow GatherNdOp with TensorFlow-style indices to retrieve + // target elements + // 4. 
Reshape result from above to correct output shape + auto self = adaptor.getSelf(); + + auto selfType = dyn_cast(self.getType()); + if (!selfType) + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); + + auto selfShape = selfType.getShape(); + auto selfRank = selfType.getRank(); + auto selfElemTy = selfType.getElementType(); + + auto resultType = + dyn_cast(typeConverter->convertType(op.getType())); + auto resultElemTy = resultType.getElementType(); + + int64_t dim; + if (!matchPattern(op.getDimension(), m_TorchConstantInt(&dim))) + return rewriter.notifyMatchFailure(op, + "Only constant int dims are supported"); + + int64_t size; + if (!matchPattern(op.getSize(), m_TorchConstantInt(&size))) + return rewriter.notifyMatchFailure(op, + "Only constant int sizes are supported"); + + int64_t step; + if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) + return rewriter.notifyMatchFailure(op, + "Only constant int steps are supported"); + + if (step <= 0) + return rewriter.notifyMatchFailure(op, "Step value must be greater than 0"); + + // Handle rank zero + if (selfRank == 0) { + if (dim != 0) + return rewriter.notifyMatchFailure( + op, "Unsupported dim value for rank zero input"); + + if (size != 1) + return rewriter.notifyMatchFailure( + op, "Unsupported size value for rank zero input"); + + auto result = rewriter.create( + op->getLoc(), RankedTensorType::get({1}, selfElemTy), self, + rewriter.getDenseI64ArrayAttr({1})); + + rewriter.replaceOp(op, {result.getResult()}); + return success(); + } + + dim = toPositiveDim(dim, selfRank); + if (!isValidDim(dim, selfRank)) + return rewriter.notifyMatchFailure(op, "Dim value is invalid"); + + // Size of dimension 'dim' in the returned tensor (or number of windows within + // the dimension that got sliced) + int64_t nWindows = (selfShape[dim] - size) / step + 1; + + // Find number of times that each base index value gets repeated for target + // dim based on dim values before and after target dim i.e. preDimAccumulate = + // d_0 * d_1 * ... * d_(target_dim - 1) + // postDimAccumulate = d_(target_dim + 1) * ... 
* d_(rank - 1) + int64_t preDimAccumulate = + std::accumulate(selfShape.begin(), selfShape.begin() + dim, 1, + std::multiplies()); + int64_t postDimAccumulate = + std::accumulate(selfShape.begin() + dim + 1, selfShape.end(), 1, + std::multiplies()); + + // Calculate PyTorch-style gather indices vector + // Example: shape = (2, 4, 3), dim = 1, size = 3, step = 1 + // -> preDimAccumulate = 2, postDimAccummulate = 3, nWindows = 2 + // pyTorchIndicesBaseVec = [0, 0, 0, 1, 1, 1, 2, 2, 2, + // 1, 1, 1, 2, 2, 2, 3, 3, 3] + // pyTorchIndicesVec = [0, 0, 0, 1, 1, 1, 2, 2, 2, + // 1, 1, 1, 2, 2, 2, 3, 3, 3, + // 0, 0, 0, 1, 1, 1, 2, 2, 2, + // 1, 1, 1, 2, 2, 2, 3, 3, 3] + SmallVector pyTorchIndicesBaseVec; + SmallVector pyTorchIndicesVec; + + for (int64_t window = 0; window < nWindows; window++) { + for (int64_t elementIndex = 0; elementIndex < size; elementIndex++) { + int32_t baseIndex = static_cast(elementIndex + window * step); + for (int64_t i = 0; i < postDimAccumulate; i++) + pyTorchIndicesBaseVec.push_back(baseIndex); + } + } + + for (int64_t i = 0; i < preDimAccumulate; i++) + pyTorchIndicesVec.insert(pyTorchIndicesVec.end(), + pyTorchIndicesBaseVec.begin(), + pyTorchIndicesBaseVec.end()); + + // Create the PyTorch-style indices tensor + // Continuing with the previous example: + // pyTorchIndicesShape = (2, nWindows * size, 3) = (2, 6, 3) + // pyTorchIndices = tensor([[[0, 0, 0], + // [1, 1, 1], + // [2, 2, 2], + // [1, 1, 1], + // [2, 2, 2], + // [3, 3, 3]], + // [[0, 0, 0], + // [1, 1, 1], + // [2, 2, 2], + // [1, 1, 1], + // [2, 2, 2], + // [3, 3, 3]]]) + SmallVector pyTorchIndicesShape(selfShape); + pyTorchIndicesShape[dim] = nWindows * size; + auto pyTorchIndices = + tosa::getConstTensor(rewriter, op, pyTorchIndicesVec, + pyTorchIndicesShape) + .value(); + + // Convert PyTorch-style indices to TensorFlow-style indices + auto tfIndices = tosa::convertTorchIndexToTfIndices(rewriter, op, self, + pyTorchIndices, dim); + if (!tfIndices) + return rewriter.notifyMatchFailure(op, + "Convert PyTorch-style indices and dim " + "to TensorFlow-style indices failed"); + + // Apply TensorFlow GatherNdOp with TensorFlow-style indices to retrieve + // target elements + auto gatherNdOp = tosa::convertGatherNdOp( + rewriter, op, RankedTensorType::get(pyTorchIndicesShape, resultElemTy), + self, tfIndices.value()); + if (!gatherNdOp) + return rewriter.notifyMatchFailure(op, "Convert GatherNdOp failed"); + + // Reshape to an intermediary shape where the gathered elements in dimension + // 'dim' are split back into 2 dimensions of sizes 'nWindows' and 'size' + SmallVector intermediaryShape; + for (int64_t currentDim = 0; currentDim < selfRank; currentDim++) { + if (currentDim == dim) { + intermediaryShape.push_back(nWindows); + intermediaryShape.push_back(size); + } else { + intermediaryShape.push_back(pyTorchIndicesShape[currentDim]); + } + } + + auto reshapeOp = rewriter.create( + op->getLoc(), RankedTensorType::get(intermediaryShape, resultElemTy), + gatherNdOp.value(), rewriter.getDenseI64ArrayAttr(intermediaryShape)); + + // Permute dims to the correct result order + SmallVector permutedDims; + for (int64_t currentDim = 0; currentDim < selfRank + 1; currentDim++) { + if (currentDim != dim + 1) + permutedDims.push_back(static_cast(currentDim)); + } + permutedDims.push_back(static_cast(dim + 1)); + + auto permutedDimsConst = tosa::getConstTensor( + rewriter, op, + /*vec=*/permutedDims, + /*shape=*/{static_cast(selfRank + 1)}) + .value(); + + auto result = rewriter.create( + op->getLoc(), resultType, 
reshapeOp.getResult(), permutedDimsConst); + + rewriter.replaceOp(op, {result.getResult()}); + + return success(); +} + +} // namespace + +// ----------------------------------------------------------------------------- +// TorchToTosa Pass +// ----------------------------------------------------------------------------- + +namespace { +class ConvertTorchToTosa : public ConvertTorchToTosaBase { +public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + registry.insert(); + TorchConversion::getBackendTypeConversionDependentDialects(registry); + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ConversionTarget target(*context); + target.addLegalDialect(); + target.addIllegalDialect(); + + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { return type; }); + TorchConversion::setupBackendTypeConversion(target, typeConverter); + + populateTorchToTosaConversionLegalOps(target); + + RewritePatternSet patterns(context); + + auto illegalOps = populateTorchToTosaConversionPatternsAndIllegalOps( + typeConverter, patterns); + + for (auto op : illegalOps) { + target.addIllegalOp(OperationName(op, context)); + } + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + return signalPassFailure(); + } +}; +} // namespace + +void torch::populateTorchToTosaConversionLegalOps(ConversionTarget &target) { + // The following ops are never the primary reason why lowering fails. + // The backend contract only allows functions to return tensors thus there + // is always another op using them. + // When we have a chain of torch.constant.int followed by a unsupported + // torch op, we want the pass to mention the unsupported torch op + // in the error message. 
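+  // Illustrative (hypothetical op name, not from this patch): given
+  //   %int1 = torch.constant.int 1
+  //   %0 = torch.aten.someUnsupportedOp %arg0, %int1 : ...
+  // keeping torch.constant.int legal means the partial conversion reports the
+  // failure against torch.aten.someUnsupportedOp rather than against the
+  // constant feeding it.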
+ target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); +} + +std::set torch::populateTorchToTosaConversionPatternsAndIllegalOps( + TypeConverter &typeConverter, RewritePatternSet &patterns) { + + MLIRContext *context = patterns.getContext(); + std::set illegalOps; + +#define INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenOp, TosaOp) \ + illegalOps.insert(AtenOp::getOperationName()); \ + patterns.add>(typeConverter, \ + context); + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenLogOp, tosa::LogOp) + INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenExpOp, tosa::ExpOp) +#undef INSERT_UNARY_PROMOTE_TO_FP_PATTERN + +#define INSERT_UNARY_PATTERN(AtenOp, TosaOp) \ + illegalOps.insert(AtenOp::getOperationName()); \ + patterns.add>(typeConverter, context); + INSERT_UNARY_PATTERN(AtenNegOp, tosa::NegateOp) + INSERT_UNARY_PATTERN(AtenFloorOp, tosa::FloorOp) + INSERT_UNARY_PATTERN(AtenRsqrtOp, tosa::RsqrtOp) + INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp) + INSERT_UNARY_PATTERN(AtenCeilOp, tosa::CeilOp) + INSERT_UNARY_PATTERN(AtenReciprocalOp, tosa::ReciprocalOp) + INSERT_UNARY_PATTERN(AtenCosOp, tosa::CosOp) + INSERT_UNARY_PATTERN(AtenSinOp, tosa::SinOp) + INSERT_UNARY_PATTERN(AtenLogicalNotOp, tosa::LogicalNotOp) +#undef INSERT_UNARY_PATTERN + +#define INSERT_BINARY_PATTERN(AtenOp, TosaOp) \ + illegalOps.insert(AtenOp::getOperationName()); \ + patterns.add>(typeConverter, context); + INSERT_BINARY_PATTERN(AtenMaximumOp, tosa::MaximumOp) + INSERT_BINARY_PATTERN(AtenMinimumOp, tosa::MinimumOp) + INSERT_BINARY_PATTERN(AtenLogicalOrOp, tosa::LogicalOrOp) + INSERT_BINARY_PATTERN(AtenLogicalXorOp, tosa::LogicalXorOp) + INSERT_BINARY_PATTERN(AtenLogicalAndOp, tosa::LogicalAndOp) + INSERT_BINARY_PATTERN(AtenBitwiseLeftShiftTensorOp, tosa::LogicalLeftShiftOp) + INSERT_BINARY_PATTERN(AtenBitwiseRightShiftTensorOp, + tosa::ArithmeticRightShiftOp) +#undef INSERT_BINARY_PATTERN + +#define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, TosaOp) \ + illegalOps.insert(AtenOp::getOperationName()); \ + patterns.add>(typeConverter, context); + INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, tosa::AddOp) + INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, tosa::AddOp) + INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, tosa::SubOp) + INSERT_BINARY_ADDSUB_PATTERN(AtenSubScalarOp, tosa::SubOp) +#undef INSERT_BINARY_ADDSUB_PATTERN + +#define INSERT_BINARY_COMPARE_PATTERN(AtenOp, TosaOp) \ + illegalOps.insert(AtenOp::getOperationName()); \ + patterns.add>(typeConverter, context); + INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp, tosa::GreaterOp) + INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp, tosa::GreaterEqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenGeTensorOp, tosa::GreaterEqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp, tosa::GreaterOp) + INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp, tosa::GreaterOp) + INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp, tosa::GreaterOp) + INSERT_BINARY_COMPARE_PATTERN(AtenLeTensorOp, tosa::GreaterEqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenLeScalarOp, tosa::GreaterEqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp, tosa::EqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenEqScalarOp, tosa::EqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenNeTensorOp, tosa::EqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenNeScalarOp, tosa::EqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndTensorOp, tosa::BitwiseAndOp) + INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndScalarOp, tosa::BitwiseAndOp) + 
INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseOrTensorOp, tosa::BitwiseOrOp) + INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseXorTensorOp, tosa::BitwiseXorOp) +#undef INSERT_BINARY_COMPARE_PATTERN + +#define INSERT_BINARY_MUL_PATTERN(AtenOp) \ + illegalOps.insert(AtenOp::getOperationName()); \ + patterns.add>(typeConverter, context); + INSERT_BINARY_MUL_PATTERN(AtenMulTensorOp); + INSERT_BINARY_MUL_PATTERN(AtenMulScalarOp); +#undef INSERT_BINARY_MUL_PATTERN + +#define INSERT_BINARY_DIV_PATTERN(AtenOp) \ + illegalOps.insert(AtenOp::getOperationName()); \ + patterns.add>(typeConverter, context); + INSERT_BINARY_DIV_PATTERN(AtenDivTensorOp); + INSERT_BINARY_DIV_PATTERN(AtenDivScalarOp); + INSERT_BINARY_DIV_PATTERN(AtenDivTensorModeOp); + INSERT_BINARY_DIV_PATTERN(AtenDivScalarModeOp); +#undef INSERT_BINARY_DIV_PATTERN + +#define INSERT_REMAINDER_FMOD_OP_PATTERN(AtenOp) \ + illegalOps.insert(AtenOp::getOperationName()); \ + patterns.add>(typeConverter, context); + INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderScalarOp); + INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderTensorOp); + INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodScalarOp); + INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodTensorOp); +#undef INSERT_REMAINDER_FMOD_OP_PATTERN + +#define INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \ + illegalOps.insert(AtenOp::getOperationName()); \ + patterns.add>( \ + typeConverter, context); + INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenMeanDimOp, + mlir::tosa::convertReduceMeanOp) + INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenSumDimIntListOp, + mlir::tosa::convertReduceSumOp) + INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenLinalgVectorNormOp, + mlir::tosa::convertLinalgVectorNormOp) +#undef INSERT_NDIMS_REDUCTION_OP_PATTERN + +#define INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \ + illegalOps.insert(AtenOp::getOperationName()); \ + patterns.add>( \ + typeConverter, context); + INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAnyDimOp, + mlir::tosa::convertReduceAnyOp) + INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAllDimOp, + mlir::tosa::convertReduceAllOp) + INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenProdDimIntOp, + mlir::tosa::convertReduceProdOp) +#undef INSERT_ONEDIM_REDUCTION_OP_PATTERN + +#define INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \ + illegalOps.insert(AtenOp::getOperationName()); \ + patterns.add>( \ + typeConverter, context); + INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAllOp, mlir::tosa::convertReduceAllOp) + INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAnyOp, mlir::tosa::convertReduceAnyOp) + INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenSumOp, mlir::tosa::convertReduceSumOp) + INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMaxOp, mlir::tosa::convertReduceMaxOp) + INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMinOp, mlir::tosa::convertReduceMinOp) + INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenProdOp, + mlir::tosa::convertReduceProdOp) +#undef INSERT_ALLDIMS_REDUCTION_OP_PATTERN + +#define INSERT_INDICES_REDUCTION_OP_PATTERN(AtenOp, TosaOp) \ + illegalOps.insert(AtenOp::getOperationName()); \ + patterns.add>(typeConverter, context); + INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMaxDimOp, tosa::ReduceMaxOp); + INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMinDimOp, tosa::ReduceMinOp); +#undef INSERT_INDICES_REDUCTION_OP_PATTERN + +#define INSERT_SQUEEZE_OP_PATTERN(AtenOp, TemplateForm) \ + illegalOps.insert(AtenOp::getOperationName()); \ + patterns.add>(typeConverter, context); + INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeOp, ConvertAtenSqueezeAllDimsOp) + INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeDimOp, ConvertAtenSqueezeOneDimOp) +#undef 
INSERT_SQUEEZE_OP_PATTERN + +#define INSERT_MATMUL_ATENOP_PATTERN(AtenOp) \ + illegalOps.insert(AtenOp::getOperationName()); \ + patterns.add>(typeConverter, context); + INSERT_MATMUL_ATENOP_PATTERN(AtenMatmulOp); +#undef INSERT_MATMUL_ATEMOP_PATTERN + +#define INSERT_MM_ATENOP_PATTERN(AtenOp) \ + illegalOps.insert(AtenOp::getOperationName()); \ + patterns.add>(typeConverter, context); + INSERT_MM_ATENOP_PATTERN(AtenMmOp); + INSERT_MM_ATENOP_PATTERN(AtenBmmOp); +#undef INSERT_MM_ATEMOP_PATTERN + +#define INSERT_LINEAR_ATENOP_PATTERN(AtenOp) \ + illegalOps.insert(AtenOp::getOperationName()); \ + patterns.add>(typeConverter, context); + INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp); +#undef INSERT_LINEAR_ATEMOP_PATTERN + +#define INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenOp, TosaOpT) \ + illegalOps.insert(AtenOp::getOperationName()); \ + patterns.add>(typeConverter, \ + context); + INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenAdaptiveAvgPool2dOp, + tosa::AvgPool2dOp); +#undef INSERT_ADAPTIVE_POOLING_ATEMOP_PATTERN + + illegalOps.insert(AtenMaxPool2dOp::getOperationName()); + patterns.add(typeConverter, context); + + illegalOps.insert(AtenMaxPool1dOp::getOperationName()); + patterns.add(typeConverter, context); + + illegalOps.insert(AtenAvgPool2dOp::getOperationName()); + patterns.add(typeConverter, context); + + illegalOps.insert(AtenAvgPool1dOp::getOperationName()); + patterns.add(typeConverter, context); + +#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \ + illegalOps.insert(AtenOp::getOperationName()); \ + patterns.add>(typeConverter, \ + context); + INSERT_CONSTANT_FILL_PATTERN(AtenOnesOp, 1); + INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0); + INSERT_CONSTANT_FILL_PATTERN(AtenEmptyMemoryFormatOp, 0); +#undef INSERT_CONSTANT_FILL_PATTERN + +#define INSERT_FILL_PATTERN(AtenOp) \ + illegalOps.insert(AtenOp::getOperationName()); \ + patterns.add>(typeConverter, context); + INSERT_FILL_PATTERN(AtenFill_ScalarOp); + INSERT_FILL_PATTERN(AtenFillScalarOp); + INSERT_FILL_PATTERN(AtenFillTensorOp); +#undef INSERT_FILL_PATTERN + +#define INSERT_MASKED_FILL_PATTERN(AtenOp) \ + illegalOps.insert(AtenOp::getOperationName()); \ + patterns.add>(typeConverter, context); + INSERT_MASKED_FILL_PATTERN(AtenMaskedFillScalarOp); + INSERT_MASKED_FILL_PATTERN(AtenMaskedFillTensorOp); +#undef INSERT_MASKED_FILL_PATTERN + +#define INSERT_POW_OP_PATTERN(AtenOp) \ + illegalOps.insert(AtenOp::getOperationName()); \ + patterns.add>(typeConverter, context); + INSERT_POW_OP_PATTERN(AtenPowTensorScalarOp); + INSERT_POW_OP_PATTERN(AtenPowTensorTensorOp); + INSERT_POW_OP_PATTERN(AtenPowScalarOp); +#undef INSERT_POW_OP_PATTERN + +#define INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN(AtenOp) \ + illegalOps.insert(AtenOp::getOperationName()); \ + patterns.add>(typeConverter, context); + INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN(AtenUpsampleNearest2dOp); + INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN(AtenUpsampleNearest2dVecOp); +#undef INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN + +#define INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenOp, TosaOp) \ + illegalOps.insert(AtenOp::getOperationName()); \ + patterns.add>(typeConverter, \ + context); + INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenTanhOp, tosa::TanhOp); + INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenSigmoidOp, tosa::SigmoidOp); + INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenErfOp, tosa::ErfOp); +#undef INSERT_ACTIVATION_FUNCITON_OP_PATTERN + +#define INSERT_ATENOP_PATTERN(AtenOp) \ + illegalOps.insert(AtenOp::getOperationName()); \ + patterns.add>(typeConverter, context); 
+ INSERT_ATENOP_PATTERN(AtenHardtanhBackwardOp); + INSERT_ATENOP_PATTERN(AtenReluOp); + INSERT_ATENOP_PATTERN(AtenLeakyReluOp); + INSERT_ATENOP_PATTERN(AtenArgmaxOp); + INSERT_ATENOP_PATTERN(AtenRsubScalarOp); + INSERT_ATENOP_PATTERN(AtenConvolutionOp); + INSERT_ATENOP_PATTERN(ValueTensorLiteralOp); + INSERT_ATENOP_PATTERN(AtenReshapeOp); + INSERT_ATENOP_PATTERN(AtenBatchNormOp); + INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp); + INSERT_ATENOP_PATTERN(AtenFlattenUsingIntsOp); + INSERT_ATENOP_PATTERN(AtenUnflattenIntOp); + INSERT_ATENOP_PATTERN(AtenPermuteOp); + INSERT_ATENOP_PATTERN(AtenLog2Op); + INSERT_ATENOP_PATTERN(AtenThresholdOp); + INSERT_ATENOP_PATTERN(AtenUnsqueezeOp); + INSERT_ATENOP_PATTERN(AtenContiguousOp); + INSERT_ATENOP_PATTERN(AtenDropoutOp); + INSERT_ATENOP_PATTERN(AtenViewOp); + INSERT_ATENOP_PATTERN(AtenGeluOp); + INSERT_ATENOP_PATTERN(AtenGeluBackwardOp); + INSERT_ATENOP_PATTERN(AtenEmbeddingOp); + INSERT_ATENOP_PATTERN(AtenTransposeIntOp); + INSERT_ATENOP_PATTERN(AtenSliceTensorOp); + INSERT_ATENOP_PATTERN(AtenBroadcastToOp); + INSERT_ATENOP_PATTERN(AtenGatherOp); + INSERT_ATENOP_PATTERN(AtenIndexPutHackedTwinOp); + INSERT_ATENOP_PATTERN(AtenIndexTensorHackedTwinOp); + INSERT_ATENOP_PATTERN(AtenAbsOp); + INSERT_ATENOP_PATTERN(AtenWhereSelfOp); + INSERT_ATENOP_PATTERN(AtenClampOp); + INSERT_ATENOP_PATTERN(AtenArangeStartStepOp); + INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp); + INSERT_ATENOP_PATTERN(AtenCopyOp); + INSERT_ATENOP_PATTERN(AtenToDtypeOp); + INSERT_ATENOP_PATTERN(AtenConstantPadNdOp); + INSERT_ATENOP_PATTERN(AtenCatOp); + INSERT_ATENOP_PATTERN(AtenSqrtOp); + INSERT_ATENOP_PATTERN(AtenIscloseOp); + INSERT_ATENOP_PATTERN(Aten__InterpolateSizeListScaleListOp); + INSERT_ATENOP_PATTERN(AtenTrilOp); + INSERT_ATENOP_PATTERN(AtenDiagonalOp); + INSERT_ATENOP_PATTERN(AtenIndexSelectOp); + INSERT_ATENOP_PATTERN(AtenFlipOp); + INSERT_ATENOP_PATTERN(AtenRoundOp); + INSERT_ATENOP_PATTERN(AtenScatterSrcOp); + INSERT_ATENOP_PATTERN(AtenSliceScatterOp); + INSERT_ATENOP_PATTERN(AtenDiagEmbedOp); + INSERT_ATENOP_PATTERN(AtenUniformOp); + INSERT_ATENOP_PATTERN(AtenThresholdBackwardOp); + INSERT_ATENOP_PATTERN(AtenAsStridedOp); + INSERT_ATENOP_PATTERN(AtenClampTensorOp); + INSERT_ATENOP_PATTERN(PrimsCollapseOp); + INSERT_ATENOP_PATTERN(AtenReflectionPad1dOp); + INSERT_ATENOP_PATTERN(AtenReflectionPad2dOp); + INSERT_ATENOP_PATTERN(AtenReflectionPad3dOp); + INSERT_ATENOP_PATTERN(AtenReplicationPad2dOp); + INSERT_ATENOP_PATTERN(PrimsSplitDimOp); + INSERT_ATENOP_PATTERN(AtenOuterOp); + INSERT_ATENOP_PATTERN(AtenLogitOp); + INSERT_ATENOP_PATTERN(AtenLog1pOp); + INSERT_ATENOP_PATTERN(AtenLog10Op); + INSERT_ATENOP_PATTERN(AtenExpm1Op); + INSERT_ATENOP_PATTERN(AtenTanOp); + INSERT_ATENOP_PATTERN(AtenUnfoldOp); +#undef INSERT_ATENOP_PATTERN + +#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \ + illegalOps.insert(AtenOp::getOperationName()); \ + patterns.add>(typeConverter, context); + INSERT_CLONE_ATENOP_PATTERN(AtenCloneOp); +#undef INSERT_CLONE_ATENOP_PATTERN + + return illegalOps; +} std::unique_ptr> mlir::torch::createConvertTorchToTosaPass() { diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp index ae8d347e0cfd..1f18cabd8cb0 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp @@ -8,20 +8,15 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h" +#include 
"mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" -#include -#include #include #include #include -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/IR/Matchers.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project #include "llvm/Support/FormatVariadic.h" namespace mlir { @@ -29,6 +24,15 @@ namespace tosa { using namespace mlir::torch::Torch; +// This function is a helper for `convertTorchIndexToTfIndices`. +// +// We convert PyTorch index to TensorFlow-style indices so that we can use +// `convertGatherNdOp` and `convertScatterNdOp` functions, which lower Gather +// and Scatter operators to TOSA using TensorFlow-style indices. +// The difference between PyTorch/ONNX Gather/Scatter and TensorFlow +// Gather/Scatter ops is that PyTorch/ONNX take in the dimension that you want +// to gather/scatter elements, while in TensorFlow, the indices point directly +// to positions that you want to gather/scatter elements. std::optional createOneDimTfIndices(PatternRewriter &rewriter, Operation *op, SmallVector indicesOneDimShape, int32_t dim, @@ -36,49 +40,55 @@ createOneDimTfIndices(PatternRewriter &rewriter, Operation *op, unsigned indexRank = indexShape.size(); SmallVector indicesVec; // input vec to create tosaConstant SmallVector indicesMetaElement; // torch.meshgrid inputs - int indicesMetaElementRepeatTimes{1}; // For torch.stack(torch.meshgrid) // Create torch.meshgrid inputs // Example: indexShape=[1,4,2] // dim0: indicesMetaElement = torch.arange(0, 1) = [0] // dim1: indicesMetaElement = torch.arange(0, 4) = [0,1,2,3] // dim2: indicesMetaElement = torch.arange(0, 2) = [0,1] - for (int i = 0; i < indexShape[dim]; i++) { + for (int i = 0; i < indexShape[dim]; i++) indicesMetaElement.push_back(i); - } - // Compute total number of meta element repeat times: - // = product(indexShape[0:dim]) x product(indexShape[dim+1:-1]), skip dim - // dim0: indicesMetaElementRepeatTimes = 1 x 4*2 = 8 - // dim1: indicesMetaElementRepeatTimes = 1 *1 x 2 = 2 - // dim2: indicesMetaElementRepeatTimes = 1 *1*4 = 4 - for (int i = 0; i < static_cast(indexRank); i++) { - if (i == dim) { - continue; - } else { - indicesMetaElementRepeatTimes *= indexShape[i]; - } - } - - if (dim != static_cast(indexShape.size()) - 1) { - // Create one dim indices for index except for last dim - // Create indices raw vector. 
- // torch.stack(torch.meshgrid) - // dim0: indicesVec = [0 0 0 0 0 0 0 0] - // dim0: indicesVec = [0 0 1 1 2 2 3 3] + int preDimMetaElementRepeatTimes = 1; + int postDimMetaElementRepeatTimes = 1; + + // Compute total number of times meta element range should repeat + // = product(indexShape[0:dim]) + // dim0: preDimMetaElementRepeatTimes = 1 + // dim1: preDimMetaElementRepeatTimes = 1 + // dim2: preDimMetaElementRepeatTimes = 1 x 4 = 4 + for (int i = 0; i < dim; i++) + preDimMetaElementRepeatTimes *= indexShape[i]; + + // Compute total number of times meta element repeat + // = product(indexShape[dim+1:indexRank]) + // dim0: postDimMetaElementRepeatTimes = 4 x 2 = 8 + // dim1: postDimMetaElementRepeatTimes = 2 + // dim2: postDimMetaElementRepeatTimes = 1 + for (int i = dim + 1; i < static_cast(indexRank); i++) + postDimMetaElementRepeatTimes *= indexShape[i]; + + // Example using dim1: + // preDimMetaElementRepeatTimes = 1 + // postDimMetaElementRepeatTimes = 2 + // Using postDimMetaElementRepeatTimes, we get the meta element range: + // [0 0 1 1 2 2 3 3] + // Using preDimMetaElementRepeatTimes, we get the full one dim indices: + // [0 0 1 1 2 2 3 3] + // + // Let's use a clearer example: + // indexShape = [3, 4, 2] + // Target dim = 1 + // => preDimMetaElementRepeatTimes = 3 + // postDimMetaElementRepeatTimes = 2 + // Using postDimMetaElementRepeatTimes, we get the meta element range: + // [0 0 1 1 2 2] + // Using preDimMetaElementRepeatTimes, we get the full one dim indices: + // [0 0 1 1 2 2 0 0 1 1 2 2 0 0 1 1 2 2] + for (int i = 0; i < preDimMetaElementRepeatTimes; i++) { for (size_t elementId = 0; elementId < indicesMetaElement.size(); elementId++) { - for (int i = 0; i < indicesMetaElementRepeatTimes; i++) { - indicesVec.push_back(indicesMetaElement[elementId]); - } - } - } else { // Create the one dim indices for last dim of index - // Create indices raw vector - // dim2: indicesVec= [0 1 0 1 0 1 0 1] - // Caution: indicesVec != [0 0 0 0 1 1 1 1] - for (int i = 0; i < indicesMetaElementRepeatTimes; i++) { - for (size_t elementId = 0; elementId < indicesMetaElement.size(); - elementId++) { + for (int j = 0; j < postDimMetaElementRepeatTimes; j++) { indicesVec.push_back(indicesMetaElement[elementId]); } } @@ -341,7 +351,7 @@ std::optional convertGatherNdOp(PatternRewriter &rewriter, Operation *op, // %3 = "tosa.reshape"(%1) {new_shape = [8, 3]} : (tensor<1x4x2x3xi32>) -> // tensor<8x3xi32> Flatten the input indices tensor to an [W, ND] matrix. 
- auto indicesMatrixReshapeOp = tosa::CreateOpAndInfer( + Value indicesMatrixReshapeOp = tosa::CreateOpAndInfer( rewriter, op->getLoc(), GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()), indicesValue, rewriter.getDenseI64ArrayAttr(indicesMatrixShape)); @@ -368,13 +378,18 @@ std::optional convertGatherNdOp(PatternRewriter &rewriter, Operation *op, if (!flattenedCoeffValue) return std::nullopt; + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), indicesMatrixReshapeOp, + flattenedCoeffValue.value()) + .failed()) + return std::nullopt; + // Multiply the coefficients by the coordinates // %5 = "tosa.mul"(%3, %4) {shift = 0 : i32} : (tensor<8x3xi32>, // tensor<3xi32>) -> tensor<8x3xi32> auto flattenedIndicesMulOp = tosa::CreateOpAndInfer( rewriter, op->getLoc(), GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()), - indicesMatrixReshapeOp.getResult(), flattenedCoeffValue.value(), 0); + indicesMatrixReshapeOp, flattenedCoeffValue.value(), 0); // Sum up the products of the coefficients and coordinates // %6 = "tosa.reduce_sum"(%5) {axis = 1 : i64} : (tensor<8x3xi32>) -> @@ -557,11 +572,12 @@ std::optional convertScatterNdOp(PatternRewriter &rewriter, // [0] -> [0,0,0] SmallVector tileShape({W}); // {3} + auto tileOpMultiples = + tosa::getTosaConstShape(rewriter, op->getLoc(), tileShape); auto tosaFillValuesTileOp = tosa::CreateOpAndInfer( rewriter, op->getLoc(), GetTypeFromTensorShape(tileShape, fillValuesType.getElementType()), - tosaFillValuesOneReshapeOp.getResult(), - rewriter.getDenseI64ArrayAttr(tileShape)); + tosaFillValuesOneReshapeOp.getResult(), tileOpMultiples); // [0,0,0] -> [[0,0,0]] SmallVector newTosaFillValuesShape({N, W}); // {1,3} @@ -605,7 +621,7 @@ std::optional convertScatterNdOp(PatternRewriter &rewriter, // [[0, 1], [0, 2], [0, 3]] -> [[0, 1], [0, 2], [0, 3]] // %11 = "tosa.reshape"(%8) {new_shape = array} : (tensor<3x2xi32>) // -> tensor<3x2xi32> - auto indicesMatrixReshapeOp = tosa::CreateOpAndInfer( + Value indicesMatrixReshapeOp = tosa::CreateOpAndInfer( rewriter, op->getLoc(), GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()), indicesValue, rewriter.getDenseI64ArrayAttr(indicesMatrixShape)); @@ -632,6 +648,11 @@ std::optional convertScatterNdOp(PatternRewriter &rewriter, if (!flattenedCoeffValue) return std::nullopt; + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), indicesMatrixReshapeOp, + flattenedCoeffValue.value()) + .failed()) + return std::nullopt; + // Multiply the coefficients by the coordinates. 
// [[0, 1], [0, 2], [0, 3]] X [4, 1] -> [[4*0, 1*1], [4*0, 1*2], [4*0, 1*3]] // %13 = "tosa.mul"(%11, %12) {shift = 0 : i32} : (tensor<3x2xi32>, @@ -639,7 +660,7 @@ std::optional convertScatterNdOp(PatternRewriter &rewriter, auto flattenedIndicesMulOp = tosa::CreateOpAndInfer( rewriter, op->getLoc(), GetTypeFromTensorShape(indicesMatrixShape, indicesType.getElementType()), - indicesMatrixReshapeOp.getResult(), flattenedCoeffValue.value(), 0); + indicesMatrixReshapeOp, flattenedCoeffValue.value(), 0); // Sum up the products of the coefficients and coordinates // [[4*0 + 1*1], [4*0 + 1*2], [4*0 + 1*3]] = [[1],[2],[3]] @@ -700,6 +721,13 @@ std::optional convertReduceOpCommon( auto input_rank = input_shape.size(); Value val = input_value; + if (output_type.getElementType() != input_type.getElementType()) { + reduce_element_type = output_type.getElementType(); + val = rewriter.createOrFold( + op->getLoc(), RankedTensorType::get(input_shape, reduce_element_type), + val); + } + if (axes_elems.getNumElements() == 0) { // No axes means return the original tensor. auto identity_op = CreateOpAndInfer( @@ -723,10 +751,20 @@ std::optional convertReduceOpCommon( RankedTensorType reduce_type = RankedTensorType::get(shape_vec, reduce_element_type); - auto reduce_op = CreateOpAndInfer(rewriter, op->getLoc(), reduce_type, - val, axis_attr); + Value reduce_op; + if constexpr (std::is_same() || + std::is_same()) { + // Use default NaN Propagation mode "PROPAGATE" for tosa.reduce_min + // and tosa.reduce_max + reduce_op = CreateOpAndInfer( + rewriter, op->getLoc(), reduce_type, val, axis_attr, + /*nan_mode=*/rewriter.getStringAttr("PROPAGATE")); + } else { + reduce_op = CreateOpAndInfer(rewriter, op->getLoc(), reduce_type, + val, axis_attr); + } - val = reduce_op.getResult(); + val = reduce_op; } if (is_quantized) { @@ -819,9 +857,9 @@ convertReduceProdOp(PatternRewriter &rewriter, Operation *op, return std::nullopt; bool input_is_qtype = - input_type.getElementType().isa(); + isa(input_type.getElementType()); bool output_is_qtype = - output_type.getElementType().isa(); + isa(output_type.getElementType()); if (input_is_qtype || output_is_qtype) { op->emitOpError("ConvertReduceProdOp: input/output tensor should " @@ -845,9 +883,9 @@ convertReduceSumOp(PatternRewriter &rewriter, Operation *op, return std::nullopt; bool input_is_qtype = - input_type.getElementType().isa(); + isa(input_type.getElementType()); bool output_is_qtype = - output_type.getElementType().isa(); + isa(output_type.getElementType()); if (input_is_qtype != output_is_qtype) { op->emitOpError("ConvertReduceSumOp: input/output tensor should " @@ -900,9 +938,9 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op, return std::nullopt; bool input_is_qtype = - input_type.getElementType().isa(); + isa(input_type.getElementType()); bool output_is_qtype = - output_type.getElementType().isa(); + isa(output_type.getElementType()); if (input_is_qtype != output_is_qtype) { op->emitOpError("ConvertReduceSumOp: input/output tensor should " @@ -911,7 +949,7 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op, } // Only supports float type mean() if it's non-quantized - if (!input_is_qtype && !output_type.getElementType().isa()) { + if (!input_is_qtype && !isa(output_type.getElementType())) { op->emitWarning( "Failed convertReduceMean: input unquantized type but output element " "not FloatType!"); @@ -962,6 +1000,12 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op, if (!input_is_qtype) { Value div_const = 
getTosaConstTensorSingleF32(rewriter, op, div_scale); + + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), val.value(), + div_const) + .failed()) + return std::nullopt; + return CreateOpAndInfer(rewriter, op->getLoc(), output_type, val.value(), div_const, 0) .getResult(); @@ -1010,6 +1054,11 @@ convertLinalgVectorNormOp(PatternRewriter &rewriter, Operation *op, return std::nullopt; } + Value ordValRank0 = ordVal; + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), input_value, ordVal) + .failed()) + return std::nullopt; + if (fabs(ordLiteralFloat) < epsilon || fabs(static_cast(ordLiteralInt)) < epsilon) { op->emitOpError("unimplemented: L0 norm"); @@ -1022,19 +1071,33 @@ convertLinalgVectorNormOp(PatternRewriter &rewriter, Operation *op, return std::nullopt; } - auto absVal = CreateOpAndInfer(rewriter, op->getLoc(), - input_type, input_value) + auto input_value_casted = + tosa::promoteType(rewriter, input_value, output_type); + auto absVal = CreateOpAndInfer( + rewriter, op->getLoc(), + RankedTensorType::get(input_type.getShape(), elemType), + input_value_casted) .getResult(); - auto powVal = CreateOpAndInfer(rewriter, op->getLoc(), - input_type, absVal, ordVal) + auto powVal = CreateOpAndInfer( + rewriter, op->getLoc(), + RankedTensorType::get(input_type.getShape(), elemType), + absVal, ordVal) .getResult(); std::optional result = convertReduceSumOp( rewriter, op, output_type, powVal, axes_elems, keep_dims); if (!result) return std::nullopt; - auto reciprocalVal = CreateOpAndInfer( - rewriter, op->getLoc(), ordVal.getType(), ordVal) - .getResult(); + + Value reciprocalVal = + CreateOpAndInfer(rewriter, op->getLoc(), + ordValRank0.getType(), ordValRank0) + .getResult(); + + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), result.value(), + reciprocalVal) + .failed()) + return std::nullopt; + return CreateOpAndInfer(rewriter, op->getLoc(), output_type, result.value(), reciprocalVal) .getResult(); diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp index ab3db75fa85f..2243c8dcfd83 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp @@ -8,9 +8,9 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h" -#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project +#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project +#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" // from @llvm-project -#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h" namespace mlir { namespace tosa { @@ -135,6 +135,18 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op, } } +Value buildSlice(PatternRewriter &rewriter, Value &input, + llvm::ArrayRef start, llvm::ArrayRef size) { + assert(start.size() == size.size() && + "Start and Size must have the same size"); + return tosa::CreateOpAndInfer( + rewriter, input.getLoc(), + RankedTensorType::get( + llvm::SmallVector(size.size(), ShapedType::kDynamic), + cast(input.getType()).getElementType()), + input, start, size); +} + // Check if scale32 mode is used for given output_element_type bool isScale32(mlir::quant::UniformQuantizedType output_element_type) { return (output_element_type.getStorageTypeIntegralWidth() == 8); @@ -183,7 +195,7 @@ std::optional getConstTensor(PatternRewriter &rewriter, Operation *op, num_total_elements *= a; } - if 
(vec.size() != num_total_elements) { + if (vec.size() != num_total_elements && vec.size() != 1) { op->emitOpError("getConstTensor(): number of elements mismatch."); return std::nullopt; } @@ -217,7 +229,7 @@ std::optional getConstTensor(PatternRewriter &rewriter, num_total_elements *= a; } - if (vec.size() != num_total_elements) { + if (vec.size() != num_total_elements && vec.size() != 1) { op->emitOpError("getConstTensor(): number of elements mismatch."); return std::nullopt; } @@ -247,7 +259,7 @@ std::optional getConstTensor(PatternRewriter &rewriter, num_total_elements *= a; } - if (vec.size() != num_total_elements) { + if (vec.size() != num_total_elements && vec.size() != 1) { op->emitOpError("getConstTensor(): number of elements mismatch."); return std::nullopt; } @@ -255,6 +267,34 @@ std::optional getConstTensor(PatternRewriter &rewriter, auto const_type = RankedTensorType::get(shape, rewriter.getF32Type()); auto const_attr = DenseElementsAttr::get(const_type, vec); + auto const_op = + rewriter.create(op->getLoc(), const_type, const_attr); + if (dtype) { + return rewriter.createOrFold( + op->getLoc(), RankedTensorType::get(shape, *dtype), const_op); + } + return const_op.getResult(); +} + +// Template specialization for double +template <> +std::optional getConstTensor(PatternRewriter &rewriter, + Operation *op, ArrayRef vec, + ArrayRef shape, + std::optional dtype) { + uint64_t num_total_elements = 1; + for (int64_t a : shape) { + num_total_elements *= a; + } + + if (vec.size() != num_total_elements) { + op->emitOpError("getConstTensor(): number of elements mismatch."); + return std::nullopt; + } + + auto const_type = RankedTensorType::get(shape, rewriter.getF64Type()); + auto const_attr = DenseElementsAttr::get(const_type, vec); + auto const_op = rewriter.create(op->getLoc(), const_type, const_attr); @@ -265,42 +305,68 @@ std::optional getConstTensor(PatternRewriter &rewriter, return const_op.getResult(); } -static LogicalResult checkValidityOfCast(Type src, Type dest) { +// Valid TOSA casting pairs according to TOSA spec: +// https://www.mlplatform.org/tosa/tosa_spec.html#_cast +// Note: currently TOSA doesn't support casting to and from I64 and F64 +[[maybe_unused]] static LogicalResult checkValidityOfCast(Type src, Type dest) { // clang-format off if ((src == dest) || - // int64 -> * - (src.isInteger(64) && dest.isInteger(32)) || - (src.isInteger(64) && dest.isInteger(8)) || - (src.isInteger(64) && dest.isInteger(1)) || - (src.isInteger(64) && dest.isF32()) || // int32 -> * - (src.isInteger(32) && dest.isInteger(64)) || + (src.isInteger(32) && dest.isInteger(16)) || + (src.isInteger(32) && dest.isInteger(8)) || (src.isInteger(32) && dest.isInteger(1)) || (src.isInteger(32) && dest.isF32()) || + (src.isInteger(32) && dest.isF16()) || (src.isInteger(32) && dest.isBF16()) || // int16 -> * + (src.isInteger(16) && dest.isInteger(32)) || + (src.isInteger(16) && dest.isInteger(8)) || + (src.isInteger(16) && dest.isInteger(1)) || (src.isInteger(16) && dest.isBF16()) || + (src.isInteger(16) && dest.isF32()) || + (src.isInteger(16) && dest.isF16()) || // int8 -> * + (src.isInteger(8) && dest.isInteger(32)) || + (src.isInteger(8) && dest.isInteger(16)) || (src.isInteger(8) && dest.isInteger(1)) || (src.isInteger(8) && dest.isBF16()) || + (src.isInteger(8) && dest.isF32()) || + (src.isInteger(8) && dest.isF16()) || // int1 -> * - (src.isInteger(1) && dest.isInteger(64)) || - (src.isInteger(1) && dest.isF32()) || - // f64 -> * - (src.isF64() && dest.isF32()) || - (src.isF64() && 
dest.isBF16()) || + (src.isInteger(1) && dest.isInteger(32)) || + (src.isInteger(1) && dest.isInteger(16)) || + (src.isInteger(1) && dest.isInteger(8)) || // f32 -> * - (src.isF32() && dest.isF64()) || + (src.isF32() && dest.isInteger(32)) || + (src.isF32() && dest.isInteger(16)) || + (src.isF32() && dest.isInteger(8)) || (src.isF32() && dest.isBF16()) || (src.isF32() && dest.isF16()) || - (src.isF32() && dest.isInteger(8)) || - (src.isF32() && dest.isInteger(64)) || - (src.isF32() && dest.isInteger(1)) || + (src.isF32() && isa(dest)) || + (src.isF32() && isa(dest)) || + // f16 -> * + (src.isF16() && dest.isInteger(32)) || + (src.isF16() && dest.isInteger(16)) || + (src.isF16() && dest.isInteger(8)) || + (src.isF16() && dest.isBF16()) || + (src.isF16() && dest.isF32()) || + (src.isF16() && isa(dest)) || + (src.isF16() && isa(dest)) || // bf16 -> * - (src.isBF16() && dest.isInteger(8)) || - (src.isBF16() && dest.isInteger(16)) || (src.isBF16() && dest.isInteger(32)) || - (src.isBF16() && dest.isF32())) { + (src.isBF16() && dest.isInteger(16)) || + (src.isBF16() && dest.isInteger(8)) || + (src.isBF16() && dest.isF32()) || + (src.isBF16() && isa(dest)) || + (src.isBF16() && isa(dest)) || + // fp8e4m3 -> * + (isa(src) && dest.isBF16()) || + (isa(src) && dest.isF32()) || + (isa(src) && dest.isF16()) || + // fp8e5m2 -> * + (isa(src) && dest.isBF16()) || + (isa(src) && dest.isF32()) || + (isa(src) && dest.isF16())) { return success(); } // clang-format on @@ -311,12 +377,21 @@ static LogicalResult checkValidityOfCast(Type src, Type dest) { LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op, Value src, Type destType, Value &result) { - Type srcElemTy = dyn_cast(src.getType()).getElementType(); + TensorType srcType = dyn_cast(src.getType()); + Type srcElemTy = srcType.getElementType(); Type destElemTy = dyn_cast(destType).getElementType(); - if (failed(checkValidityOfCast(srcElemTy, destElemTy))) - return rewriter.notifyMatchFailure( - op, "casting to result dtype is invalid or unsupported"); + // Temporarily disable checkValidityOfCast as it's currently strictly + // following TOSA spec and might cause many e2e tests to fail. This is because + // even though there are some casting pairs that are not congruent to TOSA + // spec, they are still permissible. TOSA validation should flag these illegal + // constructs in a per-profile manner. This strict validity check will be + // enabled later in a potential `--strict` mode which checks for strict + // casting only when needed (the default value of `--strict` mode will be + // off). 
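+  // For example (per the spec note above), pairs involving i64 or f64, such as
+  // i64 -> i32 or f64 -> f32, fall outside the strict TOSA cast table yet may
+  // still be relied on by existing e2e tests, so they are not rejected here.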
+ // if (failed(checkValidityOfCast(srcElemTy, destElemTy))) + // return rewriter.notifyMatchFailure( + // op, "casting to result dtype is invalid or unsupported"); if (destElemTy.isInteger(1)) { auto srcType = dyn_cast(src.getType()); @@ -334,20 +409,59 @@ LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op, SmallVector values(num_total_elements, 0); constOp = tosa::getConstTensor(rewriter, op, values, srcShape).value(); + } else if (srcElemTy.isInteger(8)) { + SmallVector values(num_total_elements, 0); + constOp = + tosa::getConstTensor(rewriter, op, values, srcShape).value(); + } else if (srcElemTy.isInteger(16)) { + SmallVector values(num_total_elements, 0); + constOp = + tosa::getConstTensor(rewriter, op, values, srcShape).value(); + } else if (srcElemTy.isBF16()) { + SmallVector values(num_total_elements, 0.0); + constOp = + tosa::getConstTensor(rewriter, op, values, srcShape, srcElemTy) + .value(); } else if (srcElemTy.isF32()) { SmallVector values(num_total_elements, 0.0); constOp = tosa::getConstTensor(rewriter, op, values, srcShape).value(); - } else if (srcElemTy.isInteger(8)) { - SmallVector values(num_total_elements, 0); + } else if (srcElemTy.isF64()) { + SmallVector values(num_total_elements, 0.0); constOp = - tosa::getConstTensor(rewriter, op, values, srcShape).value(); + tosa::getConstTensor(rewriter, op, values, srcShape).value(); + } else { + op->dump(); + op->emitError("Unsupported conversion to i1"); + return failure(); } Value equalToZero = rewriter.create(op->getLoc(), destType, src, constOp.value()); result = rewriter.create(op->getLoc(), destType, equalToZero); } else { + if (llvm::isa(srcElemTy) && destElemTy.isInteger()) { + // for float->int conversion, tosa.cast performs round-to-nearest + // torch performs round-to-zero instead + // generate round-to-zero conversion prior to tosa.cast to match with + // expected torch behavior + auto floor = rewriter.create(op->getLoc(), srcType, src); + auto ceil = rewriter.create(op->getLoc(), srcType, src); + + auto zeroValue = + tosa::getConstTensor(rewriter, op, 0, {}, srcElemTy).value(); + + if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), src, zeroValue) + .failed()) + return rewriter.notifyMatchFailure( + op, "Failed to equalize ranks among operands and result"); + + auto boolType = srcType.clone(rewriter.getIntegerType(1)); + auto isNegative = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), boolType, zeroValue, src); + src = tosa::CreateOpAndInfer( + rewriter, op->getLoc(), srcType, isNegative, ceil, floor); + } result = rewriter.create(op->getLoc(), destType, src); } return success(); @@ -365,11 +479,47 @@ Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType) { return input; } +TypedValue reshapeTo(Location loc, PatternRewriter &rewriter, + Value val, ArrayRef newShape) { + + auto tensorTy = dyn_cast(val.getType()); + assert(tensorTy); + + auto newTy = RankedTensorType::get(newShape, tensorTy.getElementType()); + return rewriter.create( + loc, newTy, val, rewriter.getDenseI64ArrayAttr(newShape)); +} + +TypedValue transposeBy(Location loc, + PatternRewriter &rewriter, Value val, + ArrayRef permutation) { + auto tensorTy = dyn_cast(val.getType()); + assert(tensorTy); + + auto permType = RankedTensorType::get({(int64_t)permutation.size()}, + rewriter.getI32Type()); + auto permAttr = DenseElementsAttr::get(permType, permutation); + auto permOp = rewriter.create(loc, permType, permAttr); + + SmallVector newShape{tensorTy.getShape()}; + for (size_t i = 0; i < 
newShape.size(); i++) + newShape[i] = tensorTy.getShape()[permutation[i]]; + + auto newTy = RankedTensorType::get(newShape, tensorTy.getElementType()); + + auto v = rewriter.createOrFold(loc, newTy, val, permOp); + return cast>(v); +} + // Template instantiation template std::optional getConstTensor(PatternRewriter &, Operation *, ArrayRef vec, ArrayRef shape, std::optional dtype); +template std::optional +getConstTensor(PatternRewriter &, Operation *, ArrayRef vec, + ArrayRef shape, std::optional dtype); + template std::optional getConstTensor(PatternRewriter &, Operation *, ArrayRef vec, ArrayRef shape, std::optional dtype); diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp index e014fbeaa9d4..3a5a5a7447c8 100644 --- a/lib/Conversion/Utils/Utils.cpp +++ b/lib/Conversion/Utils/Utils.cpp @@ -14,7 +14,6 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" @@ -32,7 +31,7 @@ LogicalResult verifyLinalgCompatibleTypes(Operation *op, return false; auto tensor = dyn_cast(type); return !tensor || - tensor.toBuiltinTensor().dyn_cast_or_null(); + dyn_cast_or_null(tensor.toBuiltinTensor()); }; bool valid = llvm::all_of(op->getOperandTypes(), isValidLinalgType) && @@ -67,7 +66,7 @@ Value toPositiveDimDynamic(OpBuilder &b, Location loc, Value dim, // Generate IR: assert(dim >= 0 && dim < inputRank) void assertIsValidDim(OpBuilder &b, Location loc, Value dim, Value inputRank) { - assert(dim.getType().isa() && + assert(isa(dim.getType()) && "dim arg of assertIsValidDim must be integer type"); Value cst0 = b.create(loc, b.getZeroAttr(inputRank.getType())); @@ -133,20 +132,36 @@ Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes, Type elemTy) { Value initTensor = b.create(loc, getAsOpFoldResult(sizes), elemTy); - RankedTensorType type = cast(initTensor.getType()); - Value c0 = - b.create(loc, b.getZeroAttr(type.getElementType())); + + Type fillValElemTy = elemTy; + if (auto dtypeComplex = dyn_cast(elemTy)) + fillValElemTy = cast(dtypeComplex.getElementType()); + + Value c0 = b.create(loc, b.getZeroAttr(fillValElemTy)); return b.create(loc, c0, initTensor).getResult(0); } +Value createOneInitTensor(OpBuilder &b, Location loc, ValueRange sizes, + Type elemTy) { + Value initTensor = + b.create(loc, getAsOpFoldResult(sizes), elemTy); + + Type fillValElemTy = elemTy; + if (auto dtypeComplex = dyn_cast(elemTy)) + fillValElemTy = cast(dtypeComplex.getElementType()); + + Value c1 = b.create(loc, b.getOneAttr(fillValElemTy)); + return b.create(loc, c1, initTensor).getResult(0); +} + Value castIntToIndex(OpBuilder &b, Location loc, Value v) { - assert(v.getType().isa() && "must be called with integer type"); - return b.create(loc, b.getIndexType(), v); + assert(isa(v.getType()) && "must be called with integer type"); + return b.createOrFold(loc, b.getIndexType(), v); } Value castIndexToInt64(OpBuilder &b, Location loc, Value idx) { - assert(idx.getType().isa() && "must be called with integer type"); - return b.create(loc, b.getI64Type(), idx); + assert(isa(idx.getType()) && "must be called with integer type"); + return b.createOrFold(loc, b.getI64Type(), idx); } SmallVector @@ -320,6 +335,10 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype, if (auto dtypeFloat = dyn_cast(dtype)) { if (auto 
scalarFloat = dyn_cast(scalarType)) { + if (scalarFloat.getWidth() == 16 && dtypeFloat.getWidth() == 16) { + auto scalarF32 = b.create(loc, b.getF32Type(), scalar); + return b.create(loc, dtype, scalarF32); + } if (scalarFloat.getWidth() > dtypeFloat.getWidth()) return b.create(loc, dtype, scalar); // Only scalarFloat width < dtypeFloat width can reach here. @@ -351,6 +370,8 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype, } if (auto dtypeComplex = dyn_cast(dtype)) { + + // Complex to complex. if (auto scalarComplex = dyn_cast(scalarType)) { auto dtypeElemType = dtypeComplex.getElementType(); @@ -365,6 +386,39 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype, return b.create(loc, dtypeComplex, realVal, imgVal); } + + // Float to complex type. + if (auto dtypeFloat = dyn_cast(scalarType)) { + auto complexElementType = + cast(dtypeComplex.getElementType()); + Value realVal; + Value imgVal = + b.create(loc, b.getZeroAttr(complexElementType)); + + if (complexElementType.getWidth() > dtypeFloat.getWidth()) { + realVal = b.create(loc, complexElementType, scalar); + } else if (complexElementType.getWidth() < dtypeFloat.getWidth()) { + realVal = b.create(loc, complexElementType, scalar); + } else { + realVal = scalar; + } + + return b.create(loc, dtypeComplex, realVal, imgVal); + } + + // Int to complex type. + if (auto dtypeInt = dyn_cast(scalarType)) { + auto complexElementType = + cast(dtypeComplex.getElementType()); + + Value realVal = + b.create(loc, complexElementType, scalar); + Value imgVal = + b.create(loc, b.getZeroAttr(complexElementType)); + + return b.create(loc, dtypeComplex, realVal, imgVal); + } + mlir::emitError(loc) << "unsupported scalar type for convertScalarToDtype " << scalarType << "(scalar type) -> " << dtype << "(dtype)"; @@ -376,7 +430,7 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype, Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc, Value torchOptionalInt, Value builtinInt, Value defaultValue, Value dimSize) { - if (torchOptionalInt.getType().isa()) + if (isa(torchOptionalInt.getType())) return defaultValue; auto dimSizeAsInt = castIndexToInt64(rewriter, loc, dimSize); Value positiveDim = @@ -397,6 +451,119 @@ Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc, return castIntToIndex(rewriter, loc, boundedByDimSize); } +// Helper function to unsqueeze the input tensor at given dim. +// Returns the unsqueezed tensor or failure. +FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, + Value input, int64_t dim) { + auto inputType = cast(input.getType()); + int64_t inputRank = inputType.getRank(); + ArrayRef inputShape = inputType.getShape(); + + // `input` has a reduced rank. Hence add 1. + int64_t unsqueezedRank = inputShape.size() + 1; + dim = toPositiveDim(dim, unsqueezedRank); + if (!isValidDim(dim, unsqueezedRank)) { + return rewriter.notifyMatchFailure(op, "dim is not a valid dim"); + } + + SmallVector unsqueezedShape{inputShape}; + unsqueezedShape.insert(unsqueezedShape.begin() + dim, 1); + Type unsqueezedType = + RankedTensorType::get(unsqueezedShape, inputType.getElementType()); + + SmallVector reassociationMap(inputRank); + // From the perspective of the reassociation map, the situation of + // unsqueezing before or after the last dimension is symmetrical. + // Normalize it to the "before" case. 
+ // The 0 case is special here, since there is no last dimension to insert + // before -- we simply rely on the loop below iterating 0 times. + if (dim == inputRank && inputRank != 0) + dim = inputRank - 1; + bool alreadyCrossedExpandedDim = false; + for (int i = 0; i != inputRank; i++) { + if (alreadyCrossedExpandedDim) { + reassociationMap[i].push_back(i + 1); + } else { + reassociationMap[i].push_back(i); + if (i == dim) { + reassociationMap[i].push_back(i + 1); + alreadyCrossedExpandedDim = true; + } + } + } + Value unsqueezed = rewriter.create( + op->getLoc(), unsqueezedType, input, reassociationMap); + return unsqueezed; +} + +// Helper function to squeeze the input tensor at given dim. +// Returns the squeezed tensor or failure. +FailureOr squeezeTensor(PatternRewriter &rewriter, Operation *op, + Value input, int64_t dim) { + Location loc = op->getLoc(); + auto inputType = cast(input.getType()); + int64_t inputRank = inputType.getRank(); + + // No scope for squeezing the input. + if (inputRank == 0) + return input; + + dim = toPositiveDim(dim, inputRank); + if (!isValidDim(dim, inputRank)) + return rewriter.notifyMatchFailure(op, "dim is statically invalid"); + + // assert dynamic squeeze dim size == 1 + if (inputType.isDynamicDim(dim)) { + Value cstDim = rewriter.create(loc, dim); + Value dimVal = rewriter.create(loc, input, cstDim); + Value cstOne = rewriter.create(loc, 1); + Value cmp = rewriter.create(loc, arith::CmpIPredicate::eq, + dimVal, cstOne); + rewriter.create( + loc, cmp, + rewriter.getStringAttr( + "Expected dynamic squeeze dim size to be statically 1")); + } + + ArrayRef inputShape = inputType.getShape(); + SmallVector squeezedShape; + squeezedShape.append(inputShape.begin(), inputShape.begin() + dim); + squeezedShape.append(inputShape.begin() + dim + 1, inputShape.end()); + int64_t squeezedRank = inputRank - 1; + Type squeezedType = + RankedTensorType::get(squeezedShape, inputType.getElementType()); + + // If the dim(th) dimension of operand tensor type is not statically unit, + // squeeze will behave as an identity operation. + if (inputType.getDimSize(dim) != 1 && !inputType.isDynamicDim(dim)) { + return input; + } + + SmallVector reassociationMap(squeezedRank); + bool alreadyCrossedSqueezedDim = false; + for (int i = 0; i != squeezedRank; i++) { + if (alreadyCrossedSqueezedDim) { + reassociationMap[i].push_back(i + 1); + } else { + reassociationMap[i].push_back(i); + if (dim != 0 && i != dim - 1) + continue; + + alreadyCrossedSqueezedDim = true; + if (dim == 0) + reassociationMap[0].push_back(1); + if (i == dim - 1) + reassociationMap[i].push_back(dim); + } + } + // Note: In case the operand tensor type is of unit rank and is statically + // shaped with unit dimension, the `reassociationMap` will be empty and the + // input will be collapsed to a 0-D tensor. 
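+  // Worked example (illustrative shapes, not from this patch): squeezing
+  // dim = 1 of tensor<2x1x3xf32> yields tensor<2x3xf32> with
+  // reassociationMap = [[0, 1], [2]], i.e. the unit dimension is folded into
+  // its left neighbour by the collapse_shape created below.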
+ Value squeezed = rewriter.create( + op->getLoc(), squeezedType, input, reassociationMap); + return squeezed; +} + } // namespace Torch } // namespace torch } // namespace mlir diff --git a/lib/Dialect/TMTensor/IR/TMTensorOps.cpp b/lib/Dialect/TMTensor/IR/TMTensorOps.cpp index 218ecad3388f..9a90b4cacaac 100644 --- a/lib/Dialect/TMTensor/IR/TMTensorOps.cpp +++ b/lib/Dialect/TMTensor/IR/TMTensorOps.cpp @@ -46,16 +46,17 @@ using namespace mlir::torch::TMTensor; static void getEffectsImpl( SmallVectorImpl> &effects, - ValueRange results, ValueRange inputBuffers, ValueRange outputBuffers) { - for (Value value : results) { + ResultRange results, ArrayRef inputBuffers, + ArrayRef outputBuffers) { + for (OpResult value : results) { effects.emplace_back(MemoryEffects::Allocate::get(), value, SideEffects::DefaultResource::get()); } - for (Value value : inputBuffers) { + for (OpOperand *value : inputBuffers) { effects.emplace_back(MemoryEffects::Read::get(), value, SideEffects::DefaultResource::get()); } - for (Value value : outputBuffers) { + for (OpOperand *value : outputBuffers) { effects.emplace_back(MemoryEffects::Read::get(), value, SideEffects::DefaultResource::get()); effects.emplace_back(MemoryEffects::Write::get(), value, @@ -92,14 +93,49 @@ LogicalResult AttentionOp::verify() { Operation *op = getOperation(); ShapedType queryType = getQueryType(); ShapedType keyType = getKeyType(); + ShapedType valueType = getValueType(); + + auto optionalMaskType = getAttnMaskType(); + ShapedType maskType = optionalMaskType ? *optionalMaskType : ShapedType(); + ArrayRef queryShape = queryType.getShape(); ArrayRef keyShape = keyType.getShape(); + ArrayRef valueShape = valueType.getShape(); + ArrayRef maskShape = + optionalMaskType ? maskType.getShape() : ArrayRef(); + for (int i = 0, s = queryShape.size() - 2; i < s; ++i) { - if (keyShape[i] != queryShape[i]) + if (keyShape[i] != queryShape[i]) { return op->emitOpError("query and key batch mismatch"); + } } - if (keyShape.back() != queryShape.back()) + if (keyShape.back() != queryShape.back()) { return op->emitOpError("query and key head dimension mismatch"); + } + + for (int i = 0, s = queryShape.size() - 2; i < s; ++i) { + if (valueShape[i] != queryShape[i]) { + return op->emitOpError("query and value batch dimension mismatch"); + } + } + if (keyShape[keyShape.size() - 2] != valueShape[valueShape.size() - 2]) { + return op->emitOpError("key and value sequence length dimension mismatch"); + } + if (optionalMaskType) { + for (int i = 0, s = maskShape.size() - 2; i < s; ++i) { + if (maskShape[i] != queryShape[i]) { + return op->emitOpError("query and mask batch dimension mismatch"); + } + } + if (maskShape[maskShape.size() - 2] != queryShape[queryShape.size() - 2]) { + return op->emitOpError( + "mask sequence length and query sequence length mismatch"); + } + if (maskShape[maskShape.size() - 1] != keyShape[keyShape.size() - 2]) { + return op->emitOpError( + "mask sequence lengt and key sequence length mismatch"); + } + } return success(); } @@ -167,10 +203,15 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b, Value query = getQuery(); Value key = getKey(); Value value = getValue(); + + auto optionalMask = getAttnMask(); + Value mask = optionalMask ? *optionalMask : Value(); + Value output = getOutput(); auto queryType = cast(query.getType()); auto keyType = cast(key.getType()); auto valueType = cast(value.getType()); + auto maskType = mask ? 
cast(mask.getType()) : MemRefType(); auto queryRank = queryType.getRank(); auto keyRank = keyType.getRank(); auto valueRank = valueType.getRank(); @@ -179,6 +220,9 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b, Value zeroF = b.create(loc, elementType, b.getFloatAttr(elementType, 0.0)); + Value negInfF = b.create( + loc, elementType, + b.getFloatAttr(elementType, -std::numeric_limits::infinity())); // TODO: This needs to be fixed, it assumes everything is dynamic however if // any shapes are static the `memref.alloc` generated is illegal. @@ -213,14 +257,43 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b, /*transposed=*/true); // weight = softmax(weight) - Value one = b.create(loc, 1); - Value zero = b.create(loc, 0); Value dim = weightDynSizes[weightRank - 1]; Value scaleFactor = b.create( loc, b.create( loc, elementType, b.create(loc, b.getI32Type(), queryDynSizes[queryRank - 1]))); + + // weight = (weight - max(weight)) / math.sqrt(querySizes[-1]) + Value one = b.create(loc, 1); + Value zero = b.create(loc, 0); + b.create( + loc, SmallVector(weightRank, zero), weightDynSizes, + SmallVector(weightRank, one), + [&](OpBuilder &b, Location loc, ValueRange localIVs) { + Value x = b.create(loc, weight, localIVs); + x = b.create(loc, x, scaleFactor); + b.create(loc, x, weight, localIVs); + }); + + // Apply mask to weights if mask is given + if (mask) { + b.create( + loc, SmallVector(weightRank, zero), weightDynSizes, + SmallVector(weightRank, one), + [&](OpBuilder &b, Location loc, ValueRange localIVs) { + Value weightValue = b.create(loc, weight, localIVs); + Value maskValue = b.create(loc, mask, localIVs); + if (maskType.getElementType().isInteger(1)) { + maskValue = + b.create(loc, maskValue, zeroF, negInfF); + } + Value maskedWeight = + b.create(loc, weightValue, maskValue); + b.create(loc, maskedWeight, weight, localIVs); + }); + } + // calculate max(weight) Value init = b.create(loc, weight, SmallVector(weightRank, zero)); @@ -248,7 +321,6 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b, [&](OpBuilder &b, Location loc, ValueRange localIVs) { Value x = b.create(loc, weight, localIVs); x = b.create(loc, x, globalMax); - x = b.create(loc, x, scaleFactor); b.create(loc, x, weight, localIVs); }); // calculate exp(weight) @@ -306,10 +378,19 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b, [&](OpBuilder &b, Location loc, ValueRange localIVs) { SmallVector sumIVs(localIVs); sumIVs.pop_back(); + Value x = b.create(loc, weight, localIVs); Value sum = b.create(loc, expWeightSum, sumIVs); - x = b.create(loc, x, sum); - b.create(loc, x, weight, localIVs); + Value divResult = b.create(loc, x, sum); + + // Set to 0 if sum is 0 (can occur during boolean mask / large negative + // QK) + Value isSumZero = + b.create(loc, arith::CmpFPredicate::OEQ, sum, zeroF); + Value result = + b.create(loc, isSumZero, zeroF, divResult); + + b.create(loc, result, weight, localIVs); }); // output = weight @ value @@ -910,12 +991,219 @@ bool SortOp::payloadUsesValueFromOperand(OpOperand *opOperand) { return true; } +//===----------------------------------------------------------------------===// +// TopkOp +//===----------------------------------------------------------------------===// + +LogicalResult TopkOp::verify() { + Operation *op = getOperation(); + if (getNumInputs() != 1 && getNumInputs() != 2) { + return op->emitOpError("expected one or two input operands"); + } + if (getNumOutputs() != 2) { + return 
op->emitOpError("expected two output operands"); + } + // First check added to eliminate comparison of different int types + if (getInputRank() < 0 || + (getDimension() >= static_cast(getInputRank()))) { + return op->emitOpError("dimension exceeds rank"); + } + // Ensure input/output element types match + auto inputValuesType = cast(values().getType()); + auto outputValuesType = cast(outputValues().getType()); + if (inputValuesType.getElementType() != outputValuesType.getElementType()) { + return op->emitOpError("expected input/output value types to be identical"); + } + // Indices must be int if provided + auto outputIndicesType = cast(outputIndices().getType()); + if (auto inputIndices = indices()) { + auto inputIndicesType = cast(inputIndices->getType()); + if (!inputIndicesType.getElementType().isInteger(32) || + !outputIndicesType.getElementType().isInteger(32)) { + return op->emitOpError("expected input/output indices types to be int32"); + } + } + + // Ranks must match + if (inputValuesType.getRank() != outputValuesType.getRank()) { + return op->emitOpError("expected input/output to have the same rank"); + } + if (auto inputIndices = indices()) { + auto inputIndicesType = cast(inputIndices->getType()); + if (inputIndicesType.getRank() != outputIndicesType.getRank()) { + return op->emitOpError("expected input/output to have the same rank"); + } + } + // Input indicies and values must have the same shape. + if (auto inputIndices = indices()) { + auto inputIndicesType = cast(inputIndices->getType()); + if (failed(verifyCompatibleShape(inputValuesType, inputIndicesType))) + return op->emitOpError("input indices/values shape must match"); + } + // Output indicies and values must have the same shape. + if (failed(verifyCompatibleShape(outputValuesType, outputIndicesType))) + return op->emitOpError("output indices/values shape must match"); + // Input shape must match the output shape except for the dimension() + uint64_t dim = getDimension(); + if (!llvm::all_of(llvm::enumerate(llvm::zip(inputValuesType.getShape(), + outputValuesType.getShape())), + [dim](auto e) { + if (e.index() == dim) { + return true; + } + std::tuple s = e.value(); + return succeeded(verifyCompatibleShape(std::get<0>(s), + + std::get<1>(s))); + })) { + return op->emitOpError("incompatible input/output shapes"); + } + // Check region compatibility + Block &block = getRegion().front(); + if (block.getNumArguments() != 2) { + return op->emitOpError("region block should have 2 arguments"); + } + if (block.getArgument(0).getType() != inputValuesType.getElementType() || + block.getArgument(1).getType() != inputValuesType.getElementType()) { + return op->emitOpError("region block types must match input"); + } + auto terminatorOp = llvm::dyn_cast(block.getTerminator()); + if (!terminatorOp || !terminatorOp.getOperand(0).getType().isInteger(1)) { + return op->emitOpError("region block must end with a linalg_ext.yield i1!"); + } + return success(); +} + +SmallVector TopkOp::getLoopIteratorTypes() { + SmallVector iteratorTypes(getInputRank(), + utils::IteratorType::parallel); + iteratorTypes[getDimension()] = utils::IteratorType::reduction; + return iteratorTypes; +} + +SmallVector TopkOp::getIterationDomain(OpBuilder &builder) { + int64_t operandRank = getInputRank(); + SmallVector loopBounds(operandRank); + Location loc = getLoc(); + Value zero = builder.create(loc, 0); + Value one = builder.create(loc, 1); + Value source = values(); + for (auto dim : llvm::enumerate(getInputType().getShape())) { + 
loopBounds[dim.index()].offset = zero; + loopBounds[dim.index()].size = + getDimValue(builder, loc, source, dim.index()); + loopBounds[dim.index()].stride = one; + } + return loopBounds; +} + +LogicalResult TopkOp::generateScalarImplementation(OpBuilder &b, Location loc, + ValueRange ivs) { + uint64_t kDim = getDimension(); + Value zero = b.create(loc, 0); + Value one = b.create(loc, 1); + Value initialValue = b.create(loc, values(), ivs); + + // If the indices tensor is not provided, the value index is derived from the + // loop induction variables. + Value initialIndex; + if (indices()) { + initialIndex = b.create(loc, *indices(), ivs); + } else { + Value rawInitialIndex = ivs[kDim]; + initialIndex = + b.create(loc, b.getI32Type(), rawInitialIndex); + } + + // Compute K (ub) from the selected dim of the output + Value ub = b.create(loc, outputValues(), getDimension()); + + // Inner K loop functions: + // Load current K value and index + // Compare N/K using inserted block compare + // Check if N == K using strict weak ordering, select which index came first + // Select new K value from N/K comparison + // Select new K index from N/K comparison or which index came first + // Store new k value and index + // Yield loop carry values after K selection + Value kValue, kIndex; + auto scfFor = b.create( + loc, zero, ub, one, ValueRange{initialValue, initialIndex}, + [&](OpBuilder &b, Location loc, Value iv, ValueRange loopCarryValues) { + SmallVector indices(ivs); + indices[kDim] = iv; + kValue = b.create(loc, outputValues(), indices); + kIndex = b.create(loc, outputIndices(), indices); + }); + + SmallVector indices(ivs); + indices[kDim] = scfFor.getInductionVar(); + auto loopCarryValues = scfFor.getRegionIterArgs(); + + // Retrieve region as black box comparison function f(x,y). Plug into op. + auto &srcBlock = getRegion().front(); + IRMapping bvmF; // f(x,y) + IRMapping bvmR; // f(y,x) + { + // Save previous insertion point. Continue within loop body. 
+ OpBuilder::InsertionGuard guard(b); + b.setInsertionPointToEnd(&scfFor.getRegion().front()); + SmallVector forwardValues{loopCarryValues[0], kValue}; + SmallVector reverseValues{kValue, loopCarryValues[0]}; + for (auto it : llvm::zip(srcBlock.getArguments(), forwardValues)) { + bvmF.map(std::get<0>(it), std::get<1>(it)); + } + for (auto it : llvm::zip(srcBlock.getArguments(), reverseValues)) { + bvmR.map(std::get<0>(it), std::get<1>(it)); + } + for (auto &blockOp : srcBlock.without_terminator()) { + b.clone(blockOp, bvmF); + b.clone(blockOp, bvmR); + } + Value forwardCmpRes = bvmF.lookup(srcBlock.getTerminator()->getOperand(0)); + Value reverseCmpRes = bvmR.lookup(srcBlock.getTerminator()->getOperand(0)); + + // Check value equality using strictly weak ordering from the region: + // f(x,y) --> forwardCmpRes + // f(y,x) --> reverseCmpRes + // if forwardCmpRes == reverseCmpRes then select which came first + Value cmpValuesEqual = b.create( + loc, arith::CmpIPredicate::eq, forwardCmpRes, reverseCmpRes); + Value cmpFirstIndex = b.create( + loc, arith::CmpIPredicate::slt, loopCarryValues[1], kIndex); + Value combinedCmpEqRes = + b.create(loc, cmpValuesEqual, cmpFirstIndex); + // True if N > K or N came before K + Value indexCmpRes = + b.create(loc, forwardCmpRes, combinedCmpEqRes); + // Select results for K based on comparisons + Value resultKValue = b.create(loc, forwardCmpRes, + loopCarryValues[0], kValue); + Value resultKIndex = + b.create(loc, indexCmpRes, loopCarryValues[1], kIndex); + b.create(loc, resultKValue, outputValues(), indices); + b.create(loc, resultKIndex, outputIndices(), indices); + // Select loop carry, opposite of K results + Value resultCarryValue = b.create( + loc, forwardCmpRes, kValue, loopCarryValues[0]); + Value resultCarryIndex = + b.create(loc, indexCmpRes, kIndex, loopCarryValues[1]); + b.create(loc, ValueRange{resultCarryValue, resultCarryIndex}); + } + return success(); +} + +bool TopkOp::payloadUsesValueFromOperand(OpOperand *opOperand) { + // Set to true so that output operands are always initialized. + return true; +} + #define DEFINE_OP_GET_EFFECTS(OP_NAME) \ void OP_NAME::getEffects( \ SmallVectorImpl> \ &effects) { \ - SmallVector inputBuffers = getInputBufferOperands(); \ - SmallVector outputBuffers = getOutputBufferOperands(); \ + OpOperandVector inputBuffers = getInputBufferOperands(); \ + OpOperandVector outputBuffers = getOutputBufferOperands(); \ getEffectsImpl(effects, getOperation()->getResults(), inputBuffers, \ outputBuffers); \ } @@ -924,6 +1212,7 @@ DEFINE_OP_GET_EFFECTS(AttentionOp) DEFINE_OP_GET_EFFECTS(ScanOp) DEFINE_OP_GET_EFFECTS(ScatterOp) DEFINE_OP_GET_EFFECTS(SortOp) +DEFINE_OP_GET_EFFECTS(TopkOp) namespace { /// This is derived from mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp without any diff --git a/lib/Dialect/TMTensor/Transforms/Bufferize.cpp b/lib/Dialect/TMTensor/Transforms/Bufferize.cpp index 6e5a6769a843..3992405a494c 100644 --- a/lib/Dialect/TMTensor/Transforms/Bufferize.cpp +++ b/lib/Dialect/TMTensor/Transforms/Bufferize.cpp @@ -121,6 +121,14 @@ class BufferizeAnyTMTensorOp : public OpInterfaceConversionPattern { }; namespace { + +static Value materializeToTensor(OpBuilder &builder, TensorType type, + ValueRange inputs, Location loc) { + assert(inputs.size() == 1); + assert(isa(inputs[0].getType())); + return builder.create(loc, type, inputs[0]); +} + /// Converts TMTensor operations that work on tensor-type operands or results to /// work on buffers. 
struct TMTensorBufferizePass @@ -133,7 +141,47 @@ struct TMTensorBufferizePass void runOnOperation() override { MLIRContext &context = getContext(); ConversionTarget target(context); - bufferization::BufferizeTypeConverter typeConverter; + // Since the `BufferizeTypeConverter` has been removed here + // https://github.com/llvm/llvm-project/commit/2ff2e871f5e632ea493efaf4f2192f8b18a54ab1, + // hence we have inlined the converter here. + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { return type; }); + // Convert RankedTensorType to MemRefType. + typeConverter.addConversion([](RankedTensorType type) -> Type { + return MemRefType::get(type.getShape(), type.getElementType()); + }); + // Convert UnrankedTensorType to UnrankedMemRefType. + typeConverter.addConversion([](UnrankedTensorType type) -> Type { + return UnrankedMemRefType::get(type.getElementType(), 0); + }); + typeConverter.addArgumentMaterialization(materializeToTensor); + typeConverter.addSourceMaterialization(materializeToTensor); + typeConverter.addTargetMaterialization([](OpBuilder &builder, + BaseMemRefType type, + ValueRange inputs, + Location loc) -> Value { + assert(inputs.size() == 1 && "expected exactly one input"); + if (auto inputType = dyn_cast(inputs[0].getType())) { + // MemRef to MemRef cast. + assert(inputType != type && "expected different types"); + // Ranked to unranked casts must be explicit. + auto rankedDestType = dyn_cast(type); + if (!rankedDestType) + return nullptr; + bufferization::BufferizationOptions options; + options.bufferAlignment = 0; + FailureOr replacement = castOrReallocMemRefValue( + builder, inputs[0], rankedDestType, options); + if (failed(replacement)) + return nullptr; + return *replacement; + } + if (isa(inputs[0].getType())) { + // Tensor to MemRef cast. + return builder.create(loc, type, inputs[0]); + } + llvm_unreachable("only tensor/memref input types supported"); + }); // Mark all Standard operations legal. target.addLegalDialect { RewritePatternSet patterns(context); patterns.insert(context); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); } } diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 1d0ff41f7845..2e0e3d8e32fd 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -30,6 +30,24 @@ using namespace mlir::torch::Torch; // Utilities //===----------------------------------------------------------------------===// +OpFoldResult genericViewLikeFold(Attribute self, Type resultType) { + auto selfAttr = dyn_cast_or_null(self); + if (!selfAttr) + return nullptr; + + auto resultTy = dyn_cast_or_null(resultType); + if (!resultTy || !resultTy.areAllSizesKnown()) + return nullptr; + + if (selfAttr.isSplat()) { + return SplatElementsAttr::get(resultTy.toBuiltinTensor(), + selfAttr.getSplatValue()); + } + return DenseElementsAttr::get( + resultTy.toBuiltinTensor(), + llvm::to_vector(selfAttr.getValues())); +} + Value mlir::torch::Torch::adjustStaticInformation(OpBuilder &builder, Location loc, Value value, Type desiredType, @@ -128,6 +146,17 @@ static FloatAttr getF64FloatAttr(MLIRContext *context, double value) { return FloatAttr::get(Float64Type::get(context), value); } +static DenseElementsAttr reshapeDenseElementsAttr(DenseElementsAttr attr, + ShapedType newType) { + // TODO: DenseElementsAttr::reshape is broken for bool splats. 
+ // Once that ticket is fixed, we can remove this conditional. + if (attr.isSplat() && newType.getElementType().isInteger(/*width=*/1)) { + auto splatValue = attr.getValues()[0]; + return DenseElementsAttr::get(newType, {splatValue}); + } + return attr.reshape(newType); +} + static Value getScalarIntValue(Value input, Location loc, PatternRewriter &rewriter) { auto inputType = input.getType(); @@ -149,14 +178,12 @@ static Value getScalarIntValue(Value input, Location loc, if (auto valueTensorLiteralOp = input.getDefiningOp()) { if (inputDtype.isInteger(64)) { - auto val = valueTensorLiteralOp.getValue() - .cast() + auto val = cast(valueTensorLiteralOp.getValue()) .getSplatValue(); return rewriter.create( loc, rewriter.getI64IntegerAttr(val)); } else { - auto val = valueTensorLiteralOp.getValue() - .cast() + auto val = cast(valueTensorLiteralOp.getValue()) .getSplatValue(); return rewriter.create( loc, rewriter.getI64IntegerAttr(val)); @@ -191,8 +218,7 @@ static Value getScalarFloatValue(Value input, Location loc, return nullptr; if (auto valueTensorLiteralOp = input.getDefiningOp()) { - auto val = valueTensorLiteralOp.getValue() - .cast() + auto val = cast(valueTensorLiteralOp.getValue()) .getSplatValue() .getValueAsDouble(); return rewriter.create( @@ -548,6 +574,24 @@ void PrimIfOp::getCanonicalizationPatterns(RewritePatternSet &patterns, }); } +//===----------------------------------------------------------------------===// +// AtenDotOp +//===----------------------------------------------------------------------===// + +void AtenDotOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(+[](AtenDotOp op, PatternRewriter &rewriter) { + auto ty = dyn_cast(op.getResult().getType()); + if (!ty || !ty.hasSizes() || !ty.hasDtype()) { + return failure(); + } + + rewriter.replaceOpWithNewOp(op, op.getResult().getType(), + op.getSelf(), op.getTensor()); + return success(); + }); +} + //===----------------------------------------------------------------------===// // RuntimeAssertOp //===----------------------------------------------------------------------===// @@ -706,6 +750,41 @@ OpFoldResult Aten__Not__Op::fold(FoldAdaptor adaptor) { return IntegerAttr::get(IntegerType::get(getContext(), 1), !value); } +//===----------------------------------------------------------------------===// +// Aten__Or__Op +//===----------------------------------------------------------------------===// + +OpFoldResult Aten__Or__BoolOp::fold(FoldAdaptor adaptor) { + auto valueA = dyn_cast_or_null(adaptor.getA()); + auto valueB = dyn_cast_or_null(adaptor.getB()); + if (!valueA && !valueB) + return nullptr; + if ((valueA && valueA.getValue() == 1) || (valueB && valueB.getValue() == 1)) + return IntegerAttr::get(IntegerType::get(getContext(), 1), 1); + if (valueA && valueA.getValue() == 0) + return getB(); + if (valueB && valueB.getValue() == 0) + return getA(); + // unreachable + return nullptr; +} + +//===----------------------------------------------------------------------===// +// AtenEqBoolOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenEqBoolOp::fold(FoldAdaptor adaptor) { + if (getOperand(0) == getOperand(1)) + return IntegerAttr::get(IntegerType::get(getContext(), 1), true); + + auto intAttrA = dyn_cast_or_null(adaptor.getA()); + auto intAttrB = dyn_cast_or_null(adaptor.getB()); + if (!intAttrA || !intAttrB) + return nullptr; + return IntegerAttr::get(IntegerType::get(getContext(), 1), + intAttrA.getValue() 
== intAttrB.getValue()); +} + //===----------------------------------------------------------------------===// // AtenNeBoolOp //===----------------------------------------------------------------------===// @@ -714,12 +793,12 @@ OpFoldResult AtenNeBoolOp::fold(FoldAdaptor adaptor) { if (getOperand(0) == getOperand(1)) return IntegerAttr::get(IntegerType::get(getContext(), 1), false); - bool a, b; - if (!matchPattern(getOperand(0), m_TorchConstantBool(&a))) + auto intAttrA = dyn_cast_or_null(adaptor.getA()); + auto intAttrB = dyn_cast_or_null(adaptor.getB()); + if (!intAttrA || !intAttrB) return nullptr; - if (!matchPattern(getOperand(1), m_TorchConstantBool(&b))) - return nullptr; - return IntegerAttr::get(IntegerType::get(getContext(), 1), a != b); + return IntegerAttr::get(IntegerType::get(getContext(), 1), + intAttrA.getValue() != intAttrB.getValue()); } //===----------------------------------------------------------------------===// @@ -783,11 +862,22 @@ OpFoldResult AtenSqueezeOp::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// OpFoldResult AtenSqueezeDimOp::fold(FoldAdaptor adaptor) { - if (getOperand(0).getType() != getResult().getType()) + auto inType = dyn_cast(getOperand(0).getType()); + auto outType = dyn_cast(getResult().getType()); + if (!inType || !outType || !inType.areAllSizesKnown() || + !outType.areAllSizesKnown() || !inType.hasDtype() || + !outType.hasDtype()) { return nullptr; - if (auto tensorType = dyn_cast(getOperand(0).getType())) { - if (tensorType.hasSizes() && tensorType.getSizes().size() == 0) - return getOperand(0); + } + + if (inType == outType) { + return getOperand(0); + } + + DenseElementsAttr input = + dyn_cast_or_null(adaptor.getSelf()); + if (input) { + return reshapeDenseElementsAttr(input, outType.toBuiltinTensor()); } return nullptr; } @@ -993,6 +1083,8 @@ void Aten_CastLongOp::getCanonicalizationPatterns(RewritePatternSet &patterns, //===----------------------------------------------------------------------===// OpFoldResult AtenViewOp::fold(FoldAdaptor adaptor) { + if (auto genericFold = genericViewLikeFold(adaptor.getSelf(), getType())) + return genericFold; auto inputType = dyn_cast(getOperand(0).getType()); if (!inputType || !inputType.hasSizes() || inputType.getSizes().size() != 1) return nullptr; @@ -1055,6 +1147,35 @@ void AtenLenTOp::getCanonicalizationPatterns(RewritePatternSet &patterns, }); } +//===----------------------------------------------------------------------===// +// AtenMulLeftTOp +//===----------------------------------------------------------------------===// + +void AtenMulLeftTOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + // `[1,2] * 3` -> `[1,2,1,2,1,2]`, if it is not mutated. 
+ patterns.add(+[](AtenMulLeftTOp op, PatternRewriter &rewriter) { + auto listLiteral = op.getL().getDefiningOp(); + if (!listLiteral || isListPotentiallyMutated(listLiteral)) + return failure(); + + int64_t numReps; + if (!matchPattern(op.getN(), m_TorchConstantInt(&numReps))) + return failure(); + + SmallVector newListElements; + for (int rep = 0; rep < numReps; ++rep) { + for (auto operand : listLiteral.getOperands()) { + newListElements.push_back(operand); + } + } + + rewriter.replaceOpWithNewOp(op, op.getL().getType(), + newListElements); + return success(); + }); +} + //===----------------------------------------------------------------------===// // AtenMinOtherOp //===----------------------------------------------------------------------===// @@ -1187,30 +1308,6 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op, // NAry folder helpers //===----------------------------------------------------------------------===// -static bool checkSameDTypes(llvm::ArrayRef attrs) { - bool allFp = true; - bool allInt = true; - - for (auto attr : attrs) { - if (!attr) - return false; - - Type attrty; - if (auto dense = dyn_cast_or_null(attr)) - attrty = dense.getType(); - if (auto fp = dyn_cast_or_null(attr)) - attrty = fp.getType(); - if (auto integer = dyn_cast_or_null(attr)) - attrty = integer.getType(); - if (auto shaped = dyn_cast_or_null(attrty)) - attrty = shaped.getElementType(); - allFp &= isa(attrty); - allInt &= isa(attrty); - } - - return allFp || allInt; -} - static bool checkAllSplats(llvm::ArrayRef attrs) { for (auto attr : attrs) { if (auto dense = dyn_cast_or_null(attr)) { @@ -1226,15 +1323,38 @@ llvm::SmallVector getFoldValueAtIndexFp(llvm::ArrayRef attrs, int64_t idx = 0) { llvm::SmallVector splattrs; + // Note that i1 is neither signed nor unsigned. + // But we should treat i1 as unsigned, otherwise + // APInt(1,1).getSExtValue() returns an all-ones 64-bit integer. + // So here we only distinguish signed integers. + auto convertAPIntToDouble = [](APInt value, bool isSigned) -> double { + if (isSigned) + return static_cast(value.getSExtValue()); + else + return static_cast(value.getZExtValue()); + }; + for (auto attr : attrs) { - if (auto dense = dyn_cast(attr)) { + if (auto dense = dyn_cast(attr)) { if (dense.isSplat()) { splattrs.push_back(dense.getSplatValue().convertToDouble()); } else { splattrs.push_back(dense.getValues()[idx].convertToDouble()); } - } else if (auto intattr = dyn_cast(attr)) { - splattrs.push_back(intattr.getValueAsDouble()); + } else if (auto dense = dyn_cast(attr)) { + bool isSigned = cast(dense.getElementType()).isSigned(); + if (dense.isSplat()) { + splattrs.push_back( + convertAPIntToDouble(dense.getSplatValue(), isSigned)); + } else { + splattrs.push_back( + convertAPIntToDouble(dense.getValues()[idx], isSigned)); + } + } else if (auto fpattr = dyn_cast(attr)) { + splattrs.push_back(fpattr.getValueAsDouble()); + } else if (auto intattr = dyn_cast(attr)) { + bool isSigned = cast(intattr.getType()).isSigned(); + splattrs.push_back(convertAPIntToDouble(intattr.getValue(), isSigned)); } else { return {}; } @@ -1249,13 +1369,9 @@ llvm::SmallVector getFoldValueAtIndexInt(llvm::ArrayRef attrs, llvm::SmallVector splattrs; for (auto attr : attrs) { - // Note that i1 is neither signed nor unsigned. - // But we should trait i1 as unsigned, otherwise that - // APInt(1,1).getSExtValue() return allOnes 64-bit integer. - // So here only distinguish signed integer. 
bool isSigned = false; - if (auto dense = dyn_cast(attr)) { - isSigned = dyn_cast(dense.getElementType()).isSigned(); + if (auto dense = dyn_cast(attr)) { + isSigned = cast(dense.getElementType()).isSigned(); if (dense.isSplat()) { splattrs.push_back(dense.getSplatValue()); } else { @@ -1268,6 +1384,10 @@ llvm::SmallVector getFoldValueAtIndexInt(llvm::ArrayRef attrs, return {}; } + // Note that i1 is neither signed nor unsigned. + // But we should treat i1 as unsigned, otherwise + // APInt(1,1).getSExtValue() returns an all-ones 64-bit integer. + // So here we only distinguish signed integers. auto &apint = splattrs.back(); if (apint.getBitWidth() < bitwidth) { if (isSigned) { @@ -1284,19 +1404,22 @@ llvm::SmallVector getFoldValueAtIndexInt(llvm::ArrayRef attrs, using NAryFoldFpOperator = std::function)>; using NAryFoldIntOperator = std::function)>; -static OpFoldResult naryFolderHelper(ArrayRef operands, Type ty, - NAryFoldFpOperator fpFolder, - NAryFoldIntOperator intFolder) { - constexpr int64_t maxFold = 16; - if (!checkSameDTypes(operands)) - return nullptr; +static OpFoldResult +naryFolderHelper(ArrayRef operands, Type ty, + std::optional fpFolder, + std::optional intFolder) { + constexpr int64_t kMaxFold = 16; + for (auto attr : operands) { + if (!attr) + return nullptr; + } auto resultTy = dyn_cast(ty); - if (!resultTy || !resultTy.hasDtype() || !resultTy.hasSizes()) + if (!resultTy || !resultTy.hasDtype() || !resultTy.areAllSizesKnown()) return nullptr; auto dty = resultTy.getDtype(); - auto resultBTy = resultTy.toBuiltinTensor().clone(dty); + auto resultBTy = resultTy.toBuiltinTensor(); auto fpTy = dyn_cast(dty); auto intTy = dyn_cast(dty); @@ -1304,10 +1427,7 @@ static OpFoldResult naryFolderHelper(ArrayRef operands, Type ty, return nullptr; bool allSplats = checkAllSplats(operands); - bool withinMaxFold = - resultBTy.hasStaticShape() && resultBTy.getNumElements() <= maxFold; - - if (!allSplats && !withinMaxFold) + if (!(allSplats || resultBTy.getNumElements() <= kMaxFold)) return nullptr; // We do not support broadcasting in the non-splat case so validate same @@ -1331,10 +1451,15 @@ static OpFoldResult naryFolderHelper(ArrayRef operands, Type ty, const int64_t numValues = allSplats ? 
1 : resultBTy.getNumElements(); if (fpTy) { + if (!fpFolder.has_value()) + return nullptr; + auto folder = fpFolder.value(); llvm::SmallVector folded; for (int i = 0, s = numValues; i < s; ++i) { auto inputs = getFoldValueAtIndexFp(operands, i); - double fold = fpFolder(inputs); + if (inputs.size() != operands.size()) + return nullptr; + double fold = folder(inputs); APFloat val(fold); bool unused; @@ -1346,11 +1471,16 @@ static OpFoldResult naryFolderHelper(ArrayRef operands, Type ty, } if (intTy) { + if (!intFolder.has_value()) + return nullptr; + auto folder = intFolder.value(); llvm::SmallVector folded; for (int i = 0, s = numValues; i < s; ++i) { auto inputs = getFoldValueAtIndexInt(operands, dty.getIntOrFloatBitWidth(), i); - folded.push_back(intFolder(inputs)); + if (inputs.size() != operands.size()) + return nullptr; + folded.push_back(folder(inputs)); } return DenseElementsAttr::get(resultBTy, folded); } @@ -1470,6 +1600,24 @@ void AtenRsubScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns, }); } +// ===----------------------------------------------------------------------===// +// AtenRSubScalarOp +// ===----------------------------------------------------------------------===// + +OpFoldResult AtenRsubScalarOp::fold(FoldAdaptor adaptor) { + auto fpFold = [](llvm::ArrayRef inputs) { + assert(inputs.size() == 3); + return inputs[1] - inputs[0] * inputs[2]; + }; + + auto intFold = [](llvm::ArrayRef inputs) { + assert(inputs.size() == 3); + return inputs[1] - inputs[0] * inputs[2]; + }; + + return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold); +} + //===----------------------------------------------------------------------===// // AtenMulTensorOp //===----------------------------------------------------------------------===// @@ -1506,7 +1654,7 @@ OpFoldResult AtenEqTensorOp::fold(FoldAdaptor adaptor) { if (!ty || !ty.hasDtype() || !ty.hasSizes()) return nullptr; - auto bty = ty.toBuiltinTensor().clone(ty.getDtype()); + auto bty = ty.toBuiltinTensor(); if (!bty.hasStaticShape()) return nullptr; @@ -1612,15 +1760,10 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs, constexpr int64_t kMaxFold = 16; if (!lhs || !rhs || !resultTy) return nullptr; - if (!resultTy.hasSizes() || !resultTy.hasDtype()) + if (!resultTy.areAllSizesKnown() || !resultTy.hasDtype()) return nullptr; - for (auto size : resultTy.getSizes()) - if (size == Torch::kUnknownSize) - return nullptr; - auto ctx = lhs.getContext(); - auto resultETy = resultTy.getDtype(); auto tensorETy = cast(lhs.getType()).getElementType(); if (lhs.isSplat()) { if (auto intAttr = dyn_cast(rhs)) { @@ -1632,8 +1775,7 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs, unsign ? 
tensorAP.getZExtValue() : tensorAP.getSExtValue(), !unsign); auto resultBool = intFolder(tensorAP, scalarAP, unsign); auto resultAP = IntegerAttr::get(IntegerType::get(ctx, 1), resultBool); - return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy), - resultAP); + return DenseElementsAttr::get(resultTy.toBuiltinTensor(), resultAP); } if (auto floatAttr = dyn_cast(rhs)) { @@ -1642,8 +1784,7 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs, auto resultBool = fpFolder(tensorAP.convertToDouble(), scalarAP.convertToDouble()); auto resultAP = IntegerAttr::get(IntegerType::get(ctx, 1), resultBool); - return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy), - resultAP); + return DenseElementsAttr::get(resultTy.toBuiltinTensor(), resultAP); } return nullptr; } @@ -1666,8 +1807,7 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs, auto resultBool = intFolder(tensorAP, scalarAP, unsign); values.push_back(resultBool); } - return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy), - values); + return DenseElementsAttr::get(resultTy.toBuiltinTensor(), values); } if (auto floatAttr = dyn_cast(rhs)) { @@ -1678,8 +1818,7 @@ static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs, fpFolder(tensorAP.convertToDouble(), scalarAP.convertToDouble()); values.push_back(resultBool); } - return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy), - values); + return DenseElementsAttr::get(resultTy.toBuiltinTensor(), values); } return nullptr; @@ -1811,75 +1950,19 @@ OpFoldResult AtenNeScalarOp::fold(FoldAdaptor adaptor) { // AtenLogOp //===----------------------------------------------------------------------===// -using UnaryPromoteFpOperator = std::function; -using UnaryPromoteIntOperator = std::function; - -static OpFoldResult unaryPromoteFolder(DenseElementsAttr operand, - ValueTensorType resultTy, - UnaryPromoteFpOperator fpFolder, - UnaryPromoteIntOperator intFolder) { - constexpr int64_t kMaxFold = 16; - if (!resultTy.hasDtype() || !resultTy.hasSizes()) - return nullptr; - if (!isa(resultTy.getDtype())) - return nullptr; - - auto fpTy = dyn_cast(operand.getType().getElementType()); - auto intTy = dyn_cast(operand.getType().getElementType()); - if (!fpTy && !intTy) - return nullptr; - - auto resultBTy = resultTy.toBuiltinTensor().clone(resultTy.getDtype()); - bool splat = operand.isSplat(); - bool withinMaxFold = - resultBTy.hasStaticShape() && resultBTy.getNumElements() <= kMaxFold; - if (!splat && !withinMaxFold) - return nullptr; - - const int64_t numValues = splat ? 1 : resultBTy.getNumElements(); - - llvm::SmallVector operands = {operand}; - llvm::SmallVector folded; - for (int i = 0, s = numValues; i < s; ++i) { - double fold = 0.0; - if (fpTy) { - auto inputs = getFoldValueAtIndexFp(operands, i); - fold = fpFolder(inputs[0]); - } - if (intTy) { - auto inputs = - getFoldValueAtIndexInt(operands, intTy.getIntOrFloatBitWidth(), i); - fold = intFolder(inputs[0], intTy.isSigned()); - } - - APFloat val(fold); - bool unused; - val.convert( - cast(resultBTy.getElementType()).getFloatSemantics(), - APFloat::rmNearestTiesToEven, &unused); - folded.push_back(val); - } - return DenseElementsAttr::get(resultBTy, folded); -} - OpFoldResult AtenLogOp::fold(FoldAdaptor adaptor) { auto self = dyn_cast_or_null(adaptor.getSelf()); auto resultType = dyn_cast(getType()); if (!self || !resultType) return nullptr; - // Note that i1 is neither signed nor unsigned. 
- // But we should trait i1 as unsigned, otherwise that - // APInt(1,1).getSExtValue() return allOnes 64-bit integer. - auto intFold = [](APInt a, bool isSigned) -> double { - if (isSigned) - return std::log(a.getSExtValue()); - else - return std::log(a.getZExtValue()); + auto fpFold = [](llvm::ArrayRef inputs) -> double { + assert(inputs.size() == 1); + return std::log(inputs[0]); }; - auto fpFold = [](double a) -> double { return std::log(a); }; - return unaryPromoteFolder(self, resultType, fpFold, intFold); + return naryFolderHelper(adaptor.getOperands(), resultType, fpFold, + std::nullopt); } //===----------------------------------------------------------------------===// @@ -1928,7 +2011,7 @@ OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenTruncOp::fold(FoldAdaptor adaptor) { auto resultType = dyn_cast(getType()); if (resultType && resultType.hasDtype() && - resultType.getDtype().isa()) { + isa(resultType.getDtype())) { return getSelf(); } return {}; @@ -1979,6 +2062,58 @@ void AtenDivTensorModeOp::getCanonicalizationPatterns( }); } +// ===----------------------------------------------------------------------===// +// AtenDivTensorModeOp +// ===----------------------------------------------------------------------===// + +OpFoldResult AtenDivTensorModeOp::fold(FoldAdaptor adaptor) { + auto resultTy = dyn_cast_or_null(getType()); + if (!resultTy || !resultTy.hasDtype()) { + return nullptr; + } + std::function)> fpFold; + std::function)> intFold; + + auto roundMode = dyn_cast_or_null(adaptor.getRoundingMode()); + auto unsign = false; + if (isa(resultTy.getDtype())) { + unsign = cast(resultTy.getDtype()).isUnsigned(); + } + + fpFold = [roundMode](llvm::ArrayRef inputs) { + assert(inputs.size() == 2); + if (!roundMode) { + return (double)inputs[0] / inputs[1]; + } else if (roundMode.getValue().str() == "floor") { + return std::floor((double)inputs[0] / inputs[1]); + } else { + return std::trunc((double)inputs[0] / inputs[1]); + } + }; + + intFold = [unsign, roundMode](llvm::ArrayRef inputs) { + assert(inputs.size() == 2); + auto lhs = unsign ? inputs[0].getZExtValue() : inputs[0].getSExtValue(); + auto rhs = unsign ? inputs[1].getZExtValue() : inputs[1].getSExtValue(); + int64_t bits = std::max(inputs[0].getBitWidth(), inputs[1].getBitWidth()); + int64_t res; + if (roundMode.getValue().str() == "floor") { + res = std::floor(lhs / rhs); + } else { + res = std::trunc(lhs / rhs); + } + return APInt(bits, res); + }; + + if (!roundMode) { + return naryFolderHelper({adaptor.getSelf(), adaptor.getOther()}, getType(), + fpFold, std::nullopt); + } + + return naryFolderHelper({adaptor.getSelf(), adaptor.getOther()}, getType(), + fpFold, intFold); +} + //===----------------------------------------------------------------------===// // AtenDivScalarModeOp //===----------------------------------------------------------------------===// @@ -2118,7 +2253,7 @@ traceKnownSizeTensorType(Value value, std::optional dim) { // Limit the loop count to 6 to avoid indefinite compilation times from // unbounded IR traversals. 
for (auto idx = 0; idx < 6; ++idx) { - if (!value || !value.getType().isa()) + if (!value || !isa(value.getType())) return failure(); auto tensorType = cast(value.getType()); @@ -2166,6 +2301,126 @@ void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, }); } +//===----------------------------------------------------------------------===// +// AtenFlattenUsingIntsOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenFlattenUsingIntsOp::fold(FoldAdaptor adaptor) { + return genericViewLikeFold(adaptor.getSelf(), getType()); +} + +//===----------------------------------------------------------------------===// +// AtenUnflattenIntOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenUnflattenIntOp::fold(FoldAdaptor adaptor) { + return genericViewLikeFold(adaptor.getSelf(), getType()); +} + +void AtenUnflattenIntOp::getCanonicalizationPatterns( + RewritePatternSet &patterns, MLIRContext *context) { + // if there are only two sizes and one of them is statically 1, then convert + // to an unsqueeze. + patterns.add(+[](AtenUnflattenIntOp op, PatternRewriter &rewriter) { + SmallVector sizeValues; + if (!getListConstructElements(op.getSizes(), sizeValues)) + return rewriter.notifyMatchFailure(op, + "sizes must come from list construct"); + if (sizeValues.size() != 2) + return failure(); + int64_t dim0, dim1; + bool dim0Constant = matchPattern(sizeValues[0], m_TorchConstantInt(&dim0)); + bool dim1Constant = matchPattern(sizeValues[1], m_TorchConstantInt(&dim1)); + if (!dim0Constant && !dim1Constant) + return failure(); + if (dim0 != 1 && dim1 != 1) + return failure(); + Value unflattenDim = op.getDim(); + int64_t dimAsInt; + bool dimWasConstant = + matchPattern(unflattenDim, m_TorchConstantInt(&dimAsInt)); + Value self = op.getSelf(); + Value cstMOne = rewriter.create(op.getLoc(), -1); + // the runtime asserts below are introduced to catch malformed unflatten ops + // possibly generated from onnx IR. + Value unsqueeze; + if (dim0 == 1) { + // unsqueeze at dim + FailureOr maybeUnsqueeze = + Torch::unsqueezeTensor(rewriter, op, self, unflattenDim); + if (failed(maybeUnsqueeze)) + return rewriter.notifyMatchFailure(op, "failed to create unsqueeze op"); + unsqueeze = maybeUnsqueeze.value(); + // check if the remaining size value is either -1 or equal to original + // size at dim + Value selfSizeAtDim = + rewriter.create(op.getLoc(), self, unflattenDim); + Value isSameSize = rewriter.create( + op.getLoc(), selfSizeAtDim, sizeValues[1]); + Value isMinusOne = + rewriter.create(op.getLoc(), cstMOne, sizeValues[1]); + Value isMOneOrSameSize = rewriter.create( + op.getLoc(), isMinusOne, isSameSize); + rewriter.create( + op.getLoc(), isMOneOrSameSize, + rewriter.getStringAttr("unflatten sizes must be compatible")); + } + if (dim1 == 1) { + // unsqueeze at dim + 1 + Value dimPlusOne; + if (!dimWasConstant) { + Value cstOne = rewriter.create(op.getLoc(), 1); + dimPlusOne = + rewriter.create(op.getLoc(), unflattenDim, cstOne); + } else { + // If dim was constant, creating an AtenAddIntOp will make + // Torch::unsqueezeTensor() interpret it as still not being a constant, + // and the resultant shape would consist of only dynamic dims. To fix + // this, emit a ConstantIntOp for (dim + 1) to avoid an assertion + // failure, when AtenUnsqueezeOp is in a later pass converted to + // ExpandShapeOp, which is bound to fail shape inference in MLIR if + // output dims are dynamic. 
+ dimPlusOne = rewriter.create( + op.getLoc(), rewriter.getI64IntegerAttr(dimAsInt + 1)); + } + FailureOr maybeUnsqueeze = + Torch::unsqueezeTensor(rewriter, op, self, dimPlusOne); + if (failed(maybeUnsqueeze)) + return rewriter.notifyMatchFailure(op, "failed to create unsqueeze op"); + unsqueeze = maybeUnsqueeze.value(); + // check if the remaining size value is either -1 or equal to original + // size at dim + Value selfSizeAtDim = + rewriter.create(op.getLoc(), self, unflattenDim); + Value isSameSize = rewriter.create( + op.getLoc(), selfSizeAtDim, sizeValues[0]); + Value isMinusOne = + rewriter.create(op.getLoc(), cstMOne, sizeValues[0]); + Value isMOneOrSameSize = rewriter.create( + op.getLoc(), isMinusOne, isSameSize); + rewriter.create( + op.getLoc(), isMOneOrSameSize, + rewriter.getStringAttr("unflatten sizes must be compatible")); + } + rewriter.replaceOpWithNewOp(op, op.getType(), + unsqueeze); + return success(); + }); +} + +//===----------------------------------------------------------------------===// +// AtenReshapeOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenReshapeOp::fold(FoldAdaptor adaptor) { + auto selfTy = dyn_cast(getSelf().getType()); + auto opTy = dyn_cast(getType()); + if (selfTy && selfTy == opTy && selfTy.hasSizes() && selfTy.hasDtype() && + selfTy.toBuiltinTensor().hasStaticShape()) + return getSelf(); + return nullptr; +} + //===----------------------------------------------------------------------===// // AtenSelectIntOp //===----------------------------------------------------------------------===// @@ -2177,7 +2432,7 @@ OpFoldResult AtenSelectIntOp::fold(FoldAdaptor adaptor) { return nullptr; auto selfTy = cast(self.getType()); - auto bty = ty.toBuiltinTensor().clone(ty.getDtype()); + auto bty = ty.toBuiltinTensor(); if (!bty.hasStaticShape()) return nullptr; @@ -2399,7 +2654,7 @@ OpFoldResult AtenNeStrOp::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// OpFoldResult Aten__Contains__StrListOp::fold(FoldAdaptor adaptor) { - StringAttr item = dyn_cast(adaptor.getItem()); + StringAttr item = dyn_cast_or_null(adaptor.getItem()); if (!item) return nullptr; @@ -2500,7 +2755,7 @@ OpFoldResult AtenAnyBoolOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenFloatScalarOp::fold(FoldAdaptor adaptor) { // Constant fold int -> float conversion. - if (auto integerAttr = adaptor.getA().dyn_cast_or_null()) { + if (auto integerAttr = dyn_cast_or_null(adaptor.getA())) { return FloatAttr::get( mlir::Float64Type::get(getContext()), static_cast(integerAttr.getValue().getSExtValue())); @@ -2517,7 +2772,7 @@ OpFoldResult AtenFloatScalarOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenIntFloatOp::fold(FoldAdaptor adaptor) { // Constant fold float -> int conversion. - if (auto floatAttr = adaptor.getA().dyn_cast_or_null()) { + if (auto floatAttr = dyn_cast_or_null(adaptor.getA())) { return IntegerAttr::get( mlir::IntegerType::get(getContext(), 64), static_cast(floatAttr.getValue().convertToDouble())); @@ -2531,7 +2786,7 @@ OpFoldResult AtenIntFloatOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenIntScalarOp::fold(FoldAdaptor adaptor) { // Constant fold float -> int conversion. 
- if (auto floatAttr = adaptor.getA().dyn_cast_or_null()) { + if (auto floatAttr = dyn_cast_or_null(adaptor.getA())) { return IntegerAttr::get( mlir::IntegerType::get(getContext(), 64), static_cast(floatAttr.getValue().convertToDouble())); @@ -2581,7 +2836,8 @@ void AtenMaskedFillTensorOp::getCanonicalizationPatterns( OpFoldResult AtenCloneOp::fold(FoldAdaptor adaptor) { // note: memory_format would be ignored - if (llvm::dyn_cast(getSelf().getType())) { + if (getSelf().getType() == getResult().getType() && + llvm::dyn_cast(getSelf().getType())) { // self should have value semantics return getSelf(); } @@ -2640,8 +2896,7 @@ LogicalResult AtenSortOp::fold(FoldAdaptor adaptor, if (!indicesTensorType.hasDtype()) return failure(); - auto indicesType = - indicesTensorType.toBuiltinTensor().clone(indicesTensorType.getDtype()); + auto indicesType = indicesTensorType.toBuiltinTensor(); if (!indicesType || !indicesType.hasStaticShape()) return failure(); @@ -2676,9 +2931,8 @@ LogicalResult NonValueTensorLiteralOp::inferReturnTypes( MLIRContext *context, std::optional location, ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { - auto attr = properties.as() - ->getValue() - .dyn_cast_or_null(); + auto attr = + dyn_cast_or_null(properties.as()->getValue()); if (!attr) return failure(); RankedTensorType tensorType = cast(attr.getType()); @@ -2704,10 +2958,10 @@ static bool areSizesAndDtypesCompatible(BaseTensorType a, BaseTensorType b) { bool NonValueTensorLiteralOp::isCompatibleReturnTypes(TypeRange inferred, TypeRange actual) { - if (!actual[0].isa()) + if (!isa(actual[0])) return false; - return areSizesAndDtypesCompatible(inferred[0].cast(), - actual[0].cast()); + return areSizesAndDtypesCompatible(cast(inferred[0]), + cast(actual[0])); } //===----------------------------------------------------------------------===// @@ -2718,9 +2972,8 @@ LogicalResult ValueTensorLiteralOp::inferReturnTypes( MLIRContext *context, std::optional location, ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { - auto attr = properties.as() - ->getValue() - .dyn_cast_or_null(); + auto attr = + dyn_cast_or_null(properties.as()->getValue()); if (!attr) return failure(); RankedTensorType tensorType = cast(attr.getType()); @@ -2741,8 +2994,8 @@ OpFoldResult ValueTensorLiteralOp::fold(FoldAdaptor adaptor) { bool TensorStaticInfoCastOp::areCastCompatible(mlir::TypeRange inputs, mlir::TypeRange outputs) { - return areSizesAndDtypesCompatible(inputs[0].cast(), - outputs[0].cast()); + return areSizesAndDtypesCompatible(cast(inputs[0]), + cast(outputs[0])); } void TensorStaticInfoCastOp::getCanonicalizationPatterns( @@ -2802,7 +3055,8 @@ LogicalResult CopyToNonValueTensorOp::inferReturnTypes( void CopyToNonValueTensorOp::getEffects( SmallVectorImpl> &effects) { - effects.emplace_back(MemoryEffects::Allocate::get(), getResult()); + effects.emplace_back(MemoryEffects::Allocate::get(), + getOperation()->getOpResult(0)); } //===----------------------------------------------------------------------===// @@ -2829,7 +3083,8 @@ LogicalResult CopyToValueTensorOp::inferReturnTypes( void CopyToValueTensorOp::getEffects( SmallVectorImpl> &effects) { - effects.emplace_back(MemoryEffects::Read::get(), getOperand()); + effects.emplace_back(MemoryEffects::Read::get(), + &getOperation()->getOpOperand(0)); } 
//===----------------------------------------------------------------------===// @@ -2872,11 +3127,11 @@ void ConstantDeviceOp::getAsmResultNames( ParseResult ConstantIntOp::parse(OpAsmParser &parser, OperationState &result) { Builder builder(result.getContext()); result.addTypes(builder.getType()); - if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); int64_t value; if (parser.parseInteger(value)) return failure(); + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); result.addAttribute("value", builder.getI64IntegerAttr(value)); return success(); } @@ -3031,6 +3286,33 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns( }); } +//===----------------------------------------------------------------------===// +// AtenMeshgridOp +//===----------------------------------------------------------------------===// +void AtenMeshgridOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(+[](AtenMeshgridOp op, PatternRewriter &rewriter) { + Value constIndexing = rewriter.create( + op->getLoc(), rewriter.getStringAttr("ij")); + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), op.getTensors(), constIndexing); + return success(); + }); +} + +//===----------------------------------------------------------------------===// +// AtenSplitSizesOp +//===----------------------------------------------------------------------===// + +void AtenSplitSizesOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(+[](AtenSplitSizesOp op, PatternRewriter &rewriter) { + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), op.getSelf(), op.getSplitSize(), op.getDim()); + return success(); + }); +} + //===----------------------------------------------------------------------===// // AtenIsFloatingPointOp //===----------------------------------------------------------------------===// @@ -3040,7 +3322,7 @@ OpFoldResult AtenIsFloatingPointOp::fold(FoldAdaptor adaptor) { if (!operandType) return nullptr; if (operandType.hasDtype()) { - bool isFloatType = operandType.getDtype().isa(); + bool isFloatType = isa(operandType.getDtype()); return IntegerAttr::get(IntegerType::get(getContext(), 1), isFloatType); } // doesn't has dtype @@ -3098,12 +3380,12 @@ void AtenSliceTOp::getCanonicalizationPatterns(RewritePatternSet &patterns, int64_t start; int64_t end; int64_t step; - if (op.getStart().getType().isa()) { + if (isa(op.getStart().getType())) { start = 0; } else if (!matchPattern(op.getStart(), m_TorchConstantInt(&start))) { return failure(); } - if (op.getEnd().getType().isa()) { + if (isa(op.getEnd().getType())) { end = listElements.size(); } else if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) { return failure(); @@ -3196,7 +3478,7 @@ void PrimTupleIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns, // things. 
Value replacement = tupleConstruct.getElements()[i]; if (replacement.getType() != op.getType()) { - if (op.getType().isa()) { + if (isa(op.getType())) { replacement = rewriter.create( op.getLoc(), op.getType(), replacement); } else { @@ -3267,7 +3549,18 @@ void PrimListUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns, if (op->getNumResults() != listConstruct.getElements().size()) return failure(); - rewriter.replaceOp(op, listConstruct.getElements()); + SmallVector unpacked; + for (int i = 0, s = op->getNumResults(); i < s; ++i) { + auto element = listConstruct.getElements()[i]; + if (element.getType() != op->getResult(i).getType()) { + element = rewriter.create( + op.getLoc(), op->getResult(i).getType(), element); + } + + unpacked.push_back(element); + } + + rewriter.replaceOp(op, unpacked); return success(); }); } @@ -3352,8 +3645,8 @@ using BinaryIntOperatorFn = std::function; static OpFoldResult atenBinaryIntOperatorFoldHelper(ArrayRef operands, BinaryIntOperatorFn f) { - auto intLhs = operands[0].dyn_cast_or_null(); - auto intRhs = operands[1].dyn_cast_or_null(); + auto intLhs = dyn_cast_or_null(operands[0]); + auto intRhs = dyn_cast_or_null(operands[1]); if (!intLhs || !intRhs) { return nullptr; } @@ -3388,7 +3681,11 @@ atenBinaryFloatOperatorFoldHelper(ArrayRef operands, // AtenAliasOp //===----------------------------------------------------------------------===// -OpFoldResult AtenAliasOp::fold(FoldAdaptor adaptor) { return getOperand(); } +OpFoldResult AtenAliasOp::fold(FoldAdaptor adaptor) { + if (getOperand().getType() != getResult().getType()) + return {}; + return getOperand(); +} //===----------------------------------------------------------------------===// // AtenFloordivIntOp @@ -3400,6 +3697,44 @@ OpFoldResult AtenFloordivIntOp::fold(FoldAdaptor adaptor) { [](int64_t a, int64_t b) { return std::floor(a / (double)b); }); } +void AtenFloordivIntOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(+[](AtenFloordivIntOp op, PatternRewriter &rewriter) { + int64_t lhs, rhs; + bool lConstant = matchPattern(op.getA(), m_TorchConstantInt(&lhs)); + bool rConstant = matchPattern(op.getB(), m_TorchConstantInt(&rhs)); + if (lConstant && rConstant) + return failure(); + if (lConstant || rConstant) { + int64_t firstConstant = lConstant ? lhs : rhs; + Value firstOperand = lConstant ? op.getB() : op.getA(); + if (firstOperand.getDefiningOp() && + firstOperand.getDefiningOp()) { + auto prevMulIntOp = firstOperand.getDefiningOp(); + int64_t prevLhs, prevRhs; + bool prevLConstant = + matchPattern(prevMulIntOp.getA(), m_TorchConstantInt(&prevLhs)); + bool prevRConstant = + matchPattern(prevMulIntOp.getB(), m_TorchConstantInt(&prevRhs)); + if (prevLConstant && prevRConstant) + return failure(); + if ((prevLConstant || prevRConstant) && + prevMulIntOp->hasOneUse() == 1) { + int64_t secondConstant = prevLConstant ? prevLhs : prevRhs; + if (secondConstant == firstConstant) { + rewriter.replaceAllUsesWith( + op.getResult(), prevMulIntOp.getOperand(prevLConstant ? 
1 : 0)); + rewriter.eraseOp(op); + rewriter.eraseOp(prevMulIntOp); + return success(); + } + } + } + } + return failure(); + }); +} + //===----------------------------------------------------------------------===// // AtenRemainderIntOp //===----------------------------------------------------------------------===// @@ -3409,11 +3744,45 @@ OpFoldResult AtenRemainderIntOp::fold(FoldAdaptor adaptor) { adaptor.getOperands(), [](int64_t a, int64_t b) { return a % b; }); } +// ===----------------------------------------------------------------------===// +// AtenRemainderScalarOp +// ===----------------------------------------------------------------------===// + +OpFoldResult AtenRemainderScalarOp::fold(FoldAdaptor adaptor) { + auto resultTy = dyn_cast_or_null(getType()); + if (!resultTy || !resultTy.hasDtype()) { + return nullptr; + } + + auto unsign = false; + if (isa(resultTy.getDtype())) { + unsign = cast(resultTy.getDtype()).isUnsigned(); + } + auto fpFold = [](llvm::ArrayRef inputs) { + assert(inputs.size() == 2); + return std::fmod(inputs[0], inputs[1]); + }; + + auto intFold = [unsign](llvm::ArrayRef inputs) { + assert(inputs.size() == 2); + auto ret = unsign ? inputs[0].urem(inputs[1]) : inputs[0].srem(inputs[1]); + return ret; + }; + + return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold); +} + //===----------------------------------------------------------------------===// // AtenAddIntOp //===----------------------------------------------------------------------===// OpFoldResult AtenAddIntOp::fold(FoldAdaptor adaptor) { + auto intLhs = dyn_cast_or_null(adaptor.getA()); + auto intRhs = dyn_cast_or_null(adaptor.getB()); + if (intRhs && intRhs.getValue().getSExtValue() == 0) + return getA(); + if (intLhs && intLhs.getValue().getSExtValue() == 0) + return getB(); return atenBinaryIntOperatorFoldHelper( adaptor.getOperands(), [](int64_t a, int64_t b) { return a + b; }); } @@ -3423,10 +3792,76 @@ OpFoldResult AtenAddIntOp::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// OpFoldResult AtenSubIntOp::fold(FoldAdaptor adaptor) { + if (getA() == getB()) + return IntegerAttr::get( + IntegerType::get(getContext(), 64, IntegerType::Signless), 0); return atenBinaryIntOperatorFoldHelper( adaptor.getOperands(), [](int64_t a, int64_t b) { return a - b; }); } +//===----------------------------------------------------------------------===// +// AtenTransposeIntOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenTransposeIntOp::fold(FoldAdaptor adaptor) { + // first check for no-op + IntegerAttr dim0 = dyn_cast_or_null(adaptor.getDim0()); + IntegerAttr dim1 = dyn_cast_or_null(adaptor.getDim1()); + if (!dim0 || !dim1) + return nullptr; + int64_t _dim0 = dim0.getValue().getSExtValue(); + int64_t _dim1 = dim1.getValue().getSExtValue(); + auto selfTy = dyn_cast(getSelf().getType()); + if (!selfTy || !selfTy.hasSizes()) + return nullptr; + int64_t rank = selfTy.getSizes().size(); + _dim0 = toPositiveDim(_dim0, rank); + _dim1 = toPositiveDim(_dim1, rank); + if (!isValidDim(_dim0, rank) || !isValidDim(_dim1, rank)) + return nullptr; + // if dims are the same, return self + if (_dim0 == _dim1) + return getSelf(); + + // We set a maximum folding size of 16. This is a reasonable upper limit + // for shape computations. 
+ constexpr int64_t kMaxFoldSize = 16; + auto self = dyn_cast_or_null(adaptor.getSelf()); + if (!self || self.getNumElements() > kMaxFoldSize) + return nullptr; + auto resultTy = dyn_cast(getType()); + if (!selfTy || !resultTy || !selfTy.areAllSizesKnown()) + return nullptr; + if (self.isSplat()) + return SplatElementsAttr::get(resultTy.toBuiltinTensor(), + self.getSplatValue()); + + // TODO: add support for rank != 2 + if (rank != 2) + return nullptr; + + ArrayRef sizes = selfTy.getSizes(); + auto values = llvm::to_vector(self.getValues()); + // reordered[i] = Trans[i//sizes[0], i % sizes[0]] = Self[i % sizes[0], + // i//sizes[0]] = values[(i % sizes[0])*sizes[1] + (i//sizes[0])]. + // e.g., Self size = [4,2]; Trans size = [2,4]. + // reindex(i) = (i % 4)*2 + (i // 4) . + // i = 0 -> Trans[0,0] -> Self[0,0] -> 0 . + // i = 1 -> Trans[0,1] -> Self[1,0] -> 2 . + // i = 2 -> Trans[0,2] -> Self[2,0] -> 4 . + // i = 3 -> Trans[0,3] -> Self[3,0] -> 6 . + // i = 4 -> Trans[1,0] -> Self[0,1] -> 1 . + // i = 5 -> Trans[1,1] -> Self[1,1] -> 3 . + auto reindex = [&](int64_t i) { + return (i % sizes[0]) * sizes[1] + (i / sizes[0]); + }; + SmallVector reordered; + for (int64_t i = 0; i < self.getNumElements(); i++) { + reordered.push_back(values[reindex(i)]); + } + return DenseElementsAttr::get(resultTy.toBuiltinTensor(), reordered); +} + //===----------------------------------------------------------------------===// // AtenCatOp //===----------------------------------------------------------------------===// @@ -3569,32 +4004,30 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) { auto inType = dyn_cast(getOperand(0).getType()); auto outType = dyn_cast(getResult().getType()); + if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes() || + !inType.hasDtype() || !outType.hasDtype() || + inType.getDtype() != outType.getDtype()) + return nullptr; + if (start && end && step && step.getValue().getSExtValue() == 1 && start.getValue().getSExtValue() == 0 && end.getValue().getSExtValue() == std::numeric_limits::max() && inType == outType) return getOperand(0); - if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes() || - !inType.hasDtype() || !outType.hasDtype() || - inType.getDtype() != outType.getDtype()) - return nullptr; - if (inType.getSizes().size() != outType.getSizes().size() || !inType.areAllSizesKnown() || !outType.areAllSizesKnown()) return nullptr; if (input && input.isSplat()) - return DenseElementsAttr::get( - outType.toBuiltinTensor().clone(inType.getDtype()), - input.getSplatValue()); + return DenseElementsAttr::get(outType.toBuiltinTensor(), + input.getSplatValue()); - int count = 1; + int64_t count = 1; for (auto dim : outType.getSizes()) count = count * dim; - if (count == 0) - return {}; + return nullptr; if (!dim) return nullptr; @@ -3602,34 +4035,69 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) { if (dimInt < 0) dimInt += inType.getSizes().size(); - bool unaryNonDim = true; - for (int i = 0, s = outType.getSizes().size(); i < s; ++i) - unaryNonDim &= outType.getSizes()[i] == 1 || i == dimInt; - // Fold the slice if the output tensor is relatively small, currently // coded to 16: - if (input && start && step && dim && count < 16 && unaryNonDim && - count < 16) { - int64_t inCount = input.getNumElements(); + constexpr int64_t kMaxFold = 16; + if (input && start && step && dim && end && count <= kMaxFold) { int64_t begin = start.getValue().getSExtValue(); - int64_t stride = step.getValue().getSExtValue(); - if (stride < 1) - return {}; 
int64_t limit = end.getValue().getSExtValue(); - begin = begin < 0 ? begin + inCount : begin; - limit = limit < 0 ? limit + inCount : limit; - limit = limit < 0 ? inType.getSizes()[dimInt] : limit; + int64_t stride = step.getValue().getSExtValue(); + begin = begin < 0 ? begin + inType.getSizes()[dimInt] : begin; + begin = std::max(begin, 0); + limit = limit < 0 ? limit + inType.getSizes()[dimInt] : limit; + limit = limit < 0 ? -1 : limit; limit = std::min(limit, inType.getSizes()[dimInt]); + assert((stride > 0 && begin < limit) || + (stride < 0 && begin > limit) && + "aten.slice.Tensor iteration args are statically invalid."); + + int64_t inputRank = inType.getSizes().size(); + llvm::SmallVector inputStrides(inputRank, 1); + for (int64_t i = inputRank - 2; i >= 0; i--) { + inputStrides[i] = inputStrides[i + 1] * inType.getSizes()[i + 1]; + } llvm::SmallVector values; - for (int i = begin; i < limit; i += stride) - values.push_back(input.getValues()[i]); - - return DenseElementsAttr::get( - outType.toBuiltinTensor().clone(inType.getDtype()), values); + values.reserve(count); + auto recursiveIter = [&](auto &self, int64_t currDim, int64_t currOffset) { + if (currDim >= inputRank) + return; + int64_t _stride = (currDim == dimInt) ? stride : 1; + int64_t _begin = (currDim == dimInt) ? begin : 0; + int64_t _limit = (currDim == dimInt) ? limit : inType.getSizes()[currDim]; + // ensure that the limit is reached exactly (even with negative strides) + // E.g., with begin = 0, limit = 10, stride = 3, we modify limit to be 11 + // = 10 + (10-0) % 3 . + // E.g., with begin = 8, limit = -1, stride = -2, limit becomes -2 = -1 + + // (-1-8) % (-2) - stride = -1 + 1 - 2 = -2 . + // Note: cpp uses true math remainder "n % d = least positive int, x, such + // that d divides (n - x)" + int64_t limit_rem = (_limit - _begin) % _stride; + limit_rem = + (_stride > 0 || limit_rem == 0) ? 
limit_rem : limit_rem - _stride; + _limit += limit_rem; + for (int64_t i = _begin; std::abs(_limit - i) > 0; i += _stride) { + if (currDim == inputRank - 1) { + values.push_back(input.getValues()[currOffset + i]); + } + self(self, currDim + 1, currOffset + inputStrides[currDim] * i); + } + }; + recursiveIter(recursiveIter, 0, 0); + if (static_cast(values.size()) != count) { + emitError( + "Op has incorrect result shape for provided arguments.\nNum elements " + "present in slice: " + + std::to_string(values.size()) + + "\nNum elements implied by result type: " + std::to_string(count)); + return nullptr; + } + return DenseElementsAttr::get(outType.toBuiltinTensor(), values); } - // If the input and output shapes are the same we can just fold: + // If the input and output shapes are the same & step == 1 we can fold: + if (!step || step.getValue().getSExtValue() != 1) + return nullptr; for (size_t i = 0; i < inType.getSizes().size(); ++i) { if (inType.getSizes()[i] != outType.getSizes()[i]) return nullptr; @@ -3645,6 +4113,10 @@ OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) { int64_t lhs, rhs; bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs)); bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs)); + if (lConstant && lhs == 1) + return getOperand(1); + if (rConstant && rhs == 1) + return getOperand(0); if ((lConstant && lhs == 0) || (rConstant && rhs == 0)) return getI64IntegerAttr(getContext(), 0); if (lConstant && rConstant) @@ -3652,6 +4124,45 @@ OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) { return nullptr; } +void AtenMulIntOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(+[](AtenMulIntOp op, PatternRewriter &rewriter) { + int64_t lhs, rhs; + bool lConstant = matchPattern(op.getA(), m_TorchConstantInt(&lhs)); + bool rConstant = matchPattern(op.getB(), m_TorchConstantInt(&rhs)); + if (lConstant && rConstant) + return failure(); + if (lConstant || rConstant) { + int64_t firstConstant = lConstant ? lhs : rhs; + Value firstOperand = lConstant ? op.getB() : op.getA(); + if (firstOperand.getDefiningOp() && + firstOperand.getDefiningOp()) { + auto prevMulIntOp = firstOperand.getDefiningOp(); + int64_t prevLhs, prevRhs; + bool prevLConstant = + matchPattern(prevMulIntOp.getA(), m_TorchConstantInt(&prevLhs)); + bool prevRConstant = + matchPattern(prevMulIntOp.getB(), m_TorchConstantInt(&prevRhs)); + if (prevLConstant && prevRConstant) + return failure(); + if ((prevLConstant || prevRConstant) && + prevMulIntOp->hasOneUse() == 1) { + auto newConstant = rewriter.create( + op.getLoc(), rewriter.getI64IntegerAttr( + prevLConstant ? prevLhs * firstConstant + : prevRhs * firstConstant)); + rewriter.replaceOpWithNewOp( + op, op.getType(), prevMulIntOp.getOperand(prevLConstant ? 
1 : 0), + newConstant); + rewriter.eraseOp(prevMulIntOp); + return success(); + } + } + } + return failure(); + }); +} + //===----------------------------------------------------------------------===// // AtenMulFloatOp //===----------------------------------------------------------------------===// @@ -3679,7 +4190,7 @@ OpFoldResult AtenAddOp::fold(FoldAdaptor adaptor) { return nullptr; } - if (adaptor.getA().isa() && adaptor.getB().isa()) { + if (isa(adaptor.getA()) && isa(adaptor.getB())) { return atenBinaryIntOperatorFoldHelper( adaptor.getOperands(), [](int64_t a, int64_t b) -> int64_t { return a + b; }); @@ -3698,7 +4209,7 @@ OpFoldResult AtenMulOp::fold(FoldAdaptor adaptor) { return nullptr; } - if (adaptor.getA().isa() && adaptor.getB().isa()) { + if (isa(adaptor.getA()) && isa(adaptor.getB())) { return atenBinaryIntOperatorFoldHelper( adaptor.getOperands(), [](int64_t a, int64_t b) -> int64_t { return a * b; }); @@ -3708,6 +4219,19 @@ OpFoldResult AtenMulOp::fold(FoldAdaptor adaptor) { [](double a, double b) -> double { return a * b; }); } +//===----------------------------------------------------------------------===// +// AtenMulIntFloatOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenMulIntFloatOp::fold(FoldAdaptor adaptor) { + if (!adaptor.getA() || !adaptor.getB()) { + return nullptr; + } + return atenBinaryFloatOperatorFoldHelper( + adaptor.getOperands(), + [](double a, double b) -> double { return a * b; }); +} + //===----------------------------------------------------------------------===// // AtenSubOp //===----------------------------------------------------------------------===// @@ -3717,7 +4241,7 @@ OpFoldResult AtenSubOp::fold(FoldAdaptor adaptor) { return nullptr; } - if (adaptor.getA().isa() && adaptor.getB().isa()) { + if (isa(adaptor.getA()) && isa(adaptor.getB())) { return atenBinaryIntOperatorFoldHelper( adaptor.getOperands(), [](int64_t a, int64_t b) -> int64_t { return a - b; }); @@ -3754,6 +4278,18 @@ OpFoldResult AtenAddFloatIntOp::fold(FoldAdaptor adaptor) { adaptor.getOperands(), [](double a, double b) { return a + b; }); } +//===----------------------------------------------------------------------===// +// AtenMulFloatIntOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenMulFloatIntOp::fold(FoldAdaptor adaptor) { + if (!adaptor.getA() || !adaptor.getB()) { + return nullptr; + } + return atenBinaryFloatOperatorFoldHelper( + adaptor.getOperands(), [](double a, double b) { return a * b; }); +} + //===----------------------------------------------------------------------===// // AtenPowIntFloatOp //===----------------------------------------------------------------------===// @@ -3774,7 +4310,7 @@ OpFoldResult AtenCeilScalarOp::fold(FoldAdaptor adaptor) { if (!adaptor.getA()) { return nullptr; } - auto floatValue = adaptor.getA().dyn_cast_or_null(); + auto floatValue = dyn_cast_or_null(adaptor.getA()); if (!floatValue) { return nullptr; } @@ -3802,7 +4338,7 @@ OpFoldResult AtenNegFloatOp::fold(FoldAdaptor adaptor) { if (!adaptor.getA()) { return nullptr; } - auto value = adaptor.getA().dyn_cast_or_null(); + auto value = dyn_cast_or_null(adaptor.getA()); if (!value) { return nullptr; } @@ -3896,7 +4432,7 @@ OpFoldResult AtenTensorOp::fold(FoldAdaptor adaptor) { if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype()) return nullptr; Type eTy = resultTy.getDtype(); - ShapedType shapedTy = resultTy.toBuiltinTensor().clone(eTy); + ShapedType 
shapedTy = resultTy.toBuiltinTensor(); SmallVector data; if (matchPattern(getData(), m_TorchListOfConstantInts(data)) && @@ -3917,7 +4453,7 @@ OpFoldResult AtenTensorIntOp::fold(FoldAdaptor adaptor) { if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype()) return nullptr; Type eTy = resultTy.getDtype(); - ShapedType shapedTy = resultTy.toBuiltinTensor().clone(eTy); + ShapedType shapedTy = resultTy.toBuiltinTensor(); int64_t data; if (matchPattern(getT(), m_TorchConstantInt(&data))) { @@ -3937,7 +4473,7 @@ OpFoldResult AtenTensorFloatOp::fold(FoldAdaptor adaptor) { if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype()) return nullptr; Type eTy = resultTy.getDtype(); - ShapedType shapedTy = resultTy.toBuiltinTensor().clone(eTy); + ShapedType shapedTy = resultTy.toBuiltinTensor(); double data; if (matchPattern(getT(), m_TorchConstantFloat(&data))) { @@ -3991,6 +4527,42 @@ void AtenIntTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns, }); } +//===----------------------------------------------------------------------===// +// AtenIntTensorOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenIntTensorOp::fold(FoldAdaptor adaptor) { + auto value = adaptor.getA(); + auto dense = dyn_cast_or_null(value); + if (!dense || !dense.isSplat()) { + return nullptr; + } + + auto splat = dense.getSplatValue(); + if (auto intAttr = dyn_cast(splat)) { + auto type = getType(); + if (!isa(type)) { + return nullptr; + } + + if (type.isSignlessInteger()) { + return getI64IntegerAttr(getContext(), intAttr.getInt()); + } else if (type.isSignedInteger()) { + return getI64IntegerAttr(getContext(), intAttr.getSInt()); + } else { + return getI64IntegerAttr(getContext(), intAttr.getUInt()); + } + } + + if (auto floatAttr = dyn_cast(splat)) { + return getI64IntegerAttr( + getContext(), + static_cast(floatAttr.getValue().convertToDouble())); + } + + return nullptr; +} + //===----------------------------------------------------------------------===// // AtenFloatTensorOp //===----------------------------------------------------------------------===// @@ -4110,7 +4682,7 @@ OpFoldResult AtenIndexSelectOp::fold(FoldAdaptor adaptor) { : selfAttr.getValues()[indexInt]; auto dty = resultTy.getDtype(); - auto attrTy = resultTy.toBuiltinTensor().clone(dty); + auto attrTy = resultTy.toBuiltinTensor(); if (auto floatAttr = dyn_cast(splattr)) return DenseElementsAttr::get( attrTy, FloatAttr::get(dty, floatAttr.getValueAsDouble())); @@ -4134,7 +4706,8 @@ OpFoldResult AtenItemOp::fold(FoldAdaptor adaptor) { if (auto intAttr = dyn_cast(splat)) { return intAttr.getType().isUnsignedInteger() ? 
getI64IntegerAttr(getContext(), intAttr.getUInt()) - : getI64IntegerAttr(getContext(), intAttr.getSInt()); + : getI64IntegerAttr(getContext(), + intAttr.getValue().getSExtValue()); } if (auto floatAttr = dyn_cast(splat)) { return getF64FloatAttr(getContext(), floatAttr.getValueAsDouble()); @@ -4303,7 +4876,7 @@ static Attribute getBroadcastedAttr(Attribute attr, ValueTensorType ty) { if (!valueDense.isSplat()) return nullptr; auto splattr = valueDense.getSplatValue(); - auto attrty = ty.toBuiltinTensor().clone(dty); + auto attrty = ty.toBuiltinTensor(); return DenseElementsAttr::get(attrty, splattr); } @@ -4311,7 +4884,7 @@ static Attribute getBroadcastedAttr(Attribute attr, ValueTensorType ty) { if (!isa(dty)) return nullptr; int64_t intval = intAttr.getInt(); - auto attrty = ty.toBuiltinTensor().clone(dty); + auto attrty = ty.toBuiltinTensor(); return DenseElementsAttr::get(attrty, IntegerAttr::get(dty, intval)); } @@ -4319,7 +4892,7 @@ static Attribute getBroadcastedAttr(Attribute attr, ValueTensorType ty) { if (!isa(dty)) return nullptr; double dblval = fpAttr.getValueAsDouble(); - auto attrty = ty.toBuiltinTensor().clone(dty); + auto attrty = ty.toBuiltinTensor(); return DenseElementsAttr::get(attrty, FloatAttr::get(dty, dblval)); } @@ -4455,8 +5028,8 @@ OpFoldResult PrimMaxIntOp::fold(FoldAdaptor adaptor) { if (getA() == getB()) return getA(); - auto lhs = adaptor.getA().dyn_cast_or_null(); - auto rhs = adaptor.getB().dyn_cast_or_null(); + auto lhs = dyn_cast_or_null(adaptor.getA()); + auto rhs = dyn_cast_or_null(adaptor.getB()); if (!lhs || !rhs) return nullptr; // Torch semantics are that !torch.int is 64-bit signed. @@ -4471,10 +5044,10 @@ OpFoldResult PrimMaxIntOp::fold(FoldAdaptor adaptor) { OpFoldResult PrimNumToTensorScalarOp::fold(FoldAdaptor adaptor) { Attribute a = adaptor.getA(); - auto resultTy = cast(getType()); + auto resultTy = dyn_cast(getType()); if (!a) return {}; - if (!resultTy.hasDtype() || !resultTy.hasSizes()) + if (!resultTy || !resultTy.hasDtype() || !resultTy.hasSizes()) return {}; auto dty = resultTy.getDtype(); @@ -4524,8 +5097,8 @@ OpFoldResult PrimMinIntOp::fold(FoldAdaptor adaptor) { if (getA() == getB()) return getA(); - auto lhs = adaptor.getA().dyn_cast_or_null(); - auto rhs = adaptor.getB().dyn_cast_or_null(); + auto lhs = dyn_cast_or_null(adaptor.getA()); + auto rhs = dyn_cast_or_null(adaptor.getB()); if (!lhs || !rhs) return nullptr; // Torch semantics are that !torch.int is 64-bit signed. @@ -4612,8 +5185,8 @@ LogicalResult AtenNormScalarOp::verify() { // Check if dtype is one of those supported by norm operation. // ComplexType will match any torch complex types, but each float must be // checked individually. 
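+  // Illustrative note (editorial, assumed example dtypes): an input tensor
+  // with dtype f32, f64, or a complex element type passes this check, while
+  // an integer dtype such as si64 triggers the emitOpError below.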
- if (!inTensorDtype.isa()) { + if (!isa(inTensorDtype)) { return emitOpError( "expected a float or complex type for input tensor, but got ") << inTensorDtype; @@ -4622,6 +5195,80 @@ LogicalResult AtenNormScalarOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// AtenRenormOp +//===----------------------------------------------------------------------===// + +LogicalResult AtenRenormOp::verify() { + + auto selfType = cast(getSelf().getType()); + + if (!selfType.hasDtype() || !selfType.hasSizes()) + return success(); + + auto inShape = selfType.getSizes(); + int64_t selfRank = inShape.size(); + auto selfDtype = selfType.getDtype(); + + if (!isa(selfDtype)) + return emitOpError( + "expected a float or complex type for input tensor, but got ") + << selfDtype; + + // According to the Pytoch documentation tensor need to be at least rank 2 + if (selfRank <= 1) + return emitOpError("renorm: input needs at least 2 dimensions, got ") + << selfRank << " dimensions"; + + // Check if argument p is valid + auto pType = getP().getType(); + + if (isa(pType)) + return emitOpError("renorm: p must be real-valued"); + + // The argument 'p' can be either an integer or a floating-point number, + // so we need to consider both options and check if 'p' is within the correct + // range + int64_t pInt = 1; + double_t pDouble = 1; + if (!matchPattern(getP(), m_TorchConstantInt(&pInt)) && + !matchPattern(getP(), m_TorchConstantFloat(&pDouble))) + return success(); + + if (pInt <= 0 || pDouble <= 0) + return emitOpError("renorm: non-positive norm not supported"); + + // Check if argument maxnorm is valid + auto maxnormType = getMaxnorm().getType(); + if (isa(maxnormType)) + return emitOpError("renorm: maxnorm must be real-valued"); + + // The argument 'maxnorm' can be either an integer or a floating-point number, + // so we need to consider both options and check if 'maxnorm' is within the + // correct range + int64_t maxnormInt = 0; + double_t maxnormDouble = 0; + if (!matchPattern(getMaxnorm(), m_TorchConstantInt(&maxnormInt)) && + !matchPattern(getMaxnorm(), m_TorchConstantFloat(&maxnormDouble))) + return success(); + + if (maxnormInt < 0 || maxnormDouble < 0) + return emitOpError("renorm: expected maxnorm to be >= 0"); + + // Get the dimension + int64_t dim; + if (!matchPattern(getDim(), m_TorchConstantInt(&dim))) + return success(); + + // check if is dim is in the correct range + if (dim >= selfRank || dim < -selfRank) + return emitOpError("Dimension out of range (expected to be in range of [") + << -selfRank << ", " << selfRank - 1 << "], but got " << dim; + + return success(); +} + //===----------------------------------------------------------------------===// // AtenPermuteOp //===----------------------------------------------------------------------===// @@ -4732,18 +5379,86 @@ OpFoldResult PrimsConvertElementTypeOp::fold(FoldAdaptor adaptor) { } //===----------------------------------------------------------------------===// -// AtenMaxPool2dWithIndicesOp +// Aten_AssertTensorMetadataOp //===----------------------------------------------------------------------===// -void AtenMaxPool2dWithIndicesOp::getCanonicalizationPatterns( - RewritePatternSet &patterns, MLIRContext *context) { - patterns.add(+[](AtenMaxPool2dWithIndicesOp op, PatternRewriter &rewriter) { +LogicalResult Aten_AssertTensorMetadataOp::fold( + FoldAdaptor adaptor, SmallVectorImpl<::mlir::OpFoldResult> &results) { + Value input = getA(); + auto inputType = 
cast(input.getType()); + if (!inputType.hasDtype() || !inputType.hasSizes()) + return failure(); + + // TODO: Add checks for stride, device, and layout when we can extract that + // information from the torch tensor. For now, we can only get the shape and + // dtype info from the tensor hence adding checks for them. + + // convert size to a list of integers. + SmallVector size; + if (!isa(getSize().getType())) { + if (!matchPattern(getSize(), m_TorchListOfConstantInts(size))) { + return emitOpError("expected dtype to be a constant int"); + } + if (!llvm::all_of(llvm::zip(inputType.getSizes(), size), + [](const auto &pair) { + return std::get<0>(pair) == std::get<1>(pair); + })) + return emitOpError("Failed to fold the _assert_tensor_metadata op since " + "the sizes do not match"); + } + + // convert dtype to an integer. + int64_t dtype; + if (!isa(getDtype().getType())) { + if (!matchPattern(getDtype(), m_TorchConstantInt(&dtype))) { + return emitOpError("expected dtype to be a constant int"); + } + FailureOr inputDtype = + getTypeForScalarType(getContext(), (torch_upstream::ScalarType)dtype); + if (failed(inputDtype)) + return failure(); + if (inputType.getDtype() != inputDtype) + return emitOpError("Failed to fold the _assert_tensor_metadata op since " + "the dtype does not match"); + } + + getOperation()->erase(); + return success(); +} + +//===----------------------------------------------------------------------===// +// AtenMaxPoolWithIndicesOp +//===----------------------------------------------------------------------===// + +namespace { + +template struct MaxPoolWithoutIndices { + using type = OpTy; +}; + +template <> struct MaxPoolWithoutIndices { + using type = AtenMaxPool2dOp; +}; + +template <> struct MaxPoolWithoutIndices { + using type = AtenMaxPool3dOp; +}; + +} // namespace + +template +struct SimplifyMaxPoolWithIndices : public mlir::OpRewritePattern { + SimplifyMaxPoolWithIndices(mlir::MLIRContext *context) + : OpRewritePattern(context, /*benefit=*/1) {} + + LogicalResult + matchAndRewrite(OpTy op, mlir::PatternRewriter &rewriter) const override { if (!op.getResult1().use_empty()) { return rewriter.notifyMatchFailure( - op, "result1 of MaxPool2dWithIndices should be unused"); + op, "result1 of MaxPoolWithIndices should be unused"); } - Value result = rewriter.create( + Value result = rewriter.create::type>( op->getLoc(), op.getResult0().getType(), op.getSelf(), op.getKernelSize(), op.getStride(), op.getPadding(), op.getDilation(), op.getCeilMode()); @@ -4751,6 +5466,30 @@ void AtenMaxPool2dWithIndicesOp::getCanonicalizationPatterns( op.getResult0().replaceAllUsesWith(result); rewriter.eraseOp(op); return success(); + } +}; + +void AtenMaxPool2dWithIndicesOp::getCanonicalizationPatterns( + RewritePatternSet &patterns, MLIRContext *context) { + patterns.add>(context); +} + +void AtenMaxPool3dWithIndicesOp::getCanonicalizationPatterns( + RewritePatternSet &patterns, MLIRContext *context) { + patterns.add>(context); +} + +//===----------------------------------------------------------------------===// +// Aten_AdaptiveAvgPool2dOp +//===----------------------------------------------------------------------===// + +void Aten_AdaptiveAvgPool2dOp::getCanonicalizationPatterns( + RewritePatternSet &patterns, MLIRContext *context) { + patterns.add(+[](Aten_AdaptiveAvgPool2dOp op, PatternRewriter &rewriter) { + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getSelf(), op.getOutputSize()); + + return success(); }); } @@ -4844,6 +5583,42 @@ LogicalResult 
AtenLinalgCrossOp::verify() { return success(); } +LogicalResult AtenKthvalueOp::verify() { + + auto selfType = cast(getSelf().getType()); + + if (!selfType.hasDtype() || !selfType.hasSizes()) + return success(); + + Type selfDtype = selfType.getDtype(); + if (selfDtype.isSignlessInteger(1)) + return emitOpError("input tensors must not have bool dtype"); + + int64_t dim; + if (!matchPattern(getDim(), m_TorchConstantInt(&dim))) + return success(); + + ArrayRef selfShape = selfType.getSizes(); + int64_t selfRank = selfShape.size(); + + dim = toPositiveDim(dim, selfRank); + if (!isValidDim(dim, selfRank)) + return emitOpError("dim expected to be in range of [") + << -selfRank << ", " << selfRank - 1 << "], but got " << dim; + + // convert k to an integer type + int64_t k; + if (!matchPattern(getK(), m_TorchConstantInt(&k))) + return success(); + + // check if k is in the correct range + if (selfShape[dim] != kUnknownSize && (k < 1 || k > selfShape[dim])) + return emitOpError("k expected to be in range of [") + << 1 << ", " << selfShape[dim] << "], but got " << k; + + return success(); +} + //===----------------------------------------------------------------------===// // DtypeCalculateYieldDtypesOp //===----------------------------------------------------------------------===// @@ -5007,3 +5782,174 @@ LogicalResult InitializeGlobalSlotsOp::verify() { return emitOpError("expected number of operands to match number of slots"); return success(); } + +//===----------------------------------------------------------------------===// +// BindSymbolicShapeOp +//===----------------------------------------------------------------------===// + +// +// torch.bind_symbolic_shape %6, [%0, %1, %2], affine_map<()[s0, s1, s2] -> +// (s0, s1 * 2 + s2, 3)> : !torch.vtensor<[?,?,3],f32> +// + +ParseResult BindSymbolicShapeOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand operand; + SmallVector shapeSymbols; + AffineMapAttr shapeExpressions; + Type operandType; + + if (parser.parseOperand(operand) || parser.parseComma() || + parser.parseLSquare() || parser.parseOperandList(shapeSymbols) || + parser.parseRSquare() || parser.parseComma() || + parser.parseAttribute(shapeExpressions, "shape_expressions", + result.attributes) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(operandType)) { + return failure(); + } + + if (parser.resolveOperand(operand, operandType, result.operands) || + parser.resolveOperands(shapeSymbols, + parser.getBuilder().getType(), + result.operands)) { + return failure(); + } + + return success(); +} + +// Use a custom printer here to avoid the AffineMap from getting hoisted +// when printed. This makes it so the AffineMap is printed inline with the op. 
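+// Illustrative output (editorial sketch with assumed SSA names): the printed
+// form keeps the affine map inline instead of hoisting it to a #map alias,
+// e.g.
+//   torch.bind_symbolic_shape %arg0, [%s0], affine_map<()[s0] -> (s0, 3)>
+//       : !torch.vtensor<[?,3],f32>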
+void BindSymbolicShapeOp::print(OpAsmPrinter &p) { + p << " " << getOperand() << ", ["; + llvm::interleaveComma(getShapeSymbols(), p); + p << "], " << "affine_map<" << getShapeExpressions().getValue() << ">"; + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"shape_expressions"}); + p << " : " << getOperand().getType(); +} + +LogicalResult BindSymbolicShapeOp::verify() { + if (getShapeSymbols().size() != + getShapeExpressions().getValue().getNumSymbols()) + return emitOpError() + << "requires equal number of shape symbol args and symbol args to " + "the attached affine map, since they are 1:1 mapped"; + + for (auto symbol : getShapeSymbols()) { + Operation *definingOp = symbol.getDefiningOp(); + if (!isa(definingOp)) { + return emitOpError() + << "shape symbol must be produced by a SymbolicIntOp"; + } + } + + return success(); +} +// AtenTriuIndicesOp +//===----------------------------------------------------------------------===// + +LogicalResult AtenTriuIndicesOp::verify() { + + // Check if row, col and offset are constant ints + int64_t row; + if (!matchPattern(getRow(), m_TorchConstantInt(&row))) + return success(); + + int64_t col; + if (!matchPattern(getCol(), m_TorchConstantInt(&col))) + return success(); + + int64_t offset; + if (!matchPattern(getOffset(), m_TorchConstantInt(&offset))) + return success(); + + // Check if values of row, and col are valid + if (row < 0) + return emitOpError("row must be non-negative, got ") << row; + + if (col < 0) + return emitOpError("col must be non-negative, got ") << col; + + // Check if dtype is valid + int64_t dtype; + if (!matchPattern(getDtype(), m_TorchConstantInt(&dtype))) + return success(); + if (dtype != (int)torch_upstream::ScalarType::Int && + dtype != (int)torch_upstream::ScalarType::Long) + return emitOpError( + "'triu_indices' implemented only for torch.int32 and torch.int64"); + + return success(); +} + +// AtenTrilIndicesOp +//===----------------------------------------------------------------------===// + +LogicalResult AtenTrilIndicesOp::verify() { + + // Check if row, col and offset are constant ints + int64_t row; + if (!matchPattern(getRow(), m_TorchConstantInt(&row))) + return success(); + + int64_t col; + if (!matchPattern(getCol(), m_TorchConstantInt(&col))) + return success(); + + int64_t offset; + if (!matchPattern(getOffset(), m_TorchConstantInt(&offset))) + return success(); + + // Check if values of row, and col are valid + if (row < 0) + return emitOpError("row must be non-negative, got ") << row; + + if (col < 0) + return emitOpError("col must be non-negative, got ") << col; + + // Check if dtype is valid + int64_t dtype; + if (!matchPattern(getDtype(), m_TorchConstantInt(&dtype))) + return success(); + if (dtype != (int)torch_upstream::ScalarType::Int && + dtype != (int)torch_upstream::ScalarType::Long) + return emitOpError( + "'triu_indices' implemented only for torch.int32 and torch.int64"); + + return success(); +} + +// AtenRot90Op +//===----------------------------------------------------------------------===// + +LogicalResult AtenRot90Op::verify() { + // Check rotation dimensions. + SmallVector dims; + if (!getListConstructElements(getDims(), dims)) + return success(); + + if (dims.size() != 2) + return emitOpError("expected total rotation dims == 2, but got dims = ") + << dims.size(); + + // Check a rank of the input tensor. 
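+  // Illustrative note (editorial, assumed shapes): a rank-1 input is rejected
+  // by the rank check below, while a [3,4] input with dims = [0, 1] satisfies
+  // both the rank and distinct-dims requirements.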
+ auto selfType = cast(getSelf().getType()); + if (!selfType.hasSizes()) + return success(); + + auto selfShape = selfType.getSizes(); + int64_t selfRank = selfShape.size(); + + if (selfRank < 2) + return emitOpError("expected total dims >= 2, but got total dims = ") + << selfRank; + + if (dims[0] == dims[1]) + return emitOpError( + "expected rotation dims to be different, but got dim0 = ") + << dims[0] << " and dim1 = " << dims[1]; + + return success(); +} diff --git a/lib/Dialect/Torch/IR/TorchTypes.cpp b/lib/Dialect/Torch/IR/TorchTypes.cpp index c162166cdd13..c46865ee5fed 100644 --- a/lib/Dialect/Torch/IR/TorchTypes.cpp +++ b/lib/Dialect/Torch/IR/TorchTypes.cpp @@ -185,13 +185,14 @@ static bool isValidTorchDtype(Type dtype) { dtype = cast(dtype).getElementType(); } // Torch quantized types. - if (isa(dtype)) + if (isa(dtype)) return true; // Builtin floating point types. if (isa(dtype)) return true; - if (dtype.isa()) + if (isa(dtype)) return true; if (isa(dtype)) @@ -228,17 +229,29 @@ Type BaseTensorType::getWithSizesAndDtypeFrom(BaseTensorType other) const { Type BaseTensorType::getWithSizesAndDtype( std::optional> optionalSizes, Type optionalDtype) const { - if (isa()) + if (mlir::isa(*this)) return NonValueTensorType::get(getContext(), optionalSizes, optionalDtype); - if (isa()) + if (mlir::isa(*this)) return ValueTensorType::get(getContext(), optionalSizes, optionalDtype); llvm_unreachable("not a BaseTensorType!"); } +Type BaseTensorType::getWithSizesAndDtypeAndSparsity( + std::optional> optionalSizes, Type optionalDtype, + Attribute optionalSparsity) const { + if (mlir::isa(*this)) + return NonValueTensorType::get(getContext(), optionalSizes, optionalDtype, + optionalSparsity); + if (mlir::isa(*this)) + return ValueTensorType::get(getContext(), optionalSizes, optionalDtype, + optionalSparsity); + llvm_unreachable("not a BaseTensorType!"); +} + ValueTensorType BaseTensorType::getWithValueSemantics() const { - if (auto tensor = dyn_cast()) + if (auto tensor = mlir::dyn_cast(*this)) return tensor.getWithValueSemantics(); - if (auto tensor = dyn_cast()) + if (auto tensor = mlir::dyn_cast(*this)) return tensor; llvm_unreachable("not a BaseTensorType!"); } @@ -441,12 +454,7 @@ ValueTensorType::getWithLeastStaticInformation(MLIRContext *context) { } static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) { - if (auto floatType = dyn_cast(dtype)) { - return dtype; - } else if (auto integerType = dyn_cast(dtype)) { - return IntegerType::get(context, integerType.getWidth(), - IntegerType::Signless); - } else if (isa(dtype)) { + if (isa(dtype)) { return dtype; } @@ -456,6 +464,9 @@ static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) { if (isa(dtype)) return IntegerType::get(context, 8, IntegerType::Signless); + if (isa(dtype)) + return IntegerType::get(context, 16, IntegerType::Signless); + if (isa(dtype)) return IntegerType::get(context, 32, IntegerType::Signless); @@ -468,11 +479,11 @@ static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) { TensorType ValueTensorType::toBuiltinTensor() const { if (!hasDtype()) return nullptr; - if (!hasSizes()) - return UnrankedTensorType::get(getDtype()); Type elementType = convertDtypeToBuiltinElementType(getContext(), getDtype()); if (!elementType) return nullptr; + if (!hasSizes()) + return UnrankedTensorType::get(elementType); return RankedTensorType::get(makeShapeLLVMCompatible(getSizes()), elementType, getOptionalSparsity()); } diff --git 
a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 43bcc3acc0eb..62f0cb130a04 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -5507,12 +5507,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " %15 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list, !torch.int -> !torch.int\n" " %16 = torch.aten.__getitem__.t %11, %int0 : !torch.list, !torch.int -> !torch.float\n" -" %17 = torch.operator \"aten.mul.int_float\"(%15, %16) : (!torch.int, !torch.float) -> !torch.float \n" +" %17 = torch.aten.mul.int_float %15, %16 : !torch.int, !torch.float -> !torch.float\n" " %18 = torch.aten.Int.float %17 : !torch.float -> !torch.int\n" " %19 = torch.aten.append.t %1, %18 : !torch.list, !torch.int -> !torch.list\n" " %20 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list, !torch.int -> !torch.int\n" " %21 = torch.aten.__getitem__.t %11, %int1 : !torch.list, !torch.int -> !torch.float\n" -" %22 = torch.operator \"aten.mul.int_float\"(%20, %21) : (!torch.int, !torch.float) -> !torch.float \n" +" %22 = torch.aten.mul.int_float %20, %21 : !torch.int, !torch.float -> !torch.float\n" " %23 = torch.aten.Int.float %22 : !torch.float -> !torch.int\n" " %24 = torch.aten.append.t %1, %23 : !torch.list, !torch.int -> !torch.list\n" " torch.prim.If.yield\n" @@ -6256,6 +6256,55 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.torchvision.roi_align\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.float, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.bool) -> !torch.list {\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %1 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %2 = torch.prim.ListConstruct %0, %1, %arg3, %arg4 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" return %2 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.torchvision.roi_align\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.float, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.bool) -> !torch.int {\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple, !torch.int -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.torchvision.roi_pool\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.float, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple, list> {\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %1 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %2 = torch.prim.ListConstruct %0, %1, %arg3, %arg4 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" %3 = torch.prim.TupleConstruct %2, %2 : !torch.list, !torch.list -> !torch.tuple, list>\n" +" return %3 : !torch.tuple, list>\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.torchvision.roi_pool\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.float, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple {\n" +" %int4 = 
torch.constant.int 4\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple, !torch.int -> !torch.int\n" +" %1 = torch.prim.TupleConstruct %0, %int4 : !torch.int, !torch.int -> !torch.tuple\n" +" return %1 : !torch.tuple\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.torchvision.nms\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.float) -> !torch.list {\n" +" %0 = call @__torch__.hacky_get_unknown_dimension_size() : () -> !torch.int\n" +" %1 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %2 = torch.prim.ListConstruct %0, %1 : (!torch.int, !torch.int) -> !torch.list\n" +" return %2 : !torch.list\n" +" }\n" +" func.func @__torch__.hacky_get_unknown_dimension_size() -> !torch.int {\n" +" %0 = torch.prim.CreateObject !torch.nn.Module<\"__torch__.DummyClassType\">\n" +" %1 = torch.prim.CallMethod %0[\"__init__\"] () : !torch.nn.Module<\"__torch__.DummyClassType\">, () -> !torch.none\n" +" %2 = torch.operator \"prim.id\"(%0) : (!torch.nn.Module<\"__torch__.DummyClassType\">) -> !torch.int \n" +" return %2 : !torch.int\n" +" }\n" +" func.func @__torch__.DummyClassType.__init__(%arg0: !torch.nn.Module<\"__torch__.DummyClassType\">) -> !torch.none {\n" +" %none = torch.constant.none\n" +" return %none : !torch.none\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.torchvision.nms\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.float) -> !torch.int {\n" +" %int3 = torch.constant.int 3\n" +" return %int3 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.diagonal\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list {\n" " %str = torch.constant.str \"AssertionError: diagonal dimensions cannot be identical\"\n" " %true = torch.constant.bool true\n" @@ -6328,6 +6377,36 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.fake_quantize_per_tensor_affine_cachemask\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple, list> {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" %1 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" %2 = torch.prim.TupleConstruct %0, %1 : !torch.list, !torch.list -> !torch.tuple, list>\n" +" return %2 : !torch.tuple, list>\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.fake_quantize_per_tensor_affine.tensor_qparams\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.int, %arg4: !torch.int) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.int, %arg5: !torch.int) -> !torch.tuple, list> {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" %1 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" %2 = torch.prim.TupleConstruct %0, %1 : !torch.list, !torch.list -> !torch.tuple, list>\n" +" return %2 : !torch.tuple, list>\n" +" }\n" +" func.func 
@\"__torch_mlir_shape_fn.aten.fake_quantize_per_channel_affine\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.fake_quantize_per_channel_affine_cachemask\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int) -> !torch.tuple, list> {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" %1 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" %2 = torch.prim.TupleConstruct %0, %1 : !torch.list, !torch.list -> !torch.tuple, list>\n" +" return %2 : !torch.tuple, list>\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.rad2deg\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.sin\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -6408,10 +6487,21 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.exp2\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.expm1\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.special_expm1\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.isfinite\"(%arg0: !torch.list) -> !torch.list {\n" +" return %arg0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.cosine_similarity\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.int, %arg3: !torch.float) -> !torch.list {\n" " %none = torch.constant.none\n" " %int1 = torch.constant.int 1\n" @@ -6450,6 +6540,116 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.linalg_det\"(%arg0: !torch.list) -> !torch.list {\n" +" %int-2 = torch.constant.int -2\n" +" %int-1 = torch.constant.int -1\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %true = torch.constant.bool true\n" +" %int2 = torch.constant.int 2\n" +" %int3 = torch.constant.int 3\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %9 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %10 = 
torch.aten.eq.int %9, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %10 : !torch.bool\n" +" }\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list, !torch.int -> !torch.int\n" +" %4 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list, !torch.int -> !torch.int\n" +" %5 = torch.aten.eq.int %3, %4 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %7 = torch.aten.eq.int %6, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" %8 = torch.prim.If %7 -> (!torch.list) {\n" +" %9 = torch.aten.slice.t %arg0, %none, %int1, %int1 : !torch.list, !torch.none, !torch.int, !torch.int -> !torch.list\n" +" torch.prim.If.yield %9 : !torch.list\n" +" } else {\n" +" %9 = torch.derefine %arg0 : !torch.list to !torch.any\n" +" %10 = func.call @__torch__.torch.jit._shape_functions.zero_dim_tensor(%9) : (!torch.any) -> !torch.list\n" +" torch.prim.If.yield %10 : !torch.list\n" +" }\n" +" return %8 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten._linalg_det\"(%arg0: !torch.list) -> !torch.tuple, list, list> {\n" +" %none = torch.constant.none\n" +" %int1 = torch.constant.int 1\n" +" %int-1 = torch.constant.int -1\n" +" %0 = call @\"__torch_mlir_shape_fn.aten.linalg_det\"(%arg0) : (!torch.list) -> !torch.list\n" +" %1 = torch.aten.slice.t %arg0, %none, %int-1, %int1 : !torch.list, !torch.none, !torch.int, !torch.int -> !torch.list\n" +" %2 = torch.prim.TupleConstruct %0, %arg0, %1 : !torch.list, !torch.list, !torch.list -> !torch.tuple, list, list>\n" +" return %2 : !torch.tuple, list, list>\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._linalg_det\"(%arg0: !torch.tuple) -> !torch.tuple {\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple, !torch.int -> !torch.int\n" +" %1 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple, !torch.int -> !torch.int\n" +" %2 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple, !torch.int -> !torch.int\n" +" %3 = torch.prim.TupleConstruct %0, %1, %2 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" +" return %3 : !torch.tuple\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.linalg_slogdet\"(%arg0: !torch.list) -> !torch.tuple, list> {\n" +" %int-2 = torch.constant.int -2\n" +" %int-1 = torch.constant.int -1\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %true = torch.constant.bool true\n" +" %int2 = torch.constant.int 2\n" +" %int3 = torch.constant.int 3\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %9 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %10 = torch.aten.eq.int %9, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %10 : !torch.bool\n" +" }\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = 
torch.aten.__getitem__.t %arg0, %int-1 : !torch.list, !torch.int -> !torch.int\n" +" %4 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list, !torch.int -> !torch.int\n" +" %5 = torch.aten.eq.int %3, %4 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %7 = torch.aten.eq.int %6, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" %8 = torch.prim.If %7 -> (!torch.tuple, list>) {\n" +" %9 = torch.aten.slice.t %arg0, %none, %int1, %int1 : !torch.list, !torch.none, !torch.int, !torch.int -> !torch.list\n" +" %10 = torch.aten.slice.t %arg0, %none, %int1, %int1 : !torch.list, !torch.none, !torch.int, !torch.int -> !torch.list\n" +" %11 = torch.prim.TupleConstruct %9, %10 : !torch.list, !torch.list -> !torch.tuple, list>\n" +" torch.prim.If.yield %11 : !torch.tuple, list>\n" +" } else {\n" +" %9 = torch.derefine %arg0 : !torch.list to !torch.any\n" +" %10 = func.call @__torch__.torch.jit._shape_functions.zero_dim_tensor(%9) : (!torch.any) -> !torch.list\n" +" %11 = torch.prim.TupleConstruct %10, %10 : !torch.list, !torch.list -> !torch.tuple, list>\n" +" torch.prim.If.yield %11 : !torch.tuple, list>\n" +" }\n" +" return %8 : !torch.tuple, list>\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.detach\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -6494,6 +6694,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.rrelu_with_noise_backward\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.float, %arg4: !torch.float, %arg5: !torch.bool, %arg6: !torch.bool) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.hardtanh_backward\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.float, %arg3: !torch.float) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -6514,6 +6718,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.hann_window.periodic\"(%arg0: !torch.int, %arg1: !torch.bool, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.list {\n" +" %0 = torch.prim.ListConstruct %arg0 : (!torch.int) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.hardshrink\"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -6522,6 +6730,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.polar\"(%arg0: 
!torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.mish\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -6575,6 +6787,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten._safe_softmax\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.optional) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.softmax.int\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.optional) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -6696,12 +6912,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " %11 = torch.aten.__getitem__.t %9, %int0 : !torch.list, !torch.int -> !torch.float\n" " %12 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list, !torch.int -> !torch.int\n" -" %13 = torch.operator \"aten.mul.float_int\"(%11, %12) : (!torch.float, !torch.int) -> !torch.float \n" +" %13 = torch.aten.mul.float_int %11, %12 : !torch.float, !torch.int -> !torch.float\n" " %14 = torch.aten.Int.float %13 : !torch.float -> !torch.int\n" " %15 = torch.aten.append.t %3, %14 : !torch.list, !torch.int -> !torch.list\n" " %16 = torch.aten.__getitem__.t %9, %int1 : !torch.list, !torch.int -> !torch.float\n" " %17 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list, !torch.int -> !torch.int\n" -" %18 = torch.operator \"aten.mul.float_int\"(%16, %17) : (!torch.float, !torch.int) -> !torch.float \n" +" %18 = torch.aten.mul.float_int %16, %17 : !torch.float, !torch.int -> !torch.float\n" " %19 = torch.aten.Int.float %18 : !torch.float -> !torch.int\n" " %20 = torch.aten.append.t %3, %19 : !torch.list, !torch.int -> !torch.list\n" " torch.prim.If.yield %true, %3 : !torch.bool, !torch.list\n" @@ -6962,6 +7178,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %4 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.kthvalue\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.bool) -> !torch.tuple, list> {\n" +" %0 = torch.derefine %arg2 : !torch.int to !torch.optional\n" +" %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg3) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" +" %2 = torch.prim.TupleConstruct %1, %1 : !torch.list, !torch.list -> !torch.tuple, list>\n" +" return %2 : !torch.tuple, list>\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten._log_softmax_backward_data\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -7074,6 +7296,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> 
!torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.rrelu\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.bool, %arg4: !torch.any) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.rrelu_with_noise\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.float, %arg3: !torch.float, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.rrelu_with_noise_functional\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.float, %arg3: !torch.float, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.tuple, list> {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" %1 = call @__torch__.torch.jit._shape_functions.unary(%arg1) : (!torch.list) -> !torch.list\n" +" %2 = torch.prim.TupleConstruct %0, %1 : !torch.list, !torch.list -> !torch.tuple, list>\n" +" return %2 : !torch.tuple, list>\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.selu\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -7130,7 +7366,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = torch.prim.ListConstruct : () -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.prims.var\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.float, %arg3: !torch.optional) -> !torch.list {\n" +" func.func @\"__torch_mlir_shape_fn.prims.var\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional, %arg3: !torch.optional) -> !torch.list {\n" " %none = torch.constant.none\n" " %false = torch.constant.bool false\n" " %0 = torch.derefine %none : !torch.none to !torch.any\n" @@ -7206,11 +7442,38 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %2 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.argmax\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.bool) -> !torch.list {\n" -" %0 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" +" %0 = call @__torch__.patched_argmax_shape_func(%arg0, %arg1, %arg2) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @__torch__.patched_argmax_shape_func(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.bool) -> !torch.list {\n" +" %true = torch.constant.bool true\n" +" %false = torch.constant.bool false\n" +" %none = torch.constant.none\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.__is__ %arg1, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.bool) {\n" +" torch.prim.If.yield %arg2 : !torch.bool\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg1 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %2 = torch.prim.If %1 -> (!torch.list) {\n" +" %3 = torch.prim.ListConstruct : () -> !torch.list\n" +" %4 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" torch.prim.Loop %4, %true, init() {\n" +" ^bb0(%arg3: !torch.int):\n" +" %5 = torch.aten.append.t %3, %int1 : 
!torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" torch.prim.If.yield %3 : !torch.list\n" +" } else {\n" +" %3 = func.call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" +" torch.prim.If.yield %3 : !torch.list\n" +" }\n" +" return %2 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.argmin\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.bool) -> !torch.list {\n" -" %0 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" +" %0 = call @__torch__.patched_argmax_shape_func(%arg0, %arg1, %arg2) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.one_hot\"(%arg0: !torch.list, %arg1: !torch.int) -> !torch.list {\n" @@ -7257,6 +7520,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg2, %1) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" " return %2 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.amin\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.bool) -> !torch.list {\n" +" %none = torch.constant.none\n" +" %0 = torch.derefine %arg1 : !torch.list to !torch.optional>\n" +" %1 = torch.derefine %none : !torch.none to !torch.any\n" +" %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg2, %1) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" +" return %2 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.aminmax\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.bool) -> !torch.tuple, list> {\n" +" %0 = call @__torch__.patched_argmax_shape_func(%arg0, %arg1, %arg2) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" +" %1 = torch.prim.TupleConstruct %0, %0 : !torch.list, !torch.list -> !torch.tuple, list>\n" +" return %1 : !torch.tuple, list>\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.mean.dim\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.optional) -> !torch.list {\n" " %0 = torch.derefine %arg3 : !torch.optional to !torch.any\n" " %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg2, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" @@ -7267,6 +7542,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg2, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" " return %1 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.prims.sum\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional) -> !torch.list {\n" +" %false = torch.constant.bool false\n" +" %0 = torch.derefine %arg2 : !torch.optional to !torch.any\n" +" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %false, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" +" return %1 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.prod.dim_int\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.optional) -> !torch.list {\n" " %0 = torch.prim.ListConstruct %arg1 : (!torch.int) -> !torch.list\n" " %1 = torch.derefine %0 
: !torch.list to !torch.optional>\n" @@ -7351,6 +7632,17 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " } : (!torch.int, !torch.bool) -> ()\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.outer\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %1 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %2 = torch.prim.ListConstruct %0, %1 : (!torch.int, !torch.int) -> !torch.list\n" +" return %2 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.dot\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %0 = torch.prim.ListConstruct : () -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.matmul\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.matmul(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -7363,6 +7655,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.mm(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten._int_mm\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.mm(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.addmm\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.float, %arg4: !torch.float) -> !torch.list {\n" " %0 = torch.derefine %arg3 : !torch.float to !torch.any\n" " %1 = torch.derefine %arg4 : !torch.float to !torch.any\n" @@ -7457,7 +7753,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %3 = torch.prim.If %2 -> (!torch.list) {\n" " %5 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list\n" " %6 = torch.aten.sub.int %1, %0 : !torch.int, !torch.int -> !torch.int\n" -" %7 = torch.operator \"aten.mul.left_t\"(%5, %6) : (!torch.list, !torch.int) -> !torch.list \n" +" %7 = torch.aten.mul.left_t %5, %6 : !torch.list, !torch.int -> !torch.list\n" " %8 = torch.aten.add.t %7, %arg1 : !torch.list, !torch.list -> !torch.list\n" " torch.prim.If.yield %8 : !torch.list\n" " } else {\n" @@ -7913,41 +8209,110 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @\"__torch_mlir_shape_fn.aten.max_pool2d_with_indices_backward\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list) -> !torch.list {\n" " return %arg1 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest2d_backward\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.list {\n" -" return %arg2 : !torch.list\n" -" }\n" -" func.func @\"__torch_mlir_shape_fn.aten.avg_pool1d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.list {\n" -" %0 = call @__torch__.pool1d(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list, !torch.list, !torch.list, !torch.list, !torch.bool) -> !torch.list\n" -" return %0 : !torch.list\n" +" func.func 
@\"__torch_mlir_shape_fn.aten.max_pool3d_with_indices\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.bool) -> !torch.tuple, list> {\n" +" %0 = call @__torch__._max_pool3d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool) -> !torch.list\n" +" %1 = torch.prim.TupleConstruct %0, %0 : !torch.list, !torch.list -> !torch.tuple, list>\n" +" return %1 : !torch.tuple, list>\n" " }\n" -" func.func @__torch__.pool1d(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.bool) -> !torch.list {\n" -" %int-1 = torch.constant.int -1\n" -" %int-2 = torch.constant.int -2\n" -" %int-3 = torch.constant.int -3\n" -" %str = torch.constant.str \"AssertionError: \"\n" -" %str_0 = torch.constant.str \"AssertionError: pool1d: padding must be a single int\"\n" -" %str_1 = torch.constant.str \"AssertionError: pool1d: stride must either be omitted, or a single int\"\n" -" %true = torch.constant.bool true\n" +" func.func @\"__torch_mlir_shape_fn.aten.max_unpool3d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list) -> !torch.list {\n" +" %str = torch.constant.str \"AssertionError: Input and indices must be of the same rank\"\n" +" %str_0 = torch.constant.str \"AssertionError: output_size must have 3 elements\"\n" " %none = torch.constant.none\n" -" %str_2 = torch.constant.str \"AssertionError: pool1d: kernel_size must be a single int\"\n" -" %int1 = torch.constant.int 1\n" +" %str_1 = torch.constant.str \"AssertionError: Input be of rank 4 or 5\"\n" +" %true = torch.constant.bool true\n" +" %int5 = torch.constant.int 5\n" +" %int4 = torch.constant.int 4\n" +" %int3 = torch.constant.int 3\n" " %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" " %int2 = torch.constant.int 2\n" -" %int3 = torch.constant.int 3\n" -" %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" -" %1 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %11 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %12 = torch.aten.eq.int %11, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %12 : !torch.bool\n" +" }\n" +" torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" " } else {\n" -" torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %2 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" " %3 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" -" %4 = torch.aten.eq.int %3, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" %5 = torch.prim.If %4 -> (!torch.bool) {\n" -" torch.prim.If.yield %true : !torch.bool\n" -" } else {\n" +" %4 = torch.aten.eq.int %3, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %6 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %7 = 
torch.aten.eq.int %5, %6 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %7 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %8 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %9 = torch.aten.eq.int %8, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" %10 = torch.prim.If %9 -> (!torch.list) {\n" +" %11 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %12 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %13 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %14 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %15 = torch.aten.__getitem__.t %arg2, %int2 : !torch.list, !torch.int -> !torch.int\n" +" %16 = torch.prim.ListConstruct %11, %12, %13, %14, %15 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" torch.prim.If.yield %16 : !torch.list\n" +" } else {\n" +" %11 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %12 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %13 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %14 = torch.aten.__getitem__.t %arg2, %int2 : !torch.list, !torch.int -> !torch.int\n" +" %15 = torch.prim.ListConstruct %11, %12, %13, %14 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" torch.prim.If.yield %15 : !torch.list\n" +" }\n" +" return %10 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest2d_backward\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.list {\n" +" return %arg2 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.avg_pool1d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.list {\n" +" %0 = call @__torch__.pool1d(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list, !torch.list, !torch.list, !torch.list, !torch.bool) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @__torch__.pool1d(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.bool) -> !torch.list {\n" +" %int-1 = torch.constant.int -1\n" +" %int-2 = torch.constant.int -2\n" +" %int-3 = torch.constant.int -3\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %str_0 = torch.constant.str \"AssertionError: pool1d: padding must be a single int\"\n" +" %str_1 = torch.constant.str \"AssertionError: pool1d: stride must either be omitted, or a single int\"\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str_2 = torch.constant.str \"AssertionError: pool1d: kernel_size must be a single int\"\n" +" %int1 = torch.constant.int 1\n" +" %int0 = torch.constant.int 0\n" +" %int2 = torch.constant.int 2\n" +" %int3 = torch.constant.int 3\n" +" %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %3 = torch.aten.len.t %arg2 : 
!torch.list -> !torch.int\n" +" %4 = torch.aten.eq.int %3, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" " %24 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" " %25 = torch.aten.eq.int %24, %int1 : !torch.int, !torch.int -> !torch.bool\n" " torch.prim.If.yield %25 : !torch.bool\n" @@ -8230,6 +8595,174 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %38 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.avg_pool3d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional) -> !torch.list {\n" +" %0 = call @__torch__.avg_pool3d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.optional) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @__torch__.avg_pool3d(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional) -> !torch.list {\n" +" %int-1 = torch.constant.int -1\n" +" %int-2 = torch.constant.int -2\n" +" %int-3 = torch.constant.int -3\n" +" %int-4 = torch.constant.int -4\n" +" %int-5 = torch.constant.int -5\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %str_0 = torch.constant.str \"AssertionError: max_pool3d: padding must either be a single int, or a tuple of thee ints\"\n" +" %str_1 = torch.constant.str \"AssertionError: max_pool3d: stride must either be omitted, a single int, or a tuple of three ints\"\n" +" %none = torch.constant.none\n" +" %str_2 = torch.constant.str \"AssertionError: max_pool3d: kernel_size must either be a single int, or a tuple of three ints\"\n" +" %true = torch.constant.bool true\n" +" %int1 = torch.constant.int 1\n" +" %int3 = torch.constant.int 3\n" +" %int0 = torch.constant.int 0\n" +" %int2 = torch.constant.int 2\n" +" %int4 = torch.constant.int 4\n" +" %int5 = torch.constant.int 5\n" +" %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %38 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %39 = torch.aten.eq.int %38, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %39 : !torch.bool\n" +" }\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %4 = torch.aten.eq.int %3, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.tuple) {\n" +" %38 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %39 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %40 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %41 = torch.prim.TupleConstruct %38, %39, %40 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %41 : !torch.tuple\n" +" } else {\n" +" %38 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %39 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %40 = 
torch.aten.__getitem__.t %arg1, %int2 : !torch.list, !torch.int -> !torch.int\n" +" %41 = torch.prim.TupleConstruct %38, %39, %40 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %41 : !torch.tuple\n" +" }\n" +" %6:3 = torch.prim.TupleUnpack %5 : !torch.tuple -> !torch.int, !torch.int, !torch.int\n" +" %7 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %9 = torch.prim.If %8 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %38 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %39 = torch.aten.eq.int %38, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %39 : !torch.bool\n" +" }\n" +" %10 = torch.prim.If %9 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %38 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %39 = torch.aten.eq.int %38, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %39 : !torch.bool\n" +" }\n" +" torch.prim.If %10 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %11 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %12 = torch.aten.eq.int %11, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %13:3 = torch.prim.If %12 -> (!torch.int, !torch.int, !torch.int) {\n" +" torch.prim.If.yield %6#0, %6#0, %6#0 : !torch.int, !torch.int, !torch.int\n" +" } else {\n" +" %38 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %39 = torch.aten.eq.int %38, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %40:3 = torch.prim.If %39 -> (!torch.int, !torch.int, !torch.int) {\n" +" %41 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %42 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %43 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %41, %42, %43 : !torch.int, !torch.int, !torch.int\n" +" } else {\n" +" %41 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %42 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %43 = torch.aten.__getitem__.t %arg2, %int2 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %41, %42, %43 : !torch.int, !torch.int, !torch.int\n" +" }\n" +" torch.prim.If.yield %40#0, %40#1, %40#2 : !torch.int, !torch.int, !torch.int\n" +" }\n" +" %14 = torch.aten.len.t %arg3 : !torch.list -> !torch.int\n" +" %15 = torch.aten.eq.int %14, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %16 = torch.prim.If %15 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %38 = torch.aten.len.t %arg3 : !torch.list -> !torch.int\n" +" %39 = torch.aten.eq.int %38, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %39 : !torch.bool\n" +" }\n" +" torch.prim.If %16 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %17 = torch.aten.len.t %arg3 : !torch.list -> !torch.int\n" +" %18 = torch.aten.eq.int %17, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %19 = torch.prim.If %18 -> (!torch.tuple) {\n" +" %38 = torch.aten.__getitem__.t %arg3, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %39 = torch.aten.__getitem__.t %arg3, 
%int0 : !torch.list, !torch.int -> !torch.int\n" +" %40 = torch.aten.__getitem__.t %arg3, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %41 = torch.prim.TupleConstruct %38, %39, %40 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %41 : !torch.tuple\n" +" } else {\n" +" %38 = torch.aten.__getitem__.t %arg3, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %39 = torch.aten.__getitem__.t %arg3, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %40 = torch.aten.__getitem__.t %arg3, %int2 : !torch.list, !torch.int -> !torch.int\n" +" %41 = torch.prim.TupleConstruct %38, %39, %40 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %41 : !torch.tuple\n" +" }\n" +" %20:3 = torch.prim.TupleUnpack %19 : !torch.tuple -> !torch.int, !torch.int, !torch.int\n" +" %21 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %22 = torch.aten.eq.int %21, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" %23 = torch.prim.If %22 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %38 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %39 = torch.aten.eq.int %38, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %39 : !torch.bool\n" +" }\n" +" torch.prim.If %23 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %24 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %25 = torch.aten.eq.int %24, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" %26 = torch.prim.If %25 -> (!torch.int) {\n" +" %38 = torch.aten.__getitem__.t %arg0, %int-5 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %38 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" }\n" +" %27 = torch.aten.__getitem__.t %arg0, %int-4 : !torch.list, !torch.int -> !torch.int\n" +" %28 = torch.aten.__getitem__.t %arg0, %int-3 : !torch.list, !torch.int -> !torch.int\n" +" %29 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list, !torch.int -> !torch.int\n" +" %30 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list, !torch.int -> !torch.int\n" +" %31 = call @__torch__.torch.jit._shape_functions.pooling_output_shape(%28, %6#0, %20#0, %13#0, %int1, %arg4) : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.bool) -> !torch.int\n" +" %32 = call @__torch__.torch.jit._shape_functions.pooling_output_shape(%29, %6#1, %20#1, %13#1, %int1, %arg4) : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.bool) -> !torch.int\n" +" %33 = call @__torch__.torch.jit._shape_functions.pooling_output_shape(%30, %6#2, %20#2, %13#2, %int1, %arg4) : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.bool) -> !torch.int\n" +" %34 = call @__torch__._pool3d_shape_check(%arg0, %6#0, %6#1, %6#2, %13#0, %13#1, %13#2, %20#0, %20#1, %20#2, %int1, %int1, %int1, %31, %32, %33) : (!torch.list, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.none\n" +" %35 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %36 = torch.aten.eq.int %35, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" %37 = torch.prim.If %36 -> (!torch.list) {\n" +" %38 = torch.prim.ListConstruct %27, %31, %32, %33 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" torch.prim.If.yield %38 : !torch.list\n" +" } 
else {\n" +" %38 = torch.prim.ListConstruct %26, %27, %31, %32, %33 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" torch.prim.If.yield %38 : !torch.list\n" +" }\n" +" return %37 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.adaptive_avg_pool2d\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.adaptive_avg_pool2d(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -8350,7 +8883,113 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.linear(%arg0, %arg1, %arg2) : (!torch.list, !torch.list, !torch.optional>) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.scaled_dot_product_attention\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.optional>, %arg4: !torch.float, %arg5: !torch.bool, %arg6: !torch.optional) -> !torch.list {\n" +" func.func @\"__torch_mlir_shape_fn.aten._trilinear\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.list, %arg7: !torch.int) -> !torch.list {\n" +" %int3 = torch.constant.int 3\n" +" %int-1 = torch.constant.int -1\n" +" %str = torch.constant.str \"AssertionError: number of dimensions must match\"\n" +" %str_0 = torch.constant.str \"expand dimension {} is out of bounds for input of shape {}\"\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str_1 = torch.constant.str \"AssertionError: \"\n" +" %str_2 = torch.constant.str \"unroll_dim must be in [0, {}]\"\n" +" %false = torch.constant.bool false\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.len.t %arg3 : !torch.list -> !torch.int\n" +" %2 = torch.aten.add.int %0, %1 : !torch.int, !torch.int -> !torch.int\n" +" %3 = torch.aten.ge.int %arg7, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" %23 = torch.aten.lt.int %arg7, %2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %23 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" %23 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %24 = torch.aten.format(%str_2, %23) : !torch.str, !torch.int -> !torch.str\n" +" %25 = torch.aten.add.str %str_1, %24 : !torch.str, !torch.str -> !torch.str\n" +" torch.prim.RaiseException %25, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = call @__torch__.torch.jit._shape_functions._copy(%arg0) : (!torch.list) -> !torch.list\n" +" %6 = call @__torch__.torch.jit._shape_functions._copy(%arg1) : (!torch.list) -> !torch.list\n" +" %7 = call @__torch__.torch.jit._shape_functions._copy(%arg2) : (!torch.list) -> !torch.list\n" +" %8 = torch.prim.ListConstruct %5, %6, %7 : (!torch.list, !torch.list, !torch.list) -> !torch.list>\n" +" %9 = torch.prim.ListConstruct %arg3, %arg4, %arg5 : (!torch.list, !torch.list, !torch.list) -> !torch.list>\n" +" torch.prim.Loop %int3, %true, init() {\n" +" ^bb0(%arg8: !torch.int):\n" +" %23 = torch.aten.__getitem__.t %9, %arg8 : !torch.list>, !torch.int -> !torch.list\n" +" %24 = torch.aten.__getitem__.t %8, %arg8 : !torch.list>, !torch.int -> !torch.list\n" +" %25 = 
torch.aten.len.t %24 : !torch.list -> !torch.int\n" +" %26 = torch.aten.len.t %23 : !torch.list -> !torch.int\n" +" torch.prim.Loop %26, %true, init() {\n" +" ^bb0(%arg9: !torch.int):\n" +" %27 = torch.aten.__getitem__.t %23, %arg9 : !torch.list, !torch.int -> !torch.int\n" +" %28 = torch.aten.le.int %27, %25 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %28 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" %30 = torch.aten.__getitem__.t %8, %arg8 : !torch.list>, !torch.int -> !torch.list\n" +" %31 = torch.aten.format(%str_0, %27, %30) : !torch.str, !torch.int, !torch.list -> !torch.str\n" +" %32 = torch.aten.add.str %str_1, %31 : !torch.str, !torch.str -> !torch.str\n" +" torch.prim.RaiseException %32, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %29 = torch.aten.__getitem__.t %8, %arg8 : !torch.list>, !torch.int -> !torch.list\n" +" torch.aten.insert.t %29, %27, %int1 : !torch.list, !torch.int, !torch.int\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %10 = torch.aten.len.t %5 : !torch.list -> !torch.int\n" +" %11 = torch.aten.len.t %6 : !torch.list -> !torch.int\n" +" %12 = torch.aten.eq.int %10, %11 : !torch.int, !torch.int -> !torch.bool\n" +" %13 = torch.prim.If %12 -> (!torch.bool) {\n" +" %23 = torch.aten.len.t %6 : !torch.list -> !torch.int\n" +" %24 = torch.aten.len.t %7 : !torch.list -> !torch.int\n" +" %25 = torch.aten.eq.int %23, %24 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %25 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %13 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %14 = call @__torch__.torch.jit._shape_functions.broadcast_three(%5, %6, %7) : (!torch.list, !torch.list, !torch.list) -> !torch.list\n" +" %15 = torch.prim.ListConstruct %false : (!torch.bool) -> !torch.list\n" +" %16 = torch.aten.len.t %14 : !torch.list -> !torch.int\n" +" %17 = torch.aten.mul.left_t %15, %16 : !torch.list, !torch.int -> !torch.list\n" +" %18 = torch.aten.len.t %arg6 : !torch.list -> !torch.int\n" +" torch.prim.Loop %18, %true, init() {\n" +" ^bb0(%arg8: !torch.int):\n" +" %23 = torch.aten.__getitem__.t %arg6, %arg8 : !torch.list, !torch.int -> !torch.int\n" +" %24 = torch.aten._set_item.t %17, %23, %true : !torch.list, !torch.int, !torch.bool -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %19 = torch.aten.len.t %14 : !torch.list -> !torch.int\n" +" %20 = torch.aten.sub.int %19, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %21 = torch.aten.__range_length %20, %int-1, %int-1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" %22 = torch.prim.Loop %21, %true, init(%14) {\n" +" ^bb0(%arg8: !torch.int, %arg9: !torch.list):\n" +" %23 = torch.aten.__derive_index %arg8, %20, %int-1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" %24 = torch.aten.__getitem__.t %17, %23 : !torch.list, !torch.int -> !torch.bool\n" +" %25 = torch.prim.If %24 -> (!torch.list) {\n" +" %26 = func.call @__torch__.torch.jit._shape_functions._reduce_along_dim(%arg9, %23, %false) : (!torch.list, !torch.int, !torch.bool) -> !torch.list\n" +" torch.prim.If.yield %26 : !torch.list\n" +" } else {\n" +" torch.prim.If.yield %arg9 : !torch.list\n" +" }\n" +" 
torch.prim.Loop.condition %true, iter(%25 : !torch.list)\n" +" } : (!torch.int, !torch.bool, !torch.list) -> !torch.list\n" +" return %22 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.scaled_dot_product_attention\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.optional>, %arg4: !torch.float, %arg5: !torch.bool, %arg6: !torch.optional, %arg7: !torch.bool) -> !torch.list {\n" " %int-1 = torch.constant.int -1\n" " %0 = torch.aten.__getitem__.t %arg2, %int-1 : !torch.list, !torch.int -> !torch.int\n" " %1 = torch.aten._set_item.t %arg0, %int-1, %0 : !torch.list, !torch.int, !torch.int -> !torch.list\n" @@ -8494,6 +9133,64 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n" " return %13 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.rot90\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.list) -> !torch.list {\n" +" %true = torch.constant.bool true\n" +" %str = torch.constant.str \"expected total rotation dims == 2, but got dims = {}\"\n" +" %none = torch.constant.none\n" +" %str_0 = torch.constant.str \"AssertionError: \"\n" +" %str_1 = torch.constant.str \"expected total dims >= 2 but got {}\"\n" +" %int2 = torch.constant.int 2\n" +" %int4 = torch.constant.int 4\n" +" %int1 = torch.constant.int 1\n" +" %int3 = torch.constant.int 3\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.ge.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" %9 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %10 = torch.aten.format(%str_1, %9) : !torch.str, !torch.int -> !torch.str\n" +" %11 = torch.aten.add.str %str_0, %10 : !torch.str, !torch.str -> !torch.str\n" +" torch.prim.RaiseException %11, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %3 = torch.aten.eq.int %2, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" %9 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %10 = torch.aten.format(%str, %9) : !torch.str, !torch.int -> !torch.str\n" +" %11 = torch.aten.add.str %str_0, %10 : !torch.str, !torch.str -> !torch.str\n" +" torch.prim.RaiseException %11, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.aten.remainder.int %arg1, %int4 : !torch.int, !torch.int -> !torch.int\n" +" %5 = torch.aten.add.int %4, %int4 : !torch.int, !torch.int -> !torch.int\n" +" %6 = torch.aten.remainder.int %5, %int4 : !torch.int, !torch.int -> !torch.int\n" +" %7 = torch.aten.eq.int %6, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %8 = torch.prim.If %7 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %9 = torch.aten.eq.int %6, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %9 : !torch.bool\n" +" }\n" +" torch.prim.If %8 -> () {\n" +" %9 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %10 = torch.aten.__getitem__.t %arg0, %9 : !torch.list, !torch.int -> !torch.int\n" +" %11 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %12 = torch.aten.__getitem__.t %arg0, %11 : !torch.list, !torch.int -> !torch.int\n" +" %13 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list, !torch.int -> 
!torch.int\n" +" %14 = torch.aten._set_item.t %arg0, %13, %10 : !torch.list, !torch.int, !torch.int -> !torch.list\n" +" %15 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %16 = torch.aten._set_item.t %arg0, %15, %12 : !torch.list, !torch.int, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" return %arg0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten._to_copy\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.bool, %arg6: !torch.optional) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -8544,30 +9241,67 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @\"__torch_mlir_shape_fn.aten.bernoulli\"(%arg0: !torch.list, %arg1: !torch.any) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.cumsum\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.optional) -> !torch.list {\n" -" return %arg0 : !torch.list\n" -" }\n" -" func.func @\"__torch_mlir_shape_fn.aten.rand_like\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.list {\n" -" return %arg0 : !torch.list\n" -" }\n" -" func.func @\"__torch_mlir_shape_fn.aten.randn_like\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.list {\n" -" return %arg0 : !torch.list\n" -" }\n" -" func.func @\"__torch_mlir_shape_fn.aten.randint.low\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.list, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.list {\n" -" return %arg2 : !torch.list\n" -" }\n" -" func.func @\"__torch_mlir_shape_fn.aten.randint\"(%arg0: !torch.int, %arg1: !torch.list, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.list {\n" -" return %arg1 : !torch.list\n" -" }\n" -" func.func @\"__torch_mlir_shape_fn.aten.randn\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.list {\n" -" return %arg0 : !torch.list\n" -" }\n" -" func.func @\"__torch_mlir_shape_fn.aten.randn.generator\"(%arg0: !torch.list, %arg1: !torch.any, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.list {\n" -" return %arg0 : !torch.list\n" -" }\n" -" func.func @\"__torch_mlir_shape_fn.aten.normal_functional\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.list {\n" -" return %arg0 : !torch.list\n" -" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.multinomial\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.any) -> !torch.list {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %true = torch.constant.bool true\n" +" %int1 = torch.constant.int 1\n" +" %int2 = torch.constant.int 2\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else 
{\n" +" %6 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %7 = torch.aten.eq.int %6, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %7 : !torch.bool\n" +" }\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %4 = torch.aten.eq.int %3, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.list) {\n" +" %6 = torch.prim.ListConstruct %arg1 : (!torch.int) -> !torch.list\n" +" torch.prim.If.yield %6 : !torch.list\n" +" } else {\n" +" %6 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %7 = torch.prim.ListConstruct %6, %arg1 : (!torch.int, !torch.int) -> !torch.list\n" +" torch.prim.If.yield %7 : !torch.list\n" +" }\n" +" return %5 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.cumsum\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.optional) -> !torch.list {\n" +" return %arg0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.cumprod\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.optional) -> !torch.list {\n" +" return %arg0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.rand_like\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.list {\n" +" return %arg0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.randn_like\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.list {\n" +" return %arg0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.randint.low\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.list, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.list {\n" +" return %arg2 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.randint\"(%arg0: !torch.int, %arg1: !torch.list, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.list {\n" +" return %arg1 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.randn\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.list {\n" +" return %arg0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.randn.generator\"(%arg0: !torch.list, %arg1: !torch.any, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.list {\n" +" return %arg0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.normal_functional\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.list {\n" +" return %arg0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.arange.start_step\"(%arg0: !torch.float, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.list {\n" " %0 = torch.derefine %arg0 : !torch.float to !torch.union\n" " %1 = torch.derefine %arg1 : !torch.float to !torch.union\n" @@ -8634,6 +9368,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) 
: (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.frac\"(%arg0: !torch.list) -> !torch.list {\n" +" return %arg0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.signbit\"(%arg0: !torch.list) -> !torch.list {\n" +" return %arg0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.ldexp.Tensor\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.copysign.Tensor\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.__and__.Tensor\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -8650,6 +9398,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.fmin\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.fmax\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.bitwise_or.Tensor\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -8833,6 +9589,248 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %0) : (!torch.list, !torch.list) -> !torch.list\n" " return %1 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.col2im\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list) -> !torch.list {\n" +" %str = torch.constant.str \"AssertionError: Expected size of input's dimension 2 to match the calculated number of sliding blocks\"\n" +" %str_0 = torch.constant.str \"AssertionError: Expected size of input's dimension 1 to be divisible by the product of kernel_size\"\n" +" %int-1 = torch.constant.int -1\n" +" %str_1 = torch.constant.str \"AssertionError: stride must be greater than 0\"\n" +" %str_2 = torch.constant.str \"AssertionError: padding should be non negative\"\n" +" %str_3 = torch.constant.str \"AssertionError: dilation should be greater than 0\"\n" +" %str_4 = torch.constant.str \"AssertionError: kernel size should be greater than 0\"\n" +" %str_5 = torch.constant.str \"AssertionError: padding is expected to have length 2\"\n" +" %str_6 = torch.constant.str \"AssertionError: stride is expected to have length 2\"\n" +" %str_7 = torch.constant.str \"AssertionError: dilation 
is expected to have length 2\"\n" +" %str_8 = torch.constant.str \"AssertionError: kernel_size is expected to have length 2\"\n" +" %str_9 = torch.constant.str \"AssertionError: output_size is expected to have length 2\"\n" +" %none = torch.constant.none\n" +" %str_10 = torch.constant.str \"AssertionError: Expected 2D or 3D (batch mode) tensor for input with possibly 0 batch size and non zero dimensions for input\"\n" +" %true = torch.constant.bool true\n" +" %false = torch.constant.bool false\n" +" %int2 = torch.constant.int 2\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %int3 = torch.constant.int 3\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" %75 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %76 = torch.aten.ne.int %75, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %76 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %3 = torch.prim.If %2 -> (!torch.bool) {\n" +" %75 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %76 = torch.aten.ne.int %75, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %76 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %75 = torch.aten.eq.int %0, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" %76 = torch.prim.If %75 -> (!torch.bool) {\n" +" %78 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %79 = torch.aten.ne.int %78, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %79 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %77 = torch.prim.If %76 -> (!torch.bool) {\n" +" %78 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list, !torch.int -> !torch.int\n" +" %79 = torch.aten.ne.int %78, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %79 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If.yield %77 : !torch.bool\n" +" }\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_10, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %6 = torch.aten.eq.int %5, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %6 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_9, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %7 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %8 = torch.aten.eq.int %7, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %8 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_8, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %9 = torch.aten.len.t %arg3 : !torch.list -> !torch.int\n" +" %10 = torch.aten.eq.int %9, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %10 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_7, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %11 = torch.aten.len.t %arg5 : !torch.list -> 
!torch.int\n" +" %12 = torch.aten.eq.int %11, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %12 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_6, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %13 = torch.aten.len.t %arg4 : !torch.list -> !torch.int\n" +" %14 = torch.aten.eq.int %13, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %14 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_5, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %15 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %16 = torch.aten.gt.int %15, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %17 = torch.prim.If %16 -> (!torch.bool) {\n" +" %75 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %76 = torch.aten.gt.int %75, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %76 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %17 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_4, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %18 = torch.aten.__getitem__.t %arg3, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %19 = torch.aten.gt.int %18, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %20 = torch.prim.If %19 -> (!torch.bool) {\n" +" %75 = torch.aten.__getitem__.t %arg3, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %76 = torch.aten.gt.int %75, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %76 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %20 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %21 = torch.aten.__getitem__.t %arg4, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %22 = torch.aten.ge.int %21, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %23 = torch.prim.If %22 -> (!torch.bool) {\n" +" %75 = torch.aten.__getitem__.t %arg4, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %76 = torch.aten.ge.int %75, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %76 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %23 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %24 = torch.aten.__getitem__.t %arg5, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %25 = torch.aten.gt.int %24, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %26 = torch.prim.If %25 -> (!torch.bool) {\n" +" %75 = torch.aten.__getitem__.t %arg5, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %76 = torch.aten.gt.int %75, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %76 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %26 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %27 = torch.aten.eq.int %0, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" %28 = torch.prim.If %27 -> (!torch.int) {\n" +" torch.prim.If.yield %int0 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int-1 : 
!torch.int\n" +" }\n" +" %29 = torch.aten.add.int %28, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %30 = torch.aten.__getitem__.t %arg0, %29 : !torch.list, !torch.int -> !torch.int\n" +" %31 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %32 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %33 = torch.aten.mul.int %31, %32 : !torch.int, !torch.int -> !torch.int\n" +" %34 = torch.aten.remainder.int %30, %33 : !torch.int, !torch.int -> !torch.int\n" +" %35 = torch.aten.eq.int %34, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %35 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %36 = torch.aten.add.int %28, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %37 = torch.aten.__getitem__.t %arg0, %36 : !torch.list, !torch.int -> !torch.int\n" +" %38 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %39 = torch.aten.__getitem__.t %arg4, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %40 = torch.aten.mul.int %int2, %39 : !torch.int, !torch.int -> !torch.int\n" +" %41 = torch.aten.add.int %38, %40 : !torch.int, !torch.int -> !torch.int\n" +" %42 = torch.aten.__getitem__.t %arg3, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %43 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %44 = torch.aten.sub.int %43, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %45 = torch.aten.mul.int %42, %44 : !torch.int, !torch.int -> !torch.int\n" +" %46 = torch.aten.sub.int %41, %45 : !torch.int, !torch.int -> !torch.int\n" +" %47 = torch.aten.sub.int %46, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %48 = torch.aten.__getitem__.t %arg5, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %49 = torch.aten.floordiv.int %47, %48 : !torch.int, !torch.int -> !torch.int\n" +" %50 = torch.aten.add.int %49, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %51 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %52 = torch.aten.__getitem__.t %arg4, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %53 = torch.aten.mul.int %int2, %52 : !torch.int, !torch.int -> !torch.int\n" +" %54 = torch.aten.add.int %51, %53 : !torch.int, !torch.int -> !torch.int\n" +" %55 = torch.aten.__getitem__.t %arg3, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %56 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %57 = torch.aten.sub.int %56, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %58 = torch.aten.mul.int %55, %57 : !torch.int, !torch.int -> !torch.int\n" +" %59 = torch.aten.sub.int %54, %58 : !torch.int, !torch.int -> !torch.int\n" +" %60 = torch.aten.sub.int %59, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %61 = torch.aten.__getitem__.t %arg5, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %62 = torch.aten.floordiv.int %60, %61 : !torch.int, !torch.int -> !torch.int\n" +" %63 = torch.aten.add.int %62, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %64 = torch.aten.mul.int %50, %63 : !torch.int, !torch.int -> !torch.int\n" +" %65 = torch.aten.eq.int %37, %64 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %65 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %66 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list, 
!torch.int -> !torch.int\n" +" %67 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %68 = torch.aten.mul.int %66, %67 : !torch.int, !torch.int -> !torch.int\n" +" %69 = torch.aten.floordiv.int %30, %68 : !torch.int, !torch.int -> !torch.int\n" +" %70 = torch.aten.eq.int %28, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %71 = torch.prim.If %70 -> (!torch.list) {\n" +" %75 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %76 = torch.prim.ListConstruct %75, %69 : (!torch.int, !torch.int) -> !torch.list\n" +" torch.prim.If.yield %76 : !torch.list\n" +" } else {\n" +" %75 = torch.prim.ListConstruct %69 : (!torch.int) -> !torch.list\n" +" torch.prim.If.yield %75 : !torch.list\n" +" }\n" +" %72 = torch.prim.ListConstruct : () -> !torch.list\n" +" %73 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" torch.prim.Loop %73, %true, init() {\n" +" ^bb0(%arg6: !torch.int):\n" +" %75 = torch.aten.__getitem__.t %arg1, %arg6 : !torch.list, !torch.int -> !torch.int\n" +" %76 = torch.aten.append.t %72, %75 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %74 = torch.aten.add.t %71, %72 : !torch.list, !torch.list -> !torch.list\n" +" return %74 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.topk\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.tuple, list> {\n" " %0 = call @__torch__.torch.jit._shape_functions.topk(%arg0, %arg1, %arg2) : (!torch.list, !torch.int, !torch.int) -> !torch.tuple, list>\n" " return %0 : !torch.tuple, list>\n" @@ -9022,14 +10020,85 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = torch.prim.ListConstruct %int9, %int10 : (!torch.int, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.torchvision.deform_conv2d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.int, %arg6: !torch.int, %arg7: !torch.int, %arg8: !torch.int, %arg9: !torch.int, %arg10: !torch.int, %arg11: !torch.int, %arg12: !torch.int, %arg13: !torch.bool) -> !torch.list {\n" +" %int0 = torch.constant.int 0\n" +" %int2 = torch.constant.int 2\n" +" %int3 = torch.constant.int 3\n" +" %0 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %1 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %2 = torch.aten.__getitem__.t %arg2, %int2 : !torch.list, !torch.int -> !torch.int\n" +" %3 = torch.aten.__getitem__.t %arg2, %int3 : !torch.list, !torch.int -> !torch.int\n" +" %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" return %4 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.torchvision.deform_conv2d\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.tuple, %arg4: !torch.tuple, %arg5: !torch.int, %arg6: !torch.int, %arg7: !torch.int, %arg8: !torch.int, %arg9: !torch.int, %arg10: !torch.int, %arg11: !torch.int, %arg12: !torch.int, %arg13: !torch.bool) -> !torch.int {\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple, !torch.int -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.conv2d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: 
!torch.list, %arg5: !torch.list, %arg6: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.conv2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.conv2d.padding\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.str, %arg5: !torch.list, %arg6: !torch.int) -> !torch.list {\n" +" %0 = call @__torch__._conv_padding(%arg1, %arg5, %arg4) : (!torch.list, !torch.list, !torch.str) -> !torch.list\n" +" %1 = call @__torch__.torch.jit._shape_functions.conv2d(%arg0, %arg1, %arg2, %arg3, %0, %arg5, %arg6) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.int) -> !torch.list\n" +" return %1 : !torch.list\n" +" }\n" +" func.func @__torch__._conv_padding(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.str) -> !torch.list {\n" +" %true = torch.constant.bool true\n" +" %int-1 = torch.constant.int -1\n" +" %str = torch.constant.str \"same\"\n" +" %none = torch.constant.none\n" +" %str_0 = torch.constant.str \"AssertionError: conv: weight must be at least 3 dimensional.\"\n" +" %int2 = torch.constant.int 2\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.gt.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.sub.int %0, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %3 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list\n" +" %4 = torch.aten.mul.left_t %3, %2 : !torch.list, !torch.int -> !torch.list\n" +" %5 = torch.aten.eq.str %arg2, %str : !torch.str, !torch.str -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" +" %6 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %7 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %8 = torch.aten.__range_length %6, %int-1, %int-1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" %9 = torch.prim.ListConstruct %7, %8 : (!torch.int, !torch.int) -> !torch.list\n" +" %10 = torch.prim.min.self_int %9 : !torch.list -> !torch.int\n" +" torch.prim.Loop %10, %true, init() {\n" +" ^bb0(%arg3: !torch.int):\n" +" %11 = torch.aten.__getitem__.t %arg1, %arg3 : !torch.list, !torch.int -> !torch.int\n" +" %12 = torch.aten.__derive_index %arg3, %6, %int-1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" %13 = torch.aten.add.int %int2, %12 : !torch.int, !torch.int -> !torch.int\n" +" %14 = torch.aten.__getitem__.t %arg0, %13 : !torch.list, !torch.int -> !torch.int\n" +" %15 = torch.aten.sub.int %14, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %16 = torch.aten.mul.int %11, %15 : !torch.int, !torch.int -> !torch.int\n" +" %17 = torch.aten.floordiv.int %16, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %18 = torch.aten._set_item.t %4, %12, %17 : !torch.list, !torch.int, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" return %4 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.conv3d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: 
!torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.conv3d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.conv3d.padding\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.str, %arg5: !torch.list, %arg6: !torch.int) -> !torch.list {\n" +" %0 = call @__torch__._conv_padding(%arg1, %arg5, %arg4) : (!torch.list, !torch.list, !torch.str) -> !torch.list\n" +" %1 = call @__torch__.torch.jit._shape_functions.conv3d(%arg0, %arg1, %arg2, %arg3, %0, %arg5, %arg6) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.int) -> !torch.list\n" +" return %1 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.conv_transpose2d.input\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int, %arg7: !torch.list) -> !torch.list {\n" " %0 = torch.derefine %arg3 : !torch.list to !torch.optional>\n" " %1 = torch.derefine %arg4 : !torch.list to !torch.optional>\n" @@ -9099,6 +10168,24 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = call @__torch__.torch.jit._shape_functions.conv_forwards(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %false, %0, %int1) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int) -> !torch.list\n" " return %1 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.conv1d.padding\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.str, %arg5: !torch.list, %arg6: !torch.int) -> !torch.list {\n" +" %false = torch.constant.bool false\n" +" %int1 = torch.constant.int 1\n" +" %0 = call @__torch__._conv_padding(%arg1, %arg5, %arg4) : (!torch.list, !torch.list, !torch.str) -> !torch.list\n" +" %1 = torch.prim.ListConstruct : () -> !torch.list\n" +" %2 = call @__torch__.torch.jit._shape_functions.conv_forwards(%arg0, %arg1, %arg2, %arg3, %0, %arg5, %false, %1, %int1) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int) -> !torch.list\n" +" return %2 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.conv_transpose1d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int, %arg7: !torch.list) -> !torch.list {\n" +" %true = torch.constant.bool true\n" +" %0 = call @\"__torch_mlir_shape_fn.aten.convolution\"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg7, %true, %arg5, %arg6) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.conv_transpose3d.input\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int, %arg7: !torch.list) -> !torch.list {\n" +" %true = torch.constant.bool true\n" +" %0 = call @\"__torch_mlir_shape_fn.aten.convolution\"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg7, %true, %arg5, %arg6) : (!torch.list, !torch.list, 
!torch.optional>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten._convolution\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool, %arg12: !torch.bool) -> !torch.list {\n" " %0 = call @\"__torch_mlir_shape_fn.aten.convolution\"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" @@ -9133,25 +10220,92 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.slice.Tensor\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.int) -> !torch.list {\n" -" %0 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list, !torch.int, !torch.optional, !torch.optional, !torch.int) -> !torch.list\n" -" return %0 : !torch.list\n" -" }\n" -" func.func @\"__torch_mlir_shape_fn.aten.sort\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple, list> {\n" -" %0 = torch.prim.TupleConstruct %arg0, %arg0 : !torch.list, !torch.list -> !torch.tuple, list>\n" -" return %0 : !torch.tuple, list>\n" -" }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.sort\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple {\n" -" %int4 = torch.constant.int 4\n" -" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple\n" -" return %1 : !torch.tuple\n" +" func.func @\"__torch_mlir_shape_fn.aten._weight_norm_interface\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.int) -> !torch.tuple, list> {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" %1 = call @__torch__.torch.jit._shape_functions.unary(%arg1) : (!torch.list) -> !torch.list\n" +" %2 = torch.prim.TupleConstruct %0, %1 : !torch.list, !torch.list -> !torch.tuple, list>\n" +" return %2 : !torch.tuple, list>\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.narrow\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list {\n" +" func.func @\"__torch_mlir_shape_fn.aten.slice.Tensor\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.int) -> !torch.list {\n" +" %int-1 = torch.constant.int -1\n" +" %none = torch.constant.none\n" +" %int0 = torch.constant.int 0\n" " %int1 = torch.constant.int 1\n" -" %0 = torch.aten.add.int %arg2, %arg3 : !torch.int, !torch.int -> !torch.int\n" -" %1 = torch.derefine %arg2 : !torch.int to !torch.optional\n" -" %2 = torch.derefine %0 : !torch.int to !torch.optional\n" +" %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %9 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %9 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int0 : !torch.int\n" +" }\n" 
+" %2 = torch.aten.__isnot__ %arg3, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %3 = torch.prim.If %2 -> (!torch.int) {\n" +" %9 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %9 : !torch.int\n" +" } else {\n" +" %9 = func.call @__torch__.torch.jit._shape_functions.max_int() : () -> !torch.int\n" +" torch.prim.If.yield %9 : !torch.int\n" +" }\n" +" %4 = torch.aten.lt.int %arg4, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %5:3 = torch.prim.If %4 -> (!torch.int, !torch.int, !torch.int) {\n" +" %9 = torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %10 = torch.prim.If %9 -> (!torch.int) {\n" +" %20 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list, !torch.int -> !torch.int\n" +" %21 = torch.aten.add.int %1, %20 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %21 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %1 : !torch.int\n" +" }\n" +" %11 = torch.aten.lt.int %10, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %12 = torch.prim.If %11 -> (!torch.int) {\n" +" torch.prim.If.yield %int0 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %10 : !torch.int\n" +" }\n" +" %13 = torch.aten.lt.int %3, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %14 = torch.prim.If %13 -> (!torch.int) {\n" +" %20 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list, !torch.int -> !torch.int\n" +" %21 = torch.aten.add.int %3, %20 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %21 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" %15 = torch.aten.lt.int %14, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %16 = torch.prim.If %15 -> (!torch.int) {\n" +" torch.prim.If.yield %int-1 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %14 : !torch.int\n" +" }\n" +" %17 = torch.aten.add.int %16, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %18 = torch.aten.add.int %12, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %19 = torch.aten.neg.int %arg4 : !torch.int -> !torch.int\n" +" torch.prim.If.yield %17, %18, %19 : !torch.int, !torch.int, !torch.int\n" +" } else {\n" +" torch.prim.If.yield %1, %3, %arg4 : !torch.int, !torch.int, !torch.int\n" +" }\n" +" %6 = torch.derefine %5#0 : !torch.int to !torch.optional\n" +" %7 = torch.derefine %5#1 : !torch.int to !torch.optional\n" +" %8 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %6, %7, %5#2) : (!torch.list, !torch.int, !torch.optional, !torch.optional, !torch.int) -> !torch.list\n" +" return %8 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.as_strided\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.optional) -> !torch.list {\n" +" return %arg1 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.sort\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple, list> {\n" +" %0 = torch.prim.TupleConstruct %arg0, %arg0 : !torch.list, !torch.list -> !torch.tuple, list>\n" +" return %0 : !torch.tuple, list>\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.sort\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple {\n" +" %int4 = torch.constant.int 4\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple\n" +" return %1 : !torch.tuple\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.narrow\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: 
!torch.int, %arg3: !torch.int) -> !torch.list {\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.add.int %arg2, %arg3 : !torch.int, !torch.int -> !torch.int\n" +" %1 = torch.derefine %arg2 : !torch.int to !torch.optional\n" +" %2 = torch.derefine %0 : !torch.int to !torch.optional\n" " %3 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %1, %2, %int1) : (!torch.list, !torch.int, !torch.optional, !torch.optional, !torch.int) -> !torch.list\n" " return %3 : !torch.list\n" " }\n" @@ -9181,6 +10335,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @\"__torch_mlir_shape_fn.aten.scatter.value\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.list, %arg3: !torch.float) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.scatter_add\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.list, %arg3: !torch.list) -> !torch.list {\n" +" return %arg0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.index_select\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.index_select(%arg0, %arg1, %arg2) : (!torch.list, !torch.int, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -9278,6 +10435,119 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = call @__torch__._embedding_bag_helper(%arg0, %arg1, %arg2, %arg7, %arg4, %arg6, %0) : (!torch.list, !torch.list, !torch.list, !torch.bool, !torch.int, !torch.optional>, !torch.optional) -> !torch.tuple, list, list, list>\n" " return %1 : !torch.tuple, list, list, list>\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.triu_indices\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.list {\n" +" %true = torch.constant.bool true\n" +" %int0 = torch.constant.int 0\n" +" %int2 = torch.constant.int 2\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.eq.int %arg0, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %3 = torch.aten.eq.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %3 : !torch.bool\n" +" }\n" +" %2 = torch.prim.If %1 -> (!torch.list) {\n" +" %3 = torch.prim.ListConstruct %int2, %int0 : (!torch.int, !torch.int) -> !torch.list\n" +" torch.prim.If.yield %3 : !torch.list\n" +" } else {\n" +" %3 = torch.aten.sub.int %arg2, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %4 = torch.aten.eq.int %arg0, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %11 = torch.aten.eq.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %11 : !torch.bool\n" +" }\n" +" %6:2 = torch.prim.If %5 -> (!torch.int, !torch.int) {\n" +" torch.prim.If.yield %int0, %int0 : !torch.int, !torch.int\n" +" } else {\n" +" %11 = torch.aten.gt.int %3, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %12 = torch.prim.If %11 -> (!torch.int) {\n" +" %27 = torch.aten.add.int %int1, %3 : !torch.int, !torch.int -> !torch.int\n" +" %28 = torch.prim.min.int %arg1, %27 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %28 : !torch.int\n" +" } else {\n" +" %27 = torch.aten.add.int %arg0, %3 : !torch.int, !torch.int -> !torch.int\n" +" %28 = torch.aten.gt.int 
%27, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %29 = torch.aten.Int.bool %28 : !torch.bool -> !torch.int\n" +" torch.prim.If.yield %29 : !torch.int\n" +" }\n" +" %13 = torch.aten.add.int %arg0, %3 : !torch.int, !torch.int -> !torch.int\n" +" %14 = torch.prim.min.int %arg1, %13 : !torch.int, !torch.int -> !torch.int\n" +" %15 = torch.prim.max.int %int0, %14 : !torch.int, !torch.int -> !torch.int\n" +" %16 = torch.aten.add.int %arg0, %3 : !torch.int, !torch.int -> !torch.int\n" +" %17 = torch.prim.min.int %arg0, %16 : !torch.int, !torch.int -> !torch.int\n" +" %18 = torch.prim.max.int %int0, %17 : !torch.int, !torch.int -> !torch.int\n" +" %19 = torch.aten.sub.int %15, %12 : !torch.int, !torch.int -> !torch.int\n" +" %20 = torch.aten.add.int %19, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %21 = torch.aten.add.int %12, %15 : !torch.int, !torch.int -> !torch.int\n" +" %22 = torch.aten.mul.int %21, %20 : !torch.int, !torch.int -> !torch.int\n" +" %23 = torch.aten.floordiv.int %22, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %24 = torch.aten.sub.int %18, %20 : !torch.int, !torch.int -> !torch.int\n" +" %25 = torch.aten.mul.int %24, %arg1 : !torch.int, !torch.int -> !torch.int\n" +" %26 = torch.prim.max.int %int0, %25 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %23, %26 : !torch.int, !torch.int\n" +" }\n" +" %7 = torch.aten.mul.int %arg0, %arg1 : !torch.int, !torch.int -> !torch.int\n" +" %8 = torch.aten.add.int %6#0, %6#1 : !torch.int, !torch.int -> !torch.int\n" +" %9 = torch.aten.sub.int %7, %8 : !torch.int, !torch.int -> !torch.int\n" +" %10 = torch.prim.ListConstruct %int2, %9 : (!torch.int, !torch.int) -> !torch.list\n" +" torch.prim.If.yield %10 : !torch.list\n" +" }\n" +" return %2 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.tril_indices\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.list {\n" +" %true = torch.constant.bool true\n" +" %int0 = torch.constant.int 0\n" +" %int2 = torch.constant.int 2\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.eq.int %arg0, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %3 = torch.aten.eq.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %3 : !torch.bool\n" +" }\n" +" %2 = torch.prim.If %1 -> (!torch.list) {\n" +" %3 = torch.prim.ListConstruct %int2, %int0 : (!torch.int, !torch.int) -> !torch.list\n" +" torch.prim.If.yield %3 : !torch.list\n" +" } else {\n" +" %3 = torch.aten.gt.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" %21 = torch.aten.add.int %int1, %arg2 : !torch.int, !torch.int -> !torch.int\n" +" %22 = torch.prim.min.int %arg1, %21 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %22 : !torch.int\n" +" } else {\n" +" %21 = torch.aten.add.int %arg0, %arg2 : !torch.int, !torch.int -> !torch.int\n" +" %22 = torch.aten.gt.int %21, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %23 = torch.aten.Int.bool %22 : !torch.bool -> !torch.int\n" +" torch.prim.If.yield %23 : !torch.int\n" +" }\n" +" %5 = torch.aten.add.int %arg0, %arg2 : !torch.int, !torch.int -> !torch.int\n" +" %6 = torch.prim.min.int %arg1, %5 : !torch.int, !torch.int -> !torch.int\n" +" %7 = torch.prim.max.int %int0, %6 : !torch.int, !torch.int -> !torch.int\n" +" %8 = 
torch.aten.add.int %arg0, %arg2 : !torch.int, !torch.int -> !torch.int\n" +" %9 = torch.prim.min.int %arg0, %8 : !torch.int, !torch.int -> !torch.int\n" +" %10 = torch.prim.max.int %int0, %9 : !torch.int, !torch.int -> !torch.int\n" +" %11 = torch.aten.sub.int %7, %4 : !torch.int, !torch.int -> !torch.int\n" +" %12 = torch.aten.add.int %11, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %13 = torch.aten.add.int %4, %7 : !torch.int, !torch.int -> !torch.int\n" +" %14 = torch.aten.mul.int %13, %12 : !torch.int, !torch.int -> !torch.int\n" +" %15 = torch.aten.floordiv.int %14, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %16 = torch.aten.sub.int %10, %12 : !torch.int, !torch.int -> !torch.int\n" +" %17 = torch.aten.mul.int %16, %arg1 : !torch.int, !torch.int -> !torch.int\n" +" %18 = torch.prim.max.int %int0, %17 : !torch.int, !torch.int -> !torch.int\n" +" %19 = torch.aten.add.int %15, %18 : !torch.int, !torch.int -> !torch.int\n" +" %20 = torch.prim.ListConstruct %int2, %19 : (!torch.int, !torch.int) -> !torch.list\n" +" torch.prim.If.yield %20 : !torch.list\n" +" }\n" +" return %2 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.deg2rad\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.nll_loss_forward\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple, list> {\n" " %0 = call @__torch__.torch.jit._shape_functions.nll_loss_forward(%arg0, %arg1, %arg2, %arg3) : (!torch.list, !torch.list, !torch.optional>, !torch.int) -> !torch.tuple, list>\n" " return %0 : !torch.tuple, list>\n" @@ -9298,10 +10568,34 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %1 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.l1_loss\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.int) -> !torch.list {\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.eq.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.list) {\n" +" %2 = func.call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" torch.prim.If.yield %2 : !torch.list\n" +" } else {\n" +" %2 = torch.prim.ListConstruct : () -> !torch.list\n" +" torch.prim.If.yield %2 : !torch.list\n" +" }\n" +" return %1 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.cross_entropy_loss\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.float) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.cross_entropy_loss(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list, !torch.list, !torch.optional>, !torch.int, !torch.int, !torch.float) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.binary_cross_entropy_with_logits\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.int) -> !torch.list {\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.prim.ListConstruct : () -> !torch.list\n" +" %1 = torch.aten.eq.int %arg4, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.list) {\n" +" %3 = func.call @__torch__.torch.jit._shape_functions._copy(%arg0) : (!torch.list) -> !torch.list\n" +" torch.prim.If.yield %3 : !torch.list\n" +" } else {\n" +" 
torch.prim.If.yield %0 : !torch.list\n" +" }\n" +" return %2 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.native_layer_norm\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.float) -> !torch.tuple, list, list> {\n" " %0 = call @__torch__.torch.jit._shape_functions.native_layer_norm(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.tuple, list, list>\n" " return %0 : !torch.tuple, list, list>\n" @@ -9311,14 +10605,17 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %0 : !torch.tuple, list, list>\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.constant_pad_nd\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.float) -> !torch.list {\n" -" %0 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" %false = torch.constant.bool false\n" +" %0 = call @__torch__.pad_shape_fn(%arg0, %arg1, %false) : (!torch.list, !torch.list, !torch.bool) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @__torch__.pad_shape_fn(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" func.func @__torch__.pad_shape_fn(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.bool) -> !torch.list {\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %false = torch.constant.bool false\n" " %true = torch.constant.bool true\n" -" %str = torch.constant.str \"AssertionError: Number of padded dimensions must be less than or equal to the input dimension\"\n" +" %str_0 = torch.constant.str \"AssertionError: Number of padded dimensions must be less than or equal to the input dimension\"\n" " %none = torch.constant.none\n" -" %str_0 = torch.constant.str \"AssertionError: Must have paired low-high pad amount values\"\n" +" %str_1 = torch.constant.str \"AssertionError: Must have paired low-high pad amount values\"\n" " %int2 = torch.constant.int 2\n" " %int0 = torch.constant.int 0\n" " %int1 = torch.constant.int 1\n" @@ -9328,7 +10625,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" " } else {\n" -" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" " %3 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" @@ -9338,18 +10635,47 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.If %6 -> () {\n" " torch.prim.If.yield\n" " } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" " %7 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" " %8 = torch.aten.floordiv.int %7, %int2 : !torch.int, !torch.int -> !torch.int\n" " torch.prim.Loop %8, %true, init() {\n" -" ^bb0(%arg2: !torch.int):\n" -" %9 = torch.aten.add.int %arg2, %int1 : !torch.int, !torch.int -> !torch.int\n" +" ^bb0(%arg3: !torch.int):\n" +" torch.prim.If %arg2 -> () {\n" +" %20 = torch.aten.mul.int %int2, %arg3 : !torch.int, !torch.int -> !torch.int\n" +" %21 = torch.aten.__getitem__.t %arg1, %20 : !torch.list, !torch.int -> !torch.int\n" +" %22 = torch.aten.add.int %arg3, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %23 = torch.aten.neg.int %22 : !torch.int -> !torch.int\n" +" %24 = torch.aten.__getitem__.t %arg0, %23 : !torch.list, !torch.int -> !torch.int\n" +" %25 = torch.aten.lt.int %21, %24 : !torch.int, !torch.int -> !torch.bool\n" +" %26 
= torch.prim.If %25 -> (!torch.bool) {\n" +" %27 = torch.aten.mul.int %int2, %arg3 : !torch.int, !torch.int -> !torch.int\n" +" %28 = torch.aten.add.int %27, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %29 = torch.aten.__getitem__.t %arg1, %28 : !torch.list, !torch.int -> !torch.int\n" +" %30 = torch.aten.add.int %arg3, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %31 = torch.aten.neg.int %30 : !torch.int -> !torch.int\n" +" %32 = torch.aten.__getitem__.t %arg0, %31 : !torch.list, !torch.int -> !torch.int\n" +" %33 = torch.aten.lt.int %29, %32 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %33 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %26 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" %9 = torch.aten.add.int %arg3, %int1 : !torch.int, !torch.int -> !torch.int\n" " %10 = torch.aten.neg.int %9 : !torch.int -> !torch.int\n" -" %11 = torch.aten.mul.int %int2, %arg2 : !torch.int, !torch.int -> !torch.int\n" +" %11 = torch.aten.mul.int %int2, %arg3 : !torch.int, !torch.int -> !torch.int\n" " %12 = torch.aten.__getitem__.t %arg1, %11 : !torch.list, !torch.int -> !torch.int\n" -" %13 = torch.aten.mul.int %int2, %arg2 : !torch.int, !torch.int -> !torch.int\n" +" %13 = torch.aten.mul.int %int2, %arg3 : !torch.int, !torch.int -> !torch.int\n" " %14 = torch.aten.add.int %13, %int1 : !torch.int, !torch.int -> !torch.int\n" " %15 = torch.aten.__getitem__.t %arg1, %14 : !torch.list, !torch.int -> !torch.int\n" " %16 = torch.aten.add.int %12, %15 : !torch.int, !torch.int -> !torch.int\n" @@ -9361,6 +10687,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %arg0 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.replication_pad2d\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %false = torch.constant.bool false\n" " %str = torch.constant.str \"AssertionError: padding size expected to be 4\"\n" " %none = torch.constant.none\n" " %str_0 = torch.constant.str \"AssertionError: \"\n" @@ -9382,7 +10709,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %4 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" %4 = call @__torch__.pad_shape_fn(%arg0, %arg1, %false) : (!torch.list, !torch.list, !torch.bool) -> !torch.list\n" " return %4 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.replication_pad2d\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" @@ -9390,17 +10717,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %0#1 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.pad\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.str, %arg3: !torch.optional) -> !torch.list {\n" -" %0 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" %false = torch.constant.bool false\n" +" %0 = call @__torch__.pad_shape_fn(%arg0, %arg1, %false) : (!torch.list, !torch.list, !torch.bool) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.reflection_pad1d\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" -" %false = torch.constant.bool false\n" -" %int-1 = torch.constant.int -1\n" 
+" %true = torch.constant.bool true\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int2 = torch.constant.int 2\n" -" %int1 = torch.constant.int 1\n" -" %int0 = torch.constant.int 0\n" " %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" " %1 = torch.aten.ge.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" " torch.prim.If %1 -> () {\n" @@ -9409,37 +10734,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %2 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list, !torch.int -> !torch.int\n" -" %3 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" -" %4 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" -" %5 = torch.aten.lt.int %3, %2 : !torch.int, !torch.int -> !torch.bool\n" -" %6 = torch.prim.If %5 -> (!torch.bool) {\n" -" %8 = torch.aten.lt.int %4, %2 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %8 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" torch.prim.If %6 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %7 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" -" return %7 : !torch.list\n" +" %2 = call @__torch__.pad_shape_fn(%arg0, %arg1, %true) : (!torch.list, !torch.list, !torch.bool) -> !torch.list\n" +" return %2 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.reflection_pad2d\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" -" %false = torch.constant.bool false\n" +" %true = torch.constant.bool true\n" " %str = torch.constant.str \"AssertionError: padding size expected to be 4\"\n" -" %int-1 = torch.constant.int -1\n" -" %int-2 = torch.constant.int -2\n" " %none = torch.constant.none\n" " %str_0 = torch.constant.str \"AssertionError: \"\n" " %int2 = torch.constant.int 2\n" -" %int1 = torch.constant.int 1\n" " %int4 = torch.constant.int 4\n" -" %int0 = torch.constant.int 0\n" -" %int3 = torch.constant.int 3\n" " %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" " %1 = torch.aten.ge.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" " torch.prim.If %1 -> () {\n" @@ -9448,48 +10752,42 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %2 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list, !torch.int -> !torch.int\n" -" %3 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list, !torch.int -> !torch.int\n" -" %4 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" -" %5 = torch.aten.eq.int %4, %int4 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %5 -> () {\n" +" %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %3 = torch.aten.eq.int %2, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %6 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" -" %7 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" -" %8 = torch.aten.__getitem__.t %arg1, %int2 : !torch.list, !torch.int -> !torch.int\n" -" %9 = 
torch.aten.__getitem__.t %arg1, %int3 : !torch.list, !torch.int -> !torch.int\n" -" %10 = torch.aten.lt.int %6, %3 : !torch.int, !torch.int -> !torch.bool\n" -" %11 = torch.prim.If %10 -> (!torch.bool) {\n" -" %15 = torch.aten.lt.int %7, %3 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %15 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" torch.prim.If %11 -> () {\n" +" %4 = call @__torch__.pad_shape_fn(%arg0, %arg1, %true) : (!torch.list, !torch.list, !torch.bool) -> !torch.list\n" +" return %4 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.reflection_pad3d\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %true = torch.constant.bool true\n" +" %str = torch.constant.str \"AssertionError: padding size expected to be 6\"\n" +" %none = torch.constant.none\n" +" %str_0 = torch.constant.str \"AssertionError: \"\n" +" %int3 = torch.constant.int 3\n" +" %int6 = torch.constant.int 6\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.ge.int %0, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %12 = torch.aten.lt.int %8, %2 : !torch.int, !torch.int -> !torch.bool\n" -" %13 = torch.prim.If %12 -> (!torch.bool) {\n" -" %15 = torch.aten.lt.int %9, %2 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %15 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" torch.prim.If %13 -> () {\n" +" %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %3 = torch.aten.eq.int %2, %int6 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" -" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %14 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" -" return %14 : !torch.list\n" +" %4 = call @__torch__.pad_shape_fn(%arg0, %arg1, %true) : (!torch.list, !torch.list, !torch.bool) -> !torch.list\n" +" return %4 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.index.Tensor\"(%arg0: !torch.list, %arg1: !torch.list>>) -> !torch.list {\n" " %0 = call @__torch__.index_tensor_like(%arg0, %arg1) : (!torch.list, !torch.list>>) -> !torch.list\n" @@ -9623,27 +10921,274 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.cat(%arg0, %arg1) : (!torch.list>, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.atleast_1d\"(%arg0: !torch.list) -> !torch.list {\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.list) {\n" +" %3 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list\n" +" torch.prim.If.yield %3 : !torch.list\n" +" } else {\n" +" torch.prim.If.yield %arg0 : !torch.list\n" +" }\n" +" return %2 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.atleast_2d\"(%arg0: !torch.list) -> !torch.list {\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.len.t %arg0 : 
!torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.list) {\n" +" %3 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list\n" +" torch.prim.If.yield %3 : !torch.list\n" +" } else {\n" +" %3 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %4 = torch.aten.eq.int %3, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.list) {\n" +" %6 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %7 = torch.prim.ListConstruct %int1, %6 : (!torch.int, !torch.int) -> !torch.list\n" +" torch.prim.If.yield %7 : !torch.list\n" +" } else {\n" +" torch.prim.If.yield %arg0 : !torch.list\n" +" }\n" +" torch.prim.If.yield %5 : !torch.list\n" +" }\n" +" return %2 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.stack\"(%arg0: !torch.list>, %arg1: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.stack(%arg0, %arg1) : (!torch.list>, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.fft_fft\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.int, %arg3: !torch.optional) -> !torch.list {\n" -" return %arg0 : !torch.list\n" -" }\n" -" func.func @\"__torch_mlir_shape_fn.aten.bincount\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.int) -> !torch.list {\n" -" %0 = call @__torch__.hacky_get_unknown_dimension_size() : () -> !torch.int\n" -" %1 = torch.prim.ListConstruct %0 : (!torch.int) -> !torch.list\n" -" return %1 : !torch.list\n" -" }\n" -" func.func @__torch__.hacky_get_unknown_dimension_size() -> !torch.int {\n" -" %0 = torch.prim.CreateObject !torch.nn.Module<\"__torch__.DummyClassType\">\n" -" %1 = torch.prim.CallMethod %0[\"__init__\"] () : !torch.nn.Module<\"__torch__.DummyClassType\">, () -> !torch.none\n" -" %2 = torch.operator \"prim.id\"(%0) : (!torch.nn.Module<\"__torch__.DummyClassType\">) -> !torch.int \n" -" return %2 : !torch.int\n" +" func.func @\"__torch_mlir_shape_fn.aten.hstack\"(%arg0: !torch.list>) -> !torch.list {\n" +" %true = torch.constant.bool true\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.prim.ListConstruct : () -> !torch.list>\n" +" %1 = torch.aten.len.t %arg0 : !torch.list> -> !torch.int\n" +" torch.prim.Loop %1, %true, init() {\n" +" ^bb0(%arg1: !torch.int):\n" +" %6 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list>, !torch.int -> !torch.list\n" +" %7 = func.call @\"__torch_mlir_shape_fn.aten.atleast_1d\"(%6) : (!torch.list) -> !torch.list\n" +" %8 = torch.aten.append.t %0, %7 : !torch.list>, !torch.list -> !torch.list>\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %2 = torch.aten.__getitem__.t %0, %int0 : !torch.list>, !torch.int -> !torch.list\n" +" %3 = torch.aten.len.t %2 : !torch.list -> !torch.int\n" +" %4 = torch.aten.eq.int %3, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.list) {\n" +" %6 = func.call @__torch__.torch.jit._shape_functions.cat(%0, %int0) : (!torch.list>, !torch.int) -> !torch.list\n" +" torch.prim.If.yield %6 : !torch.list\n" +" } else {\n" +" %6 = func.call @__torch__.torch.jit._shape_functions.cat(%0, %int1) : (!torch.list>, !torch.int) -> !torch.list\n" +" torch.prim.If.yield %6 : !torch.list\n" +" }\n" +" return %5 : !torch.list\n" " }\n" -" func.func @__torch__.DummyClassType.__init__(%arg0: 
!torch.nn.Module<\"__torch__.DummyClassType\">) -> !torch.none {\n" +" func.func @\"__torch_mlir_shape_fn.aten.column_stack\"(%arg0: !torch.list>) -> !torch.list {\n" +" %true = torch.constant.bool true\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.prim.ListConstruct : () -> !torch.list>\n" +" %1 = torch.aten.len.t %arg0 : !torch.list> -> !torch.int\n" +" torch.prim.Loop %1, %true, init() {\n" +" ^bb0(%arg1: !torch.int):\n" +" %3 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list>, !torch.int -> !torch.list\n" +" %4 = torch.aten.len.t %3 : !torch.list -> !torch.int\n" +" %5 = torch.aten.eq.int %4, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.list) {\n" +" %8 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list\n" +" torch.prim.If.yield %8 : !torch.list\n" +" } else {\n" +" %8 = torch.aten.len.t %3 : !torch.list -> !torch.int\n" +" %9 = torch.aten.eq.int %8, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %9 -> () {\n" +" %10 = torch.aten.append.t %3, %int1 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %3 : !torch.list\n" +" }\n" +" %7 = torch.aten.append.t %0, %6 : !torch.list>, !torch.list -> !torch.list>\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %2 = call @__torch__.torch.jit._shape_functions.cat(%0, %int1) : (!torch.list>, !torch.int) -> !torch.list\n" +" return %2 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.fft_fft\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.int, %arg3: !torch.optional) -> !torch.list {\n" +" return %arg0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.fft_rfft\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.int, %arg3: !torch.optional) -> !torch.list {\n" +" %true = torch.constant.bool true\n" " %none = torch.constant.none\n" -" return %none : !torch.none\n" +" %str = torch.constant.str \"AssertionError: Expected dim in [-rank, rank-1]\"\n" +" %false = torch.constant.bool false\n" +" %int0 = torch.constant.int 0\n" +" %int2 = torch.constant.int 2\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.lt.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %10 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %11 = torch.aten.add.int %arg2, %10 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %11 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %arg2 : !torch.int\n" +" }\n" +" %2 = torch.aten.ge.int %1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %3 = torch.prim.If %2 -> (!torch.bool) {\n" +" %10 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %11 = torch.aten.lt.int %1, %10 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %11 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.prim.ListConstruct : () -> !torch.list\n" +" %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" torch.prim.Loop %5, %true, init() {\n" +" ^bb0(%arg4: !torch.int):\n" +" %10 = torch.aten.__getitem__.t %arg0, %arg4 : !torch.list, !torch.int -> !torch.int\n" +" %11 = torch.aten.append.t %4, %10 : 
!torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %6 = torch.aten.__getitem__.t %arg0, %1 : !torch.list, !torch.int -> !torch.int\n" +" %7 = torch.aten.floordiv.int %6, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %8 = torch.aten.add.int %7, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %9 = torch.aten._set_item.t %4, %1, %8 : !torch.list, !torch.int, !torch.int -> !torch.list\n" +" return %4 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.stft\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional>, %arg5: !torch.bool, %arg6: !torch.optional, %arg7: !torch.optional, %arg8: !torch.optional) -> !torch.list {\n" +" %str = torch.constant.str \"AssertionError: Expected hop_length to be greater than 0\"\n" +" %str_0 = torch.constant.str \"AssertionError: Expected that 0 < n_fft <= len\"\n" +" %false = torch.constant.bool false\n" +" %none = torch.constant.none\n" +" %str_1 = torch.constant.str \"AssertionError: Expected input tensor to be of shape (B?,L), where B is an optional batch dimension\"\n" +" %true = torch.constant.bool true\n" +" %int1 = torch.constant.int 1\n" +" %int2 = torch.constant.int 2\n" +" %int0 = torch.constant.int 0\n" +" %int4 = torch.constant.int 4\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %24 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %25 = torch.aten.eq.int %24, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %25 : !torch.bool\n" +" }\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %4 = torch.aten.eq.int %3, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.optional) {\n" +" %24 = torch.derefine %none : !torch.none to !torch.optional\n" +" torch.prim.If.yield %24 : !torch.optional\n" +" } else {\n" +" %24 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %25 = torch.derefine %24 : !torch.int to !torch.optional\n" +" torch.prim.If.yield %25 : !torch.optional\n" +" }\n" +" %6 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %7 = torch.aten.eq.int %6, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %8 = torch.prim.If %7 -> (!torch.int) {\n" +" %24 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %24 : !torch.int\n" +" } else {\n" +" %24 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %24 : !torch.int\n" +" }\n" +" %9 = torch.aten.__is__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %10 = torch.prim.If %9 -> (!torch.int) {\n" +" %24 = torch.aten.floordiv.int %arg1, %int4 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %24 : !torch.int\n" +" } else {\n" +" %24 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %24 : !torch.int\n" +" }\n" +" %11 = torch.aten.gt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %12 = torch.prim.If %11 -> (!torch.bool) {\n" +" %24 = torch.aten.le.int 
%arg1, %8 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %24 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %12 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %13 = torch.aten.gt.int %10, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %13 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %14 = torch.prim.ListConstruct : () -> !torch.list\n" +" %15 = torch.aten.__isnot__ %5, %none : !torch.optional, !torch.none -> !torch.bool\n" +" torch.prim.If %15 -> () {\n" +" %24 = torch.prim.unchecked_cast %5 : !torch.optional -> !torch.int\n" +" %25 = torch.aten.append.t %14, %24 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" %16 = torch.aten.__is__ %arg6, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %17 = torch.prim.If %16 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %24 = torch.prim.unchecked_cast %arg6 : !torch.optional -> !torch.bool\n" +" %25 = torch.aten.eq.bool %24, %true : !torch.bool, !torch.bool -> !torch.bool\n" +" torch.prim.If.yield %25 : !torch.bool\n" +" }\n" +" torch.prim.If %17 -> () {\n" +" %24 = torch.aten.floordiv.int %arg1, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %25 = torch.aten.add.int %24, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %26 = torch.aten.append.t %14, %25 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" } else {\n" +" %24 = torch.aten.append.t %14, %arg1 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" }\n" +" %18 = torch.aten.sub.int %8, %arg1 : !torch.int, !torch.int -> !torch.int\n" +" %19 = torch.aten.floordiv.int %18, %10 : !torch.int, !torch.int -> !torch.int\n" +" %20 = torch.aten.add.int %int1, %19 : !torch.int, !torch.int -> !torch.int\n" +" %21 = torch.aten.append.t %14, %20 : !torch.list, !torch.int -> !torch.list\n" +" %22 = torch.aten.__isnot__ %arg7, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %23 = torch.prim.If %22 -> (!torch.bool) {\n" +" %24 = torch.prim.unchecked_cast %arg7 : !torch.optional -> !torch.bool\n" +" %25 = torch.aten.eq.bool %24, %false : !torch.bool, !torch.bool -> !torch.bool\n" +" torch.prim.If.yield %25 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %23 -> () {\n" +" %24 = torch.aten.append.t %14, %int2 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" return %14 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.fft_ifft\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.int, %arg3: !torch.optional) -> !torch.list {\n" +" return %arg0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.bincount\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.int) -> !torch.list {\n" +" %0 = call @__torch__.hacky_get_unknown_dimension_size() : () -> !torch.int\n" +" %1 = torch.prim.ListConstruct %0 : (!torch.int) -> !torch.list\n" +" return %1 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.nonzero\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.hacky_get_unknown_dimension_size() : () -> !torch.int\n" @@ -9678,6 
+11223,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg2, %1) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" " return %2 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.renorm\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.float) -> !torch.list {\n" +" return %arg0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.norm.Scalar\"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list {\n" " %false = torch.constant.bool false\n" " %none = torch.constant.none\n" @@ -9693,6 +11241,83 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg3, %1) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" " return %2 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest1d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional) -> !torch.list {\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %1 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %2 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %3 = torch.prim.ListConstruct %0, %1, %2 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" +" return %3 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest1d.vec\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional>) -> !torch.list {\n" +" %false = torch.constant.bool false\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %int2 = torch.constant.int 2\n" +" %0 = torch.prim.Uninitialized : !torch.list\n" +" %1 = torch.prim.Uninitialized : !torch.optional>\n" +" %2 = torch.prim.Uninitialized : !torch.optional>\n" +" %3 = torch.aten.__is__ %arg1, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %11 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list\n" +" %12 = torch.aten.__is__ %arg2, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" torch.prim.If.yield %12 : !torch.bool\n" +" }\n" +" %5:2 = torch.prim.If %4 -> (!torch.optional>, !torch.optional>) {\n" +" torch.prim.If.yield %arg1, %arg2 : !torch.optional>, !torch.optional>\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield %1, %2 : !torch.optional>, !torch.optional>\n" +" }\n" +" %6 = torch.aten.__is__ %5#0, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" %7 = torch.prim.If %6 -> (!torch.bool) {\n" +" %11 = torch.aten.__is__ %5#1, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" torch.prim.If.yield %11 : !torch.bool\n" +" } else {\n" +" %11 = torch.prim.unchecked_cast %5#0 : !torch.optional> -> !torch.list\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool\n" +" torch.prim.If %8 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %9 = 
torch.aten.__isnot__ %5#0, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" %10 = torch.prim.If %9 -> (!torch.list) {\n" +" %11 = torch.prim.unchecked_cast %5#0 : !torch.optional> -> !torch.list\n" +" %12 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %13 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %14 = torch.aten.__getitem__.t %11, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %15 = torch.prim.ListConstruct %12, %13, %14 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" +" torch.prim.If.yield %15 : !torch.list\n" +" } else {\n" +" %11 = torch.aten.__isnot__ %5#1, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" %12 = torch.prim.If %11 -> (!torch.list) {\n" +" %20 = torch.prim.unchecked_cast %5#1 : !torch.optional> -> !torch.list\n" +" torch.prim.If.yield %20 : !torch.list\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield %0 : !torch.list\n" +" }\n" +" %13 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %14 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %15 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list, !torch.int -> !torch.int\n" +" %16 = torch.aten.__getitem__.t %12, %int0 : !torch.list, !torch.int -> !torch.float\n" +" %17 = torch.aten.mul.int_float %15, %16 : !torch.int, !torch.float -> !torch.float\n" +" %18 = torch.aten.Int.float %17 : !torch.float -> !torch.int\n" +" %19 = torch.prim.ListConstruct %13, %14, %18 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" +" torch.prim.If.yield %19 : !torch.list\n" +" }\n" +" return %10 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest2d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional, %arg3: !torch.optional) -> !torch.list {\n" " %int0 = torch.constant.int 0\n" " %int1 = torch.constant.int 1\n" @@ -9703,11 +11328,188 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" " return %4 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.prims.split_dim\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest2d.vec\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional>) -> !torch.list {\n" +" %false = torch.constant.bool false\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %int2 = torch.constant.int 2\n" +" %int3 = torch.constant.int 3\n" +" %0 = torch.prim.Uninitialized : !torch.list\n" +" %1 = torch.prim.Uninitialized : !torch.optional>\n" +" %2 = torch.prim.Uninitialized : !torch.optional>\n" +" %3 = torch.aten.__is__ %arg1, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %11 = torch.prim.unchecked_cast %arg1 : !torch.optional> -> !torch.list\n" +" %12 = torch.aten.__is__ %arg2, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" torch.prim.If.yield %12 : !torch.bool\n" +" }\n" +" %5:2 = torch.prim.If %4 -> (!torch.optional>, !torch.optional>) {\n" +" torch.prim.If.yield %arg1, %arg2 : !torch.optional>, !torch.optional>\n" +" } 
else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield %1, %2 : !torch.optional>, !torch.optional>\n" +" }\n" +" %6 = torch.aten.__is__ %5#0, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" %7 = torch.prim.If %6 -> (!torch.bool) {\n" +" %11 = torch.aten.__is__ %5#1, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" torch.prim.If.yield %11 : !torch.bool\n" +" } else {\n" +" %11 = torch.prim.unchecked_cast %5#0 : !torch.optional> -> !torch.list\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %8 = torch.aten.__not__ %7 : !torch.bool -> !torch.bool\n" +" torch.prim.If %8 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %9 = torch.aten.__isnot__ %5#0, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" %10 = torch.prim.If %9 -> (!torch.list) {\n" +" %11 = torch.prim.unchecked_cast %5#0 : !torch.optional> -> !torch.list\n" +" %12 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %13 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %14 = torch.aten.__getitem__.t %11, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %15 = torch.aten.__getitem__.t %11, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %16 = torch.prim.ListConstruct %12, %13, %14, %15 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" torch.prim.If.yield %16 : !torch.list\n" +" } else {\n" +" %11 = torch.aten.__isnot__ %5#1, %none : !torch.optional>, !torch.none -> !torch.bool\n" +" %12 = torch.prim.If %11 -> (!torch.list) {\n" +" %24 = torch.prim.unchecked_cast %5#1 : !torch.optional> -> !torch.list\n" +" torch.prim.If.yield %24 : !torch.list\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield %0 : !torch.list\n" +" }\n" +" %13 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %14 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %15 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list, !torch.int -> !torch.int\n" +" %16 = torch.aten.__getitem__.t %12, %int0 : !torch.list, !torch.int -> !torch.float\n" +" %17 = torch.aten.mul.int_float %15, %16 : !torch.int, !torch.float -> !torch.float\n" +" %18 = torch.aten.Int.float %17 : !torch.float -> !torch.int\n" +" %19 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list, !torch.int -> !torch.int\n" +" %20 = torch.aten.__getitem__.t %12, %int1 : !torch.list, !torch.int -> !torch.float\n" +" %21 = torch.aten.mul.int_float %19, %20 : !torch.int, !torch.float -> !torch.float\n" +" %22 = torch.aten.Int.float %21 : !torch.float -> !torch.int\n" +" %23 = torch.prim.ListConstruct %13, %14, %18, %22 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" torch.prim.If.yield %23 : !torch.list\n" +" }\n" +" return %10 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.upsample_bilinear2d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.bool, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.list {\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %1 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %2 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" 
+" %3 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" return %4 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.upsample_bilinear2d.vec\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.optional>) -> !torch.list {\n" +" %0 = call @\"__torch_mlir_shape_fn.aten.upsample_nearest2d.vec\"(%arg0, %arg1, %arg3) : (!torch.list, !torch.optional>, !torch.optional>) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.prims.split_dim\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.fake_quantize_per_tensor_affine\"(%arg0: !torch.tuple, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.int) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%arg0: !torch.int) -> !torch.bool {\n" +" %0 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_float_dtypes() : () -> !torch.list\n" +" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list, !torch.int -> !torch.bool\n" +" return %1 : !torch.bool\n" +" }\n" +" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_float_dtypes() -> !torch.list {\n" +" %int7 = torch.constant.int 7\n" +" %int6 = torch.constant.int 6\n" +" %int15 = torch.constant.int 15\n" +" %int5 = torch.constant.int 5\n" +" %0 = torch.prim.ListConstruct %int5, %int15, %int6, %int7 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.fake_quantize_per_tensor_affine_cachemask\"(%arg0: !torch.tuple, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple {\n" +" %int11 = torch.constant.int 11\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int1 = torch.constant.int 1\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple, !torch.int -> !torch.int\n" +" %3 = torch.prim.TupleConstruct %2, %int11 : !torch.int, !torch.int -> !torch.tuple\n" +" return %3 : !torch.tuple\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.fake_quantize_per_tensor_affine.tensor_qparams\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: 
!torch.int, %arg4: !torch.int) -> !torch.int {\n" +" %int15 = torch.constant.int 15\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.ne.int %0#1, %int15 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.tuple, %arg4: !torch.int, %arg5: !torch.int) -> !torch.tuple {\n" +" %int11 = torch.constant.int 11\n" +" %int15 = torch.constant.int 15\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int1 = torch.constant.int 1\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" return %0#1 : !torch.int\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.ne.int %0#1, %int15 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple, !torch.int -> !torch.int\n" +" %4 = torch.prim.TupleConstruct %3, %int11 : !torch.int, !torch.int -> !torch.tuple\n" +" return %4 : !torch.tuple\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.fake_quantize_per_tensor_affine\"(%arg0: !torch.tuple, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.int) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.fake_quantize_per_channel_affine\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int) -> !torch.int {\n" " %int15 = torch.constant.int 15\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" @@ -9728,18 +11530,30 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %0#1 : !torch.int\n" " }\n" -" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%arg0: !torch.int) -> !torch.bool {\n" -" %0 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_float_dtypes() : () -> !torch.list\n" -" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list, !torch.int -> !torch.bool\n" -" return %1 : !torch.bool\n" -" }\n" -" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_float_dtypes() -> !torch.list {\n" -" %int7 = torch.constant.int 7\n" -" %int6 = torch.constant.int 6\n" +" func.func @\"__torch_mlir_dtype_fn.aten.fake_quantize_per_channel_affine_cachemask\"(%arg0: !torch.tuple, 
%arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int) -> !torch.tuple {\n" +" %int11 = torch.constant.int 11\n" " %int15 = torch.constant.int 15\n" -" %int5 = torch.constant.int 5\n" -" %0 = torch.prim.ListConstruct %int5, %int15, %int6, %int7 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" -" return %0 : !torch.list\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int1 = torch.constant.int 1\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.ne.int %0#1, %int15 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple, !torch.int -> !torch.int\n" +" %4 = torch.prim.TupleConstruct %3, %int11 : !torch.int, !torch.int -> !torch.tuple\n" +" return %4 : !torch.tuple\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.cosh\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" @@ -9785,11 +11599,61 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" " return %1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.exp2\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.expm1\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" " return %1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.special_expm1\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.isfinite\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.rad2deg\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %true = torch.constant.bool true\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %4 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" 
torch.prim.If.yield %4 : !torch.bool\n" +" }\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %3 : !torch.int\n" +" }\n" +" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%arg0: !torch.int) -> !torch.bool {\n" +" %0 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() : () -> !torch.list\n" +" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list, !torch.int -> !torch.bool\n" +" return %1 : !torch.bool\n" +" }\n" +" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() -> !torch.list {\n" +" %int4 = torch.constant.int 4\n" +" %int3 = torch.constant.int 3\n" +" %int2 = torch.constant.int 2\n" +" %int1 = torch.constant.int 1\n" +" %int0 = torch.constant.int 0\n" +" %int11 = torch.constant.int 11\n" +" %0 = torch.prim.ListConstruct %int11, %int0, %int1, %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.sin\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" @@ -9870,6 +11734,26 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.hann_window.periodic\"(%arg0: !torch.int, %arg1: !torch.bool, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int6 = torch.constant.int 6\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" %3 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.hardshrink\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %int4 = torch.constant.int 4\n" " %int11 = torch.constant.int 11\n" @@ -9887,6 +11771,44 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" " return %1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.polar\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %int9 = torch.constant.int 9\n" +" %int6 = torch.constant.int 6\n" +" %int10 = torch.constant.int 10\n" +" %int7 = torch.constant.int 7\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %false = torch.constant.bool false\n" +" %true = torch.constant.bool true\n" +" %0 = torch.prim.Uninitialized : !torch.int\n" 
+" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %3 = torch.aten.eq.int %1#1, %2#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.aten.eq.int %1#1, %int7 : !torch.int, !torch.int -> !torch.bool\n" +" %5:2 = torch.prim.If %4 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %int10 : !torch.bool, !torch.int\n" +" } else {\n" +" %7 = torch.aten.eq.int %1#1, %int6 : !torch.int, !torch.int -> !torch.bool\n" +" %8:2 = torch.prim.If %7 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %int9 : !torch.bool, !torch.int\n" +" } else {\n" +" torch.prim.If.yield %false, %0 : !torch.bool, !torch.int\n" +" }\n" +" torch.prim.If.yield %8#0, %8#1 : !torch.bool, !torch.int\n" +" }\n" +" %6 = torch.prim.If %5#0 -> (!torch.int) {\n" +" torch.prim.If.yield %5#1 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %1#1 : !torch.int\n" +" }\n" +" return %6 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.logit\"(%arg0: !torch.tuple, %arg1: !torch.optional) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" @@ -9913,21 +11835,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %2 : !torch.int\n" " }\n" -" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%arg0: !torch.int) -> !torch.bool {\n" -" %0 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() : () -> !torch.list\n" -" %1 = torch.aten.__contains__.int_list %0, %arg0 : !torch.list, !torch.int -> !torch.bool\n" -" return %1 : !torch.bool\n" -" }\n" -" func.func @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.all_integer_dtypes() -> !torch.list {\n" -" %int4 = torch.constant.int 4\n" -" %int3 = torch.constant.int 3\n" -" %int2 = torch.constant.int 2\n" -" %int1 = torch.constant.int 1\n" -" %int0 = torch.constant.int 0\n" -" %int11 = torch.constant.int 11\n" -" %0 = torch.prim.ListConstruct %int11, %int0, %int1, %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" -" return %0 : !torch.list\n" -" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.frobenius_norm.dim\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.bool) -> !torch.int {\n" " %int6 = torch.constant.int 6\n" " %int9 = torch.constant.int 9\n" @@ -9978,6 +11885,19 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %2 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.prims.sum\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.optional) -> !torch.int {\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" torch.prim.If %0 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %1#1 : !torch.int\n" +" }\n" " func.func 
@\"__torch_mlir_dtype_fn.aten.abs\"(%arg0: !torch.tuple) -> !torch.int {\n" " %int6 = torch.constant.int 6\n" " %int9 = torch.constant.int 9\n" @@ -10015,7 +11935,17 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %0#1 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.avg_pool1d\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int0 = torch.constant.int 0\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" " return %0#1 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.adaptive_avg_pool2d\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" @@ -10027,11 +11957,31 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %0#1 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.avg_pool2d\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int0 = torch.constant.int 0\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" " return %0#1 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.avg_pool3d\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int0 = torch.constant.int 0\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" " return %0#1 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.batch_norm\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.optional>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.int {\n" @@ -10071,6 +12021,63 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._weight_norm_interface\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int) -> !torch.tuple {\n" +" %int15 = torch.constant.int 15\n" +" %int6 = torch.constant.int 6\n" +" %int9 = torch.constant.int 9\n" +" %int7 = torch.constant.int 7\n" +" %int10 = torch.constant.int 10\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %false = 
torch.constant.bool false\n" +" %true = torch.constant.bool true\n" +" %0 = torch.prim.Uninitialized : !torch.tuple\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %3 = torch.aten.eq.int %1#1, %2#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%2#1) : (!torch.int) -> !torch.bool\n" +" %5 = torch.aten.__not__ %4 : !torch.bool -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = torch.aten.eq.int %2#1, %int10 : !torch.int, !torch.int -> !torch.bool\n" +" %7:2 = torch.prim.If %6 -> (!torch.bool, !torch.tuple) {\n" +" %9 = torch.prim.TupleConstruct %1#1, %int7 : !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %true, %9 : !torch.bool, !torch.tuple\n" +" } else {\n" +" %9 = torch.aten.eq.int %2#1, %int9 : !torch.int, !torch.int -> !torch.bool\n" +" %10:2 = torch.prim.If %9 -> (!torch.bool, !torch.tuple) {\n" +" %11 = torch.prim.TupleConstruct %1#1, %int6 : !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %true, %11 : !torch.bool, !torch.tuple\n" +" } else {\n" +" %11 = torch.aten.eq.int %2#1, %int15 : !torch.int, !torch.int -> !torch.bool\n" +" %12:2 = torch.prim.If %11 -> (!torch.bool, !torch.tuple) {\n" +" %13 = torch.prim.TupleConstruct %1#1, %int6 : !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %true, %13 : !torch.bool, !torch.tuple\n" +" } else {\n" +" torch.prim.If.yield %false, %0 : !torch.bool, !torch.tuple\n" +" }\n" +" torch.prim.If.yield %12#0, %12#1 : !torch.bool, !torch.tuple\n" +" }\n" +" torch.prim.If.yield %10#0, %10#1 : !torch.bool, !torch.tuple\n" +" }\n" +" %8 = torch.prim.If %7#0 -> (!torch.tuple) {\n" +" torch.prim.If.yield %7#1 : !torch.tuple\n" +" } else {\n" +" %9 = torch.prim.TupleConstruct %1#1, %2#1 : !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %9 : !torch.tuple\n" +" }\n" +" return %8 : !torch.tuple\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.bernoulli_.float\"(%arg0: !torch.tuple, %arg1: !torch.float, %arg2: !torch.any) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" @@ -10083,6 +12090,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.multinomial\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.any) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" return %int4 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.bitwise_not\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" @@ -10258,7 +12269,26 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.cumsum\"(%arg0: !torch.tuple, 
%arg1: !torch.int, %arg2: !torch.optional) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.cumsum\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.optional) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" } else {\n" +" %2:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%2#1) : (!torch.int) -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %2#1 : !torch.int\n" +" }\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.cumprod\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.optional) -> !torch.int {\n" " %int4 = torch.constant.int 4\n" " %none = torch.constant.none\n" " %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" @@ -10281,6 +12311,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.linalg_det\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.dropout\"(%arg0: !torch.tuple, %arg1: !torch.float, %arg2: !torch.bool) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" @@ -10433,6 +12477,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise_backward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.number, %arg4: !torch.number, %arg5: !torch.bool, %arg6: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.lift_fresh_copy\"(%arg0: !torch.tuple) -> 
!torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" @@ -10462,6 +12514,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %6 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %5) : (!torch.list>, !torch.list) -> !torch.int\n" " return %6 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.kthvalue\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.bool) -> !torch.tuple {\n" +" %int4 = torch.constant.int 4\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple\n" +" return %1 : !torch.tuple\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten._log_softmax_backward_data\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" " return %arg3 : !torch.int\n" " }\n" @@ -10499,6 +12557,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple\n" " return %1 : !torch.tuple\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.max_pool3d_with_indices\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.bool) -> !torch.tuple {\n" +" %int4 = torch.constant.int 4\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple\n" +" return %1 : !torch.tuple\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.max_unpool3d\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.adaptive_max_pool1d\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.tuple {\n" " %int4 = torch.constant.int 4\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" @@ -10589,6 +12657,39 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.rrelu\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.number, %arg3: !torch.bool, %arg4: !torch.any) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.number, %arg3: !torch.number, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.aten.eq.int %0#0, %1#0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise_functional\"(%arg0: 
!torch.tuple, %arg1: !torch.tuple, %arg2: !torch.number, %arg3: !torch.number, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.tuple {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.aten.eq.int %0#0, %1#0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.prim.TupleConstruct %0#1, %1#1 : !torch.int, !torch.int -> !torch.tuple\n" +" return %3 : !torch.tuple\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.relu6\"(%arg0: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" @@ -10667,6 +12768,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.scatter_add\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.tuple, %arg3: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.masked_scatter\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" @@ -10683,9 +12788,68 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.as_strided\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten._softmax_backward_data\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" " return %arg3 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.linalg_slogdet\"(%arg0: !torch.tuple) -> !torch.tuple {\n" +" %int7 = torch.constant.int 7\n" +" %int10 = torch.constant.int 10\n" +" %int6 = torch.constant.int 6\n" +" %int9 = torch.constant.int 9\n" +" %true = torch.constant.bool true\n" +" %int8 = torch.constant.int 8\n" +" %false = torch.constant.bool false\n" +" %int15 = torch.constant.int 15\n" +" %int5 = torch.constant.int 5\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.ne.int %0#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" %11 = torch.aten.ne.int %0#1, %int15 : !torch.int, !torch.int -> 
!torch.bool\n" +" torch.prim.If.yield %11 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = torch.aten.eq.int %0#1, %int8 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %11 = torch.aten.eq.int %0#1, %int9 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %11 : !torch.bool\n" +" }\n" +" %7 = torch.prim.If %6 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" %8 = torch.aten.eq.int %0#1, %int10 : !torch.int, !torch.int -> !torch.bool\n" +" %9 = torch.prim.If %8 -> (!torch.int) {\n" +" torch.prim.If.yield %int7 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %7 : !torch.int\n" +" }\n" +" %10 = torch.prim.TupleConstruct %0#1, %9 : !torch.int, !torch.int -> !torch.tuple\n" +" return %10 : !torch.tuple\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.square\"(%arg0: !torch.tuple) -> !torch.int {\n" " %int4 = torch.constant.int 4\n" " %int11 = torch.constant.int 11\n" @@ -10782,10 +12946,30 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.upsample_nearest1d\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.upsample_nearest1d.vec\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.optional>) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.upsample_nearest2d\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.optional, %arg3: !torch.optional) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.upsample_nearest2d.vec\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.optional>) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.upsample_bilinear2d\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.bool, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.upsample_bilinear2d.vec\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.bool, %arg3: !torch.optional>) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.view\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" @@ -10901,7 +13085,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() 
{ " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.scaled_dot_product_attention\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.optional>, %arg4: !torch.float, %arg5: !torch.bool, %arg6: !torch.optional) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.scaled_dot_product_attention\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.optional>, %arg4: !torch.float, %arg5: !torch.bool, %arg6: !torch.optional, %arg7: !torch.bool) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" @@ -10933,70 +13117,289 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.isneginf\"(%arg0: !torch.tuple) -> !torch.int {\n" -" %int11 = torch.constant.int 11\n" +" func.func @\"__torch_mlir_dtype_fn.aten.isneginf\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %false = torch.constant.bool false\n" +" %int9 = torch.constant.int 9\n" +" %int10 = torch.constant.int 10\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int10 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" %3 = torch.aten.ne.int %0#1, %int9 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %3 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.isposinf\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %false = torch.constant.bool false\n" +" %int9 = torch.constant.int 9\n" +" %int10 = torch.constant.int 10\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int10 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" %3 = torch.aten.ne.int %0#1, %int9 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %3 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.ne.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.ne.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.add\"(%arg0: !torch.number, %arg1: !torch.number) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0 = torch.prim.ListConstruct %none, %none : (!torch.none, !torch.none) -> !torch.list>\n" +" %1 = call 
@__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n" +" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" +" %3 = torch.prim.ListConstruct %1, %2 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%0, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.fft_fft\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.int, %arg3: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: Unsupported dtype\"\n" +" %int10 = torch.constant.int 10\n" +" %int7 = torch.constant.int 7\n" +" %int9 = torch.constant.int 9\n" +" %int6 = torch.constant.int 6\n" +" %int8 = torch.constant.int 8\n" +" %int5 = torch.constant.int 5\n" +" %0 = torch.prim.Uninitialized : !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.prim.If %2 -> (!torch.int) {\n" +" torch.prim.If.yield %1#1 : !torch.int\n" +" } else {\n" +" %4 = torch.aten.eq.int %1#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.int) {\n" +" torch.prim.If.yield %int8 : !torch.int\n" +" } else {\n" +" %6 = torch.aten.eq.int %1#1, %int6 : !torch.int, !torch.int -> !torch.bool\n" +" %7 = torch.prim.If %6 -> (!torch.int) {\n" +" torch.prim.If.yield %int9 : !torch.int\n" +" } else {\n" +" %8 = torch.aten.eq.int %1#1, %int7 : !torch.int, !torch.int -> !torch.bool\n" +" %9 = torch.prim.If %8 -> (!torch.int) {\n" +" torch.prim.If.yield %int10 : !torch.int\n" +" } else {\n" +" %10 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %11 = torch.prim.If %10 -> (!torch.int) {\n" +" torch.prim.If.yield %int9 : !torch.int\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield %0 : !torch.int\n" +" }\n" +" torch.prim.If.yield %11 : !torch.int\n" +" }\n" +" torch.prim.If.yield %9 : !torch.int\n" +" }\n" +" torch.prim.If.yield %7 : !torch.int\n" +" }\n" +" torch.prim.If.yield %5 : !torch.int\n" +" }\n" +" return %3 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.fft_rfft\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.int, %arg3: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: Unsupported dtype\"\n" +" %int10 = torch.constant.int 10\n" +" %int7 = torch.constant.int 7\n" +" %int9 = torch.constant.int 9\n" +" %int6 = torch.constant.int 6\n" +" %int8 = torch.constant.int 8\n" +" %int5 = torch.constant.int 5\n" +" %0 = torch.prim.Uninitialized : !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.aten.eq.int %1#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" %3 = torch.prim.If %2 -> (!torch.int) {\n" +" torch.prim.If.yield %int8 : !torch.int\n" +" } else {\n" +" %4 = torch.aten.eq.int %1#1, %int6 : !torch.int, !torch.int -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.int) {\n" +" torch.prim.If.yield %int9 : !torch.int\n" +" } else {\n" +" 
%6 = torch.aten.eq.int %1#1, %int7 : !torch.int, !torch.int -> !torch.bool\n" +" %7 = torch.prim.If %6 -> (!torch.int) {\n" +" torch.prim.If.yield %int10 : !torch.int\n" +" } else {\n" +" %8 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %9 = torch.prim.If %8 -> (!torch.int) {\n" +" torch.prim.If.yield %int9 : !torch.int\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield %0 : !torch.int\n" +" }\n" +" torch.prim.If.yield %9 : !torch.int\n" +" }\n" +" torch.prim.If.yield %7 : !torch.int\n" +" }\n" +" torch.prim.If.yield %5 : !torch.int\n" +" }\n" +" return %3 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.stft\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional>, %arg5: !torch.bool, %arg6: !torch.optional, %arg7: !torch.optional, %arg8: !torch.optional) -> !torch.int {\n" +" %str = torch.constant.str \"AssertionError: Unsupported dtype\"\n" +" %int7 = torch.constant.int 7\n" +" %int10 = torch.constant.int 10\n" +" %int6 = torch.constant.int 6\n" +" %int9 = torch.constant.int 9\n" +" %int5 = torch.constant.int 5\n" +" %int8 = torch.constant.int 8\n" " %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: \"\n" " %false = torch.constant.bool false\n" -" %int9 = torch.constant.int 9\n" -" %int10 = torch.constant.int 10\n" -" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = torch.aten.ne.int %0#1, %int10 : !torch.int, !torch.int -> !torch.bool\n" -" %2 = torch.prim.If %1 -> (!torch.bool) {\n" -" %3 = torch.aten.ne.int %0#1, %int9 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %3 : !torch.bool\n" +" %true = torch.constant.bool true\n" +" %0 = torch.prim.Uninitialized : !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.prim.If %2 -> (!torch.bool) {\n" +" %7 = torch.aten.__isnot__ %arg7, %none : !torch.optional, !torch.none -> !torch.bool\n" +" torch.prim.If.yield %7 : !torch.bool\n" " } else {\n" " torch.prim.If.yield %false : !torch.bool\n" " }\n" -" torch.prim.If %2 -> () {\n" -" torch.prim.If.yield\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" %7 = torch.prim.unchecked_cast %arg7 : !torch.optional -> !torch.bool\n" +" torch.prim.If.yield %7 : !torch.bool\n" " } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" +" torch.prim.If.yield %false : !torch.bool\n" " }\n" -" return %int11 : !torch.int\n" -" }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.isposinf\"(%arg0: !torch.tuple) -> !torch.int {\n" -" %int11 = torch.constant.int 11\n" -" %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: \"\n" -" %false = torch.constant.bool false\n" -" %int9 = torch.constant.int 9\n" -" %int10 = torch.constant.int 10\n" -" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = torch.aten.ne.int %0#1, %int10 : !torch.int, !torch.int -> !torch.bool\n" -" %2 = torch.prim.If %1 -> (!torch.bool) {\n" -" %3 = torch.aten.ne.int %0#1, %int9 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %3 : !torch.bool\n" +" %5:2 = torch.prim.If %4 -> (!torch.bool, !torch.int) {\n" +" 
torch.prim.If.yield %true, %1#1 : !torch.bool, !torch.int\n" " } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" +" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %8 = torch.prim.If %7 -> (!torch.bool) {\n" +" %11 = torch.aten.__isnot__ %arg7, %none : !torch.optional, !torch.none -> !torch.bool\n" +" torch.prim.If.yield %11 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %9 = torch.prim.If %8 -> (!torch.bool) {\n" +" %11 = torch.prim.unchecked_cast %arg7 : !torch.optional -> !torch.bool\n" +" %12 = torch.aten.ne.bool %11, %true : !torch.bool, !torch.bool -> !torch.bool\n" +" torch.prim.If.yield %12 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %10:2 = torch.prim.If %9 -> (!torch.bool, !torch.int) {\n" +" %11 = torch.aten.eq.int %1#1, %int8 : !torch.int, !torch.int -> !torch.bool\n" +" %12:2 = torch.prim.If %11 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %int5 : !torch.bool, !torch.int\n" +" } else {\n" +" %13 = torch.aten.eq.int %1#1, %int9 : !torch.int, !torch.int -> !torch.bool\n" +" %14:2 = torch.prim.If %13 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %int6 : !torch.bool, !torch.int\n" +" } else {\n" +" %15 = torch.aten.eq.int %1#1, %int10 : !torch.int, !torch.int -> !torch.bool\n" +" %16:2 = torch.prim.If %15 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %int7 : !torch.bool, !torch.int\n" +" } else {\n" +" torch.prim.If.yield %false, %0 : !torch.bool, !torch.int\n" +" }\n" +" torch.prim.If.yield %16#0, %16#1 : !torch.bool, !torch.int\n" +" }\n" +" torch.prim.If.yield %14#0, %14#1 : !torch.bool, !torch.int\n" +" }\n" +" torch.prim.If.yield %12#0, %12#1 : !torch.bool, !torch.int\n" +" } else {\n" +" %11 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %12 = torch.prim.If %11 -> (!torch.bool) {\n" +" %15 = torch.aten.__isnot__ %arg7, %none : !torch.optional, !torch.none -> !torch.bool\n" +" torch.prim.If.yield %15 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %13 = torch.prim.If %12 -> (!torch.bool) {\n" +" %15 = torch.prim.unchecked_cast %arg7 : !torch.optional -> !torch.bool\n" +" torch.prim.If.yield %15 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %14:2 = torch.prim.If %13 -> (!torch.bool, !torch.int) {\n" +" %15 = torch.aten.eq.int %1#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" %16:2 = torch.prim.If %15 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %int8 : !torch.bool, !torch.int\n" +" } else {\n" +" %17 = torch.aten.eq.int %1#1, %int6 : !torch.int, !torch.int -> !torch.bool\n" +" %18:2 = torch.prim.If %17 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %int9 : !torch.bool, !torch.int\n" +" } else {\n" +" %19 = torch.aten.eq.int %1#1, %int7 : !torch.int, !torch.int -> !torch.bool\n" +" %20:2 = torch.prim.If %19 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %int10 : !torch.bool, !torch.int\n" +" } else {\n" +" torch.prim.If.yield %false, %0 : !torch.bool, !torch.int\n" +" }\n" +" torch.prim.If.yield %20#0, %20#1 : !torch.bool, !torch.int\n" +" }\n" +" torch.prim.If.yield %18#0, %18#1 : !torch.bool, !torch.int\n" +" }\n" +" torch.prim.If.yield %16#0, %16#1 : !torch.bool, !torch.int\n" +" } else {\n" +" %15 = 
func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %16 = torch.prim.If %15 -> (!torch.bool) {\n" +" %19 = torch.aten.__isnot__ %arg7, %none : !torch.optional, !torch.none -> !torch.bool\n" +" torch.prim.If.yield %19 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %17 = torch.prim.If %16 -> (!torch.bool) {\n" +" %19 = torch.prim.unchecked_cast %arg7 : !torch.optional -> !torch.bool\n" +" %20 = torch.aten.ne.bool %19, %true : !torch.bool, !torch.bool -> !torch.bool\n" +" torch.prim.If.yield %20 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %18:2 = torch.prim.If %17 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %1#1 : !torch.bool, !torch.int\n" +" } else {\n" +" %19 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %20:2 = torch.prim.If %19 -> (!torch.bool, !torch.int) {\n" +" torch.prim.If.yield %true, %int9 : !torch.bool, !torch.int\n" +" } else {\n" +" torch.prim.If.yield %false, %0 : !torch.bool, !torch.int\n" +" }\n" +" torch.prim.If.yield %20#0, %20#1 : !torch.bool, !torch.int\n" +" }\n" +" torch.prim.If.yield %18#0, %18#1 : !torch.bool, !torch.int\n" +" }\n" +" torch.prim.If.yield %14#0, %14#1 : !torch.bool, !torch.int\n" +" }\n" +" torch.prim.If.yield %10#0, %10#1 : !torch.bool, !torch.int\n" " }\n" -" torch.prim.If %2 -> () {\n" -" torch.prim.If.yield\n" +" %6 = torch.prim.If %5#0 -> (!torch.int) {\n" +" torch.prim.If.yield %5#1 : !torch.int\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" +" torch.prim.If.yield %0 : !torch.int\n" " }\n" -" return %int11 : !torch.int\n" -" }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.ne.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" -" %int11 = torch.constant.int 11\n" -" return %int11 : !torch.int\n" -" }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.ne.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" -" %int11 = torch.constant.int 11\n" -" return %int11 : !torch.int\n" -" }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.add\"(%arg0: !torch.number, %arg1: !torch.number) -> !torch.int {\n" -" %none = torch.constant.none\n" -" %0 = torch.prim.ListConstruct %none, %none : (!torch.none, !torch.none) -> !torch.list>\n" -" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n" -" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n" -" %3 = torch.prim.ListConstruct %1, %2 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%0, %3) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %4 : !torch.int\n" +" return %6 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.fft_fft\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.int, %arg3: !torch.optional) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.fft_ifft\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.int, %arg3: !torch.optional) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: Unsupported dtype\"\n" " %int10 = torch.constant.int 10\n" @@ -11049,6 +13452,93 @@ 
StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.frac\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.signbit\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.ldexp.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %false = torch.constant.bool false\n" +" %int7 = torch.constant.int 7\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = torch.aten.eq.int %0#1, %int7 : !torch.int, !torch.int -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.bool) {\n" +" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If.yield %7 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %6 = torch.prim.If %5 -> (!torch.int) {\n" +" torch.prim.If.yield %1#1 : !torch.int\n" +" } else {\n" +" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %8 = torch.prim.If %7 -> (!torch.bool) {\n" +" %10 = torch.aten.eq.int %1#1, %int7 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %10 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %9 = torch.prim.If %8 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" %10 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %11 = torch.prim.If %10 -> (!torch.bool) {\n" +" %13 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If.yield %13 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %12 = torch.prim.If %11 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" %13 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" torch.prim.If.yield %13 : !torch.int\n" +" }\n" +" torch.prim.If.yield %12 : !torch.int\n" +" }\n" +" torch.prim.If.yield %9 : !torch.int\n" +" }\n" +" return %6 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.copysign.Tensor\"(%arg0: !torch.tuple, %arg1: 
!torch.tuple) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %false = torch.constant.bool false\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.bool) {\n" +" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If.yield %7 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %6 = torch.prim.If %5 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" torch.prim.If.yield %7 : !torch.int\n" +" }\n" +" return %6 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.__and__.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" @@ -11437,6 +13927,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %2 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.dot\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.aten.eq.int %1#1, %0#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %1#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.matmul\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" @@ -11466,6 +13970,30 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.fmax\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.fmin\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> 
!torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.outer\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.mm\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %false = torch.constant.bool false\n" " %int5 = torch.constant.int 5\n" @@ -11484,17 +14012,40 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %7 = torch.aten.ne.int %1#1, %0#1 : !torch.int, !torch.int -> !torch.bool\n" " torch.prim.If.yield %7 : !torch.bool\n" " } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %6 = torch.prim.If %5 -> (!torch.int) {\n" +" torch.prim.If.yield %int5 : !torch.int\n" +" } else {\n" +" %7 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %8 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %9 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%7, %8) : (!torch.list>, !torch.list) -> !torch.int\n" +" torch.prim.If.yield %9 : !torch.int\n" +" }\n" +" return %6 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._int_mm\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %int3 = torch.constant.int 3\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int1 = torch.constant.int 1\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.aten.eq.int %0#1, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" " }\n" -" %6 = torch.prim.If %5 -> (!torch.int) {\n" -" torch.prim.If.yield %int5 : !torch.int\n" +" %3 = torch.aten.eq.int %1#1, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" " } else {\n" -" %7 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" -" %8 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %9 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%7, %8) : (!torch.list>, !torch.list) -> !torch.int\n" -" torch.prim.If.yield %9 : !torch.int\n" +" 
torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" " }\n" -" return %6 : !torch.int\n" +" return %int3 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.mse_loss\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int) -> !torch.int {\n" " %none = torch.constant.none\n" @@ -11514,6 +14065,24 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.l1_loss\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%4) : (!torch.int) -> !torch.bool\n" +" %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool\n" +" torch.prim.If %6 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %4 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.mul.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" @@ -11772,10 +14341,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.conv_transpose1d\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int, %arg7: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.conv_transpose2d.input\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int, %arg7: !torch.list) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.conv_transpose3d.input\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int, %arg7: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.convolution\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : 
!torch.int\n" @@ -11814,6 +14391,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.col2im\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.nonzero\"(%arg0: !torch.tuple) -> !torch.int {\n" " %int4 = torch.constant.int 4\n" " return %int4 : !torch.int\n" @@ -11841,63 +14422,30 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %5 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.lerp.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.number) -> !torch.int {\n" -" %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = torch.prim.ListConstruct %0#0, %1#0, %none : (!torch.int, !torch.int, !torch.none) -> !torch.list>\n" -" %3 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.number) -> !torch.int\n" -" %4 = torch.prim.ListConstruct %0#1, %1#1, %3 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" -" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %4) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %5 : !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.addcmul\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.number) -> !torch.int {\n" -" %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: \"\n" -" %int11 = torch.constant.int 11\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" -" %3 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %3 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %4 = torch.aten.ne.int %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %4 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %5 = torch.aten.ne.int %2#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %5 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %6 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list>\n" -" %7 = torch.prim.ListConstruct %0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" -" %8 = call 
@__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %8 : !torch.int\n" +" %3 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list>\n" +" %4 = torch.prim.ListConstruct %0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" +" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %5 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.addcdiv\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.number) -> !torch.int {\n" -" %int6 = torch.constant.int 6\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" " %3 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list>\n" " %4 = torch.prim.ListConstruct %0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" " %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list>, !torch.list) -> !torch.int\n" -" %6 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%5) : (!torch.int) -> !torch.bool\n" -" %7 = torch.prim.If %6 -> (!torch.int) {\n" -" torch.prim.If.yield %int6 : !torch.int\n" -" } else {\n" -" torch.prim.If.yield %5 : !torch.int\n" -" }\n" -" return %7 : !torch.int\n" +" return %5 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.add.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" @@ -12494,12 +15042,21 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = torch.prim.TupleConstruct %0, %int4 : !torch.int, !torch.int -> !torch.tuple\n" " return %1 : !torch.tuple\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.amin\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.bool) -> !torch.int {\n" +" %0 = call @\"__torch_mlir_dtype_fn.aten.min\"(%arg0) : (!torch.tuple) -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.min.dim\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple {\n" " %int4 = torch.constant.int 4\n" " %0 = call @\"__torch_mlir_dtype_fn.aten.min\"(%arg0) : (!torch.tuple) -> !torch.int\n" " %1 = torch.prim.TupleConstruct %0, %int4 : !torch.int, !torch.int -> !torch.tuple\n" " return %1 : !torch.tuple\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.aminmax\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.bool) -> !torch.tuple {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.TupleConstruct %0#1, %0#1 : !torch.int, !torch.int -> !torch.tuple\n" +" return %1 : !torch.tuple\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.mean\"(%arg0: !torch.tuple, %arg1: !torch.optional) -> !torch.int {\n" " %false = torch.constant.bool false\n" " %none = torch.constant.none\n" @@ -12552,7 +15109,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" " return %0 : !torch.int\n" " }\n" -" func.func 
@\"__torch_mlir_dtype_fn.prims.var\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.float, %arg3: !torch.optional) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.prims.var\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.optional, %arg3: !torch.optional) -> !torch.int {\n" " %true = torch.constant.bool true\n" " %0 = call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" " return %0 : !torch.int\n" @@ -12665,6 +15222,24 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.binary_cross_entropy_with_logits\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.renorm\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.int, %arg3: !torch.number) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.norm.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %true = torch.constant.bool true\n" " %int5 = torch.constant.int 5\n" @@ -12928,6 +15503,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.rot90\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.rand_like\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" @@ -13221,6 +15800,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._trilinear\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.list, %arg7: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple -> !torch.int, !torch.int\n" +" %3 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list>\n" +" %4 = torch.prim.ListConstruct 
%0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" +" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %5 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.cat\"(%arg0: !torch.list>, %arg1: !torch.int) -> !torch.int {\n" " %true = torch.constant.bool true\n" " %none = torch.constant.none\n" @@ -13248,6 +15836,68 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list>, !torch.list) -> !torch.int\n" " return %5 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.atleast_1d\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.atleast_2d\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.hstack\"(%arg0: !torch.list>) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.prim.ListConstruct : () -> !torch.list>\n" +" %1 = torch.prim.ListConstruct : () -> !torch.list\n" +" %2 = torch.aten.len.t %arg0 : !torch.list> -> !torch.int\n" +" %3 = torch.aten.ne.int %2, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.aten.len.t %arg0 : !torch.list> -> !torch.int\n" +" torch.prim.Loop %4, %true, init() {\n" +" ^bb0(%arg1: !torch.int):\n" +" %6 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list>, !torch.int -> !torch.tuple\n" +" %7:2 = torch.prim.TupleUnpack %6 : !torch.tuple -> !torch.int, !torch.int\n" +" %8 = torch.aten.append.t %0, %7#0 : !torch.list>, !torch.int -> !torch.list>\n" +" %9 = torch.aten.append.t %1, %7#1 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %5 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.column_stack\"(%arg0: !torch.list>) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.prim.ListConstruct : () -> !torch.list>\n" +" %1 = torch.prim.ListConstruct : () -> !torch.list\n" +" %2 = torch.aten.len.t %arg0 : !torch.list> -> !torch.int\n" +" %3 = torch.aten.ne.int %2, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.aten.len.t %arg0 : !torch.list> -> !torch.int\n" +" torch.prim.Loop %4, %true, init() {\n" +" ^bb0(%arg1: !torch.int):\n" +" %6 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list>, !torch.int -> !torch.tuple\n" +" %7:2 = torch.prim.TupleUnpack %6 : !torch.tuple -> 
!torch.int, !torch.int\n" +" %8 = torch.aten.append.t %0, %7#0 : !torch.list>, !torch.int -> !torch.list>\n" +" %9 = torch.aten.append.t %1, %7#1 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %5 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.einsum\"(%arg0: !torch.str, %arg1: !torch.list>, %arg2: !torch.optional>) -> !torch.int {\n" " %true = torch.constant.bool true\n" " %none = torch.constant.none\n" @@ -13371,6 +16021,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._safe_softmax\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.optional) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" } else {\n" +" %2:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" torch.prim.If.yield %2#1 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten._log_softmax\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.int {\n" " %int6 = torch.constant.int 6\n" " %none = torch.constant.none\n" @@ -13451,6 +16113,35 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %int6 = torch.constant.int 6\n" " return %int6 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.triu_indices\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg3, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.tril_indices\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg3, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.deg2rad\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.int_repr\"(%arg0: !torch.tuple) -> !torch.int {\n" " %int3 = torch.constant.int 3\n" " %int1 = torch.constant.int 1\n" @@ 
-13514,6 +16205,83 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %2 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.unfold\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list {\n" +" %str = torch.constant.str \"size must be less than or equal to {}\"\n" +" %false = torch.constant.bool false\n" +" %str_0 = torch.constant.str \"AssertionError: size must be less than or equal to 1\"\n" +" %none = torch.constant.none\n" +" %str_1 = torch.constant.str \"AssertionError: \"\n" +" %str_2 = torch.constant.str \"dimension out of range of {}\"\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.list) {\n" +" %3 = torch.aten.eq.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" %6 = torch.aten.format(%str_2, %0) : !torch.str, !torch.int -> !torch.str\n" +" %7 = torch.aten.add.str %str_1, %6 : !torch.str, !torch.str -> !torch.str\n" +" torch.prim.RaiseException %7, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.aten.le.int %arg2, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = torch.prim.ListConstruct %arg2 : (!torch.int) -> !torch.list\n" +" torch.prim.If.yield %5 : !torch.list\n" +" } else {\n" +" %3 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" %15 = torch.aten.add.int %arg1, %0 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %15 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %arg1 : !torch.int\n" +" }\n" +" %5 = torch.aten.ge.int %4, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.bool) {\n" +" %15 = torch.aten.lt.int %4, %0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %15 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %6 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" %15 = torch.aten.format(%str_2, %0) : !torch.str, !torch.int -> !torch.str\n" +" %16 = torch.aten.add.str %str_1, %15 : !torch.str, !torch.str -> !torch.str\n" +" torch.prim.RaiseException %16, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %7 = torch.aten.__getitem__.t %arg0, %4 : !torch.list, !torch.int -> !torch.int\n" +" %8 = torch.aten.le.int %arg2, %7 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %8 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" %15 = torch.aten.format(%str, %7) : !torch.str, !torch.int -> !torch.str\n" +" %16 = torch.aten.add.str %str_1, %15 : !torch.str, !torch.str -> !torch.str\n" +" torch.prim.RaiseException %16, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %9 = torch.aten.sub.int %7, %arg2 : !torch.int, !torch.int -> !torch.int\n" +" %10 = torch.aten.floordiv.int %9, %arg3 : !torch.int, !torch.int -> !torch.int\n" +" %11 = torch.aten.add.int %10, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %12 = func.call @__torch__.torch.jit._shape_functions._copy(%arg0) : (!torch.list) -> !torch.list\n" +" %13 = torch.aten._set_item.t 
%12, %4, %11 : !torch.list, !torch.int, !torch.int -> !torch.list\n" +" %14 = torch.aten.append.t %12, %arg2 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield %12 : !torch.list\n" +" }\n" +" return %2 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.unfold\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" "}\n" ""; // clang-format on diff --git a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp index 6eb949e589c6..e8b0d6b0364c 100644 --- a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp +++ b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp @@ -9,13 +9,8 @@ #include "PassDetail.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" @@ -31,6 +26,15 @@ using namespace mlir::torch::Torch; using TypeBoundMap = DenseMap, Type>; namespace { + +Value materializeAsCopyTensorToType(OpBuilder &builder, + Torch::BaseTensorType type, + ValueRange inputs, Location loc) { + assert(inputs.size() == 1); + assert(isa(inputs[0].getType())); + return copyTensorToType(builder, loc, type, inputs[0]); +} + class AdjustCallingConventionForFunc : public OpConversionPattern { public: @@ -67,7 +71,7 @@ class AdjustCallingConventionForFunc // TODO: add tuple type. 
conversion.addInputs(type.index(), type.value()); } - rewriter.applySignatureConversion(&func.getBody(), conversion, + rewriter.applySignatureConversion(&func.getBody().front(), conversion, typeConverter); SmallVector newResultTypes; @@ -115,7 +119,7 @@ class AdjustCallingConventionForCall continue; auto it = typeBoundMap.find({call.getCallee(), operand.index()}); if (it != typeBoundMap.end()) { - if (auto valueTensorType = it->second.dyn_cast()) { + if (auto valueTensorType = dyn_cast(it->second)) { newOperands.push_back(copyTensorToType( rewriter, call->getLoc(), valueTensorType, operand.value())); continue; @@ -203,13 +207,9 @@ static LogicalResult adjustCallingConventions(func::FuncOp func, return success(); }); - typeConverter.addArgumentMaterialization( - [](OpBuilder &builder, Torch::BaseTensorType type, ValueRange inputs, - Location loc) -> Value { - assert(inputs.size() == 1); - assert(isa(inputs[0].getType())); - return copyTensorToType(builder, loc, type, inputs[0]); - }); + typeConverter.addArgumentMaterialization(materializeAsCopyTensorToType); + typeConverter.addSourceMaterialization(materializeAsCopyTensorToType); + typeConverter.addTargetMaterialization(materializeAsCopyTensorToType); patterns.add(typeConverter, context); patterns.add(typeConverter, context, typeBoundMap); @@ -220,11 +220,11 @@ static LogicalResult adjustCallingConventions(func::FuncOp func, for (int i = 0, e = func.getNumArguments(); i != e; i++) { if (func.getArgAttr(i, "torch.type_bound")) return false; - if (func.getArgumentTypes()[i].isa()) + if (isa(func.getArgumentTypes()[i])) return false; } for (int i = 0, e = func.getNumResults(); i != e; i++) { - if (func.getFunctionType().getResults()[i].isa()) + if (isa(func.getFunctionType().getResults()[i])) return false; } return true; diff --git a/lib/Dialect/Torch/Transforms/CMakeLists.txt b/lib/Dialect/Torch/Transforms/CMakeLists.txt index ba6af02c8e9a..1ce006fbe913 100644 --- a/lib/Dialect/Torch/Transforms/CMakeLists.txt +++ b/lib/Dialect/Torch/Transforms/CMakeLists.txt @@ -17,6 +17,7 @@ add_mlir_library(TorchMLIRTorchPasses ReifyShapeCalculations.cpp ReifyDtypeCalculations.cpp ReifyAbstractInterpCalculationsUtils.cpp + RestructureNonConstantAxes.cpp ScalarizeShapes.cpp AbstractInterpLibrary.cpp SimplifyShapeCalculations.cpp diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 9fad15e132ff..d04f9f802f3a 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -9,6 +9,7 @@ #include "PassDetail.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -38,7 +39,7 @@ static bool isNoneOrFloatDtype(MLIRContext *context, Value dtype) { getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt); if (failed(resDtype)) return false; - return resDtype->isa(); + return isa(*resDtype); } // Helper function to compute the return type of the reduction function. @@ -71,10 +72,10 @@ static Type computeReductionType(PatternRewriter &rewriter, Operation *op, } } - Type resultType = tensorType.getWithSizesAndDtype( + Type resultType = tensorType.getWithSizesAndDtypeAndSparsity( !tensorType.hasSizes() ? 
std::optional>() : llvm::ArrayRef(sizes), - tensorType.getOptionalDtype()); + tensorType.getOptionalDtype(), tensorType.getOptionalSparsity()); return resultType; } @@ -99,24 +100,39 @@ static Value createMaxAlongDimension(PatternRewriter &rewriter, Location loc, Operation *op, Value input, Value dim, bool keepDim) { Value keepDimCst = rewriter.create(loc, keepDim); - BaseTensorType valueType = - computeReductionType(rewriter, op, cast(input.getType()), - dim, keepDim) - .cast(); + BaseTensorType valueType = cast(computeReductionType( + rewriter, op, cast(input.getType()), dim, keepDim)); if (!valueType) return nullptr; BaseTensorType indexType = - valueType - .getWithSizesAndDtype( - !valueType.hasSizes() ? std::optional>() - : llvm::ArrayRef(valueType.getSizes()), - IntegerType::get(op->getContext(), 64, IntegerType::Signed)) - .cast(); + cast(valueType.getWithSizesAndDtype( + !valueType.hasSizes() ? std::optional>() + : llvm::ArrayRef(valueType.getSizes()), + IntegerType::get(op->getContext(), 64, IntegerType::Signed))); return rewriter .create(loc, valueType, indexType, input, dim, keepDimCst) .getValues(); } +// Reduction function to calculate min along given `dim`. +static Value createMinAlongDimension(PatternRewriter &rewriter, Location loc, + Operation *op, Value input, Value dim, + bool keepDim) { + Value keepDimCst = rewriter.create(loc, keepDim); + BaseTensorType valueType = cast(computeReductionType( + rewriter, op, cast(input.getType()), dim, keepDim)); + if (!valueType) + return nullptr; + BaseTensorType indexType = + cast(valueType.getWithSizesAndDtype( + !valueType.hasSizes() ? std::optional>() + : llvm::ArrayRef(valueType.getSizes()), + IntegerType::get(op->getContext(), 64, IntegerType::Signed))); + return rewriter + .create(loc, valueType, indexType, input, dim, keepDimCst) + .getValues(); +} + // Helper for creating `aten::sub_tensor_op`. static Value createTensorSub(PatternRewriter &rewriter, Location loc, Type tensorType, Value lhs, Value rhs) { @@ -277,7 +293,7 @@ static bool parseEquation(const std::string &equation, inputToken.clear(); } else if ((index < (equation.size() - 1)) && (equation.substr(index, 2).find("->") != std::string::npos)) { - inputTokens.push_back(inputToken); + inputTokens.push_back(std::move(inputToken)); inputToken.clear(); currentVariable = kIsResult; index++; @@ -286,6 +302,91 @@ static bool parseEquation(const std::string &equation, } index++; } + + if (!inputToken.empty() && currentVariable == kIsInput) { + inputTokens.push_back(std::move(inputToken)); + } + + return true; +} + +static bool +diagonalizeInputAndRewriteEquation(Location loc, PatternRewriter &rewriter, + std::string &equation, + SmallVector &inputTensors) { + SmallVector resultTokens; + SmallVector> inputTokens; + + if (!parseEquation(equation, inputTokens, resultTokens)) { + return false; + } + + for (size_t i = 0, d = inputTokens.size(); i < d; ++i) { + SmallVector inputStr = inputTokens[i]; + Value input = inputTensors[i]; + + for (size_t d0 = 0; d0 < inputStr.size(); ++d0) { + char id = inputStr[d0]; + + size_t d1; + for (d1 = d0 + 1; d1 < inputStr.size(); d1++) { + if (id == inputStr[d1]) + break; + } + + // No duplicate found so we can continue. 
+ if (d1 == inputStr.size()) + continue; + + // Remove the ID and move to the end: + for (size_t i = d0 + 1; i < d1; ++i) + inputStr[i - 1] = inputStr[i]; + for (size_t i = d1 + 1, s = inputStr.size(); i < s; ++i) + inputStr[i - 2] = inputStr[i]; + + inputStr[inputStr.size() - 2] = id; + inputStr.resize(inputStr.size() - 1); + + auto inputTy = cast(input.getType()); + llvm::SmallVector newShape; + for (size_t i = 0, s = inputTy.getSizes().size(); i < s; ++i) { + if (i == d0 || i == d1) + continue; + newShape.push_back(inputTy.getSizes()[i]); + } + newShape.push_back(inputTy.getSizes()[d0]); + + inputTy = rewriter.getType(newShape, inputTy.getDtype()); + + Value zero = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + + Value d0Val = rewriter.create( + loc, rewriter.getI64IntegerAttr(d0)); + Value d1Val = rewriter.create( + loc, rewriter.getI64IntegerAttr(d1)); + + input = rewriter.create(loc, inputTy, /*input=*/input, + /*offset=*/zero, /*dim1=*/d0Val, + /*dim2=*/d1Val); + + // Frontmost token will have changed: + d0--; + } + + inputTokens[i] = inputStr; + inputTensors[i] = input; + } + + llvm::SmallVector inputStrings; + for (auto inputStr : inputTokens) + inputStrings.emplace_back(inputStr.begin(), inputStr.end()); + + std::string resultString(resultTokens.begin(), resultTokens.end()); + + equation = llvm::join(inputStrings, ","); + if (!resultString.empty()) + equation = equation + "->" + resultString; return true; } @@ -296,12 +397,12 @@ static Value collapseDimForMatmul(PatternRewriter &rewriter, Location loc, int64_t contractingDimsLength, int64_t otherDimsLength, int64_t reduceDimsLength, bool isLhs) { - auto inputType = cast(input.getType()); + auto inputType = cast(input.getType()); auto inputRank = batchDimsLength + contractingDimsLength + otherDimsLength + reduceDimsLength; - SmallVector inputShapeTensor; + SmallVector inputShapeTensor; for (auto i = 0; i < inputRank; ++i) { - inputShapeTensor.emplace_back(rewriter.create( + inputShapeTensor.emplace_back(rewriter.createOrFold( loc, input, rewriter.create(loc, rewriter.getI64IntegerAttr(i)))); @@ -312,13 +413,23 @@ static Value collapseDimForMatmul(PatternRewriter &rewriter, Location loc, rewriter.create(loc, rewriter.getI64IntegerAttr(1)); auto dimOffset = 0; + auto materializeIntFold = [&](OpFoldResult thing) { + if (auto attr = dyn_cast(thing)) { + Value result = rewriter.create( + loc, cast(attr)); + return result; + } + return cast(thing); + }; + auto appendDims = [&](int64_t dimLength) { - Value prod = constOne; + OpFoldResult prod = getAsOpFoldResult(constOne); for (auto i = 0; i < dimLength; ++i) { - prod = rewriter.create(loc, prod, - inputShapeTensor[i + dimOffset]); + prod = rewriter.createOrFold( + loc, materializeIntFold(prod), + materializeIntFold(inputShapeTensor[i + dimOffset])); } - outShapeTensor.emplace_back(prod); + outShapeTensor.emplace_back(materializeIntFold(prod)); dimOffset += dimLength; }; @@ -329,12 +440,22 @@ static Value collapseDimForMatmul(PatternRewriter &rewriter, Location loc, if (isLhs) appendDims(contractingDimsLength); + SmallVector resultShape; + for (auto value : outShapeTensor) { + int64_t v; + if (matchPattern(value, m_TorchConstantInt(&v))) { + resultShape.push_back(v); + continue; + } + resultShape.push_back(Torch::kUnknownSize); + } + auto outShapeValue = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(input.getContext())), outShapeTensor); - auto outType = inputType.getWithSizesAndDtype(std::nullopt, - inputType.getOptionalDtype()); + auto outType = + 
inputType.getWithSizesAndDtype(resultShape, inputType.getOptionalDtype()); return rewriter.create(loc, outType, input, outShapeValue); } @@ -415,17 +536,19 @@ static Value permuteTensorForMatmul(PatternRewriter &rewriter, Location loc, SmallVector &contractingDims, SmallVector &otherDims, SmallVector &reduceDims, bool isLhs) { - auto inputType = cast(input.getType()); + auto inputType = cast(input.getType()); llvm::SmallDenseMap dimTokenMap; for (size_t idx = 0; idx < dimTokens.size(); ++idx) { dimTokenMap[dimTokens[idx]] = idx; } + SmallVector permuteShape; SmallVector permuteVec; auto appendDims = [&](SmallVector dimTokens) { for (auto d : dimTokens) { permuteVec.push_back(rewriter.create( loc, rewriter.getI64IntegerAttr(dimTokenMap[d]))); + permuteShape.push_back(inputType.getSizes()[dimTokenMap[d]]); } }; @@ -440,7 +563,8 @@ static Value permuteTensorForMatmul(PatternRewriter &rewriter, Location loc, Value dstDims = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(rewriter.getContext())), permuteVec); - auto outType = inputType.getWithSizesAndDtype(std::nullopt, + + auto outType = inputType.getWithSizesAndDtype(permuteShape, inputType.getOptionalDtype()); return rewriter.create(loc, outType, input, dstDims); } @@ -451,27 +575,38 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc, Value &result, SmallVector &resultTokens, SmallVector &finalResultTokens) { - auto lhsType = cast(lhs.getType()); - auto rhsType = cast(rhs.getType()); + auto lhsType = cast(lhs.getType()); + auto rhsType = cast(rhs.getType()); Type outputDType = lhsType.hasDtype() ? lhsType.getOptionalDtype() : rhsType.getOptionalDtype(); + auto materializeIntFold = [&](OpFoldResult thing) { + if (auto attr = dyn_cast(thing)) { + Value result = rewriter.create( + loc, cast(attr)); + return result; + } + return cast(thing); + }; + llvm::SmallDenseMap lhsDimShapeMap; for (size_t idx = 0; idx < lhsTokens.size(); ++idx) { char d = lhsTokens[idx]; - lhsDimShapeMap[d] = rewriter.create( + OpFoldResult lhsFold = rewriter.createOrFold( loc, lhs, rewriter.create(loc, rewriter.getI64IntegerAttr(idx))); + lhsDimShapeMap[d] = materializeIntFold(lhsFold); } llvm::SmallDenseMap rhsDimShapeMap; for (size_t idx = 0; idx < rhsTokens.size(); ++idx) { char d = rhsTokens[idx]; - rhsDimShapeMap[d] = rewriter.create( + OpFoldResult rhsFold = rewriter.createOrFold( loc, rhs, rewriter.create(loc, rewriter.getI64IntegerAttr(idx))); + rhsDimShapeMap[d] = materializeIntFold(rhsFold); } // parse batch, contracting, other, reduce dims of lhs and rhs @@ -491,8 +626,9 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc, bool lhsContains = lhsDimShapeMap.count(d) > 0; bool rhsContains = rhsDimShapeMap.count(d) > 0; if (lhsContains && rhsContains) { - outDimShapeMap[d] = rewriter.create( + OpFoldResult out = rewriter.createOrFold( loc, lhsDimShapeMap[d], rhsDimShapeMap[d]); + outDimShapeMap[d] = materializeIntFold(out); } else if (lhsContains) { outDimShapeMap[d] = lhsDimShapeMap[d]; } else if (rhsContains) { @@ -508,12 +644,6 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc, generateOutDimShapeMap(lhsOtherDims); generateOutDimShapeMap(rhsOtherDims); - if (contractingDims.size() == 0 && lhsOtherDims.size() == 0 && - rhsOtherDims.size() == 0) { - return rewriter.notifyMatchFailure( - loc, "Hadamard product is currently not supported"); - } - // shape: [*batchingDims, *lhsOtherDims, *lhsReduceDims, *lhsContractingDims] lhs = permuteTensorForMatmul(rewriter, loc, 
lhs, lhsTokens, batchingDims, contractingDims, lhsOtherDims, lhsReduceDims, @@ -531,8 +661,17 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc, contractingDims.size(), rhsOtherDims.size(), rhsReduceDims.size(), false); + lhsType = cast(lhs.getType()); + rhsType = cast(rhs.getType()); + + SmallVector outShape; + outShape.push_back(lhsType.getSizes()[0]); + outShape.push_back(lhsType.getSizes()[1]); + outShape.push_back(rhsType.getSizes()[2]); + // perform matmul - auto outType = lhsType.getWithSizesAndDtype(std::nullopt, outputDType); + auto outType = lhsType.getWithSizesAndDtype(outShape, outputDType); + result = rewriter.create(loc, outType, lhs, rhs); // generate ideal result dims. @@ -548,11 +687,21 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc, outShapeTensors.emplace_back(outDimShapeMap[d]); } + SmallVector resultShape; + for (auto value : outShapeTensors) { + int64_t v; + if (matchPattern(value, m_TorchConstantInt(&v))) { + resultShape.push_back(v); + continue; + } + resultShape.push_back(Torch::kUnknownSize); + } + auto outResultShape = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(lhs.getContext())), outShapeTensors); result = rewriter.create( - loc, lhsType.getWithSizesAndDtype(std::nullopt, outputDType), result, + loc, lhsType.getWithSizesAndDtype(resultShape, outputDType), result, outResultShape); return success(); } @@ -609,72 +758,12 @@ static Value performLastReduceAndPermute(PatternRewriter &rewriter, return out; } -namespace { -/// We decompose aten.amax into a set of aten.max.dim op(s) depending on the -/// number of dimensions across which the max needs to be computed. -/// Eg: -/// INPUT: -/// final_output = aten.amax(initial_input, dim=(0, 2, 1), keepdim=False) -/// -/// OUTPUT: -/// input_1 = aten.max.dim(initial_input, 2, keepdim) #1 -/// input_2 = aten.max.dim(input_1, 1, keepdim) #2 -/// final_output = aten.max.dim(input_2, 0, keepdim) #3 -/// -/// NOTE: We iterate over, in reverse order, every dimension included in `dim` -/// of the `aten.amax` op and create an `aten.amax.dim` op. -/// Input tensor to the next `aten.amax.dim` op is thus the output of the -/// previous `aten.amax.dim` op. -class DecomposeAtenAmaxOp : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtenAmaxOp op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - SmallVector dims; - if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dims))) - - return rewriter.notifyMatchFailure(op, - "non-const dim parameter unsupported"); - - bool keepDim; - if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) - return rewriter.notifyMatchFailure( - op, "Expected a constant boolean value for keepDim"); - - Value input = op.getSelf(); - auto inputTy = dyn_cast(input.getType()); - if (!inputTy || !inputTy.hasSizes()) { - return rewriter.notifyMatchFailure(op, - "Expected input type having sizes"); - } - // For every dimension included in `dim` of the op, iterated over in - // reverse order, we create a call to aten.max.dim. 
- std::sort(dims.rbegin(), dims.rend()); - for (int64_t dimInt : dims) { - int64_t inputRank = inputTy.getSizes().size(); - dimInt = toPositiveDim(dimInt, inputRank); - if (!isValidDim(dimInt, inputRank)) - return rewriter.notifyMatchFailure(op, "dim is statically invalid"); - Value dim = rewriter.create( - loc, rewriter.getI64IntegerAttr(dimInt)); - // The input to the next invocation of aten.max.dim is the output of the - // previous aten.max.dim op. - input = createMaxAlongDimension(rewriter, loc, op, input, dim, keepDim); - } - rewriter.replaceOp(op, input); - return success(); - } -}; -} // end namespace - namespace { class DecomposeAtenTriuOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenTriuOp op, PatternRewriter &rewriter) const override { - MLIRContext *context = op.getContext(); Location loc = op.getLoc(); Value input = op.getSelf(); auto inputType = cast(input.getType()); @@ -685,37 +774,50 @@ class DecomposeAtenTriuOp : public OpRewritePattern { return rewriter.notifyMatchFailure(op, "the rank of tensor should >= 2"); } - auto baseType = ValueTensorType::getWithLeastStaticInformation(context); Value cstZero = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); Value cstOne = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); Value none = rewriter.create(loc); - Value rowDim = rewriter.create( - loc, rewriter.getI64IntegerAttr(-2)); - Value colDim = rewriter.create( - loc, rewriter.getI64IntegerAttr(-1)); - Value rowSize = rewriter.create(loc, input, rowDim); - Value colSize = rewriter.create(loc, input, colDim); - - Value rowArange = rewriter.create( - loc, baseType, rowSize, /*dtype=*/none, /*layout=*/none, - /*device=*/none, /*pin_memory=*/none); - Value colArange = rewriter.create( - loc, baseType, colSize, /*dtype=*/none, /*layout=*/none, - /*device=*/none, /*pin_memory=*/none); + Value rowSize = getTensorDimSize(rewriter, input, -2); + Value colSize = getTensorDimSize(rewriter, input, -1); + + auto si64Type = rewriter.getIntegerType(/*width=*/64, /*isSigned*/ true); + auto int64DtypeInt = getDtypeIntValueForType(rewriter, loc, si64Type); + auto rowArrangeType = getTensorTypeFromShapeValues({rowSize}, si64Type); + auto colArrangeType = getTensorTypeFromShapeValues({colSize}, si64Type); + + Value rowArange = + rewriter.create(loc, rowArrangeType, rowSize, + /*dtype=*/int64DtypeInt, /*layout=*/none, + /*device=*/none, /*pin_memory=*/none); + Value colArange = + rewriter.create(loc, colArrangeType, colSize, + /*dtype=*/int64DtypeInt, /*layout=*/none, + /*device=*/none, /*pin_memory=*/none); + + auto unsqueezeRowArangeInfo = + unsqueezeTensor(rewriter, op, rowArange, cstOne); + auto unsqueezeColArangeInfo = + unsqueezeTensor(rewriter, op, colArange, cstZero); + + if (failed(unsqueezeRowArangeInfo) || failed(unsqueezeColArangeInfo)) { + return rewriter.notifyMatchFailure(op, + "cannot generate unsqueeze tensor"); + } - Value unsqueezeRowArange = - rewriter.create(loc, baseType, rowArange, cstOne); - Value unsqueezeColArange = - rewriter.create(loc, baseType, colArange, cstZero); + Value unsqueezeRowArange = unsqueezeRowArangeInfo.value(); + Value unsqueezeColArange = unsqueezeColArangeInfo.value(); Value unsqueezeRowArangePlusDiagonal = rewriter.create( - loc, baseType, unsqueezeRowArange, op.getDiagonal(), cstOne); + loc, unsqueezeRowArange.getType(), unsqueezeRowArange, op.getDiagonal(), + cstOne); + auto boolType = rewriter.getI1Type(); + auto condType = getTensorTypeFromShapeValues({rowSize, colSize}, 
boolType); Value condTensor = rewriter.create( - loc, baseType, unsqueezeColArange, unsqueezeRowArangePlusDiagonal); + loc, condType, unsqueezeColArange, unsqueezeRowArangePlusDiagonal); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), condTensor, input, cstZero); @@ -724,6 +826,552 @@ }; } // namespace +/* + This function calculates the number of elements in the lower triangle (below + the main diagonal) of a tensor with dimensions [row, col]. The main diagonal + can be shifted using the 'offset' parameter. The lower triangle is divided into + two parts: a trapezoid and a rectangle. The return tuple includes the number of + elements in the trapezoid, the number of elements in the rectangle, and the + index of the first row such that the element [mFirstRow, 0] is below the main + diagonal. + */ +static std::tuple +getTrilSizes(int64_t row, int64_t col, int64_t offset) { + + // Base case + if (row == 0 || col == 0) { + return std::make_tuple(0, 0, 0); + } + + // Calculate mFirstRow size + int64_t mFirstRow; + if (offset > 0) + mFirstRow = (col < offset + 1) ? col : offset + 1; + else + mFirstRow = (row + offset > 0) ? 1 : 0; + + // Calculate mLastRow size + int64_t minimum = (col < row + offset) ? col : row + offset; + int64_t mLastRow = (minimum > 0) ? minimum : 0; + + // Calculate nRowAll + minimum = (row < row + offset) ? row : row + offset; + int64_t nRowAll = (minimum > 0) ? minimum : 0; + + // Calculate nRowTrapezoid + int64_t nRowTrapezoid = mLastRow - mFirstRow + 1; + + // Number of elements in top trapezoid - trapezoidSize + int64_t trapezoidSize = (mFirstRow + mLastRow) * nRowTrapezoid / 2; + + // Number of elements in bottom rectangle - rectangleSize + int64_t diffRow = nRowAll - nRowTrapezoid; + int64_t rectangleSize = (diffRow * col > 0) ? diffRow * col : 0; + + // Create return value + return std::make_tuple(trapezoidSize, rectangleSize, mFirstRow); +} + +/* + This function calculates the number of elements in the upper triangle (above + the main diagonal) of a tensor with dimensions [row, col]. The main diagonal + can be shifted using the 'offset' parameter. The upper triangle is divided into + two parts: a trapezoid and a rectangle. The return tuple includes the number of + elements in the trapezoid, the number of elements in the rectangle, and the + index of the first row such that the element [mFirstRow, 0] is above the main + diagonal. + */ +static std::tuple +getTriuSizes(int64_t row, int64_t col, int64_t offset) { + + // Base case + if (row == 0 || col == 0) + return std::make_tuple(0, 0, 0); + + // Calculate mFirstRow size + int64_t maximum = (col - offset > 0) ? col - offset : 0; + int64_t mFirstRow = (offset > 0) ? maximum : col; + + // Number of elements in top rectangle - calculate rectangle size + int64_t minimum = (row < -offset) ? row : -offset; + int64_t rectangleSize = (minimum * col > 0) ? 
minimum * col : 0; + + // Number of elements in bottom trapezoid - calculte trapezoid size + std::tuple trilSizes = + getTrilSizes(row, col, offset - 1); + int64_t trapezoidSizeTril = std::get<0>(trilSizes); + int64_t rectangleSizeTril = std::get<1>(trilSizes); + + int64_t triuSize = row * col - (trapezoidSizeTril + rectangleSizeTril); + int64_t trapezoidSize = triuSize - rectangleSize; + + // Create return value + return std::make_tuple(trapezoidSize, rectangleSize, mFirstRow); +} + +// decomposition of torch.triu_indices +// https://github.com/pytorch/pytorch/blob/67ef2683d970fc541b6d266d4b3f8ba9d13844ca/torch/_refs/__init__.py#L5829 +namespace { +class DecomposeAtenTriuIndicesOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenTriuIndicesOp op, + PatternRewriter &rewriter) const override { + + Location loc = op.getLoc(); + MLIRContext *context = op.getContext(); + + // Required parameters + Value row = op.getRow(); + Value col = op.getCol(); + Value offset = op.getOffset(); + + // Check if row, col and offset are constant ints + int64_t rowInt; + if (!matchPattern(row, m_TorchConstantInt(&rowInt))) + return rewriter.notifyMatchFailure(op, + "Unimplemented: row not constant int"); + + int64_t colInt; + if (!matchPattern(col, m_TorchConstantInt(&colInt))) + return rewriter.notifyMatchFailure(op, + "Unimplemented: col not constant int"); + + int64_t offsetInt; + if (!matchPattern(offset, m_TorchConstantInt(&offsetInt))) + return rewriter.notifyMatchFailure( + op, "Unimplemented: offset not constant int"); + + // Optional parameters + Value dtype = op.getDtype(); + Value layout = op.getLayout(); + Value device = op.getDevice(); + Value pinMemory = op.getPinMemory(); + + // Get int value for dtype + int64_t dtypeInt; + if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt))) + return rewriter.notifyMatchFailure( + op, "Unimplemented: dtype not constant int"); + + FailureOr dtypeType = + getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt); + if (failed(dtypeType)) + return rewriter.notifyMatchFailure(op, "dtype is undefined"); + + // Constants + Value cstZero = rewriter.create(loc, 0); + Value cstOne = rewriter.create(loc, 1); + Value cstTwo = rewriter.create(loc, 2); + Value cstFalse = rewriter.create(loc, false); + Value cstMinusZeroPointFive = rewriter.create( + loc, rewriter.getF64FloatAttr(-0.5)); + Value cstMinusTwoFloat = rewriter.create( + loc, rewriter.getF64FloatAttr(-2.0)); + + // Calculte trapezoidSize, rectangleSize and mFirstRow + std::tuple triuSizes = + getTriuSizes(rowInt, colInt, offsetInt); + + int64_t trapezoidSizeInt = std::get<0>(triuSizes); + int64_t rectangleSizeInt = std::get<1>(triuSizes); + int64_t mFirstRowInt = std::get<2>(triuSizes); + + // Create const int Values from ints + Value trapezoidSize = rewriter.create( + loc, rewriter.getI64IntegerAttr(trapezoidSizeInt)); + Value rectangleSize = rewriter.create( + loc, rewriter.getI64IntegerAttr(rectangleSizeInt)); + Value mFirstRow = rewriter.create( + loc, rewriter.getI64IntegerAttr(mFirstRowInt)); + + // Calculte column offset + Value colOffset = (offsetInt > 0) ? 
offset : cstZero; + + // Calculate indices for top rectangle + auto arrangeType = + getTensorTypeFromShapeValues({rectangleSize}, *dtypeType); + Value xs2 = + rewriter.create(loc, arrangeType, rectangleSize, + /*dtype=*/dtype, /*layout=*/layout, + /*device=*/device, + /*pin_memory=*/pinMemory); + + // Calculate row_indices2 and column_idices 2 + Value rowInds2 = + rewriter.create(loc, xs2.getType(), xs2, col); + Value colInds2 = + rewriter.create(loc, xs2.getType(), xs2, col); + + // Bottom trapezoid + auto f64DtypeInt = + getDtypeIntValueForType(rewriter, loc, rewriter.getF64Type()); + arrangeType = + getTensorTypeFromShapeValues({trapezoidSize}, rewriter.getF64Type()); + Value xs1 = + rewriter.create(loc, arrangeType, trapezoidSize, + /*dtype=*/f64DtypeInt, /*layout=*/layout, + /*device=*/device, + /*pin_memory=*/pinMemory); + + // b = -0.5 - m_first_row + Value mFirstRowFloat = rewriter.create( + loc, rewriter.getF64FloatAttr(mFirstRowInt)); + Value b = rewriter.create(loc, cstMinusZeroPointFive, + mFirstRowFloat); + + // Implements this piece of code: row_inds1 = torch.floor(-b - torch.sqrt(b + // * b - 2 * xs1)) + Value bSquare = rewriter.create(loc, b, b); + + Value twoTimesXs1 = rewriter.create(loc, xs1.getType(), + xs1, cstMinusTwoFloat); + Value sqrtInput = rewriter.create( + loc, twoTimesXs1.getType(), twoTimesXs1, bSquare, cstOne); + + Value sqrt = + rewriter.create(loc, sqrtInput.getType(), sqrtInput); + Value negativeSqrt = rewriter.create(loc, sqrt.getType(), sqrt); + + Value rowInds1 = rewriter.create( + loc, negativeSqrt.getType(), negativeSqrt, b, cstOne); + rowInds1 = rewriter.create(loc, rowInds1.getType(), rowInds1); + + // Implements this piece of code: col_inds1 = torch.floor(xs1 - ((2 * + // m_first_row - 1 - row_inds1) * row_inds1) * 0.5) + Value twoTimesMFirstRow = + rewriter.create(loc, cstTwo, mFirstRow); + twoTimesMFirstRow = + rewriter.create(loc, twoTimesMFirstRow, cstOne); + Value negativeRowInds1 = + rewriter.create(loc, rowInds1.getType(), rowInds1); + + negativeRowInds1 = rewriter.create( + loc, negativeRowInds1.getType(), negativeRowInds1, twoTimesMFirstRow, + cstOne); + negativeRowInds1 = rewriter.create( + loc, negativeRowInds1.getType(), negativeRowInds1, rowInds1); + negativeRowInds1 = rewriter.create( + loc, negativeRowInds1.getType(), negativeRowInds1, + cstMinusZeroPointFive); + + Value colInds1 = rewriter.create(loc, xs1.getType(), xs1, + negativeRowInds1, cstOne); + colInds1 = rewriter.create(loc, colInds1.getType(), colInds1); + + // Convert to dtype + Type int64Type = rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true); + + auto rowInds1Type = cast(rowInds1.getType()); + ArrayRef sizes = rowInds1Type.getSizes(); + Type finalRowType = rowInds1Type.getWithSizesAndDtype(sizes, int64Type); + rowInds1 = rewriter.create( + loc, finalRowType, rowInds1, dtype, + /*non_blocking=*/cstFalse, /*copy=*/cstFalse, + /*memory_format=*/cstOne); + + auto colInds1Type = cast(colInds1.getType()); + sizes = colInds1Type.getSizes(); + Type finalColType = colInds1Type.getWithSizesAndDtype(sizes, int64Type); + colInds1 = rewriter.create( + loc, finalColType, colInds1, dtype, + /*non_blocking=*/cstFalse, /*copy=*/cstFalse, + /*memory_format=*/cstOne); + + // Final calculation for row and col indices + if (colInt) { + + Value rectangleSizeDivCol = + rewriter.create(loc, rectangleSizeInt / colInt); + + rowInds1 = rewriter.create( + loc, rowInds1.getType(), rowInds1, rectangleSizeDivCol, cstOne); + } + + colInds1 = rewriter.create(loc, colInds1.getType(), + colInds1, 
colOffset, cstOne); + + Type listElemType = + cast(rowInds1.getType()) + .getWithSizesAndDtype(/*optionalSizes=*/std::nullopt, + /*optionalDtype=*/nullptr); + Type listType = Torch::ListType::get(listElemType); + + Value sequenceRow = rewriter.create( + loc, listType, SmallVector{rowInds2, rowInds1}); + Value sequenceCol = rewriter.create( + loc, listType, SmallVector{colInds2, colInds1}); + + // Concatenate row and col indices + Type finalCatType = colInds1Type.getWithSizesAndDtype( + {rectangleSizeInt + trapezoidSizeInt}, int64Type); + + Value catRow = rewriter.create(loc, finalCatType, sequenceRow, + /*dim=*/cstZero); + Value catCol = rewriter.create(loc, finalCatType, sequenceCol, + /*dim=*/cstZero); + + // Make return value + Value sequence = rewriter.create( + loc, Torch::ListType::get(context, rowInds1.getType()), + ValueRange{catRow, catCol}); + Type finalStackType = colInds1Type.getWithSizesAndDtype( + ArrayRef{2, rectangleSizeInt + trapezoidSizeInt}, int64Type); + + rewriter.replaceOpWithNewOp(op, finalStackType, sequence, + cstZero); + + return success(); + } +}; +} // namespace + +// decomposition of torch.tril_indices +// https://github.com/pytorch/pytorch/blob/67ef2683d970fc541b6d266d4b3f8ba9d13844ca/torch/_refs/__init__.py#L5797 +namespace { +class DecomposeAtenTrilIndicesOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenTrilIndicesOp op, + PatternRewriter &rewriter) const override { + + Location loc = op.getLoc(); + MLIRContext *context = op.getContext(); + + // Required parameters + Value row = op.getRow(); + Value col = op.getCol(); + Value offset = op.getOffset(); + + // Check if row, col and offset are constant ints + int64_t rowInt; + if (!matchPattern(row, m_TorchConstantInt(&rowInt))) + return rewriter.notifyMatchFailure(op, + "Unimplemented: row not constant int"); + + int64_t colInt; + if (!matchPattern(col, m_TorchConstantInt(&colInt))) + return rewriter.notifyMatchFailure(op, + "Unimplemented: col not constant int"); + + int64_t offsetInt; + if (!matchPattern(offset, m_TorchConstantInt(&offsetInt))) + return rewriter.notifyMatchFailure( + op, "Unimplemented: offset not constant int"); + + // Optional parameters + Value dtype = op.getDtype(); + Value layout = op.getLayout(); + Value device = op.getDevice(); + Value pinMemory = op.getPinMemory(); + + // Constants + Value cstZero = rewriter.create(loc, 0); + Value cstOne = rewriter.create(loc, 1); + Value cstTwo = rewriter.create(loc, 2); + Value cstFalse = rewriter.create(loc, false); + Value cstZeroPointFive = rewriter.create( + loc, rewriter.getF64FloatAttr(0.5)); + Value cstTwoFloat = rewriter.create( + loc, rewriter.getF64FloatAttr(2.0)); + + // Get int value for dtype + int64_t dtypeInt; + if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt))) + return rewriter.notifyMatchFailure( + op, "Unimplemented: dtype not constant int"); + + FailureOr dtypeType = + getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt); + if (failed(dtypeType)) + return rewriter.notifyMatchFailure(op, "dtype is undefined"); + + // Calculte trapezoidSize, rectangleSize and mFirstRow + std::tuple triuSizes = + getTrilSizes(rowInt, colInt, offsetInt); + + int64_t trapezoidSizeInt = std::get<0>(triuSizes); + int64_t rectangleSizeInt = std::get<1>(triuSizes); + int64_t mFirstRowInt = std::get<2>(triuSizes); + + // Create const int Values from ints + Value trapezoidSize = rewriter.create( + loc, rewriter.getI64IntegerAttr(trapezoidSizeInt)); + Value 
rectangleSize = rewriter.create( + loc, rewriter.getI64IntegerAttr(rectangleSizeInt)); + Value mFirstRow = rewriter.create( + loc, rewriter.getI64IntegerAttr(mFirstRowInt)); + + // Calculte column offset + int64_t rowOffsetInt = (-offsetInt > 0) ? (-offsetInt) : 0; + Value rowOffset = rewriter.create(loc, rowOffsetInt); + + // First we do the indices for TOP trapezoid + auto f64DtypeInt = + getDtypeIntValueForType(rewriter, loc, rewriter.getF64Type()); + auto arrangeType = + getTensorTypeFromShapeValues({trapezoidSize}, rewriter.getF64Type()); + Value xs1 = + rewriter.create(loc, arrangeType, trapezoidSize, + /*dtype=*/f64DtypeInt, /*layout=*/layout, + /*device=*/device, + /*pin_memory=*/pinMemory); + + // b = m_first_row - 0.5 + Value mFirstRowFloat = rewriter.create( + loc, rewriter.getF64FloatAttr(mFirstRowInt)); + Value b = + rewriter.create(loc, mFirstRowFloat, cstZeroPointFive); + + // Implements this piece of code: row_inds1 = torch.floor(-b + torch.sqrt(b + // * b + 2 * xs1)) + Value bSquare = rewriter.create(loc, b, b); + + Value twoTimesXs1 = + rewriter.create(loc, xs1.getType(), xs1, cstTwoFloat); + Value sqrtInput = rewriter.create( + loc, twoTimesXs1.getType(), twoTimesXs1, bSquare, cstOne); + + Value sqrt = + rewriter.create(loc, sqrtInput.getType(), sqrtInput); + + Value rowInds1 = + rewriter.create(loc, sqrt.getType(), sqrt, b, cstOne); + rowInds1 = rewriter.create(loc, rowInds1.getType(), rowInds1); + + // Implements this piece of code: col_inds1 = torch.floor(xs1 - (2 * + // m_first_row - 1 + row_inds1) * row_inds1 * 0.5) + Value twoTimesMFirstRow = + rewriter.create(loc, cstTwo, mFirstRow); + twoTimesMFirstRow = + rewriter.create(loc, twoTimesMFirstRow, cstOne); + twoTimesMFirstRow = rewriter.create( + loc, rowInds1.getType(), rowInds1, twoTimesMFirstRow, cstOne); + twoTimesMFirstRow = rewriter.create( + loc, twoTimesMFirstRow.getType(), twoTimesMFirstRow, rowInds1); + twoTimesMFirstRow = rewriter.create( + loc, twoTimesMFirstRow.getType(), twoTimesMFirstRow, cstZeroPointFive); + + Value colInds1 = rewriter.create( + loc, xs1.getType(), xs1, twoTimesMFirstRow, cstOne); + colInds1 = rewriter.create(loc, colInds1.getType(), colInds1); + + // Convert top trapezoid indices to dtype + Type int64Type = rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true); + + auto rowInds1Type = cast(rowInds1.getType()); + ArrayRef sizes = rowInds1Type.getSizes(); + Type finalRowType = rowInds1Type.getWithSizesAndDtype(sizes, int64Type); + rowInds1 = rewriter.create(loc, rowInds1.getType(), + rowInds1, rowOffset, cstOne); + rowInds1 = rewriter.create( + loc, finalRowType, rowInds1, dtype, + /*non_blocking=*/cstFalse, /*copy=*/cstFalse, + /*memory_format=*/cstOne); + + auto colInds1Type = cast(colInds1.getType()); + sizes = colInds1Type.getSizes(); + Type finalColType = colInds1Type.getWithSizesAndDtype(sizes, int64Type); + colInds1 = rewriter.create( + loc, finalColType, colInds1, dtype, + /*non_blocking=*/cstFalse, /*copy=*/cstFalse, + /*memory_format=*/cstOne); + + // Calculate indices for BOTTOM rectangle + arrangeType = getTensorTypeFromShapeValues({rectangleSize}, *dtypeType); + Value xs2 = + rewriter.create(loc, arrangeType, rectangleSize, + /*dtype=*/dtype, /*layout=*/layout, + /*device=*/device, + /*pin_memory=*/pinMemory); + + // Implements this line of code: row_inds2 = xs2 // col + (col - m_first_row + // + 1 + row_offset) + Value rowInds2 = + rewriter.create(loc, xs2.getType(), xs2, col); + int64_t addInt = colInt - mFirstRowInt + 1 + rowOffsetInt; + Value cstAdd = 
rewriter.create(loc, addInt); + rowInds2 = rewriter.create(loc, rowInds2.getType(), + rowInds2, cstAdd, cstOne); + + // Implements this line of code: col_inds2 = xs2 % col + Value colInds2 = + rewriter.create(loc, xs2.getType(), xs2, col); + + // Prepare tensors for concatenation + Type listElemType = + cast(rowInds1.getType()) + .getWithSizesAndDtype(/*optionalSizes=*/std::nullopt, + /*optionalDtype=*/nullptr); + Type listType = Torch::ListType::get(listElemType); + + Value sequenceRow = rewriter.create( + loc, listType, SmallVector{rowInds1, rowInds2}); + Value sequenceCol = rewriter.create( + loc, listType, SmallVector{colInds1, colInds2}); + + // Concatenate row and col indices + Type finalCatType = colInds1Type.getWithSizesAndDtype( + {rectangleSizeInt + trapezoidSizeInt}, int64Type); + + Value catRow = rewriter.create(loc, finalCatType, sequenceRow, + /*dim=*/cstZero); + Value catCol = rewriter.create(loc, finalCatType, sequenceCol, + /*dim=*/cstZero); + + // Make return value - stack row and col indices + Value sequence = rewriter.create( + loc, Torch::ListType::get(context, rowInds1.getType()), + ValueRange{catRow, catCol}); + Type finalStackType = colInds1Type.getWithSizesAndDtype( + ArrayRef{2, rectangleSizeInt + trapezoidSizeInt}, int64Type); + + rewriter.replaceOpWithNewOp(op, finalStackType, sequence, + cstZero); + + return success(); + } +}; +} // namespace + +namespace { +class DecomposeAtenDeg2radOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenDeg2radOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = op.getSelf(); + auto selfTy = dyn_cast(self.getType()); + if (!selfTy || !selfTy.getDtype()) { + return rewriter.notifyMatchFailure(op, "requires tensor types input."); + } + + auto outTy = dyn_cast(op.getType()); + if (!outTy || !outTy.getDtype()) { + return rewriter.notifyMatchFailure( + op, "requires output is a tensor with dtype."); + } + + if (selfTy.getDtype() != outTy.getDtype()) { + self = convertTensorToDtype(rewriter, loc, self, outTy.getDtype()); + } + + Value pi = + rewriter.create(loc, rewriter.getF64FloatAttr(M_PI)); + Value basic = + rewriter.create(loc, rewriter.getF64FloatAttr(180.0)); + Value rad = + rewriter.create(loc, op.getType(), self, basic); + Value result = rewriter.create(loc, op.getType(), rad, pi); + + rewriter.replaceOp(op, result); + + return success(); + } +}; +} // namespace + namespace { class DecomposeAtenSizeOp : public OpRewritePattern { public: @@ -850,18 +1498,6 @@ class DecomposePrimTolistOp : public OpRewritePattern { }; } // namespace -namespace { -class DecomposeAtenSplitSizesOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtenSplitSizesOp op, - PatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp( - op, op->getResultTypes(), op.getSelf(), op.getSplitSize(), op.getDim()); - return success(); - } -}; -} // namespace - namespace { class DecomposeAtenSplitWithSizesOp : public OpRewritePattern { @@ -888,7 +1524,7 @@ class DecomposeAtenSplitWithSizesOp auto sliceTy = dyn_cast_or_null(resultTy.getContainedType()); - if (!isa(sliceTy)) + if (!sliceTy || !sliceTy.hasSizes()) return rewriter.notifyMatchFailure(op, "Slice type is unknown"); int64_t dimInt = 0; @@ -1059,7 +1695,7 @@ class DecomposeAtenEyeMOp : public OpRewritePattern { LogicalResult matchAndRewrite(AtenEyeMOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); 
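// --- Illustrative sketch, not part of this patch ---
// The triu_indices/tril_indices decompositions above rely on the
// getTriuSizes/getTrilSizes trapezoid + rectangle bookkeeping. The standalone
// C++ below re-derives that count for one arbitrary example (row=4, col=5,
// offset=1) and checks it against a brute-force element count; it only mirrors
// the arithmetic documented above and is not the pass's code.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <tuple>

static std::tuple<int64_t, int64_t, int64_t>
trilSizes(int64_t row, int64_t col, int64_t offset) {
  if (row == 0 || col == 0)
    return {0, 0, 0};
  int64_t mFirstRow = offset > 0 ? std::min(col, offset + 1)
                                 : (row + offset > 0 ? 1 : 0);
  int64_t mLastRow = std::max<int64_t>(0, std::min(col, row + offset));
  int64_t nRowAll = std::max<int64_t>(0, std::min(row, row + offset));
  int64_t nRowTrapezoid = mLastRow - mFirstRow + 1;
  int64_t trapezoid = (mFirstRow + mLastRow) * nRowTrapezoid / 2;
  int64_t rectangle = std::max<int64_t>(0, (nRowAll - nRowTrapezoid) * col);
  return {trapezoid, rectangle, mFirstRow};
}

int main() {
  const int64_t row = 4, col = 5, offset = 1;

  // Brute-force count of lower-triangle entries (i, j) with j <= i + offset.
  int64_t expected = 0;
  for (int64_t i = 0; i < row; ++i)
    for (int64_t j = 0; j < col; ++j)
      if (j <= i + offset)
        ++expected;
  auto [trapezoid, rectangle, mFirstRow] = trilSizes(row, col, offset);
  (void)mFirstRow;
  assert(trapezoid + rectangle == expected); // 14 elements for this example

  // getTriuSizes complements the tril count: #triu(offset) equals
  // row*col - #tril(offset - 1), which is exactly how the helper computes it.
  auto [trapM1, rectM1, firstM1] = trilSizes(row, col, offset - 1);
  (void)firstM1;
  int64_t triuCount = 0;
  for (int64_t i = 0; i < row; ++i)
    for (int64_t j = 0; j < col; ++j)
      if (j >= i + offset)
        ++triuCount;
  assert(triuCount == row * col - (trapM1 + rectM1));
  return 0;
}
// --- end sketch ---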
- auto outType = op.getType().dyn_cast(); + auto outType = dyn_cast(op.getType()); if (!outType) return rewriter.notifyMatchFailure( op, "Only tensor types input are currently supported"); @@ -1221,22 +1857,107 @@ class DecomposeAtenReshapeOp : public OpRewritePattern { } // namespace namespace { -// Decompose AtenEinsumOp to AtenMatmulOp, and supports possible reduce -// operation and permute operation. Currently, this pass doesn't support -// Hadamard product. The basic idea is that: -// Step 1: split the string equation to input/result tokens and find -// batchingDims, contractingDims, otherDims and reduceDims. -// Step 2: permute and reshape input tensors suitable -// for matmul operations. -// Step 3: use AtenMatmulOp to get the result. -// Step 4: iteratively execute step 2 & 3 until we get the final result. -// Step 5: perform remaining permute and reduce operations. -// notice: support static shape only - -class DecomposeAtenEinsumOp : public OpRewritePattern { +// Decompose aten.atleast_1d into: aten.reshape. See +// https://github.com/pytorch/pytorch/blob/9a8ab778d34bd24c5caceb340837483decc4c311/torch/_refs/__init__.py#L2591 +// def atleast_1d( +// arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: +// TensorLikeType +// ) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]: +// """Refrence implementation of :func:`torch.atleast_1d`.""" +// if not args and isinstance(arg, collections.abc.Sequence): +// args_ = arg +// else: +// assert not isinstance(arg, collections.abc.Sequence) +// args_ = (arg,) + args +// res = tuple(a if a.ndim >= 1 else unsqueeze(a, 0) for a in args_) +// return res if len(res) > 1 else res[0] +class DecomposeAtenAtleast1dOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtenEinsumOp op, + LogicalResult matchAndRewrite(AtenAtleast1dOp op, + PatternRewriter &rewriter) const override { + Value input = op.getSelf(); + Location loc = op.getLoc(); + Type opType = op.getType(); + auto inpType = cast(input.getType()); + SmallVector inputShape(inpType.getSizes()); + if (inputShape.empty()) { + Value zero = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + rewriter.replaceOpWithNewOp(op, opType, input, zero); + return success(); + } + + rewriter.replaceOp(op, input); + return success(); + } +}; +} // namespace + +namespace { +// Decompose aten.atleast_2d into: aten.reshape. 
See +// https://github.com/pytorch/pytorch/blob/9a8ab778d34bd24c5caceb340837483decc4c311/torch/_refs/__init__.py#L2604 +// def atleast_2d( +// arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: +// TensorLikeType +// ) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]: +// """Reference implementation of :func:`torch.atleast_2d`.""" +// if not args and isinstance(arg, collections.abc.Sequence): +// args_ = arg +// else: +// assert not isinstance(arg, collections.abc.Sequence) +// args_ = (arg,) + args +// unsqueeze_atleast_1d = partial(_unsqueeze_atleast, atleast_1d, 0) +// res = tuple(a if a.ndim >= 2 else unsqueeze_atleast_1d(a) for a in args_) +// return res if len(res) > 1 else res[0] +class DecomposeAtenAtleast2dOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenAtleast2dOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value input = op.getSelf(); + Type opType = op.getType(); + + auto inputType = cast(input.getType()); + SmallVector inputShape(inputType.getSizes()); + + if (inputShape.size() >= 2) { + rewriter.replaceOp(op, input); + return success(); + } + auto atleast1dResShape = + inputShape.empty() ? SmallVector{1} : inputShape; + auto atleast1dResType = rewriter.getType( + atleast1dResShape, inputType.getOptionalDtype()); + auto atleast1dRes = + rewriter.create(loc, atleast1dResType, input); + Value zero = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + rewriter.replaceOpWithNewOp(op, opType, atleast1dRes, + zero); + return success(); + } +}; +} // namespace + +namespace { +// Decompose AtenEinsumOp to AtenMatmulOp, and supports possible reduce +// operation and permute operation. Currently, this pass doesn't support +// Hadamard product. The basic idea is that: +// Step 1: split the string equation to input/result tokens and find +// batchingDims, contractingDims, otherDims and reduceDims. +// Step 2: permute and reshape input tensors suitable +// for matmul operations. +// Step 3: use AtenMatmulOp to get the result. +// Step 4: iteratively execute step 2 & 3 until we get the final result. +// Step 5: perform remaining permute and reduce operations. +// notice: support static shape only + +class DecomposeAtenEinsumOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenEinsumOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); @@ -1275,6 +1996,13 @@ class DecomposeAtenEinsumOp : public OpRewritePattern { op, "Unexpected character in equations encountered"); } } + + if (!diagonalizeInputAndRewriteEquation(op.getLoc(), rewriter, equation, + inputTensors)) { + return rewriter.notifyMatchFailure(op, + "Failed to handle diagonalization"); + } + SmallVector resultTokens; SmallVector> inputTokens; if (!parseEquation(equation, inputTokens, resultTokens)) { @@ -1306,6 +2034,125 @@ class DecomposeAtenEinsumOp : public OpRewritePattern { }; } // namespace +namespace { +// Trilinear einstein sum, decomposed to: +// (i1.unsqueeze(expand1) * i2.unsqueeze(expand2) * i3.unsqueeze(expand3)) +// .sum(sumdim) +// The unrollDim operand does not impact the output of the operation, so +// it is ignored. 
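// --- Illustrative sketch, not part of this patch ---
// The broadcast-multiply-then-sum semantics described in the comment above,
// specialised to 1-D inputs i1, i2, i3 with the (hypothetical, example-only)
// choices expand1 = {1, 2}, expand2 = {0, 2}, expand3 = {0, 1} and
// sumdim = {1, 2}. With those expansions the result collapses to
// out[a] = i1[a] * sum(i2) * sum(i3), which the sketch verifies with loops.
#include <cassert>
#include <cmath>
#include <vector>

int main() {
  std::vector<double> i1 = {1.0, 2.0}, i2 = {3.0, 4.0}, i3 = {5.0, 6.0, 7.0};

  // Full decomposition: unsqueeze to [A,1,1], [1,B,1], [1,1,C], multiply with
  // broadcasting, then sum over dims 1 and 2.
  std::vector<double> out(i1.size(), 0.0);
  for (size_t a = 0; a < i1.size(); ++a)
    for (size_t b = 0; b < i2.size(); ++b)
      for (size_t c = 0; c < i3.size(); ++c)
        out[a] += i1[a] * i2[b] * i3[c];

  // Closed form for this particular choice of expansions.
  double sum2 = 3.0 + 4.0, sum3 = 5.0 + 6.0 + 7.0;
  for (size_t a = 0; a < i1.size(); ++a)
    assert(std::fabs(out[a] - i1[a] * sum2 * sum3) < 1e-12);
  return 0;
}
// --- end sketch ---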
+ +class DecomposeAten_TrilinearOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(Aten_TrilinearOp op, + PatternRewriter &rewriter) const override { + + Location loc = op.getLoc(); + + Value input1 = op.getI1(); + Value input2 = op.getI2(); + Value input3 = op.getI3(); + + // Expansions + SmallVector expand1; + SmallVector expand2; + SmallVector expand3; + if (!matchPattern(op.getExpand1(), m_TorchListOfConstantInts(expand1))) { + return rewriter.notifyMatchFailure(op, "expand1 should be constant"); + } + if (!matchPattern(op.getExpand2(), m_TorchListOfConstantInts(expand2))) { + return rewriter.notifyMatchFailure(op, "expand2 should be constant"); + } + if (!matchPattern(op.getExpand3(), m_TorchListOfConstantInts(expand3))) { + return rewriter.notifyMatchFailure(op, "expand3 should be constant"); + } + + SmallVector sumDim; + if (!matchPattern(op.getSumdim(), m_TorchListOfConstantInts(sumDim))) { + return rewriter.notifyMatchFailure(op, "sumDim should be constant"); + } + + // Check if there are any dimensions that intersect between expand1, + // expand2, and expand3. + int64_t totalDims = + cast(input1.getType()).getSizes().size() + + expand1.size(); + if (sharedExpandDims(totalDims, expand1, expand2, expand3, sumDim)) { + // pytorch issue filed: https://github.com/pytorch/pytorch/issues/138353 + // TODO: Remove warning when issue gets resolved. + op->emitWarning("aten::_trilinear implementation in this case is " + "non-functional (returns an empty dimension). We will " + "intentionally deviate from this behavior."); + } + + // Apply unsqueeze to respective input tensors at the specified dimensions + SmallVector sortedExpand1 = expand1; + std::sort(sortedExpand1.begin(), sortedExpand1.end()); + for (auto expand : sortedExpand1) { + Value expandDim = rewriter.create( + loc, rewriter.getI64IntegerAttr(expand)); + input1 = *unsqueezeTensor(rewriter, op, input1, expandDim); + } + SmallVector sortedExpand2 = expand2; + std::sort(sortedExpand2.begin(), sortedExpand2.end()); + for (auto expand : sortedExpand2) { + Value expandDim = rewriter.create( + loc, rewriter.getI64IntegerAttr(expand)); + input2 = *unsqueezeTensor(rewriter, op, input2, expandDim); + } + SmallVector sortedExpand3 = expand3; + std::sort(sortedExpand3.begin(), sortedExpand3.end()); + for (auto expand : sortedExpand3) { + Value expandDim = rewriter.create( + loc, rewriter.getI64IntegerAttr(expand)); + input3 = *unsqueezeTensor(rewriter, op, input3, expandDim); + } + + // Apply multiplication operation. + auto mul1 = + rewriter.create(loc, op.getType(), input1, input2); + auto mul2 = + rewriter.create(loc, op.getType(), mul1, input3); + + // Apply sum operation. + // Parse sumDim in descending order to avoid any issues with the + // dimensions being removed. + Value result = mul2; + SmallVector sortedSumDims = sumDim; + std::sort(sortedSumDims.rbegin(), sortedSumDims.rend()); + for (int64_t dim : sortedSumDims) { + Value dimValue = rewriter.create( + loc, rewriter.getI64IntegerAttr(dim)); + result = + createSumAlongDimension(rewriter, loc, op, result, dimValue, false); + } + + rewriter.replaceOp(op, result); + return success(); + } + +private: + // Determine if there are any dimensions that intersect between expand1, + // expand2, and expand3. 
+ bool sharedExpandDims(const int64_t &totalDims, + const SmallVector &expand1, + const SmallVector &expand2, + const SmallVector &expand3, + const SmallVector &sumDim) const { + for (int64_t i = 0; i < totalDims; ++i) { + if (!contains(sumDim, i) && contains(expand1, i) && + contains(expand2, i) && contains(expand3, i)) { + return true; + } + } + return false; + } + bool contains(const SmallVector &vec, int64_t value) const { + return std::find(vec.begin(), vec.end(), value) != vec.end(); + } +}; +} // namespace + namespace { // Calculate the trace of the input tensor as the sum over its diagonal // elements. This computation is performed as: @@ -1481,6 +2328,62 @@ class DecomposeAten_SoftmaxOp : public OpRewritePattern { }; } // namespace +// Ref: +// https://github.com/pytorch/pytorch/blob/5314ae2660a778b87987030182f787bb6cb092c0/aten/src/ATen/native/transformers/attention.cpp#L663-L673 +namespace { +class DecomposeAten_SafeSoftmaxOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(Aten_SafeSoftmaxOp op, + PatternRewriter &rewriter) const override { + BaseTensorType resultTensorType = cast(op.getType()); + if (!resultTensorType.hasDtype() || !resultTensorType.hasSizes()) { + return rewriter.notifyMatchFailure( + op, "expected result type to have sizes and dtype"); + } + SmallVector sizes(resultTensorType.getSizes()); + + int64_t dimInt; + if (!matchPattern(op.getDim(), m_TorchConstantInt(&dimInt))) + return rewriter.notifyMatchFailure(op, "Unsupported: non-constant dim"); + + dimInt = toPositiveDim(dimInt, sizes.size()); + if (!isValidDim(dimInt, sizes.size())) + return rewriter.notifyMatchFailure(op, "dim int is not valid"); + + Location loc = op.getLoc(); + Value softmax = rewriter.create( + loc, op.getType(), op.getSelf(), op.getDim(), op.getDtype()); + + Type resultTensorDtype = resultTensorType.getDtype(); + + Value negInfinity = getConstantWithGivenDtypeAndValue( + rewriter, loc, -std::numeric_limits::infinity(), + resultTensorDtype); + + auto boolDtype = rewriter.getI1Type(); + auto boolTensorType = + resultTensorType.getWithSizesAndDtype(sizes, boolDtype); + Value masked = rewriter.create(loc, boolTensorType, + op.getSelf(), negInfinity); + + sizes[dimInt] = 1; + auto maskedRowsType = + resultTensorType.getWithSizesAndDtype(sizes, boolDtype); + Value cstTrue = + rewriter.create(loc, rewriter.getBoolAttr(true)); + Value maskedRows = rewriter.create( + loc, maskedRowsType, masked, op.getDim(), cstTrue); + Value cstZero = getConstantWithGivenDtypeAndValue(rewriter, loc, 0.0, + resultTensorDtype); + rewriter.replaceOpWithNewOp( + op, resultTensorType, maskedRows, cstZero, softmax); + return success(); + } +}; +} // namespace + // Aten_SoftmaxBackwardDataOp(gradOutput, output, dim) => // newGrad = gradOutput * output // result = newGrad - output * sum(newGrad, dim)) @@ -1584,51 +2487,69 @@ class DecomposeAten_LogSoftmaxBackwardDataOp } // namespace namespace { -class DecomposeAtenAMinMaxOp : public OpRewritePattern { +/// We decompose aten.amax into a set of aten.max.dim op(s) depending on the +/// number of dimensions across which the max needs to be computed. 
+/// Eg: +/// INPUT: +/// final_output = aten.amax(initial_input, dim=(0, 2, 1), keepdim=False) +/// +/// OUTPUT: +/// input_1 = aten.max.dim(initial_input, 2, keepdim) #1 +/// input_2 = aten.max.dim(input_1, 1, keepdim) #2 +/// final_output = aten.max.dim(input_2, 0, keepdim) #3 +/// +/// NOTE: We iterate over, in reverse order, every dimension included in `dim` +/// of the `aten.amax` op and create an `aten.amax.dim` op. +/// Input tensor to the next `aten.amax.dim` op is thus the output of the +/// previous `aten.amax.dim` op. +template +class DecomposeAtenAminAmaxOp : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(Torch::AtenAminOp op, + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { - llvm::SmallVector dimList; - if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dimList))) { - return rewriter.notifyMatchFailure(op, "dims not foldable constants"); - } + Location loc = op.getLoc(); + bool keepDim; + if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) + return rewriter.notifyMatchFailure( + op, "Expected a constant boolean value for keepDim"); - bool keepdim; - if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepdim))) { - return rewriter.notifyMatchFailure(op, "keepdims not foldable constants"); + Value input = op.getSelf(); + auto inputTy = dyn_cast(input.getType()); + if (!inputTy || !inputTy.hasSizes()) { + return rewriter.notifyMatchFailure(op, + "Expected input type having sizes"); } - auto loc = op.getLoc(); - std::sort(dimList.begin(), dimList.end(), std::greater()); - - Value reduction = op.getSelf(); - auto resultTy = cast(op.getType()); - auto reductionTy = cast(reduction.getType()); - llvm::SmallVector reductionShape(reductionTy.getSizes()); + SmallVector dims; + if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dims))) + return rewriter.notifyMatchFailure(op, + "non-const dim parameter unsupported"); + if (dims.size() == 0) { + dims = llvm::to_vector(llvm::seq(0, inputTy.getSizes().size())); + } - for (auto dim : dimList) { - auto dimValue = rewriter.create( - loc, rewriter.getI64IntegerAttr(dim)); - reductionShape[dim] = 1; - if (!keepdim) { - for (int i = dim, s = reductionShape.size() - 1; i < s; ++i) - reductionShape[i] = reductionShape[i + 1]; - reductionShape.resize(reductionShape.size() - 1); + // For every dimension included in `dim` of the op, iterated over in + // reverse order, we create a call to aten.max.dim. + std::sort(dims.rbegin(), dims.rend()); + for (int64_t dimInt : dims) { + int64_t inputRank = inputTy.getSizes().size(); + dimInt = toPositiveDim(dimInt, inputRank); + if (!isValidDim(dimInt, inputRank)) + return rewriter.notifyMatchFailure(op, "dim is statically invalid"); + Value dim = rewriter.create( + loc, rewriter.getI64IntegerAttr(dimInt)); + // The input to the next invocation of aten.max.dim is the output of the + // previous aten.max.dim op. 
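// --- Illustrative sketch, not part of this patch ---
// The amin/amax decomposition documented above: a multi-dim reduction is a
// chain of single-dim reductions applied in reverse-sorted dim order, each
// feeding its output into the next. The standalone code below mimics that on a
// plain row-major buffer with an explicit shape vector (example data only) and
// checks the chained result against a global max.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Single-dim max reduction (keepdim = false) on a row-major buffer.
static void maxAlongDim(std::vector<double> &data, std::vector<int64_t> &shape,
                        int64_t dim) {
  int64_t outer = 1, inner = 1, n = shape[dim];
  for (int64_t i = 0; i < dim; ++i) outer *= shape[i];
  for (size_t i = dim + 1; i < shape.size(); ++i) inner *= shape[i];
  std::vector<double> out(outer * inner);
  for (int64_t o = 0; o < outer; ++o)
    for (int64_t in = 0; in < inner; ++in) {
      double m = data[o * n * inner + in];
      for (int64_t k = 1; k < n; ++k)
        m = std::max(m, data[(o * n + k) * inner + in]);
      out[o * inner + in] = m;
    }
  data = std::move(out);
  shape.erase(shape.begin() + dim);
}

int main() {
  std::vector<int64_t> shape = {2, 3, 4};
  std::vector<double> data(24);
  for (size_t i = 0; i < data.size(); ++i) data[i] = (i * 7) % 24; // example
  double expected = *std::max_element(data.begin(), data.end());

  // dims = {0, 1, 2}: reduce in reverse-sorted order, exactly as the pattern
  // chains aten.max.dim ops.
  for (int64_t dim : {2, 1, 0}) maxAlongDim(data, shape, dim);
  assert(data.size() == 1 && data[0] == expected);
  return 0;
}
// --- end sketch ---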
+ static_assert(std::is_same_v || + std::is_same_v); + if (std::is_same_v) { + input = createMaxAlongDimension(rewriter, loc, op, input, dim, keepDim); + } else if (std::is_same_v) { + input = createMinAlongDimension(rewriter, loc, op, input, dim, keepDim); } - - reductionTy = rewriter.getType( - reductionShape, resultTy.getOptionalDtype()); - auto idxTy = rewriter.getType( - reductionShape, rewriter.getIntegerType(32, /*is_signed*/ true)); - llvm::SmallVector types{reductionTy, idxTy}; - reduction = rewriter - .create(loc, types, reduction, - dimValue, op.getKeepdim()) - .getResult(0); } - - rewriter.replaceOp(op, reduction); + rewriter.replaceOp(op, input); return success(); } }; @@ -1646,49 +2567,113 @@ class DecomposeAtenArgMinMaxOp : public OpRewritePattern { Location loc = op.getLoc(); Value input = op.getSelf(); Value dim = op.getDim(); - Value keepDim = op.getKeepdim(); Value result = op.getResult(); + bool keepDim; + if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) { + return rewriter.notifyMatchFailure( + op, "expected keepdim to be a constant bool"); + } BaseTensorType inputType = cast(input.getType()); BaseTensorType indicesTensorType = cast(result.getType()); std::optional maybeInputRank = getTensorRank(input); - if (!maybeInputRank) { + if (!maybeInputRank || *maybeInputRank == 0) { return rewriter.notifyMatchFailure( - op, "expected input tensor to have a rank"); + op, "expected input tensor to have a rank > 0"); } unsigned inputRank = *maybeInputRank; if (!indicesTensorType.hasSizes()) return failure(); - BaseTensorType valueTensorType = - inputType - .getWithSizesAndDtype(indicesTensorType.getOptionalSizes(), - inputType.getOptionalDtype()) - .cast(); + BaseTensorType valueTensorType = cast( + inputType.getWithSizesAndDtype(indicesTensorType.getOptionalSizes(), + inputType.getOptionalDtype())); // If the dim type is `NoneType` i.e. reduce along all the dimensions. // `AtenMaxDimOp` and `AtenMinDimOp` do not support dim as `NoneType` so // first the input tensor is flattened to 1d tensor and then the reduction // happens on the 0th dimension. 
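// --- Illustrative sketch, not part of this patch ---
// The dim == None case described above: the input is flattened and the
// reduction runs on dim 0, so the produced index is an index into the
// flattened buffer; with keepdim = true that scalar is then viewed with an
// all-ones shape of the original rank. Example values are arbitrary.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Example 2x3 input, row-major.
  std::vector<double> input = {1.0, 9.0, 2.0, 4.0, 3.0, 0.0};
  const int64_t inputRank = 2;

  // Flatten + argmax along dim 0 of the flattened tensor.
  int64_t flatIndex =
      std::max_element(input.begin(), input.end()) - input.begin();
  assert(flatIndex == 1); // element 9.0

  // keepdim = true: the result is reshaped to {1, 1, ..., 1} (inputRank ones).
  std::vector<int64_t> keepdimShape(inputRank, 1);
  assert(keepdimShape.size() == 2);
  return 0;
}
// --- end sketch ---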
if (isa(dim.getType())) { - BaseTensorType flattenType = - inputType - .getWithSizesAndDtype({kUnknownSize}, - inputType.getOptionalDtype()) - .cast(); - dim = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); - Value end = rewriter.create( - loc, rewriter.getI64IntegerAttr(inputRank - 1)); - input = rewriter.create(loc, flattenType, input, - dim, end); + Value zero = rewriter.create(loc, 0); + Value falseValue = rewriter.create(loc, false); + if (inputType.getSizes().size() > 1) { + int64_t flattenSize = Torch::kUnknownSize; + if (inputType.areAllSizesKnown()) { + flattenSize = 1; + for (int64_t sze : inputType.getSizes()) + flattenSize *= sze; + } + auto flattenType = cast(inputType.getWithSizesAndDtype( + {flattenSize}, inputType.getOptionalDtype())); + Value end = rewriter.create( + loc, rewriter.getI64IntegerAttr(inputRank - 1)); + input = rewriter.create(loc, flattenType, input, + zero, end); + } + Value resultIndices = + rewriter + .create( + loc, + valueTensorType.getWithSizesAndDtype( + ArrayRef{}, valueTensorType.getOptionalDtype()), + indicesTensorType.getWithSizesAndDtype( + ArrayRef{}, + indicesTensorType.getOptionalDtype()), + input, /*dim=*/zero, /*keepdim=*/falseValue) + .getIndices(); + if (keepDim) { + Value one = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + Value dimList = rewriter.create( + loc, + Torch::ListType::get(Torch::IntType::get(rewriter.getContext())), + SmallVector(inputRank, one)); + resultIndices = rewriter.create( + loc, + indicesTensorType.getWithSizesAndDtype( + SmallVector(inputRank, 1), + indicesTensorType.getOptionalDtype()), + resultIndices, dimList); + } + rewriter.replaceOp(op, resultIndices); + return success(); + } else { + Value resultIndices = + rewriter + .create(loc, valueTensorType, indicesTensorType, + input, dim, op.getKeepdim()) + .getIndices(); + rewriter.replaceOp(op, resultIndices); + return success(); } + } +}; +} // namespace - Value resultArg = - rewriter - .create(loc, valueTensorType, indicesTensorType, input, - dim, keepDim) - .getIndices(); +// Decompose `AtenAminmaxOp` to `AtenAminOp` + `AtenAmaxOp` +namespace { +class DecomposeAtenAminmaxOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenAminmaxOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + + Torch::ListType listType = + rewriter.getType(rewriter.getType()); + Value dimList; + if (isa(op.getDim().getType())) { + dimList = rewriter.create(loc, listType, + ArrayRef{}); + } else { + dimList = rewriter.create( + loc, listType, ArrayRef{op.getDim()}); + } - rewriter.replaceOp(op, resultArg); + auto amin = rewriter.create( + loc, op.getMin().getType(), op.getSelf(), dimList, op.getKeepdim()); + auto amax = rewriter.create( + loc, op.getMax().getType(), op.getSelf(), dimList, op.getKeepdim()); + rewriter.replaceOp(op, {amin, amax}); return success(); } }; @@ -2076,6 +3061,145 @@ class DecomposeAtenMvOp : public OpRewritePattern { }; } // namespace +// https://github.com/pytorch/pytorch/blob/9dec41b684a4284c4e052e295314c23f0f942fec/torch/_refs/__init__.py#L3229 +// Decompose aten.renorm into: linalg_vector_norm +namespace { +class DecomposeAtenRenormOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenRenormOp op, + PatternRewriter &rewriter) const override { + + Location loc = op.getLoc(); + Value self = op.getSelf(); + Value dim = op.getDim(); + Value p = op.getP(); + Value maxnorm = 
op.getMaxnorm(); + + // Prepare all necessary variables + auto ndim = getTensorRank(self); + auto resType = cast(self.getType()); + + if (!resType.hasDtype() || !resType.hasSizes()) { + return rewriter.notifyMatchFailure(op, + "result should have dtype and sizes"); + } + + Type dtype = resType.getDtype(); + if (isa(dtype)) { + return rewriter.notifyMatchFailure( + op, "lowering of aten.renorm for complex inputs dtype is " + "currently unimplemented"); + } + + SmallVector inputSize(resType.getSizes()); + + // Convert dim from Value to int + int64_t dimInt; + if (!matchPattern(dim, m_TorchConstantInt(&dimInt))) + return rewriter.notifyMatchFailure(op, + "Unimplemented: dim not constant int"); + + // Define all constants + Value cstTrue = rewriter.create(loc, true); + Value cstZero = rewriter.create(loc, 0); + Value cstOne = rewriter.create(loc, 1); + Value cstNone = rewriter.create(loc); + + // Arragne reduce_dims tensor (vector), [0, 1, ... , dim-1, dim+1, ... , + // ndim-1] + llvm::SmallVector reduceDimsVector; + for (uint64_t i = 0; i < ndim; i++) { + if (i == (uint64_t)dimInt) + continue; + + Value constI = rewriter.create( + loc, rewriter.getI64IntegerAttr(i)); + + reduceDimsVector.push_back(constI); + } + + Value reduceDimsList = rewriter.create( + loc, + rewriter.getType(rewriter.getType()), + reduceDimsVector); + + // Make output shape for linalg.vector_norm operation + SmallVector inputSizeValue; + for (uint64_t i = 0; i < inputSize.size(); i++) { + if (i != (uint64_t)dimInt) + inputSize[i] = 1; + + inputSizeValue.push_back( + rewriter.create(loc, inputSize[i])); + } + + // Prepare arguments for linalg.vector_norm + Value dtypeValue; + Type vectorNormOutType; + + if (isa(dtype)) { + dtype = cast(rewriter.getF32Type()); + dtypeValue = getDtypeIntValueForType(rewriter, loc, dtype); + vectorNormOutType = resType.getWithSizesAndDtype(inputSize, dtype); + } else { + dtypeValue = cstNone; + vectorNormOutType = resType.getWithSizesAndDtype(inputSize, dtype); + } + + auto norm = rewriter.create( + loc, vectorNormOutType, self, p, reduceDimsList, cstTrue, dtypeValue); + + // Define epsiolon constant 10^-7 + mlir::FloatType f64Type = rewriter.getF64Type(); + Value epsValue = rewriter.create( + loc, rewriter.getFloatAttr(f64Type, 1e-7)); + + Value normPlusEps = rewriter.create( + loc, vectorNormOutType, norm, epsValue, cstOne); + + Value maxnormTensorValue = rewriter.create( + loc, normPlusEps.getType(), normPlusEps, maxnorm, cstNone, cstNone, + cstNone, cstNone, cstNone); + + // Divide maxnorm and normPlusEps + auto divideMaxnormAndNorm = rewriter.create( + loc, vectorNormOutType, maxnormTensorValue, normPlusEps); + + // Next few lines corespond to this pythorch code: norm_factor = + // torch.where(norm > maxnorm, maxnorm / (norm + eps), 1.0) + auto boolTensorType = rewriter.getType( + cast(vectorNormOutType).getOptionalSizes(), + rewriter.getI1Type()); + + Value greaterThanMaxnorm = + rewriter.create(loc, boolTensorType, norm, maxnorm); + + Value cstOnetensor = rewriter.create( + loc, normPlusEps.getType(), normPlusEps, cstOne, cstNone, cstNone, + cstNone, cstNone, cstNone); + + auto normFactor = rewriter.create( + loc, vectorNormOutType, greaterThanMaxnorm, divideMaxnormAndNorm, + cstOnetensor); + + // Converte norm_factor to input dtype + Value normFactorFinal = rewriter.create( + loc, resType.getWithSizesAndDtype(inputSize, resType.getDtype()), + normFactor, getDtypeIntValueForType(rewriter, loc, resType.getDtype())); + + // Multiply input tensor with norm factor + auto output = 
rewriter.create(loc, self.getType(), self, + normFactorFinal); + + rewriter.replaceOpWithNewOp(op, self.getType(), output, + /*memory_format*/ cstZero); + + return success(); + } +}; +} // namespace + // Decompose aten.linalg_cross into: aten.broadcast_to, aten.index_select, // aten.add.Tensor and aten.mull.Tensor. See // https://github.com/pytorch/pytorch/blob/ed3c256b61f05720843454a9282aa7c903da2c81/torch/_refs/linalg/__init__.py#L70. @@ -2187,27 +3311,78 @@ class DecomposeAtenLinalgCrossOp : public OpRewritePattern { }; } // namespace -// Decompose aten.pixel_shuffle into: prims.split_dim, aten.permute, and -// prims.collapse operations. -// -// If input is a tensor of shape -// (*leading_dims, C*r*r, H, W), -// -// where leading_dims is of size N, then -// X = pixel_shuffle(input, upscale_factor) -// -// gets replaced with -// X = input.split_dim(...) # shape (*leading_dims, C, r*r, H, W) -// X = X.split_dim(...) # shape (*leading_dims, C, r, r, H, W) -// X = X.permute(0, ..., N, N+3, N+1, N+4, N+2) -// # shape (*leading_dims, C, H, r, W, r) -// X = X.collapse(...) # shape (*leading_dims, C, r, H, r*W) -// X = X.collapse(...) # shape (*leading_dims, C, r*H, r*W) -// -// 'r' above is referred to as the 'upscale factor' or just 'factor' below. +// decompose aten.linalg_slogdet into: aten.sgn, aten.log, aten.abs +// aten.linalg_det namespace { -class DecomposeAtenPixelShuffleOp - : public OpRewritePattern { + +class DecomposeAtenLinalgSlogdetOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenLinalgSlogdetOp op, + PatternRewriter &rewriter) const override { + SmallVector results = op.getResults(); + Location loc = op.getLoc(); + Value input = op.getA(); + Value determinant = rewriter.create( + loc, results[0].getType(), input); + Value sign = + rewriter.create(loc, determinant.getType(), determinant); + Value abs_det = + rewriter.create(loc, determinant.getType(), determinant); + Value ln_abs_det = + rewriter.create(loc, abs_det.getType(), abs_det); + rewriter.replaceAllUsesWith(results[0], sign); + rewriter.replaceAllUsesWith(results[1], ln_abs_det); + rewriter.eraseOp(op); + return success(); + } +}; +} // namespace + +namespace { + +class DecomposeAten_LinalgDetOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(Aten_LinalgDetOp op, + PatternRewriter &rewriter) const override { + SmallVector results = op.getResults(); + if (!results[1].use_empty() || !results[2].use_empty()) + return rewriter.notifyMatchFailure( + op, "unsupported: _linalg_det results: LU and pivot"); + Location loc = op.getLoc(); + Value input = op.getA(); + Value determinant = rewriter.create( + loc, results[0].getType(), input); + rewriter.replaceAllUsesWith(results[0], determinant); + rewriter.eraseOp(op); + return success(); + } +}; +} // namespace + +// Decompose aten.pixel_shuffle into: prims.split_dim, aten.permute, and +// prims.collapse operations. +// +// If input is a tensor of shape +// (*leading_dims, C*r*r, H, W), +// +// where leading_dims is of size N, then +// X = pixel_shuffle(input, upscale_factor) +// +// gets replaced with +// X = input.split_dim(...) # shape (*leading_dims, C, r*r, H, W) +// X = X.split_dim(...) # shape (*leading_dims, C, r, r, H, W) +// X = X.permute(0, ..., N, N+3, N+1, N+4, N+2) +// # shape (*leading_dims, C, H, r, W, r) +// X = X.collapse(...) # shape (*leading_dims, C, r, H, r*W) +// X = X.collapse(...) 
# shape (*leading_dims, C, r*H, r*W) +// +// 'r' above is referred to as the 'upscale factor' or just 'factor' below. +namespace { +class DecomposeAtenPixelShuffleOp + : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenPixelShuffleOp op, @@ -2500,6 +3675,59 @@ class DecomposeAtenLeakyReluBackwardOp }; } // namespace +namespace { +class DecomposeAtenRreluWithNoiseBackwardOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenRreluWithNoiseBackwardOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value gradOutput = op.getGradOutput(); + Value self = op.getSelf(); + Value noise = op.getNoise(); + auto resType = cast(op.getType()); + if (!resType.hasDtype()) { + return rewriter.notifyMatchFailure(op, "result should have dtype"); + } + + bool training; + if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training))) { + return rewriter.notifyMatchFailure(op, + "training should be a bool constant"); + } + + bool selfIsResult = false; + if (!matchPattern(op.getSelfIsResult(), + m_TorchConstantBool(&selfIsResult)) || + selfIsResult) + return rewriter.notifyMatchFailure( + op, "unimplemented: self_is_result should be false"); + + double lower, upper; + if (!matchPattern(op.getLower(), m_TorchConstantFloat(&lower)) || + !matchPattern(op.getUpper(), m_TorchConstantFloat(&upper))) { + return rewriter.notifyMatchFailure( + op, "lower and upper should be float constants"); + } + + if (training && (upper - lower > 0.000001)) { + Value rreluWithNoiseBackwardOutput = + rewriter.create(loc, resType, gradOutput, noise); + rewriter.replaceOp(op, rreluWithNoiseBackwardOutput); + } else { + double negative_slope = (upper + lower) / 2; + Value cstNegativeSlope = rewriter.create( + loc, rewriter.getF64FloatAttr(negative_slope)); + rewriter.replaceOpWithNewOp( + op, resType, gradOutput, self, cstNegativeSlope, + op.getSelfIsResult()); + } + return success(); + } +}; +} // namespace + namespace { class DecomposeAtenPreluOp : public OpRewritePattern { public: @@ -2528,6 +3756,176 @@ class DecomposeAtenPreluOp : public OpRewritePattern { } // namespace +// rrelu = max(0, x) + min(0, alpha * x) +// if in training mode, the alpha is sampled from uniform distribution (lower, +// upper) if in testing mode, the alpha is (lower + upper) / 2 +namespace { +class DecomposeAtenRreluOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenRreluOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = op.getSelf(); + Value lower = op.getLower(); + Value upper = op.getUpper(); + auto resType = cast(op.getType()); + if (!resType.hasDtype()) { + return rewriter.notifyMatchFailure(op, "result should have dtype"); + } + + bool training; + if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training))) { + return rewriter.notifyMatchFailure(op, "training should be a constant"); + } + + Value constantZeroFloat = + rewriter.create(loc, rewriter.getF64FloatAttr(0.0)); + Value constantOneFloat = + rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); + Value constantTwoFloat = + rewriter.create(loc, rewriter.getF64FloatAttr(2.0)); + + Value alpha; + if (training) { + // Create a uniform random op with low and high set to `lower` and + // `upper`, respectively. 
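// --- Illustrative sketch, not part of this patch ---
// The rrelu formula this decomposition emits, max(0, x) + min(0, alpha * x),
// evaluated here in inference mode where alpha is the fixed midpoint
// (lower + upper) / 2 rather than a sample from Uniform(lower, upper). The
// lower/upper values below are the PyTorch RReLU defaults, used only as an
// example.
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

int main() {
  const double lower = 1.0 / 8.0, upper = 1.0 / 3.0;
  const double alpha = (lower + upper) / 2.0; // eval-mode slope

  std::vector<double> x = {-2.0, -0.5, 0.0, 0.5, 2.0};
  for (double v : x) {
    double rrelu = std::max(0.0, v) + std::min(0.0, alpha * v);
    // Negative inputs are scaled by alpha, non-negative inputs pass through.
    double expected = v < 0 ? alpha * v : v;
    assert(std::fabs(rrelu - expected) < 1e-12);
  }
  return 0;
}
// --- end sketch ---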
+ Value none = rewriter.create(loc); + alpha = rewriter.create(loc, resType, self, + /*from=*/lower, /*to=*/upper, + /*generator=*/none); + } else { + Value half = rewriter.create(loc, constantTwoFloat.getType(), + lower, upper); + alpha = rewriter.create(loc, constantTwoFloat.getType(), half, + constantTwoFloat); + } + + Value zeroTensor = + createRank0Tensor(rewriter, loc, resType, constantZeroFloat); + Value positiveOutput = + rewriter.create(loc, resType, zeroTensor, self); + + Value scaledSelf; + if (training) { + scaledSelf = rewriter.create(loc, resType, self, alpha); + } else { + scaledSelf = rewriter.create(loc, resType, self, alpha); + } + + Value negativeOutput = + rewriter.create(loc, resType, zeroTensor, scaledSelf); + Value rreluOutput = rewriter.create( + loc, resType, positiveOutput, negativeOutput, constantOneFloat); + rewriter.replaceOp(op, rreluOutput); + return success(); + } +}; +} // namespace + +namespace { +class DecomposeAtenRreluWithNoiseOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenRreluWithNoiseOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = op.getSelf(); + Value noise = op.getNoise(); + Value lower = op.getLower(); + Value upper = op.getUpper(); + auto resType = cast(op.getType()); + Value cstNone = rewriter.create(loc); + Value cstFalse = + rewriter.create(loc, rewriter.getBoolAttr(false)); + Value result = + rewriter + .create( + loc, resType, self, noise, lower, upper, cstFalse, cstNone) + ->getResult(0); + rewriter.replaceOp(op, result); + return success(); + } +}; +} // namespace + +namespace { +class DecomposeAtenRreluWithNoiseFunctionalOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenRreluWithNoiseFunctionalOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = op.getSelf(); + Value noise = op.getNoise(); + Value lower = op.getLower(); + Value upper = op.getUpper(); + auto resType = cast(op.getResultTypes()[0]); + if (!resType.hasDtype()) { + return rewriter.notifyMatchFailure(op, "result should have dtype"); + } + + bool training; + if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training))) { + return rewriter.notifyMatchFailure(op, "training should be a constant"); + } + + Value constantZeroFloat = + rewriter.create(loc, rewriter.getF64FloatAttr(0.0)); + Value constantOneFloat = + rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); + Value constantTwoFloat = + rewriter.create(loc, rewriter.getF64FloatAttr(2.0)); + + Value alpha; + if (training) { + Value none = rewriter.create(loc); + Value emptyTensor = rewriter.create( + loc, resType, self, constantZeroFloat, /*dtype=*/none, + /*layout=*/none, + /*device=*/none, /*pin_memoty=*/none, /*memory_format=*/none); + alpha = rewriter.create(loc, resType, emptyTensor, + /*from=*/lower, /*to=*/upper, + /*generator=*/none); + } else { + Value half = rewriter.create(loc, constantTwoFloat.getType(), + lower, upper); + alpha = rewriter.create(loc, constantTwoFloat.getType(), half, + constantTwoFloat); + } + + Value zeroTensor = + createRank0Tensor(rewriter, loc, resType, constantZeroFloat); + Value positiveOutput = + rewriter.create(loc, resType, zeroTensor, self); + + Value scaledSelf; + if (training) { + scaledSelf = rewriter.create(loc, resType, self, alpha); + auto boolResType = resType.getWithSizesAndDtype(resType.getSizes(), + rewriter.getI1Type()); + Value 
oneTensor = + createRank0Tensor(rewriter, loc, resType, constantOneFloat); + Value not_positive = rewriter.create( + loc, boolResType, self, constantZeroFloat); + noise = rewriter.create(loc, resType, not_positive, + alpha, oneTensor); + } else { + scaledSelf = rewriter.create(loc, resType, self, alpha); + } + + Value negativeOutput = + rewriter.create(loc, resType, zeroTensor, scaledSelf); + Value rreluOutput = rewriter.create( + loc, resType, positiveOutput, negativeOutput, constantOneFloat); + rewriter.replaceOp(op, {rreluOutput, noise}); + return success(); + } +}; +} // namespace + // CELU(x)=max(0,x)+min(0,alpha∗(exp(x/alpha)−1)) namespace { class DecomposeAtenCeluOp : public OpRewritePattern { @@ -2593,7 +3991,36 @@ class DecomposeAtenLerpScalarOp : public OpRewritePattern { auto weightedDelta = rewriter.create(loc, inputType, delta, op.getWeight()); - auto lerp = rewriter.create(loc, inputType, start, + auto lerp = rewriter.create(loc, resType, start, + weightedDelta, cstOne); + rewriter.replaceOp(op, lerp); + return success(); + } +}; +} // namespace + +namespace { +class DecomposeAtenLerpTensorOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenLerpTensorOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto resType = cast(op.getType()); + if (!resType.hasDtype()) { + return rewriter.notifyMatchFailure(op, "result should have dtype"); + } + Value cstOne = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + auto start = op.getSelf(); + auto inputType = cast(start.getType()); + + auto delta = rewriter.create(loc, inputType, op.getEnd(), + start, cstOne); + + auto weightedDelta = + rewriter.create(loc, inputType, delta, op.getWeight()); + auto lerp = rewriter.create(loc, resType, start, weightedDelta, cstOne); rewriter.replaceOp(op, lerp); return success(); @@ -2780,6 +4207,120 @@ class DecomposeAtenStackOp : public OpRewritePattern { }; } // namespace +// Decompose `aten.hstack` into `aten.at_least1d` and `aten.cat`. +// https://github.com/pytorch/pytorch/blob/207564bab1c4fe42750931765734ee604032fb69/torch/_refs/__init__.py#L3908 +namespace { +class DecomposeAtenHstackOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenHstackOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + + // Get SmallVector from Value. + SmallVector tensors; + if (!getListConstructElements(op.getTensors(), tensors)) + return rewriter.notifyMatchFailure( + op, "unimplemented: the tensor list is not from list construct"); + + // Execute AtenAtleast1dOp on every tensor inside tensors. + SmallVector atleast1dTensors; + for (auto tensor : tensors) { + std::optional tensorRank = getTensorRank(tensor); + + // Check if the tensor is already of rank >= 1. + if (*tensorRank < 1) { + auto atleast1dTensor = + rewriter.create(loc, tensor.getType(), tensor); + atleast1dTensors.push_back(atleast1dTensor); + } else { + atleast1dTensors.push_back(tensor); + } + } + + // Make Value list from atleast1dTensors variable. + auto elemType = cast(atleast1dTensors[0].getType()) + .getWithSizesAndDtype(std::nullopt, nullptr); + Value atleast1dTensorList = rewriter.create( + loc, Torch::ListType::get(elemType), atleast1dTensors); + + // Replace hstack with cat operator. 
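// --- Illustrative sketch, not part of this patch ---
// hstack semantics after atleast_1d, as used by the pattern above: concatenate
// along dim 0 when the promoted inputs are 1-D, otherwise along dim 1. Only
// the shape bookkeeping is modelled; the example shapes are arbitrary.
#include <cassert>
#include <cstdint>
#include <vector>

static std::vector<int64_t> hstackShape(std::vector<int64_t> a,
                                        std::vector<int64_t> b) {
  // atleast_1d: promote rank-0 shapes to {1}.
  if (a.empty()) a = {1};
  if (b.empty()) b = {1};
  int64_t dim = (a.size() == 1) ? 0 : 1; // cat dim chosen as in the pattern
  std::vector<int64_t> out = a;
  out[dim] += b[dim];
  return out;
}

int main() {
  assert(hstackShape({3}, {4}) == (std::vector<int64_t>{7}));          // 1-D: dim 0
  assert(hstackShape({2, 3}, {2, 5}) == (std::vector<int64_t>{2, 8})); // 2-D: dim 1
  return 0;
}
// --- end sketch ---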
+ if (getTensorRank(atleast1dTensors[0]) == 1) + rewriter.replaceOpWithNewOp( + op, op.getType(), atleast1dTensorList, + rewriter.create(loc, rewriter.getI64IntegerAttr(0))); + else + rewriter.replaceOpWithNewOp( + op, op.getType(), atleast1dTensorList, + rewriter.create(loc, rewriter.getI64IntegerAttr(1))); + + return success(); + } +}; +} // namespace + +// Decompose `aten.column_stack` into `aten.reshape` and `aten.cat`. +// https://github.com/pytorch/pytorch/blob/207564bab1c4fe42750931765734ee604032fb69/torch/_refs/__init__.py#L2822 +namespace { +class DecomposeAtenColumnStackOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenColumnStackOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + + SmallVector tensors; + if (!getListConstructElements(op.getTensors(), tensors)) + return rewriter.notifyMatchFailure( + op, "unimplemented: the tensor list is not from list construct"); + + for (auto tensor : tensors) { + auto tTy = dyn_cast(tensor.getType()); + if (!tTy || !tTy.hasSizes()) + return rewriter.notifyMatchFailure( + op, "unimplemented: one tensor does not have known sizes"); + } + + SmallVector tensors2d; + for (auto tensor : tensors) { + auto tTy = dyn_cast(tensor.getType()); + SmallVector tSizes(tTy.getSizes()); + if (tSizes.size() <= 1) { + if (tSizes.size() == 0) { + tSizes.push_back(1); + } + tSizes.push_back(1); + auto newTy = tTy.getWithSizesAndDtype(tSizes, tTy.getDtype()); + SmallVector newShapeList; + for (auto tSize : tSizes) { + newShapeList.push_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(tSize))); + } + auto newShape = rewriter.create( + loc, Torch::ListType::get(rewriter.getType()), + newShapeList); + Value tensor2d = + rewriter.create(loc, newTy, tensor, newShape); + tensors2d.push_back(tensor2d); + } else { + tensors2d.push_back(tensor); + } + } + + auto elemType = cast(tensors2d[0].getType()) + .getWithSizesAndDtype(std::nullopt, nullptr); + Value newTensors = rewriter.create( + loc, Torch::ListType::get(elemType), tensors2d); + + rewriter.replaceOpWithNewOp( + op, op.getType(), newTensors, + rewriter.create(loc, rewriter.getI64IntegerAttr(1))); + + return success(); + } +}; +} // namespace + // Decompose aten.roll into aten.slice and aten.cat ops. 
// https://pytorch.org/docs/stable/generated/torch.roll.html namespace { @@ -3003,7 +4544,7 @@ class DecomposeAtenRepeatInterleaveSelfIntOp bool dimIsNone = false; int64_t dim; Value dimValue = op.getDim(); - if (dimValue.getType().isa()) { + if (isa(dimValue.getType())) { dimIsNone = true; dim = inputRank - 1; } else { @@ -3175,6 +4716,11 @@ class DecomposeAtenUnflattenIntOp if (!isValidDim(dimInt, inputRank)) return rewriter.notifyMatchFailure(op, "dim is not a valid dim"); + if (inputShape[dimInt] == Torch::kUnknownSize && + llvm::count(sizesInts, -1) > 0) + return rewriter.notifyMatchFailure( + op, "Unimplemented: dynamic unflatten dim with an inferred size."); + SmallVector sizesTorchInt; if (!getListConstructElements(op.getSizes(), sizesTorchInt)) return rewriter.notifyMatchFailure( @@ -3313,37 +4859,50 @@ class DecomposeAtenNanToNumOp : public OpRewritePattern { LogicalResult matchAndRewrite(AtenNanToNumOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - mlir::FloatType f64Type = rewriter.getF64Type(); Value nan = op.getNan(); Value posinf = op.getPosinf(); Value neginf = op.getNeginf(); - auto baseType = - ValueTensorType::getWithLeastStaticInformation(op.getContext()); - if (dyn_cast_or_null(nan.getDefiningOp())) - nan = rewriter.create( - loc, rewriter.getFloatAttr( - f64Type, APFloat::getZero(f64Type.getFloatSemantics()))); - if (dyn_cast_or_null(posinf.getDefiningOp())) + auto outputType = cast(op.getResult().getType()); + if (!outputType.hasDtype() || + !isa(outputType.getDtype())) { + return rewriter.notifyMatchFailure( + op, "expect output type to have float dtype"); + } + mlir::FloatType outputElementType = + cast(outputType.getDtype()); + + if (isa(nan.getType())) { + nan = + rewriter.create(loc, rewriter.getF64FloatAttr(0.0)); + } + if (isa(posinf.getType())) { posinf = rewriter.create( - loc, rewriter.getFloatAttr( - f64Type, APFloat::getInf(f64Type.getFloatSemantics()))); - if (dyn_cast_or_null(neginf.getDefiningOp())) + loc, rewriter.getF64FloatAttr( + APFloat::getLargest(outputElementType.getFloatSemantics()) + .convertToDouble())); + } + if (isa(neginf.getType())) { neginf = rewriter.create( - loc, - rewriter.getFloatAttr( - f64Type, APFloat::getInf(f64Type.getFloatSemantics(), true))); + loc, rewriter.getF64FloatAttr( + APFloat::getLargest(outputElementType.getFloatSemantics(), + /*Negative=*/true) + .convertToDouble())); + } + + auto compareType = outputType.getWithSizesAndDtype( + outputType.getOptionalSizes(), rewriter.getI1Type()); Value isNan = - rewriter.create(loc, baseType, op.getSelf()); + rewriter.create(loc, compareType, op.getSelf()); Value where = rewriter.create( - loc, baseType, isNan, nan, op.getSelf()); + loc, outputType, isNan, nan, op.getSelf()); Value isposinf = - rewriter.create(loc, baseType, where); + rewriter.create(loc, compareType, where); where = rewriter.create( - loc, baseType, isposinf, posinf, where); + loc, outputType, isposinf, posinf, where); Value isneginf = - rewriter.create(loc, baseType, where); + rewriter.create(loc, compareType, where); rewriter.replaceOpWithNewOp( - op, op.getType(), isneginf, neginf, where); + op, outputType, isneginf, neginf, where); return success(); } }; @@ -3371,17 +4930,135 @@ class DecomposeAtenMaskedFillScalarOp }; } // namespace -// Decompose aten._convolution-like to aten.convolution +// Decompose aten.masked_fill.Tensor into aten.where.self op. 
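+// i.e. out = where(mask, value, self), with `value` broadcast against `self`.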
namespace { -template -class DecomposeAten_ConvolutionLikeOp - : public OpRewritePattern { +class DecomposeAtenMaskedFillTensorOp + : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(ConvolutionLikeOp op, + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenMaskedFillTensorOp op, PatternRewriter &rewriter) const override { - - rewriter.replaceOpWithNewOp( + auto resType = cast(op.getType()); + if (!resType.hasDtype()) { + return rewriter.notifyMatchFailure(op, "result should have dtype"); + } + rewriter.replaceOpWithNewOp(op, resType, op.getMask(), + op.getValue(), op.getSelf()); + + return success(); + } +}; +} // namespace + +// Decompose aten.masked_scatter: +// def masked_scatter(self: Tensor, mask: Tensor, source: Tensor) -> Tensor: +// mask_int = mask + torch.zeros_like(self) +// prefix_sum = torch.cumsum(mask_int.flatten(), dim=0) +// mask_prefix = torch.clamp(prefix_sum - 1, min=0) +// mask = mask.to(torch.bool) +// source = source.flatten()[mask_prefix].reshape(mask.shape) +// return torch.where(mask, source, self) +namespace { +class DecomposeAtenMaskedScatterOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenMaskedScatterOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto context = op.getContext(); + Value mask = op.getMask(); + Value source = op.getSource(); + Value self = op.getSelf(); + + auto selfTy = cast(self.getType()); + auto resTy = cast(op.getType()); + auto sourceTy = cast(source.getType()); + + if (!resTy || !resTy.hasDtype()) { + return rewriter.notifyMatchFailure(op, "result should have dtype"); + } + if (!selfTy || !selfTy.areAllSizesKnown()) + return rewriter.notifyMatchFailure( + op, "Unimplemented: no implementation for rankless tensor"); + if (!sourceTy || !sourceTy.areAllSizesKnown() || !sourceTy.hasDtype()) + return rewriter.notifyMatchFailure( + op, "Unimplemented: no implementation for rankless tensor"); + + int64_t selfNumel = getTensorNumel(self).value(); // as selfTy has sizes + int64_t sourceNumel = + getTensorNumel(source).value(); // as sourceTy has sizes + int64_t selfRank = selfTy.getSizes().size(); + int64_t sourceRank = sourceTy.getSizes().size(); + + Value constZero = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + Value constOne = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + Value constNone = rewriter.create(loc); + Value selfLastDim = rewriter.create( + loc, rewriter.getI64IntegerAttr(selfRank - 1)); + Value sourceLastDim = rewriter.create( + loc, rewriter.getI64IntegerAttr(sourceRank - 1)); + + auto si64Type = IntegerType::get(context, 64, IntegerType::Signed); + auto int64Dtype = getDtypeIntValueForType( + rewriter, loc, + rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true)); + auto selfIntType = selfTy.getWithSizesAndDtype(selfTy.getSizes(), si64Type); + + Value zerosLike = rewriter.create( + loc, selfIntType, self, int64Dtype, constNone, constNone, constNone, + constNone); + Value maskInt = rewriter.create( + loc, selfIntType, mask, zerosLike, constOne); + + auto flattenMaskedType = selfTy.getWithSizesAndDtype( + /*optionalSizes=*/{selfNumel}, si64Type); + Value maskIntFlatten = rewriter.create( + loc, flattenMaskedType, maskInt, constZero, selfLastDim); + Value prefixSum = rewriter.create( + loc, flattenMaskedType, maskIntFlatten, + /*dim=*/constZero, constNone); + Value prefixSumMinusOne = rewriter.create( + loc, 
flattenMaskedType, prefixSum, constOne, constOne); + Value maskPrefix = rewriter.create( + loc, flattenMaskedType, prefixSumMinusOne, /*min=*/constZero, + /*max=*/constNone); + + auto sourceFlattenType = sourceTy.getWithSizesAndDtype( + /*optionalSizes=*/{sourceNumel}, sourceTy.getDtype()); + Value sourceFlatten = rewriter.create( + loc, sourceFlattenType, source, constZero, sourceLastDim); + + auto selectSourceType = sourceTy.getWithSizesAndDtype( + /*optionalSizes=*/{selfNumel}, sourceTy.getDtype()); + Value selectSource = rewriter.create( + loc, selectSourceType, sourceFlatten, constZero, maskPrefix); + + // Reshape normalized output back to the original input shape + auto selfShape = rewriter.create( + loc, Torch::ListType::get(IntType::get(context)), self); + Value sourceReshape = rewriter.create( + loc, selfTy, selectSource, selfShape); + rewriter.replaceOpWithNewOp(op, resTy, mask, + sourceReshape, self); + return success(); + } +}; +} // namespace + +// Decompose aten._convolution-like to aten.convolution +namespace { +template +class DecomposeAten_ConvolutionLikeOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ConvolutionLikeOp op, + PatternRewriter &rewriter) const override { + + rewriter.replaceOpWithNewOp( op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(), op.getStride(), op.getPadding(), op.getDilation(), op.getTransposed(), op.getOutputPadding(), op.getGroups()); @@ -3521,6 +5198,82 @@ class DecomposeAtenConv2dOp : public OpRewritePattern { }; } // namespace +// Decompose aten.conv(1/2/3)d.padding to aten.convolution +namespace { +template +class DecomposeAtenConvPaddingOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ConvPaddingOp op, + PatternRewriter &rewriter) const override { + + Location loc = op.getLoc(); + + Value weight = op.getWeight(); + std::optional maybeRank = getTensorRank(weight); + if (!maybeRank) { + return rewriter.notifyMatchFailure(op, "expected weight to have a rank"); + } + unsigned rank = *maybeRank; + // first 2 dimensions of weight are out_channels and in_channels / groups + if (rank < 3) + return rewriter.notifyMatchFailure( + op, "ConvPaddingOp weight must be at least 3 dimensional."); + + std::string padding_str; + if (!matchPattern(op.getPadding(), m_TorchConstantStr(padding_str))) + return rewriter.notifyMatchFailure(op, + "padding must be a constant string"); + + Value zero = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + + SmallVector paddingValues; + if (padding_str == "valid") { + // valid means no padding + for (unsigned iRank = 2; iRank < rank; iRank++) { + paddingValues.push_back(zero); + } + } else { + + SmallVector dilation; + getListConstructElements(op.getDilation(), dilation); + + Value one = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + Value two = + rewriter.create(loc, rewriter.getI64IntegerAttr(2)); + for (unsigned iRank = 2; iRank < rank; iRank++) { + Value dim = rewriter.create( + loc, rewriter.getI64IntegerAttr(iRank)); + Value kernelSize = + rewriter.create(loc, weight, dim); + Value kernelSizeMinusOne = + rewriter.create(loc, kernelSize, one); + Value padding = rewriter.create( + loc, dilation[iRank - 2], kernelSizeMinusOne); + padding = rewriter.create(loc, padding, two); + paddingValues.push_back(padding); + } + } + + Value emptyList = rewriter.create( + op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), + SmallVector()); + 
Value cstFalse = rewriter.create(op.getLoc(), false); + Value padding = rewriter.create( + op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), + paddingValues); + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(), + op.getStride(), padding, op.getDilation(), cstFalse, emptyList, + op.getGroups()); + + return success(); + } +}; +} // namespace + // Decompose aten.conv3d to aten.convolution namespace { class DecomposeAtenConv3dOp : public OpRewritePattern { @@ -3543,6 +5296,25 @@ class DecomposeAtenConv3dOp : public OpRewritePattern { }; } // namespace +// Decompose aten.conv_transpose1d to aten.convolution +namespace { +class DecomposeAtenConvTranspose1dOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenConvTranspose1dOp op, + PatternRewriter &rewriter) const override { + + Value cstTrue = rewriter.create(op.getLoc(), true); + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(), + op.getStride(), op.getPadding(), op.getDilation(), + /*transposed=*/cstTrue, op.getOutputPadding(), op.getGroups()); + return success(); + } +}; +} // namespace + // Decompose aten.conv_transpose2d to aten.convolution namespace { class DecomposeAtenConvTranspose2dOp @@ -3562,6 +5334,25 @@ class DecomposeAtenConvTranspose2dOp }; } // namespace +// Decompose aten.conv_transpose3d to aten.convolution +namespace { +class DecomposeAtenConvTranspose3dOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenConvTranspose3dInputOp op, + PatternRewriter &rewriter) const override { + + Value cstTrue = rewriter.create(op.getLoc(), true); + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(), + op.getStride(), op.getPadding(), op.getDilation(), + /*transposed=*/cstTrue, op.getOutputPadding(), op.getGroups()); + return success(); + } +}; +} // namespace + // The convolution backward op is decomposed as follows: // inputH, inputW = input.shape[2:] // output_padding_ = [ @@ -3789,10 +5580,9 @@ class DecomposeAtenConvolutionBackwardOp gradOutputViewSizesInt[0] = kUnknownSize; gradOutputViewSizesInt[1] = 1; BaseTensorType gradOutputTypeForView = - gradOutputTy - .getWithSizesAndDtype(llvm::ArrayRef(gradOutputViewSizesInt), - gradOutputTy.getOptionalDtype()) - .cast(); + cast(gradOutputTy.getWithSizesAndDtype( + llvm::ArrayRef(gradOutputViewSizesInt), + gradOutputTy.getOptionalDtype())); Value gradOutputView = rewriter.create( loc, gradOutputTypeForView, gradOutput, gradOutputViewShapeList); @@ -3820,10 +5610,9 @@ class DecomposeAtenConvolutionBackwardOp } BaseTensorType gradWeightTy = - inputTransposedTy - .getWithSizesAndDtype(llvm::ArrayRef(gradWeightSizesInt), - inputTransposedTy.getOptionalDtype()) - .cast(); + cast(inputTransposedTy.getWithSizesAndDtype( + llvm::ArrayRef(gradWeightSizesInt), + inputTransposedTy.getOptionalDtype())); Value numGroup = rewriter.create(loc, input, cstZero); gradWeight = rewriter.create( @@ -3839,10 +5628,9 @@ class DecomposeAtenConvolutionBackwardOp for (unsigned i = 0; i < gradWeightTy.getSizes().size() - 2; i++) { gradWeightSizesInt[i + 2] = weightSizes[i + 2]; BaseTensorType gradWeightNarrowTy = - gradWeightTy - .getWithSizesAndDtype(llvm::ArrayRef(gradWeightSizesInt), - gradWeightTy.getOptionalDtype()) - .cast(); + cast(gradWeightTy.getWithSizesAndDtype( + llvm::ArrayRef(gradWeightSizesInt), 
+ gradWeightTy.getOptionalDtype())); Value dim = rewriter.create( loc, rewriter.getI64IntegerAttr(i + 2)); @@ -3872,10 +5660,9 @@ class DecomposeAtenConvolutionBackwardOp gradWeightViewShapeValue); BaseTensorType gradWeightTypeForView = - gradWeightTy - .getWithSizesAndDtype(llvm::ArrayRef(gradWeightViewShapeInt), - gradWeightTy.getOptionalDtype()) - .cast(); + cast(gradWeightTy.getWithSizesAndDtype( + llvm::ArrayRef(gradWeightViewShapeInt), + gradWeightTy.getOptionalDtype())); gradWeight = rewriter.create( loc, gradWeightTypeForView, gradWeight, gradWeightViewShapeList); @@ -3888,10 +5675,9 @@ class DecomposeAtenConvolutionBackwardOp gradWeightViewShapeInt[gradWeightDimsOrder[i]]); } BaseTensorType gradWeightTypeForMoveDim = - gradWeightTy - .getWithSizesAndDtype(llvm::ArrayRef(gradWeightMoveDimShape), - gradWeightTy.getOptionalDtype()) - .cast(); + cast(gradWeightTy.getWithSizesAndDtype( + llvm::ArrayRef(gradWeightMoveDimShape), + gradWeightTy.getOptionalDtype())); gradWeight = rewriter.create( loc, gradWeightTypeForMoveDim, gradWeight, /*source=*/cstZero, @@ -3911,9 +5697,8 @@ class DecomposeAtenConvolutionBackwardOp Value gradOutputTransposed = rewriter.create( loc, transposedType, gradOutput, cstZero, cstOne); // Convolve input with grad_output. - if (failed( - getTransposedType(op.getResultTypes()[1].cast(), - 0, 1, transposedType))) + if (failed(getTransposedType(cast(op.getResultTypes()[1]), + 0, 1, transposedType))) return failure(); gradWeight = rewriter.create( loc, transposedType, inputTransposed, gradOutputTransposed, cstNone, @@ -3943,6 +5728,240 @@ class DecomposeAtenConvolutionBackwardOp }; } // namespace +/** + * # one dim input + * t = torch.tensor([0, 0, 1, 1, 0, 0] + * # t_flat:[0, 0, 1, 1, 0, 0] + * t_flat = t.flatten(0, 0) + * nonzero_mask = t_flat != 0 + * # nonzero_mask:[0, 0, 1, 1, 0, 0] + * nonzero_mask = nonzero_mask.long() + * # destination_indices:[-1, -1, 0, 1, 1, 1] + * destination_indices = torch.cumsum(nonzero_mask, 0) - 1 + * # destination_indices_clamp:[0, 0, 0, 1, 1, 1] + * destination_indices_clamp = torch.clamp(destination_indices, min=0) + * # iota:[0, 0, 2, 3, 0, 0] + * iota = torch.arange(t_flat.size(0)) * nonzero_mask + * # scatter_self:[0, 0, 0, 0, 0, 0] + * scatter_self = torch.zeros_like(t_flat, dtype=torch.int64) + * # compacted:[2, 3, 0, 0, 0, 0] + * compacted = torch.scatter_add( + * scatter_self, dim=0, index=destination_indices_clamp, src=iota + * ) + * # result_flat:[2, 3] + * result_flat = compacted[: torch.sum(nonzero_mask)] + * + * # multi dim support + * original_shape = t.shape + * # input_shape_tensor:[6] + * input_shape_tensor = torch.tensor(original_shape) + * strides = torch.cumprod(torch.flip(input_shape_tensor, [0]), 0).flip(0) + * + * one = torch.tensor([1]) + * if(t.dim() > 1): + * slicedStrides = strides[1:-1] + * strides = torch.cat([slicedStrides, one]) + * else: + * strides = one + * # a: tensor([[2], [3]]) torch.Size([2, 1]) + * a = result_flat.unsqueeze(1) # tensor([[2], [3]]) torch.Size([2, 1]) + * # b: tensor([[1]]) torch.Size([1, 1]) + * b = strides.unsqueeze(0) + * # c: tensor([[2], [3]]) torch.Size([2, 1]) + * c = a // b + * # result: tensor([[2], [3]]) torch.Size([2, 1]) + * result = c % input_shape_tensor + */ +class DecomposeAtenNonzeroOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenNonzeroOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto resultType = cast(op.getType()); + auto intType = resultType.getDtype(); + 
Value intTypeValue = getDtypeIntValueForType(rewriter, loc, intType); + auto constantZero = + rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + auto constantOne = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + std::function makeOneElementList = [&](Value element) { + auto listType = Torch::ListType::get(element.getType()); + return rewriter.create(loc, listType, + ArrayRef{element}); + }; + + Value input = op.getSelf(); + auto inputType = dyn_cast(input.getType()); + int64_t inputRank = inputType.getSizes().size(); + + // t_flat = t.flatten() # torch.flatten(t, 0, 0) + int64_t flattenedSize = 1; + if (inputType.hasSizes()) { + for (auto size : inputType.getSizes()) { + flattenedSize *= size; + } + } else { + flattenedSize = kUnknownSize; + } + + auto flattendInputShape = SmallVector{flattenedSize}; + auto flattenedInputType = rewriter.getType( + flattendInputShape, inputType.getOptionalDtype()); + + // %1 = torch.aten.flatten.using_ints %arg0, %int0, %int0_0 : + auto inputDimsEnd = rewriter.create( + loc, rewriter.getI64IntegerAttr(inputRank - 1)); + Value flattenedInput = rewriter.create( + loc, flattenedInputType, input, constantZero /*inputDimsStart*/, + inputDimsEnd /*inputDimsEnd*/); + + // nonzero_mask = (t_flat != 0) + auto boolMaskType = inputType.getWithSizesAndDtype( + flattenedInputType.getOptionalSizes(), rewriter.getI1Type()); + Value boolMask = rewriter.create( + loc, boolMaskType, flattenedInput, constantZero); + + // nonzero_mask = nonzero_mask.int() + Value falseCst = rewriter.create(loc, false); + Value noneCst = rewriter.create(loc); + auto intMaskType = flattenedInputType.getWithSizesAndDtype( + flattenedInputType.getOptionalSizes(), intType); + Value intMask = rewriter.create( + loc, intMaskType, boolMask, intTypeValue, falseCst, falseCst, noneCst); + + // destination_indices = torch.cumsum(nonzero_mask, 0) - 1 + Value cumulativeSum = rewriter.create( + loc, intMaskType, intMask, constantZero, noneCst); + Value subtracted = rewriter.create( + loc, intMaskType, cumulativeSum, constantOne, /*alpha=*/constantOne); + + // destination_indices = torch.clamp(destination_indices, min=0) + Value indices = rewriter.create(loc, intMaskType, + subtracted, constantZero); + + // iota = torch.arange(len(t_flat)) * nonzero_mask + Value end = rewriter.create(loc, flattenedInput, + /*dim=*/constantZero); + Value rangeTensor = rewriter.create( + loc, intMaskType, /*start*/ constantZero, /*end*/ end, + /*step*/ constantOne, noneCst, noneCst, noneCst, noneCst); + Value multiplied = rewriter.create(loc, intMaskType, + rangeTensor, intMask); + + // scatter_self = torch.zeros_like(t, dtype=torch.int64) + // AtenFullLike doesn't support index type so we have to use int. 
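+    // Note: positions where the mask is 0 carry an iota value of 0, so the
+    // scatter_add below leaves the flat indices of the nonzero elements in the
+    // first sum(mask) slots of the compacted tensor.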
+ Value zerosTensor = rewriter.create( + loc, intMaskType, flattenedInput, intTypeValue, noneCst, noneCst, + noneCst, noneCst); + + // compacted = torch.scatter_add( + // scatter_self, dim=0, index=destination_indices_clamp, src=iota) + Value scatteredTensor = rewriter.create( + loc, intMaskType, /*self*/ zerosTensor, /*dim=*/constantZero, + /*index=*/indices, /*src=*/multiplied); + + // result_flat = compacted[:torch.sum(nonzero_mask)] + auto scalarType = ValueTensorType::get(rewriter.getContext(), + ArrayRef{}, intType); + Value sumMask = + rewriter.create(loc, scalarType, intMask, noneCst); + Value numNonzero = rewriter.create(loc, sumMask); + + auto slicedResultType = Torch::ValueTensorType::get( + rewriter.getContext(), SmallVector{kUnknownSize}, intType); + Value slicedResult = + rewriter.create(loc, slicedResultType, + /*self=*/scatteredTensor, + /*dim=*/constantZero, + /*start=*/noneCst, + /*end=*/numNonzero, + /*step=*/constantOne); + + // TODO fix multidim dynamic support. The following code only work for + // static multidim. Convert flattened indices back to multi-dimensional + // indices original_shape = t.shape input_shape_tensor = + // torch.tensor(original_shape) + auto shapeType = Torch::ValueTensorType::get( + rewriter.getContext(), SmallVector{inputRank}, intType); + SmallVector shapeValues; + for (int i = 0; i < inputRank; i++) { + auto constantI = + rewriter.create(loc, rewriter.getI64IntegerAttr(i)); + Value shape = rewriter.create(loc, input, + /*dim=*/constantI); + shapeValues.push_back(shape); + } + Value shapeTensorList = rewriter.create( + loc, Torch::ListType::get(shapeValues[0].getType()), shapeValues); + Value inputShapeTensor = rewriter.create( + loc, shapeType, shapeTensorList, noneCst, noneCst, falseCst); + + // strides = torch.cumprod(torch.flip(input_shape_tensor,[0]),0).flip(0) + Value flippedShape = rewriter.create( + loc, shapeType, inputShapeTensor, makeOneElementList(constantZero)); + Value cumulativeProduct = rewriter.create( + loc, shapeType, flippedShape, constantZero, noneCst); + Value flippedCumulativeProduct = rewriter.create( + loc, shapeType, cumulativeProduct, makeOneElementList(constantZero)); + + // strides = torch.cat([strides[1:-1], torch.tensor([1])]) + auto oneTensorType = ValueTensorType::get(rewriter.getContext(), + SmallVector{1}, intType); + Value oneTensor = rewriter.create( + loc, oneTensorType, constantOne, intTypeValue, noneCst, noneCst, + noneCst); + + Value strides; + if (inputRank > 1) { + // strides[1:-1] + auto slicedStrideType = Torch::ValueTensorType::get( + rewriter.getContext(), SmallVector{inputRank - 1}, // sizes + intType); + Value strideSliceEnd = rewriter.create( + loc, rewriter.getI64IntegerAttr(inputRank)); + Value slicedStrides = rewriter.create( + loc, slicedStrideType, /*self*/ flippedCumulativeProduct, + /*dim*/ constantZero, + /*start=*/constantOne, /*end=*/strideSliceEnd, /*step=*/constantOne); + // torch.cat + auto tensorListElementType = Torch::ValueTensorType::get( + rewriter.getContext(), SmallVector{kUnknownSize}, intType); + Value tensorList = rewriter.create( + loc, Torch::ListType::get(tensorListElementType), + SmallVector{slicedStrides, oneTensor}); + strides = rewriter.create(loc, shapeType, tensorList, + constantZero); + } else { + // strides[1:-1] is empty + strides = oneTensor; + } + + // multi_indices = (result_flat.unsqueeze(1) // strides.unsqueeze(0)) % + // input_shape_tensor + auto unsqueezedResultType = ValueTensorType::get( + rewriter.getContext(), SmallVector{kUnknownSize, 1}, 
intType); + Value unsqueezedResult = rewriter.create( + loc, unsqueezedResultType, slicedResult, constantOne); + + auto unsqueezedStridesType = ValueTensorType::get( + rewriter.getContext(), SmallVector{1, inputRank}, intType); + Value unsqueezedStrides = rewriter.create( + loc, unsqueezedStridesType, strides, constantZero); + + auto dividedBroadcastType = ValueTensorType::get( + rewriter.getContext(), SmallVector{kUnknownSize, inputRank}, + intType); + Value divided = rewriter.create( + loc, dividedBroadcastType, unsqueezedResult, unsqueezedStrides); + + Value modded = rewriter.create( + loc, resultType, divided, inputShapeTensor); + + rewriter.replaceOp(op, modded); + return success(); + } +}; + // Decompose aten.addmm into aten.mm and aten.add.Tensor op. namespace { class DecomposeAtenAddmmOp : public OpRewritePattern { @@ -3965,7 +5984,7 @@ class DecomposeAtenAddmmOp : public OpRewritePattern { // TODO: Handle integer type operands. auto inputType = cast(input.getType()); - if (!inputType.hasDtype() || !inputType.getDtype().isa()) { + if (!inputType.hasDtype() || !isa(inputType.getDtype())) { return rewriter.notifyMatchFailure( op, "unimplemented: non-floating point dtype"); } @@ -4027,7 +6046,7 @@ class DecomposeAtenMeanDimOp : public OpRewritePattern { MLIRContext *context = op.getContext(); BaseTensorType inputType = cast(input.getType()); - if (!inputType.hasDtype() || !inputType.getDtype().isa() || + if (!inputType.hasDtype() || !isa(inputType.getDtype()) || !isNoneOrFloatDtype(context, dtype)) { return rewriter.notifyMatchFailure( op, "only floating-point type is supported"); @@ -4035,7 +6054,7 @@ class DecomposeAtenMeanDimOp : public OpRewritePattern { SmallVector dimListElements; if (!getListConstructElements(dimList, dimListElements) && - !dimList.getType().isa()) { + !isa(dimList.getType())) { return rewriter.notifyMatchFailure( op, "expected `dim` to be `None` or constructed from list construct"); } @@ -4117,7 +6136,7 @@ class DecomposeAtenDropoutOp : public OpRewritePattern { return success(); } BaseTensorType inputType = cast(input.getType()); - if (!inputType.hasDtype() || !inputType.getDtype().isa()) + if (!inputType.hasDtype() || !isa(inputType.getDtype())) return rewriter.notifyMatchFailure( op, "only support floating type input for training mode"); Value noneVal = rewriter.create(loc); @@ -4145,7 +6164,7 @@ class DeomposeAtenNativeDropoutOp Value input = op.getInput(); Value prob = op.getP(); bool train = false; - if (!op.getTrain().getType().isa()) { + if (!isa(op.getTrain().getType())) { if (!matchPattern(op.getTrain(), m_TorchConstantBool(&train))) { return rewriter.notifyMatchFailure( op, "train must be a boolean constant or none"); @@ -4165,7 +6184,7 @@ class DeomposeAtenNativeDropoutOp return success(); } BaseTensorType inputType = cast(input.getType()); - if (!inputType.hasDtype() || !inputType.getDtype().isa()) { + if (!inputType.hasDtype() || !isa(inputType.getDtype())) { return rewriter.notifyMatchFailure( op, "only support floating type input for training mode"); } @@ -4234,7 +6253,7 @@ class DecomposeAtenStdOp : public OpRewritePattern { Value self = op.getSelf(); BaseTensorType inputTensorTy = cast(self.getType()); if (!inputTensorTy.hasDtype() || - !inputTensorTy.getDtype().isa()) { + !isa(inputTensorTy.getDtype())) { return rewriter.notifyMatchFailure(op, "Only aten.std support floating type"); } @@ -4290,7 +6309,7 @@ class DecomposeAtenStdDimOp : public OpRewritePattern { Value self = op.getSelf(); BaseTensorType inputTensorType = cast(self.getType()); 
if (!inputTensorType.hasDtype() || - !inputTensorType.getDtype().isa()) { + !isa(inputTensorType.getDtype())) { return rewriter.notifyMatchFailure( op, "aten.std.dim expects input tensor of floating-point type"); } @@ -4304,6 +6323,72 @@ class DecomposeAtenStdDimOp : public OpRewritePattern { }; } // namespace +// Decompose aten.rot90 +// github: +// https://github.com/pytorch/pytorch/blob/207564bab1c4fe42750931765734ee604032fb69/torch/_refs/__init__.py#L3830 +namespace { +class DecomposeAtenRot90Op : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenRot90Op op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = op.getSelf(); + + // Convert dims from Value to SmallVector. + SmallVector dims; + if (!getListConstructElements(op.getDims(), dims)) + return rewriter.notifyMatchFailure( + op, "unimplemented: dims not list of Scalar"); + + // Convert k from Value to int + int64_t k; + if (!matchPattern(op.getK(), m_TorchConstantInt(&k))) + return rewriter.notifyMatchFailure(op, + "Unimplemented: k not constant int"); + + k = (k % 4 + 4) % + 4; // This is equal to python code k = k % 4, because python and c++ + // have different implementation for operand %. + + if (k == 1) { + Value flipDimList = rewriter.create( + loc, + rewriter.getType(rewriter.getType()), + ArrayRef{dims[1]}); + + Value flip = + rewriter.create(loc, self.getType(), self, flipDimList); + + rewriter.replaceOpWithNewOp( + op, op.getType(), flip, dims[0], dims[1]); + } else if (k == 2) { + rewriter.replaceOpWithNewOp(op, op.getType(), self, + op.getDims()); + } else if (k == 3) { + Value flipDimList = rewriter.create( + loc, + rewriter.getType(rewriter.getType()), + ArrayRef{dims[0]}); + + Value flip = + rewriter.create(loc, self.getType(), self, flipDimList); + + rewriter.replaceOpWithNewOp( + op, op.getType(), flip, dims[0], dims[1]); + } else { + rewriter.replaceOpWithNewOp( + op, op.getType(), self, + /*memory_format=*/ + rewriter.create(loc, + rewriter.getI64IntegerAttr(0))); + } + + return success(); + } +}; +} // namespace + // Decompose aten.std.correction to sqrt(var.correction(x)) namespace { class DecomposeAtenStdCorrectionOp @@ -4315,7 +6400,7 @@ class DecomposeAtenStdCorrectionOp Value self = op.getSelf(); BaseTensorType inputTensorType = cast(self.getType()); if (!inputTensorType.hasDtype() || - !inputTensorType.getDtype().isa()) { + !isa(inputTensorType.getDtype())) { return rewriter.notifyMatchFailure( op, "aten.std.correction expects input tensor of floating-point type"); @@ -4408,7 +6493,7 @@ class DecomposeAtenRandLikeOp : public OpRewritePattern { Value input = op.getSelf(); Type resultType = op.getType(); auto inputType = cast(input.getType()); - if (!inputType.hasDtype() || !inputType.getDtype().isa()) { + if (!inputType.hasDtype() || !isa(inputType.getDtype())) { return rewriter.notifyMatchFailure(op, "only support floating-point type"); } @@ -4449,7 +6534,7 @@ static LogicalResult decomposeBernoulliLikeOp(PatternRewriter &rewriter, op, "can't decompose bernoulli like ops without sizes or dtype"); } // The `prob` is expected to be a float type tensor. 
- if (!probType.getDtype().isa()) { + if (!isa(probType.getDtype())) { return rewriter.notifyMatchFailure( op, "probabilities must be a float type tensor"); } @@ -4484,7 +6569,7 @@ class DecomposeAtenBernoulliOp : public OpRewritePattern { PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value input = op.getSelf(); - if (!op.getGenerator().getType().isa()) + if (!isa(op.getGenerator().getType())) return rewriter.notifyMatchFailure( op, "The generator has to be None because only global default " "generator is supported"); @@ -4542,7 +6627,7 @@ class DecomposeAtenBernoulliTensorOp Location loc = op.getLoc(); Value input = op.getSelf(); Value prob = op.getP(); - if (!op.getGenerator().getType().isa()) + if (!isa(op.getGenerator().getType())) return rewriter.notifyMatchFailure( op, "The generator has to be None because only global default " "generator is supported"); @@ -4567,7 +6652,7 @@ class DecomposeAtenExponentialOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenExponentialOp op, PatternRewriter &rewriter) const override { - if (!op.getGenerator().getType().isa()) + if (!isa(op.getGenerator().getType())) return rewriter.notifyMatchFailure( op, "The generator has to be None because only global default " "generator is supported"); @@ -4608,7 +6693,7 @@ class DecomposeAtenNormalFunctionalOp using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenNormalFunctionalOp op, PatternRewriter &rewriter) const override { - if (!op.getGenerator().getType().isa()) + if (!isa(op.getGenerator().getType())) return rewriter.notifyMatchFailure( op, "The generator has to be None because only global default " "generator is supported"); @@ -4817,6 +6902,63 @@ class DecomposeAtenInstanceNormOp }; } // namespace +namespace { +class DecomposeAten_WeightNormInterfaceOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(Aten_WeightNormInterfaceOp op, + PatternRewriter &rewriter) const override { + + Location loc = op.getLoc(); + Value v = op.getV(); + Value g = op.getG(); + Value dim = op.getDim(); + + auto inputType = cast(v.getType()); + if (!inputType.hasSizes()) + return rewriter.notifyMatchFailure(op, "expected input to have sizes"); + + if (!cast(dim.getDefiningOp())) + return rewriter.notifyMatchFailure(op, "dim is not a ConstantIntOp"); + + auto sizes = inputType.getSizes(); + SmallVector keepDims; + for (int64_t i = 0; i < static_cast(sizes.size()); ++i) { + if (i != + static_cast(dim.getDefiningOp().getValue())) + keepDims.push_back( + rewriter.create(loc, rewriter.getI64IntegerAttr(i))); + } + + Value ord = + rewriter.create(loc, rewriter.getI64IntegerAttr(2)); + Value keepdim = + rewriter.create(loc, rewriter.getBoolAttr(true)); + Value dtypeNone = rewriter.create(loc); + + Value dimList = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(op->getContext())), + keepDims); + + Value norm = rewriter.create( + loc, v.getType(), v, ord, dimList, keepdim, dtypeNone); + + auto vShape = rewriter.create( + loc, Torch::ListType::get(rewriter.getI64Type()), v); + + Value gDivNorm = + rewriter.create(loc, g.getType(), g, norm); + Value broadcastedGDivNorm = + rewriter.create(loc, v.getType(), gDivNorm, vShape); + Value vMulBroadcastedGDivNorm = rewriter.create( + loc, v.getType(), v, broadcastedGDivNorm); + + rewriter.replaceOp(op, ArrayRef{vMulBroadcastedGDivNorm, norm}); + return success(); + } +}; +} // namespace + namespace { class 
DecomposeAtenNativeLayerNormOp : public OpRewritePattern { @@ -4826,7 +6968,7 @@ class DecomposeAtenNativeLayerNormOp Location loc = op.getLoc(); auto context = op.getContext(); - auto inputTy = cast(op.getInput().getType()); + auto inputTy = cast(op.getInput().getType()); if (!inputTy.hasSizes()) return rewriter.notifyMatchFailure( op, "input tensor should have known sizes."); @@ -4881,15 +7023,27 @@ class DecomposeAtenNativeLayerNormOp loc, inputTy, inputRsqrtVar, op.getInput()); Value inputNormalized = rewriter.create( loc, inputTy, inputZeroMean, inputRsqrtVarExpanded); + // Convert resultType if dtype is different + auto resultTensorType = + dyn_cast(op.getResult(0).getType()); + if (inputTy.getDtype() != resultTensorType.getDtype()) { + Value dtypeValue = Torch::getDtypeIntValueForType( + rewriter, loc, resultTensorType.getDtype()); + Value cstFalse = rewriter.create(loc, false); + inputNormalized = rewriter.create( + loc, resultTensorType, inputNormalized, + /*dtype=*/dtypeValue, /*non_blocking=*/cstFalse, /*copy=*/cstFalse, + /*memory_format=*/none); + } Value out = rewriter.create( loc, op.getResult(0).getType(), inputNormalized); Value weight = op.getWeight(); Value bias = op.getBias(); - if (!weight.getType().isa()) { + if (!isa(weight.getType())) { out = rewriter.create(loc, out.getType(), out, weight); } - if (!bias.getType().isa()) { + if (!isa(bias.getType())) { out = rewriter.create(loc, out.getType(), out, bias, one); } @@ -5014,7 +7168,6 @@ class DecomposeAtenGroupNormOp : public OpRewritePattern { LogicalResult matchAndRewrite(AtenGroupNormOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - MLIRContext *context = op.getContext(); Value input = op.getInput(); Value weight = op.getWeight(); @@ -5022,11 +7175,23 @@ class DecomposeAtenGroupNormOp : public OpRewritePattern { Value numGroups = op.getNumGroups(); Value eps = op.getEps(); + int64_t numGroupsInt; + if (!matchPattern(numGroups, m_TorchConstantInt(&numGroupsInt))) + return rewriter.notifyMatchFailure( + op, "unimplemented: num_groups must be a constant int"); + Value cstZero = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); Value cstOne = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); - auto baseType = ValueTensorType::getWithLeastStaticInformation(context); + + auto inputType = cast(input.getType()); + if (!inputType.hasSizes()) + return rewriter.notifyMatchFailure(op, "input should have sizes."); + + SmallVector baseTypeSizes{inputType.getSizes()[0], numGroupsInt}; + auto baseType = inputType.getWithSizesAndDtype( + baseTypeSizes, inputType.getOptionalDtype()); Value N = rewriter.create(loc, input, cstZero); Value C = rewriter.create(loc, input, cstOne); @@ -5080,7 +7245,6 @@ class DecomposeAtenNativeGroupNormOp rewriter.create(loc, rewriter.getI64IntegerAttr(-1)); Value cstTrue = rewriter.create(loc, true); Value cstFalse = rewriter.create(loc, false); - auto baseType = ValueTensorType::getWithLeastStaticInformation(context); // GroupNorm requires the channel dimension (C) to be exactly divisible by // the number of groups. @@ -5094,12 +7258,34 @@ class DecomposeAtenNativeGroupNormOp "the number of groups")); // Reshape the input tensor to (N, numGroups, -1) to apply normalization. 
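+    // For example, a (2, 6, 4, 4) input with num_groups == 3 is reshaped to (2, 3, 32).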
+ int64_t numGroupsInt; + if (!matchPattern(numGroups, m_TorchConstantInt(&numGroupsInt))) + return rewriter.notifyMatchFailure( + op, "unimplemented: num_groups must be a constant int"); + SmallVector newShape; + SmallVector inputShapeInt{inputType.getSizes()}; + SmallVector reshapeInputShape{inputShapeInt[0], numGroupsInt}; + int64_t reshapeInputLastDim = 1; + for (size_t i = 1; i < inputShapeInt.size(); i++) { + if (inputShapeInt[i] == Torch::kUnknownSize) { + reshapeInputLastDim = Torch::kUnknownSize; + break; + } + reshapeInputLastDim *= inputShapeInt[i]; + } + reshapeInputLastDim = reshapeInputLastDim == Torch::kUnknownSize + ? reshapeInputLastDim + : reshapeInputLastDim / numGroupsInt; + reshapeInputShape.push_back(reshapeInputLastDim); + newShape.push_back(rewriter.create(loc, input, cstZero)); newShape.push_back(numGroups); newShape.push_back(cstNegtiveOne); + Type reshapeInputType = inputType.getWithSizesAndDtype( + reshapeInputShape, inputType.getOptionalDtype()); Value reshapedInput = rewriter.create( - loc, baseType, input, + loc, reshapeInputType, input, rewriter.create( loc, Torch::ListType::get(IntType::get(context)), newShape)); @@ -5108,21 +7294,28 @@ class DecomposeAtenNativeGroupNormOp Value dimList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), ArrayRef{cstNegtiveOne}); - auto mean = rewriter.create( - loc, baseType, reshapedInput, /*dims=*/dimList, /*keepdim=*/cstTrue, - /*dtype=*/none); - auto var = rewriter.create( - loc, baseType, reshapedInput, /*dims=*/dimList, /*unbiased=*/cstFalse, - /*keepdim=*/cstTrue); + + reshapeInputShape[2] = 1; + Type reductionType = inputType.getWithSizesAndDtype( + reshapeInputShape, inputType.getOptionalDtype()); + auto mean = + rewriter.create(loc, reductionType, reshapedInput, + /*dims=*/dimList, /*keepdim=*/cstTrue, + /*dtype=*/none); + auto var = + rewriter.create(loc, reductionType, reshapedInput, + /*dims=*/dimList, /*unbiased=*/cstFalse, + /*keepdim=*/cstTrue); // Compute the normalized output: (input - mean) * rsqrt(var + eps) - auto varPlusEps = rewriter.create(loc, baseType, var, eps, - /*alpha=*/cstOne); - auto invStd = rewriter.create(loc, baseType, varPlusEps); + auto varPlusEps = + rewriter.create(loc, reductionType, var, eps, + /*alpha=*/cstOne); + auto invStd = rewriter.create(loc, reductionType, varPlusEps); auto inputSubMean = rewriter.create( - loc, baseType, reshapedInput, mean, /*alpha=*/cstOne); - auto normalizedOutput = - rewriter.create(loc, baseType, inputSubMean, invStd); + loc, reshapeInputType, reshapedInput, mean, /*alpha=*/cstOne); + auto normalizedOutput = rewriter.create( + loc, reshapeInputType, inputSubMean, invStd); // Reshape normalized output back to the original input shape auto inputShape = rewriter.create( @@ -5133,22 +7326,26 @@ class DecomposeAtenNativeGroupNormOp // Apply weight and bias if they are not None // Reshape weight and bias to C,1,1,... 
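+    // e.g. for a 4-D (N, C, H, W) input, a (C,) weight or bias becomes (C, 1, 1)
+    // so that it broadcasts over the spatial dims.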
SmallVector viewShape = {channel}; + SmallVector viewShapeInt{inputShapeInt[1]}; for (unsigned i = 2; i < inputType.getSizes().size(); i++) { viewShape.push_back(cstOne); + viewShapeInt.push_back(1); } Value viewShapeSizeList = rewriter.create( loc, ListType::get(IntType::get(context)), viewShape); + Type viewType = inputType.getWithSizesAndDtype( + viewShapeInt, inputType.getOptionalDtype()); Value groupNormOutput = reshapedOutput; - if (!weight.getType().isa()) { + if (!isa(weight.getType())) { auto weightReshaped = rewriter.create( - loc, baseType, weight, /*shape=*/viewShapeSizeList); + loc, viewType, weight, /*shape=*/viewShapeSizeList); groupNormOutput = rewriter.create( loc, inputType, groupNormOutput, weightReshaped); } - if (!bias.getType().isa()) { + if (!isa(bias.getType())) { auto biasReshaped = rewriter.create( - loc, baseType, bias, /*shape=*/viewShapeSizeList); + loc, viewType, bias, /*shape=*/viewShapeSizeList); groupNormOutput = rewriter.create( loc, inputType, groupNormOutput, biasReshaped, /*alpha=*/cstOne); @@ -5199,8 +7396,8 @@ class DecomposeAtenNativeBatchNormOp // In the inference mode, the `runningMean` and `runningVar` must not be // None. - if (runningMean.getType().isa() || - runningVar.getType().isa()) + if (isa(runningMean.getType()) || + isa(runningVar.getType())) return rewriter.notifyMatchFailure( op, "running stats must not be None in inference mode"); @@ -5256,7 +7453,7 @@ class DecomposeAtenNativeBatchNormOp // 2. bias = bias.view(1, C, 1?, 1?, 1?) // 3. output = normalizedInput * weight + bias Value batchNormOutput = normalizedInput; - if (!weight.getType().isa()) { + if (!isa(weight.getType())) { // Rank of `weight` must be exactly 1. std::optional weightRank = getTensorRank(weight); if (!weightRank || *weightRank != 1) @@ -5266,7 +7463,7 @@ class DecomposeAtenNativeBatchNormOp batchNormOutput = rewriter.create( loc, batchNormOutput.getType(), batchNormOutput, weight); } - if (!bias.getType().isa()) { + if (!isa(bias.getType())) { // Rank of `bias` must be exactly 1. std::optional biasRank = getTensorRank(bias); if (!biasRank || *biasRank != 1) @@ -5346,7 +7543,7 @@ class DecomposeConstantTensorNewLikeOp : public OpRewritePattern { LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { Value dtype = op.getDtype(); - if (dtype.getType().isa()) { + if (isa(dtype.getType())) { BaseTensorType tensorType = cast(op.getSelf().getType()); if (!tensorType.hasDtype()) { return rewriter.notifyMatchFailure( @@ -5403,38 +7600,57 @@ class DecomposeAtenLinearOp : public OpRewritePattern { Value bias = op.getBias(); BaseTensorType inputType = cast(input.getType()); - if (!inputType.hasSizes() || inputType.getSizes().size() < 2) - return rewriter.notifyMatchFailure( - op, "expected input to be rank 2 or greater"); + if (!inputType.hasSizes()) + return rewriter.notifyMatchFailure(op, "expected input to have sizes"); BaseTensorType weightType = cast(weight.getType()); - // `weight` must be a rank 2 matrix. 
- if (!weightType.hasSizes() || weightType.getSizes().size() != 2) - return rewriter.notifyMatchFailure(op, "expected weight to be a rank 2"); - - SmallVector transposeShape = - llvm::to_vector(llvm::reverse(weightType.getSizes())); - Type transposeType = weightType.getWithSizesAndDtype( - llvm::ArrayRef(transposeShape), weightType.getOptionalDtype()); - Value transposeWeight = - rewriter.create(loc, transposeType, weight); - - Value matmul = rewriter.create(loc, op.getType(), input, - transposeWeight); - if (bias.getType().isa()) { - rewriter.replaceOp(op, matmul); - return success(); - } + if (!weightType.hasSizes()) + return rewriter.notifyMatchFailure(op, "expected weight to have sizes"); + + auto transposeWeight = [&]() -> Value { + SmallVector transposeShape = + llvm::to_vector(llvm::reverse(weightType.getSizes())); + Type transposeType = weightType.getWithSizesAndDtype( + llvm::ArrayRef(transposeShape), weightType.getOptionalDtype()); + Value transposeWeight = + rewriter.create(loc, transposeType, weight); + return transposeWeight; + }; - BaseTensorType biasType = cast(bias.getType()); - if (!biasType.hasSizes() || biasType.getSizes().size() != 1) - return rewriter.notifyMatchFailure(op, "expected bias to be rank 1"); + if (isa(bias.getType())) { + auto weightRank = weightType.getSizes().size(); + if (weightRank > 2 || weightRank <= 0) + return rewriter.notifyMatchFailure( + op, "expected weight's rank <= 2 && >= 1"); + if (weightRank == 1) { + rewriter.replaceOpWithNewOp(op, op.getType(), input, + weight); + return success(); + } else if (weightRank == 2) { + rewriter.replaceOpWithNewOp(op, op.getType(), input, + transposeWeight()); + return success(); + } + llvm_unreachable("unsupported weightRank"); + } else { + BaseTensorType biasType = cast(bias.getType()); + if (!biasType.hasSizes() || biasType.getSizes().size() != 1) + return rewriter.notifyMatchFailure(op, "expected bias to be rank 1"); - Value alpha = - rewriter.create(loc, rewriter.getF64FloatAttr(1)); - rewriter.replaceOpWithNewOp(op, op.getType(), matmul, - op.getBias(), alpha); - return success(); + // `weight` must be a rank 2 matrix. 
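+      // e.g. a (N, in) input times the transposed (in, out) weight gives (N, out),
+      // after which the rank-1 (out,) bias is added.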
+ auto weightRank = weightType.getSizes().size(); + if (weightRank != 2) + return rewriter.notifyMatchFailure(op, + "expected weight to be a rank 2"); + + Value matmul = rewriter.create(loc, op.getType(), input, + transposeWeight()); + Value alpha = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + rewriter.replaceOpWithNewOp(op, op.getType(), matmul, + op.getBias(), alpha); + return success(); + } } }; } // namespace @@ -5505,7 +7721,7 @@ class DecomposeAtenNewFullOp : public OpRewritePattern { LogicalResult matchAndRewrite(AtenNewFullOp op, PatternRewriter &rewriter) const override { Value dtype = op.getDtype(); - if (dtype.getType().isa()) { + if (isa(dtype.getType())) { BaseTensorType tensorType = cast(op.getSelf().getType()); if (!tensorType.hasDtype()) { return rewriter.notifyMatchFailure( @@ -5601,7 +7817,7 @@ class DecomposeAtenNewEmptyOp : public OpRewritePattern { PatternRewriter &rewriter) const override { Value noneVal = rewriter.create(op.getLoc()); Value dtype = op.getDtype(); - if (dtype.getType().isa()) { + if (isa(dtype.getType())) { BaseTensorType tensorType = cast(op.getSelf().getType()); if (!tensorType.hasDtype()) { return rewriter.notifyMatchFailure( @@ -5624,31 +7840,113 @@ class DecomposeAtenPadOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenPadOp op, PatternRewriter &rewriter) const override { + std::string mode; + if (!matchPattern(op.getMode(), m_TorchConstantStr(mode))) + return rewriter.notifyMatchFailure(op, "mode must be a constant string"); - Value value = op.getValue(); - if (value.getType().isa()) - return rewriter.notifyMatchFailure(op, "optional type not supported"); - if (value.getType().isa()) - value = rewriter.create( - op.getLoc(), rewriter.getF64FloatAttr(0)); + if (mode == "constant") { + Value value = op.getValue(); + if (isa(value.getType())) + return rewriter.notifyMatchFailure(op, "optional type not supported"); + if (isa(value.getType())) + value = rewriter.create( + op.getLoc(), rewriter.getF64FloatAttr(0)); - rewriter.replaceOpWithNewOp( - op, op.getType(), op.getSelf(), op.getPad(), value); - return success(); - } -}; -} // namespace + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getSelf(), op.getPad(), value); + return success(); + } -namespace { -// Decompose `aten.to.dtypeLayout` op into `aten.to.dtype` op. -class DecomposeAtenToDtypeLayoutOp - : public OpRewritePattern { -public: + SmallVector padValues; + if (!getListConstructElements(op.getPad(), padValues)) + return failure(); + SmallVector padInts; + Value usefulPads = op.getPad(); + uint64_t usefulPadIndexEnd = padValues.size(); + + // try to reduce the number of padding dims if possible + if (matchPattern(op.getPad(), m_TorchListOfConstantInts(padInts))) { + if ((padInts.size() % 2) == 1) + return rewriter.notifyMatchFailure(op, + "expected an even number of pads"); + + for (uint64_t i = padInts.size() - 1; i > 0; i -= 2) { + if (padInts[i] != 0 || padInts[i - 1] != 0) + break; + usefulPadIndexEnd = i - 1; + } + if (usefulPadIndexEnd == 0) { + rewriter.replaceOp(op, op.getSelf()); + return success(); + } + } + + // we don't have support for 1-D replicate pad, so pass it as 2d if + // possible. + // TODO: add support for AtenReplicatePad1dOp and remove this. 
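+    // e.g. a pad list [1, 1, 0, 0] reduces to a 1-D pad above, but is kept as
+    // [1, 1, 0, 0] here so it can be lowered through ReplicationPad2d.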
+ if (mode == "replicate" && usefulPadIndexEnd == 2 && padValues.size() >= 4) + usefulPadIndexEnd = 4; + + // make a new list of padding ints if dimensionality reduction can be + // performed + if (usefulPadIndexEnd < padValues.size()) { + ArrayRef usefulPadValues(padValues.begin(), + padValues.begin() + usefulPadIndexEnd); + usefulPads = rewriter.create( + op.getLoc(), + rewriter.getType(rewriter.getType()), + usefulPadValues); + } + + uint64_t numPadDims = usefulPadIndexEnd / 2; + + if (mode == "reflect") { + // only support for relectionpad 1d and 2d + switch (numPadDims) { + case 1: + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getSelf(), usefulPads); + break; + case 2: + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getSelf(), usefulPads); + break; + case 3: + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getSelf(), usefulPads); + break; + default: + return rewriter.notifyMatchFailure( + op, "unsupported number of dims for 'reflect' mode: " + + std::to_string(numPadDims)); + } + return success(); + } + + if (mode == "replicate") { + // only support for replication pad 2d + if (numPadDims != 2) + return failure(); + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getSelf(), usefulPads); + return success(); + } + + return rewriter.notifyMatchFailure(op, "unsupported mode: " + mode); + } +}; +} // namespace + +namespace { +// Decompose `aten.to.dtypeLayout` op into `aten.to.dtype` op. +class DecomposeAtenToDtypeLayoutOp + : public OpRewritePattern { +public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenToDtypeLayoutOp op, PatternRewriter &rewriter) const override { // TODO: Add support for pinMemory arg equal to `True`. - if (!op.getPinMemory().getType().isa()) { + if (!isa(op.getPinMemory().getType())) { bool pinMemory; if (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory))) return rewriter.notifyMatchFailure( @@ -5659,7 +7957,7 @@ class DecomposeAtenToDtypeLayoutOp } // TODO: Add support for device arg other than cpu. - if (!op.getDevice().getType().isa()) { + if (!isa(op.getDevice().getType())) { std::string device; if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device))) return rewriter.notifyMatchFailure( @@ -5671,7 +7969,7 @@ class DecomposeAtenToDtypeLayoutOp // TODO: Add support for non-strided layout. // torch.layout is by default strided i.e. 0. - if (!op.getLayout().getType().isa()) { + if (!isa(op.getLayout().getType())) { int64_t tensorLayout; if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout))) return rewriter.notifyMatchFailure( @@ -5735,6 +8033,94 @@ class DecomposeAtenToDeviceOp : public OpRewritePattern { }; } // namespace +namespace { +// Decompose `aten.adaptive_max_pool1d` op into `aten.max_pool1d_with_indices` +// op. 
+class DecomposeAtenAdaptiveMaxPool1dOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenAdaptiveMaxPool1dOp op, + PatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + MLIRContext *context = op.getContext(); + + Value input = op.getSelf(); + std::optional maybeRank = getTensorRank(input); + if (!maybeRank) { + return rewriter.notifyMatchFailure(op, "expected input to have a rank"); + } + unsigned rank = *maybeRank; + Value sizeDim = rewriter.create( + loc, rewriter.getI64IntegerAttr(rank - 1)); + Value inputSize = rewriter.create(loc, input, sizeDim); + + Value outputShape = op.getOutputSize(); + SmallVector outputShapeSizesTorchInt; + getListConstructElements(outputShape, outputShapeSizesTorchInt); + Value outputSize = outputShapeSizesTorchInt[0]; + + Value constantOne = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + Value constantZero = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + Value constantFalse = rewriter.create(loc, false); + + int64_t outputSizeInt; + if (!matchPattern(outputSize, m_TorchConstantInt(&outputSizeInt))) { + return rewriter.notifyMatchFailure( + op, "the output size of adaptive_max_pool1d must be a constant int"); + } + + SmallVector kernelSize; + if (outputSizeInt == 1) { + BaseTensorType inputTensorType = cast(input.getType()); + ArrayRef inputShape = inputTensorType.getSizes(); + kernelSize.push_back( + inputShape[rank - 1] == kUnknownSize + ? inputSize + : rewriter.create( + loc, rewriter.getI64IntegerAttr(inputShape[rank - 1]))); + } else { + if (!isAssumingStrictSymbolicShapes(rewriter)) { + Value cond = rewriter.create(loc, inputSize, outputSize); + rewriter.create( + loc, cond, + "unimplemented: only support cases where input and output size are " + "equal for non-unit output size"); + } + kernelSize.push_back(constantOne); + } + + Value kernelSizeList = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(context)), kernelSize); + Value strideList = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(context)), + ValueRange{constantOne}); + Value paddingSizeList = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(context)), + ValueRange{constantZero}); + Value dialationList = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(context)), + ValueRange{constantOne}); + + if (op.getResult(1).use_empty()) { + auto maxPool = rewriter.create( + loc, op.getType(0), input, kernelSizeList, strideList, + paddingSizeList, dialationList, + /*ceil_mode=*/constantFalse); + rewriter.replaceOp(op, {maxPool.getResult(), Value()}); + } else { + auto maxPool = rewriter.create( + loc, op.getType(0), op.getType(1), input, kernelSizeList, strideList, + paddingSizeList, dialationList, + /*ceil_mode=*/constantFalse); + rewriter.replaceOp(op, maxPool.getResults()); + } + return success(); + } +}; +} // namespace + namespace { // Decompose `aten.adaptive_avg_pool1d` op into `aten.avg_pool1d` op. @@ -5857,32 +8243,80 @@ class DecomposeAtenAdaptiveAvgPool2dOp getListConstructElements(outputShape, outputShapeSizesTorchInt); // TODO: Add support for cases other than: - // inH % outH != 0 or inW % outW != 0 - + // inH % outH != 0 or inW % outW != 0 where + // the stride/kernel size is not fixed. + // The following logic of stride/kernel size derivation is consistent + // with torch/_decomp/decomposations.py:adaptive_avg_pool2d. 
Value constantZero = rewriter.create( loc, rewriter.getI64IntegerAttr(0)); + Value constantOne = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); Value constantFalse = rewriter.create(loc, false); Value constantTrue = rewriter.create(loc, true); Value constantNone = rewriter.create(loc); - SmallVector kernelSize; + SmallVector strideSize; + SmallVector kernelSize; for (unsigned i = 0; i < inputHW.size(); i++) { Value remainder = rewriter.create( loc, inputHW[i], outputShapeSizesTorchInt[i]); - Value cond = rewriter.create(loc, remainder, constantZero); - rewriter.create(loc, cond, - "unimplemented: only support cases " - "input size is an integer multiple of " - "output size"); - Value stride = rewriter.create( + + // Filter cases with fixed stride size. + Value cond1 = rewriter.create( + loc, outputShapeSizesTorchInt[i], + rewriter.create( + loc, remainder, + rewriter.create( + loc, outputShapeSizesTorchInt[i], constantOne))); + rewriter.create( + loc, cond1, + "unimplemented: only support cases with fixed stride size."); + + // Filter cases with fixed kernel size. + // cond2: whether input_size % output_size == 0. + Value cond2 = + rewriter.create(loc, remainder, constantZero); + // cond3: whether output_size % (input_size % output_size) == 0. + // To avoid potential crash (eg. tosa) happens,choose to mod 1 (add + // offset) when remainder equals 0, which has no side effect on + // effectiveness. + Value offset = rewriter.create( + loc, rewriter.create( + loc, rewriter.create(loc, remainder))); + Value remainder_not_zero = + rewriter.create(loc, remainder, offset); + Value cond3 = rewriter.create( + loc, + rewriter.create( + loc, outputShapeSizesTorchInt[i], remainder_not_zero), + constantZero); + Value cond = rewriter.create(loc, cond2, cond3); + + rewriter.create( + loc, cond, + "unimplemented: only support cases with fixed kernel size."); + + Value stride = rewriter.create( + loc, inputHW[i], outputShapeSizesTorchInt[i]); + strideSize.emplace_back(stride); + + Value kernel = rewriter.create( loc, inputHW[i], outputShapeSizesTorchInt[i]); - Value kernelSizeValue = stride; - kernelSize.push_back(kernelSizeValue); + + // When remainder equals 0, it is no need for kernel to add 1 + // and just keep the same as stride, otherwise it is necessary + // to add 1 (torch/_decomp/decomposations.py:adaptive_avg_pool2d). 
+ Value boolMod = rewriter.create(loc, remainder); + Value intMod = rewriter.create(loc, boolMod); + + kernel = rewriter.create(loc, kernel, intMod); + kernelSize.emplace_back(kernel); } Value kernelSizeList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(context)), kernelSize); - Value strideList = kernelSizeList; + Value strideList = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(context)), strideSize); Value paddingSizeList = rewriter.create( loc, Torch::ListType::get(Torch::IntType::get(context)), ValueRange{constantZero, constantZero}); @@ -5939,6 +8373,20 @@ class DecomposeAtenClampMaxOp : public OpRewritePattern { }; } // namespace +namespace { +class DecomposeAtenRad2degOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenRad2degOp op, + PatternRewriter &rewriter) const override { + Value constant180OverPi = rewriter.create( + op.getLoc(), rewriter.getF64FloatAttr(180 / 3.14159)); + rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), + constant180OverPi); + return success(); + } +}; +} // namespace + namespace { class DecomposeAtenCosineSimilarityOp : public OpRewritePattern { @@ -6028,6 +8476,160 @@ class DecomposeAtenTruncOp : public OpRewritePattern { }; } // namespace +namespace { +// decompose `signbit(x)` to `view.dtype(x, si32/si64) < 0 ` +class DecomposeAtenSignbitOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenSignbitOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = op.getSelf(); + + auto operandTy = dyn_cast(self.getType()); + auto resultTy = dyn_cast(op.getType()); + if (!operandTy || !operandTy.hasDtype() || !resultTy || + !resultTy.hasDtype()) { + return rewriter.notifyMatchFailure(op, + "operand and result must have dtype"); + } + + if (isa(operandTy.getDtype())) { + mlir::IntegerType intType = rewriter.getIntegerType( + operandTy.getDtype().getIntOrFloatBitWidth(), /*isSigned*/ true); + Value dtype = getDtypeIntValueForType(rewriter, loc, intType); + Value view = rewriter.create( + loc, + operandTy.getWithSizesAndDtype(operandTy.getOptionalSizes(), intType), + self, dtype); + Value zero = + rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + Value shift = rewriter.create(loc, resultTy, view, zero); + rewriter.replaceOp(op, shift); + return success(); + } else if (isa(operandTy.getDtype())) { + Value zero = + rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + Value shift = rewriter.create(loc, resultTy, self, zero); + rewriter.replaceOp(op, shift); + } + return failure(); + } +}; +} // namespace + +namespace { +// decompose `frac(x)` to `x - trunc(x)` +class DecomposeAtenFracOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenFracOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = op.getSelf(); + auto resultTy = op.getType(); + + Value one = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + Value trunc = rewriter.create(loc, resultTy, self); + rewriter.replaceOpWithNewOp(op, resultTy, self, trunc, + /*alpha=*/one); + return success(); + } +}; +} // namespace + +namespace { +// decompose `copysign(x, y)` to `signbit(y) ? 
-abs(x) : abs(x)` +class DecomposeAtenCopysignTensorOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenCopysignTensorOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = op.getSelf(); + Value other = op.getOther(); + auto selfTy = self.getType(); + auto otherTy = cast(other.getType()); + auto resultTy = op.getType(); + + Value signbit = rewriter.create( + loc, + otherTy.getWithSizesAndDtype(otherTy.getOptionalSizes(), + rewriter.getI1Type()), + other); + Value abs = rewriter.create(loc, selfTy, self); + Value neg = rewriter.create(loc, selfTy, abs); + rewriter.replaceOpWithNewOp(op, resultTy, signbit, neg, + abs); + return success(); + } +}; +} // namespace + +namespace { +// decompose `ldexp(x, y)` to `x * 2^y` +class DecomposeAtenLdexpTensorOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenLdexpTensorOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = op.getSelf(); + Value other = op.getOther(); + + auto otherTy = dyn_cast(other.getType()); + auto resultTy = dyn_cast(op.getType()); + if (!resultTy || !resultTy.hasDtype()) { + return rewriter.notifyMatchFailure(op, "result must have dtype"); + } + + Value exp2 = rewriter.create( + loc, + resultTy.getWithSizesAndDtype(otherTy.getOptionalSizes(), + resultTy.getDtype()), + other); + rewriter.replaceOpWithNewOp(op, resultTy, self, exp2); + return success(); + } +}; +} // namespace + +namespace { +// decompose `fmod(x, y)` to `x - trunc(x/y) * y` +class DecomposeAtenFmodTensorOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenFmodTensorOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = op.getSelf(); + Value other = op.getOther(); + + auto resultTy = dyn_cast(op.getType()); + if (!resultTy || !resultTy.hasDtype()) { + return rewriter.notifyMatchFailure(op, "result must have dtype"); + } + + if (isa(resultTy.getDtype())) { + Value div = rewriter.create(loc, resultTy, self, other); + Value mul = rewriter.create(loc, resultTy, div, other); + Value alpha = + rewriter.create(loc, rewriter.getF64FloatAttr(1)); + rewriter.replaceOpWithNewOp(op, resultTy, self, mul, + alpha); + return success(); + } else if (isa(resultTy.getDtype())) { + Value div = rewriter.create(loc, resultTy, self, other); + Value trunc = rewriter.create(loc, resultTy, div); + Value mul = rewriter.create(loc, resultTy, trunc, other); + Value alpha = + rewriter.create(loc, rewriter.getF64FloatAttr(1)); + rewriter.replaceOpWithNewOp(op, resultTy, self, mul, + alpha); + return success(); + } + return failure(); + } +}; +} // namespace + namespace { // Decompose `aten.baddbmm` op into `aten.bmm`, `aten.mul.Scalar`, and // `aten.add.Tensor` op. 
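The elementwise patterns above (aten.frac, aten.copysign.Tensor, aten.fmod.Tensor) expand to textbook identities. A scalar sanity check of those identities in plain C++ (illustrative only, not torch-mlir code):

// frac(x) = x - trunc(x); copysign(x, y) = signbit(y) ? -|x| : |x|;
// fmod(x, y) = x - trunc(x / y) * y.
#include <cassert>
#include <cmath>

int main() {
  double x = -3.75, y = 2.5;
  assert(x - std::trunc(x) == -0.75);                        // frac
  double copysign_via_signbit = std::signbit(y) ? -std::fabs(x) : std::fabs(x);
  assert(copysign_via_signbit == std::copysign(x, y));       // copysign
  double fmod_via_trunc = x - std::trunc(x / y) * y;
  assert(fmod_via_trunc == std::fmod(x, y));                 // fmod
  return 0;
}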
@@ -6137,7 +8739,7 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter, Type newOutputType = outputTensorType.getWithSizesAndDtype( outputTensorType.getSizes(), rewriter.getF64Type()); if (!inputTensorTy.hasDtype() || - !inputTensorTy.getDtype().isa()) { + !isa(inputTensorTy.getDtype())) { return rewriter.notifyMatchFailure( op, "support floating-point type input only"); } @@ -6274,14 +8876,14 @@ class DecomposeAtenVarCorrectionOp PatternRewriter &rewriter) const override { int64_t correctionValInt; double correctionValFloat = 1.0; - if (!op.getCorrection().getType().isa()) { - if (op.getCorrection().getType().isa()) { + if (!isa(op.getCorrection().getType())) { + if (isa(op.getCorrection().getType())) { if (!matchPattern(op.getCorrection(), m_TorchConstantFloat(&correctionValFloat))) return rewriter.notifyMatchFailure( op, "Only support constant int or float correction value for " "aten.var"); - } else if (op.getCorrection().getType().isa()) { + } else if (isa(op.getCorrection().getType())) { if (!matchPattern(op.getCorrection(), m_TorchConstantInt(&correctionValInt))) return rewriter.notifyMatchFailure( @@ -6365,7 +8967,6 @@ class DecomposeAten_EmbeddingBagOp rewriter.replaceOpWithNewOp( op, returnTypes, weight, indices, offsets, scaleGradByFreq, mode, sparse, perSampleWeights, includeLastOffset, paddingIdx); - return success(); } }; @@ -6408,11 +9009,9 @@ class DecomposeAtenMseLossOp : public OpRewritePattern { if (!inputType.hasSizes()) return rewriter.notifyMatchFailure( op, "Expected the input tensor to have sizes"); - BaseTensorType subType = - inputType - .getWithSizesAndDtype(llvm::ArrayRef(inputType.getSizes()), - resultType.getOptionalDtype()) - .cast(); + BaseTensorType subType = cast( + inputType.getWithSizesAndDtype(llvm::ArrayRef(inputType.getSizes()), + resultType.getOptionalDtype())); Value sub = createTensorSub(rewriter, loc, subType, op.getSelf(), op.getTarget()); @@ -6438,6 +9037,71 @@ class DecomposeAtenMseLossOp : public OpRewritePattern { }; } // namespace +namespace { +class DecomposeAtenL1LossOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenL1LossOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = op.getSelf(); + auto selfTy = dyn_cast(self.getType()); + if (!selfTy || !selfTy.hasSizes() || !selfTy.hasDtype()) { + return rewriter.notifyMatchFailure( + op, "Expected self to be a tensor with sizes and a dtype"); + } + + Value target = op.getTarget(); + auto targetTy = dyn_cast(target.getType()); + if (!targetTy || !targetTy.hasDtype()) { + return rewriter.notifyMatchFailure( + op, "Expected target to be a tensor with sizes and a dtype"); + } + + auto outTy = dyn_cast(op.getType()); + if (!outTy || !outTy.hasDtype()) { + return rewriter.notifyMatchFailure( + op, "Expected output type to be a tensor with a dtype"); + } + + auto outDtype = outTy.getDtype(); + if (selfTy.getDtype() != outDtype) { + self = convertTensorToDtype(rewriter, loc, self, outDtype); + } + if (targetTy.getDtype() != outDtype) { + target = convertTensorToDtype(rewriter, loc, target, outDtype); + } + + Value reduction = op.getReduction(); + int64_t reductionInt; + if (!matchPattern(reduction, m_TorchConstantInt(&reductionInt))) { + return rewriter.notifyMatchFailure( + op, "Expected reduction to be a constant int"); + } + + auto subTy = outTy.getWithSizesAndDtype(selfTy.getSizes(), outDtype); + Value sub = createTensorSub(rewriter, loc, subTy, self, target); + 
Value abs = rewriter.create(loc, subTy, sub); + + if (reductionInt == 0) { + rewriter.replaceOp(op, abs); + } else if (reductionInt == 1) { + Value none = rewriter.create(loc); + Value sum = rewriter.create(loc, outTy, abs, none); + Value numel = rewriter.create(loc, abs); + Value mean = rewriter.create(loc, outTy, sum, numel); + rewriter.replaceOp(op, mean); + } else { + Value none = rewriter.create(loc); + Value sum = rewriter.create(loc, outTy, abs, none); + rewriter.replaceOp(op, sum); + } + + return success(); + } +}; +} // namespace + namespace { // Decompose `aten.norm.ScalarOpt_dim` op to `aten.linalg_vector_norm` op class DecomposeAtenNormScalarOptDimOp @@ -6449,7 +9113,7 @@ class DecomposeAtenNormScalarOptDimOp Location loc = op->getLoc(); Value none = rewriter.create(loc); Value ord = op.getP(); - if (ord.getType().isa()) { + if (isa(ord.getType())) { ord = rewriter.create( loc, rewriter.getF64FloatAttr(2.0)); } @@ -6492,10 +9156,8 @@ class DecomposeAtenRandintLowOp : public OpRewritePattern { loc, rewriter.getF64FloatAttr((double)cstHigh)); BaseTensorType floatResultType = - resultTensorType - .getWithSizesAndDtype(resultTensorType.getSizes(), - rewriter.getF32Type()) - .cast(); + cast(resultTensorType.getWithSizesAndDtype( + resultTensorType.getSizes(), rewriter.getF32Type())); Value emptyTensor = rewriter.create( loc, floatResultType, op.getSize(), /*dtype=*/none, /*layout=*/op.getLayout(), @@ -6587,7 +9249,7 @@ class DecomposePrimsVarOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(PrimsVarOp op, PatternRewriter &rewriter) const override { - if (!op.getOutputDtype().getType().isa()) + if (!isa(op.getOutputDtype().getType())) return rewriter.notifyMatchFailure( op, "Unimplemented non-None dtype for prims::var op"); Value cstFalse = rewriter.create(op.getLoc(), false); @@ -6699,7 +9361,7 @@ class DecomposeAtenRandnLikeOp : public OpRewritePattern { LogicalResult matchAndRewrite(AtenRandnLikeOp op, PatternRewriter &rewriter) const override { // Only `none`, `contiguous` and `preserve` memory_format is supported. 
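A standalone sketch (plain C++, illustrative values, not part of this patch) of what the aten.l1_loss decomposition above computes: the elementwise absolute difference, followed by reduction 0 = none, 1 = mean (sum / numel), 2 = sum.

// l1_loss = |self - target|, then optional mean or sum reduction.
#include <cassert>
#include <cmath>
#include <numeric>
#include <vector>

int main() {
  std::vector<double> self{1.0, -2.0, 3.0}, target{0.5, 0.0, 1.0};
  std::vector<double> abs_diff;
  for (size_t i = 0; i < self.size(); ++i)
    abs_diff.push_back(std::fabs(self[i] - target[i]));       // {0.5, 2.0, 2.0}
  double sum = std::accumulate(abs_diff.begin(), abs_diff.end(), 0.0);
  double mean = sum / abs_diff.size();
  assert(sum == 4.5 && mean == 1.5);
  return 0;
}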
- if (!op.getMemoryFormat().getType().isa()) { + if (!isa(op.getMemoryFormat().getType())) { int64_t memoryFormat; if (!matchPattern(op.getMemoryFormat(), m_TorchConstantInt(&memoryFormat))) @@ -6767,7 +9429,6 @@ class DecomposeAtenLinspaceOp : public OpRewritePattern { Location loc = op.getLoc(); MLIRContext *context = getContext(); - auto baseType = ValueTensorType::getWithLeastStaticInformation(context); Value none = rewriter.create(loc); Value falseVal = rewriter.create(loc, false); Value zero = @@ -6777,13 +9438,25 @@ class DecomposeAtenLinspaceOp : public OpRewritePattern { Value addStart; int64_t steps; + auto si64Type = rewriter.getIntegerType(/*width=*/64, /*isSigned*/ true); + auto fp32Type = rewriter.getF32Type(); + auto arangeIntType = + getTensorTypeFromShapeValues({op.getSteps()}, si64Type); + auto arangeFp32Type = + getTensorTypeFromShapeValues({op.getSteps()}, fp32Type); if (matchPattern(op.getSteps(), m_TorchConstantInt(&steps)) && steps == 1) { // specically handle steps == 1 Value arange = rewriter.create( - loc, baseType, zero, op.getSteps(), /*dtype=*/none, op.getLayout(), - op.getDevice(), op.getPinMemory()); - addStart = rewriter.create(loc, baseType, arange, - op.getStart(), one); + loc, arangeIntType, zero, op.getSteps(), /*dtype=*/none, + op.getLayout(), op.getDevice(), op.getPinMemory()); + if (isa(op.getEnd().getType()) || + isa(op.getStart().getType())) { + addStart = rewriter.create(loc, arangeFp32Type, arange, + op.getStart(), one); + } else { + addStart = rewriter.create(loc, arangeIntType, arange, + op.getStart(), one); + } } else { // handle steps != 1 or dynamic steps Value neOrNot = rewriter.create(loc, op.getSteps(), one); @@ -6792,12 +9465,12 @@ class DecomposeAtenLinspaceOp : public OpRewritePattern { rewriter.getStringAttr("linspace's dynamic steps must not be 1")); // create arange: [0, ..., steps - 1] Value arange = rewriter.create( - loc, baseType, zero, op.getSteps(), /*dtype=*/none, op.getLayout(), - op.getDevice(), op.getPinMemory()); + loc, arangeIntType, zero, op.getSteps(), /*dtype=*/none, + op.getLayout(), op.getDevice(), op.getPinMemory()); // calculate (end - start) / (steps - 1) Value sub; - if (op.getEnd().getType().isa() || - op.getStart().getType().isa()) { + if (isa(op.getEnd().getType()) || + isa(op.getStart().getType())) { sub = rewriter.create(loc, Torch::FloatType::get(context), op.getEnd(), op.getStart()); } else { @@ -6807,15 +9480,16 @@ class DecomposeAtenLinspaceOp : public OpRewritePattern { loc, sub, rewriter.create(loc, op.getSteps(), one)); // calculate [0, ..., steps - 1] * ((end - start) / (steps - 1)) + start Value mulScalar = - rewriter.create(loc, baseType, arange, div); - addStart = rewriter.create(loc, baseType, mulScalar, - op.getStart(), one); + rewriter.create(loc, arangeFp32Type, arange, div); + addStart = rewriter.create( + loc, arangeFp32Type, mulScalar, op.getStart(), one); } // to dtype Value result; - if (!op.getDtype().getType().isa()) { + if (!isa(op.getDtype().getType())) { result = rewriter.create( - loc, op.getType(), addStart, op.getDtype(), /*non_blocking=*/falseVal, + loc, op.getType(), addStart, op.getDtype(), + /*non_blocking=*/falseVal, /*copy=*/falseVal, /*memory_format=*/none); } else { Value f32Type = rewriter.create( @@ -6929,7 +9603,8 @@ class DecomposePrimsSqueezeOp : public OpRewritePattern { } result = *squeezeTensorInfo; } - rewriter.replaceOp(op, result); + rewriter.replaceOpWithNewOp(op, op.getType(), + result); return success(); } }; @@ -7008,6 +9683,12 @@ class 
DecomposeAtenCrossEntropyLossOp return rewriter.notifyMatchFailure( op, "Unimplemented: unranked target tensor"); unsigned targetRank = maybeRank.value(); + Value reduction = op.getReduction(); + int64_t reductionInt; + if (!matchPattern(reduction, m_TorchConstantInt(&reductionInt))) { + return rewriter.notifyMatchFailure(op, + "reduction should be a constant int!"); + } // When the input is 2-d i.e. of the form [minibatch, C] and target is 1-d // of the form [minibatch] the cross entropy loss decomposes to the @@ -7038,10 +9719,19 @@ class DecomposeAtenCrossEntropyLossOp loc, rewriter.getI64IntegerAttr(1)); Value logSoftmax = rewriter.create( loc, self.getType(), self, dim, /*dtype=*/noneVal); - Value nllLoss = + + Type secondType; + if (reductionInt == 0) { + secondType = target.getType(); + } else { + auto targetType = dyn_cast(target.getType()); + secondType = targetType.getWithSizesAndDtype({}, targetType.getDtype()); + } + + Value nllLoss = rewriter .create( - loc, op.getType(), target.getType(), logSoftmax, target, + loc, op.getType(), secondType, logSoftmax, target, op.getWeight(), op.getReduction(), op.getIgnoreIndex()) ->getResult(0); rewriter.replaceOp(op, nllLoss); @@ -7050,6 +9740,330 @@ class DecomposeAtenCrossEntropyLossOp }; } // namespace +namespace { +// Decompose aten::nll_loss_forward according to : +// torch/_decomp/decompositions.py and +// https://pytorch.org/docs/stable/generated/torch.nn.NLLLoss.html. +// The (self, target) can be: +// 1. [N, C] and [C], +// or +// 2. [N] or []. +// The weight must be None or 1d where the numel must keep consistent with the +// number of classes. +class DecomposeAtenNllLossForwardOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenNllLossForwardOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto ctx = op.getContext(); + + auto self = op.getSelf(); + auto target = op.getTarget(); + + auto selfType = dyn_cast(self.getType()); + auto targetType = dyn_cast(target.getType()); + + // constraints. 
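The aten.cross_entropy_loss decomposition above relies on the identity cross_entropy(x, t) = nll_loss(log_softmax(x, dim=1), t); the change threads the reduction through so the nll_loss intermediate gets a 0-d total-weight type when the result is reduced. A single-sample sketch of the identity (plain C++, no weights or ignore_index, illustrative values):

// cross_entropy for one sample equals -log_softmax(logits)[target].
#include <cassert>
#include <cmath>
#include <vector>

int main() {
  std::vector<double> logits = {1.0, 2.0, 0.5};
  int target = 1;
  double denom = 0.0;
  for (double v : logits)
    denom += std::exp(v);
  double log_softmax_t = logits[target] - std::log(denom); // log_softmax at the target class
  double nll = -log_softmax_t;                              // nll_loss of the log-probabilities
  double direct = std::log(denom) - logits[target];        // cross-entropy written directly
  assert(std::fabs(nll - direct) < 1e-12);
  return 0;
}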
+    if (!selfType.hasSizes() || !targetType.hasSizes()) { + return rewriter.notifyMatchFailure( + op, "requires self and target to have sizes!"); + } + + if (!selfType.hasDtype() || !targetType.hasDtype()) { + return rewriter.notifyMatchFailure( + op, "requires self and target to have a dtype!"); + } + + auto selfSizes = selfType.getSizes(); + auto targetSizes = targetType.getSizes(); + int64_t selfRank = selfSizes.size(); + int64_t targetRank = targetSizes.size(); + if (selfRank <= 0 || selfRank > 2) { + return rewriter.notifyMatchFailure(op, "input tensor should be 1D or 2D"); + } + if (targetRank > 1) { + return rewriter.notifyMatchFailure(op, + "target tensor should be 0D or 1D!"); + } + + if (selfRank != 1 || targetRank != 0) { + if (!(selfSizes[0] == kUnknownSize && targetSizes[0] == kUnknownSize) && + selfSizes[0] != targetSizes[0]) { + return rewriter.notifyMatchFailure( + op, + "input tensor and target tensor should have the same batch size!"); + } + } + + int64_t numClasses = selfSizes.back(); + auto weight = op.getWeight(); + auto weightT = weight.getType(); + if (!isa(weightT) && numClasses != kUnknownSize) { + auto weightType = dyn_cast(weightT); + if (weightType.areAllSizesKnown()) { + auto weightSizes = weightType.getSizes(); + int64_t weightNumel = 1; + for (size_t i = 0; i < weightSizes.size(); i++) { + weightNumel *= weightSizes[i]; + } + if (weightNumel != numClasses) { + return rewriter.notifyMatchFailure( + op, "weight tensor should be defined either for all classes or " + "no classes!"); + } + } + } + + Value reductionValue = op.getReduction(); + int64_t reduction; + if (!matchPattern(reductionValue, m_TorchConstantInt(&reduction))) { + return rewriter.notifyMatchFailure(op, + "reduction should be a constant int!"); + } + + // decomposition.
+ uint64_t channelDim = 1; + if (selfRank < 2) { + channelDim = 0; + } + Value channelDimValue = rewriter.create( + loc, rewriter.getI64IntegerAttr(channelDim)); + + auto ignoreIndex = op.getIgnoreIndex(); + Value w; + if (!isa(weightT)) { + if (selfRank > 1) { + auto weightType = dyn_cast(weightT); + auto weightSizes = weightType.getSizes(); + SmallVector newShapeList(selfRank, 1); + newShapeList[channelDim] = weightSizes[0]; + SmallVector newShapeListValue; + for (size_t i = 0; i < newShapeList.size(); ++i) { + newShapeListValue.push_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(newShapeList[i]))); + } + Value newShape = rewriter.create( + loc, + rewriter.getType( + rewriter.getType()), + newShapeListValue); + auto newType = weightType.getWithSizesAndDtype(newShapeList, + weightType.getDtype()); + w = rewriter.create(loc, newType, weight, newShape); + } else { + w = weight; + } + + self = rewriter.create(loc, self.getType(), self, w); + } + + SmallVector targetDimSizes(targetSizes); + Value zero = + rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + auto condType = + ValueTensorType::get(ctx, targetDimSizes, rewriter.getI1Type()); + auto unequalCond = + rewriter.create(loc, condType, target, ignoreIndex); + auto zeroTensorType = + ValueTensorType::get(ctx, {}, rewriter.getIntegerType(64, true)); + Value zeroTensor = + rewriter.create(loc, zeroTensorType, zero); + auto safeTarget = rewriter.create( + loc, target.getType(), unequalCond, target, zeroTensor); + + SmallVector safeTargetShape; + for (size_t i = 0; i < targetSizes.size(); ++i) { + if (channelDim == i) { + safeTargetShape.push_back(1); + } + safeTargetShape.push_back(targetSizes[i]); + } + if (channelDim == safeTargetShape.size()) { + safeTargetShape.push_back(1); + } + + auto gatherType = + ValueTensorType::get(ctx, safeTargetShape, targetType.getDtype()); + auto safeTarget_ = rewriter.create( + loc, gatherType, safeTarget, channelDimValue); + auto falseValue = + rewriter.create(loc, rewriter.getBoolAttr(false)); + auto none = rewriter.create(loc); + auto _gather = rewriter.create( + loc, ValueTensorType::get(ctx, safeTargetShape, selfType.getDtype()), + self, channelDimValue, safeTarget_, falseValue); + Value gather = rewriter.create(loc, _gather.getType(), _gather); + auto unequalCondType = cast(unequalCond.getType()); + auto result = rewriter.create( + loc, + unequalCondType.getWithSizesAndDtype(unequalCondType.getSizes(), + selfType.getDtype()), + unequalCond, + rewriter.create( + loc, ValueTensorType::get(ctx, targetSizes, selfType.getDtype()), + gather, channelDimValue), + zeroTensor); + + Value totalWeight; + if (reduction == 0 && selfRank > 1) { + auto zeroFloat = + rewriter.create(loc, rewriter.getF64FloatAttr(0.0)); + Value twSize = rewriter.create( + loc, + rewriter.getType(rewriter.getType()), + ValueRange({})); + + totalWeight = rewriter.create( + loc, op.getType(1), self, twSize, zeroFloat, none, none, none, none); + rewriter.replaceOp(op, {result, totalWeight}); + + return success(); + } + + if (!isa(weightT)) { + auto wType = cast(w.getType()); + auto newWType = wType.getWithSizesAndDtype(selfSizes, wType.getDtype()); + SmallVector selfSizesValue; + for (size_t i = 0; i < selfSizes.size(); ++i) { + selfSizesValue.push_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(selfSizes[i]))); + } + auto wSize = rewriter.create( + loc, + rewriter.getType(rewriter.getType()), + selfSizesValue); + w = rewriter.create(loc, newWType, w, wSize, falseValue); + auto wSumGather = rewriter.create( + loc, 
ValueTensorType::get(ctx, safeTargetShape, wType.getDtype()), w, + channelDimValue, safeTarget_, falseValue); + auto wSumSq = rewriter.create( + loc, ValueTensorType::get(ctx, targetSizes, wType.getDtype()), + wSumGather, channelDimValue); + auto wSum = rewriter.create( + loc, + ValueTensorType::get(ctx, unequalCondType.getSizes(), + wType.getDtype()), + unequalCond, wSumSq, zeroTensor); + + totalWeight = rewriter.create(loc, op.getType(1), wSum, none); + } else { + totalWeight = + rewriter.create(loc, op.getType(1), unequalCond, none); + } + + auto resultSum = + rewriter.create(loc, op.getType(0), result, none); + if (reduction == 1) { + auto resultMean = rewriter.create( + loc, op.getType(0), resultSum, totalWeight); + rewriter.replaceOp(op, {resultMean, totalWeight}); + + return success(); + } + + rewriter.replaceOp(op, {resultSum, totalWeight}); + return success(); + } +}; +} // namespace + +namespace { +class DecomposeAtenBinaryCrossEntropyWithLogitsOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenBinaryCrossEntropyWithLogitsOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto self = op.getSelf(); + auto target = op.getTarget(); + auto posWeight = op.getPosWeight(); + auto weight = op.getWeight(); + auto reduction = op.getReduction(); + + Value loss; + auto one = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + auto _one = + rewriter.create(loc, rewriter.getI64IntegerAttr(-1)); + + auto _target = + rewriter.create(loc, target.getType(), target, _one); + auto _target_1 = rewriter.create(loc, _target.getType(), + _target, one, one); + Value mm = + rewriter.create(loc, self.getType(), _target_1, self); + Value logSigm = + rewriter.create(loc, self.getType(), self); + + if (!isa(posWeight.getType())) { + auto logWeight = rewriter.create( + loc, posWeight.getType(), + rewriter.create(loc, posWeight.getType(), posWeight, + one, one), + one, one); + loss = rewriter.create( + loc, mm.getType(), mm, + rewriter.create(loc, logWeight.getType(), logWeight, + logSigm), + one); + } else { + loss = + rewriter.create(loc, mm.getType(), mm, logSigm, one); + } + + if (!isa(weight.getType())) { + loss = + rewriter.create(loc, loss.getType(), loss, weight); + } + + // apply loss reduction. 
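To summarize the aten.nll_loss_forward decomposition above (mirroring torch/_decomp/decompositions.py): optionally scale the log-probabilities by the per-class weight, gather the entry at each target index, zero out ignored targets, negate, and for the mean reduction divide the summed loss by the total weight of the non-ignored targets. A standalone sketch for self of shape [N, C] and target of shape [N] (plain C++, illustrative values, not part of this patch):

// nll_loss_forward with per-class weight, ignore_index, reduction = mean.
#include <cassert>
#include <cmath>
#include <vector>

int main() {
  std::vector<std::vector<double>> logp = {{-0.1, -2.5}, {-1.0, -0.5}, {-0.3, -1.4}};
  std::vector<int> target = {0, 1, 1};
  std::vector<double> weight = {1.0, 2.0}; // one entry per class
  int ignore_index = -100;

  double loss_sum = 0.0, total_weight = 0.0;
  for (size_t n = 0; n < target.size(); ++n) {
    if (target[n] == ignore_index)
      continue;                            // masked out, contributes 0
    double w = weight[target[n]];
    loss_sum += -w * logp[n][target[n]];   // gather + negate, scaled by weight
    total_weight += w;
  }
  double mean = loss_sum / total_weight;   // reduction == 1: sum / total_weight
  assert(total_weight == 5.0);
  assert(std::fabs(mean - (0.1 * 1 + 0.5 * 2 + 1.4 * 2) / 5.0) < 1e-12);
  return 0;
}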
+ int64_t reductionInt; + if (!matchPattern(reduction, m_TorchConstantInt(&reductionInt))) { + return rewriter.notifyMatchFailure(op, "no reduction type is appointed!"); + } + + auto none = rewriter.create(loc); + Value res; + if (reductionInt == 1) { + res = rewriter.create(loc, op.getType(), loss, none); + } else if (reductionInt == 2) { + res = rewriter.create(loc, op.getType(), loss, none); + } else { + res = loss; + } + + rewriter.replaceOp(op, res); + return success(); + } +}; +} // namespace + +namespace { +class DecomposeAtenExp2Op : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenExp2Op op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = op.getSelf(); + + auto resultTy = dyn_cast(op.getType()); + if (!resultTy || !resultTy.hasDtype()) { + return rewriter.notifyMatchFailure(op, "result must have dtype"); + } + + auto two = + rewriter.create(loc, rewriter.getI64IntegerAttr(2)); + Value to = convertTensorToDtype(rewriter, loc, self, resultTy.getDtype()); + Value pow = rewriter.create(loc, resultTy, two, to); + rewriter.replaceOp(op, pow); + return success(); + } +}; + +} // namespace + namespace { class DecomposeAtenOneHotOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -7064,10 +10078,9 @@ class DecomposeAtenOneHotOp : public OpRewritePattern { return rewriter.notifyMatchFailure( op, "input tensor should have known sizes."); int64_t inputRank = inputType.getSizes().size(); - int64_t numClasses; - if (!matchPattern(op.getNumClasses(), m_TorchConstantInt(&numClasses))) - return rewriter.notifyMatchFailure( - op, "unimplemented: num_classes must be constant"); + int64_t numClasses = Torch::kUnknownSize; + auto resultType = cast(op.getType()); + matchPattern(op.getNumClasses(), m_TorchConstantInt(&numClasses)); Value none = rewriter.create(loc); // arange tensor @@ -7079,14 +10092,15 @@ class DecomposeAtenOneHotOp : public OpRewritePattern { /*device=*/none, /*pin_memory=*/none); // unsqueeze input - llvm::SmallVector unsqueezeShape(inputType.getSizes()); - unsqueezeShape.push_back(1); - auto unsqueezeType = - ValueTensorType::get(context, unsqueezeShape, si64Type); - Value unsqueezeTensor = rewriter.create( - loc, unsqueezeType, input, - rewriter.create(loc, - rewriter.getI64IntegerAttr(inputRank))); + Value rankV = rewriter.create( + loc, rewriter.getI64IntegerAttr(inputRank)); + auto unsqueeze = Torch::unsqueezeTensor(rewriter, op, input, rankV); + if (failed(unsqueeze)) + return rewriter.notifyMatchFailure(op, + "cannot generate unsqueeze tensor"); + + Value unsqueezeTensor = + convertTensorToDtype(rewriter, loc, *unsqueeze, si64Type); // compare auto eqType = ValueTensorType::get( @@ -7096,7 +10110,8 @@ class DecomposeAtenOneHotOp : public OpRewritePattern { loc, eqType, unsqueezeTensor, arangeTensor); // convert to si64 - Value result = convertTensorToDtype(rewriter, loc, eqTensor, si64Type); + Value result = + convertTensorToDtype(rewriter, loc, eqTensor, resultType.getDtype()); rewriter.replaceOp(op, result); return success(); } @@ -7199,6 +10214,221 @@ class DecomposeAtenTopkOp : public OpRewritePattern { }; } // namespace +namespace { + +/// Creates coefficients based on DFT definition, see +/// https://en.wikipedia.org/wiki/Discrete_Fourier_transform. +/// Even indices of the second dimension are for the real components of the +/// output. Odd indices for the imaginary components. 
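The aten.binary_cross_entropy_with_logits decomposition above computes loss = (1 - t) * x - log_weight * log_sigmoid(x) with log_weight = (pos_weight - 1) * t + 1 (log_weight = 1 when pos_weight is None), then applies the optional elementwise weight and the reduction. A scalar check that this matches the textbook definition (plain C++, illustrative values, not part of this patch):

// Decomposed form vs. -[p*t*log(sigmoid(x)) + (1-t)*log(1-sigmoid(x))].
#include <cassert>
#include <cmath>

int main() {
  double x = 0.7, t = 1.0, pos_weight = 3.0;
  double log_sigmoid = -std::log1p(std::exp(-x));
  double log_weight = (pos_weight - 1.0) * t + 1.0;
  double decomposed = (1.0 - t) * x - log_weight * log_sigmoid;
  double sig = 1.0 / (1.0 + std::exp(-x));
  double reference =
      -(pos_weight * t * std::log(sig) + (1.0 - t) * std::log(1.0 - sig));
  assert(std::fabs(decomposed - reference) < 1e-12);
  return 0;
}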
+Value getDFTMatmulCoeff(PatternRewriter &rewriter, Location loc, + ValueTensorType matrixType) { + // scale = 2 * pi / N + double scale = 2 * M_PI / matrixType.getSizes()[0]; + + SmallVector values; + assert(matrixType.getSizes().size() == 2 && "expected 2D matrix"); + for (auto i : llvm::seq(0, matrixType.getSizes()[0])) { + for (auto j : llvm::seq(0, matrixType.getSizes()[1])) { + const bool isImagPart = j % 2; + double v = scale * i * (j / 2); + v = isImagPart ? -sin(v) : cos(v); + values.push_back(rewriter.getF32FloatAttr(v)); + } + } + + return rewriter.create( + loc, matrixType, + DenseElementsAttr::get(matrixType.toBuiltinTensor(), + ArrayRef(values))); +} + +class DecomposeAtenFftRfftOp final : public OpRewritePattern { + + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenFftRfftOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = op.getSelf(); + + int64_t dim; + auto dimVal = op.getDim(); + if (isa(dimVal.getType())) { + dim = -1; + } else if (!matchPattern(dimVal, torch::Torch::m_TorchConstantInt(&dim))) { + return rewriter.notifyMatchFailure( + op, "unimplemented: requires dim to be constant"); + } + + if (!isa(op.getN().getType())) { + return rewriter.notifyMatchFailure(op, "unimplemented: parameter n"); + } + + if (!isa(op.getNorm().getType())) { + return rewriter.notifyMatchFailure(op, "unimplemented: parameter norm"); + } + + BaseTensorType inputType = cast(self.getType()); + + if (!inputType.hasSizes()) { + return rewriter.notifyMatchFailure( + op, "unsupported: only ranked tensors are supported"); + } + + const ArrayRef inputShape = inputType.getSizes(); + dim += dim < 0 ? inputShape.size() : 0; + + const int64_t fftLength = inputShape[dim]; + if (fftLength == kUnknownSize) { + return rewriter.notifyMatchFailure( + op, "unsupported: input signal length must be known"); + } + const int64_t rank = inputShape.size(); + const int64_t lastDim = rank - 1; + const int64_t outputFftDim = fftLength / 2 + 1; + const bool needTranspose = dim != lastDim; + + auto transposeValue = [](PatternRewriter &rewriter, Location loc, + Value input, int64_t dimA, int64_t dimB, + Value &transposed) { + Type transposedType; + if (failed(getTransposedType(cast(input.getType()), dimA, + dimB, transposedType))) + return failure(); + Value cstDimA = + rewriter.create(loc, rewriter.getI64IntegerAttr(dimA)); + Value cstDimB = + rewriter.create(loc, rewriter.getI64IntegerAttr(dimB)); + transposed = rewriter.create(loc, transposedType, + input, cstDimA, cstDimB); + return success(); + }; + + SmallVector lhsShape(inputShape); + // Transpose if FFT dimension is not the last one + if (needTranspose) { + if (failed(transposeValue(rewriter, loc, self, dim, lastDim, self))) { + return failure(); + } + std::swap(lhsShape[dim], lhsShape[lastDim]); + } + // self : (D_0 x ... x D_m x fftLength) + + Type dtype = inputType.getOptionalDtype(); + + // coeff : (fftLength x outputFftDim*2) + ValueTensorType matrixType = ValueTensorType::get( + op.getContext(), SmallVector{fftLength, outputFftDim * 2}, + dtype); + Value coeffMatrix = getDFTMatmulCoeff(rewriter, loc, matrixType); + + // X = matmul(self, coeff) : (D_0 x ... 
x D_m x outputFftDim*2) + SmallVector matmulShape(lhsShape.begin(), lhsShape.end() - 1); + matmulShape.push_back(outputFftDim * 2); + ValueTensorType matmulType = + ValueTensorType::get(op.getContext(), matmulShape, dtype); + Value flatRes = + rewriter.create(loc, matmulType, self, coeffMatrix); + + // Y = unflatten(X, -1, [outputFftDim, 2]) + // : (D_0 x ... x D_m x outputFftDim x 2) + // Z = view_as_complex(Y) : complex(D_0 x ... x D_m x outputFftDim) + SmallVector complexResShape(matmulShape); + complexResShape.back() = outputFftDim; + SmallVector unflattenedResShape(complexResShape); + unflattenedResShape.push_back(2); + Type unflattenedResType = + ValueTensorType::get(op.getContext(), unflattenedResShape, dtype); + Value cstMinusOne = + rewriter.create(loc, rewriter.getI64IntegerAttr(-1)); + Value unflattenSizes = toIntListConstruct(rewriter, loc, {outputFftDim, 2}); + Value unflattenedRes = rewriter.create( + loc, unflattenedResType, flatRes, /*dim=*/cstMinusOne, unflattenSizes); + Type complexResType = ValueTensorType::get(op.getContext(), complexResShape, + ComplexType::get(dtype)); + Value complexRes = rewriter.create(loc, complexResType, + unflattenedRes); + + // Transpose back + if (needTranspose) { + if (failed(transposeValue(rewriter, loc, complexRes, dim, lastDim, + complexRes))) { + return failure(); + } + } + + rewriter.replaceOp(op, {complexRes}); + + return success(); + } +}; + +} // namespace + +namespace { +// Decompose `aten.hann_window` into `aten.arange.start`, `aten.mul.Scalar`, +// `aten.sin` and `aten.square` or into `aten.ones` in the trivial case +class DecomposeAtenHannWindowPeriodicOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenHannWindowPeriodicOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + MLIRContext *context = op.getContext(); + Type opType = op.getType(); + + Value opWindowLength = op.getWindowLength(); + Value opDtype = op.getDtype(); + Value opLayout = op.getLayout(); + Value opDevice = op.getDevice(); + Value opPinMemory = op.getPinMemory(); + + int64_t window_length; + if (!matchPattern(opWindowLength, m_TorchConstantInt(&window_length)) || + window_length <= 0) + return rewriter.notifyMatchFailure( + op, "Expected a constant integer greater than zero"); + bool periodic; + if (!matchPattern(op.getPeriodic(), m_TorchConstantBool(&periodic))) + return rewriter.notifyMatchFailure( + op, "Expected a constant boolean value for periodic"); + + if (window_length == 1) { + Value one = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + SmallVector sizes({one}); + Value sizeList = rewriter.create( + loc, ListType::get(IntType::get(context)), sizes); + rewriter.replaceOpWithNewOp(op, opType, sizeList, opDtype, + opLayout, opDevice, opPinMemory); + return success(); + } + + Value zero = + rewriter.create(loc, rewriter.getF64FloatAttr(0.0)); + + Value arange = rewriter.create( + loc, opType, zero, op.getWindowLength(), opDtype, opLayout, opDevice, + opPinMemory); + + double denominator = !periodic ? 
window_length - 1 : window_length; + + double piOverDenominator = 3.14159 / denominator; + + Value cstFactor = rewriter.create( + loc, rewriter.getF64FloatAttr(piOverDenominator)); + + Value fraction = + rewriter.create(loc, opType, arange, cstFactor); + Value sine = rewriter.create(loc, opType, fraction); + + rewriter.replaceOpWithNewOp(op, opType, sine); + + return success(); + } +}; +} // namespace + namespace { // Decompose `aten.scatter.value` op into `aten.scatter.src` op. class DecomposeAtenScatterValueOp @@ -7229,11 +10459,8 @@ class DecomposeAtenScatterValueOp auto selfType = cast(self.getType()); auto indexType = cast(index.getType()); - BaseTensorType srcType = - selfType - .getWithSizesAndDtype(indexType.getOptionalSizes(), - selfType.getOptionalDtype()) - .cast(); + BaseTensorType srcType = cast(selfType.getWithSizesAndDtype( + indexType.getOptionalSizes(), selfType.getOptionalDtype())); Value src = createInitTensor(rewriter, loc, srcType, op.getValue(), sizeList); rewriter.replaceOpWithNewOp(op, op.getType(), self, @@ -7243,6 +10470,23 @@ class DecomposeAtenScatterValueOp }; } // namespace +namespace { +// Decompose prims.sum into aten.sum +class DecomposePrimsSumOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(PrimsSumOp op, + PatternRewriter &rewriter) const override { + Value cstFalse = rewriter.create(op.getLoc(), false); + + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getInp(), op.getDims(), /*keepdim=*/cstFalse, + op.getOutputDtype()); + return success(); + } +}; +} // namespace + namespace { // Decompose `aten.sgn` op into comparisons and aten.where. class DecomposeAtenSgnOp : public OpRewritePattern { @@ -7257,7 +10501,7 @@ class DecomposeAtenSgnOp : public OpRewritePattern { "expected result type to have dtype"); } // TODO: support complex type in future. - if (outType.getDtype().isa()) { + if (isa(outType.getDtype())) { return rewriter.notifyMatchFailure(op, "doesn't support complex type now"); } @@ -7320,6 +10564,31 @@ class DecomposeAtenTypeAsOp : public OpRewritePattern { }; } // namespace +namespace { +// Decompose aten.max_pool2d_with_indices +// into aten.max_pool2d +// when the second result is unused. +class DecomposeAtenMaxPool2dWithIndicesOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenMaxPool2dWithIndicesOp op, + PatternRewriter &rewriter) const override { + if (!op->getResult(1).use_empty()) + return failure(); + + auto newOp = rewriter.create( + op.getLoc(), op->getResult(0).getType(), op.getSelf(), + op.getKernelSize(), op.getStride(), op.getPadding(), op.getDilation(), + op.getCeilMode()); + + rewriter.replaceAllUsesWith(op->getResult(0), newOp); + rewriter.eraseOp(op); + return success(); + } +}; +} // namespace + // Torch ops related to indexing tensors, e.g., AtenIndexTensor, AtenIndexPut. 
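The aten.fft_rfft decomposition above materializes the DFT as a matmul: for a real, length-N signal, column pair (2k, 2k+1) of the coefficient matrix holds cos(2*pi*n*k/N) and -sin(2*pi*n*k/N), so self @ coeff produces the real and imaginary parts of the one-sided spectrum (k = 0 .. N/2). A standalone numeric check of that construction (plain C++, illustrative values, not part of this patch):

// DFT-by-matmul coefficients vs. the direct complex DFT.
#include <cassert>
#include <cmath>
#include <complex>
#include <vector>

int main() {
  std::vector<double> x = {1.0, 2.0, 0.0, -1.0};
  const int N = (int)x.size(), outputFftDim = N / 2 + 1;
  for (int k = 0; k < outputFftDim; ++k) {
    double re = 0.0, im = 0.0;
    for (int n = 0; n < N; ++n) {
      double v = 2.0 * M_PI * n * k / N;
      re += x[n] * std::cos(v);  // even column of the coefficient matrix
      im += x[n] * -std::sin(v); // odd column
    }
    std::complex<double> ref(0.0, 0.0); // reference bin: sum_n x[n]*exp(-2*pi*i*n*k/N)
    for (int n = 0; n < N; ++n)
      ref += x[n] * std::exp(std::complex<double>(0.0, -2.0 * M_PI * n * k / N));
    assert(std::abs(std::complex<double>(re, im) - ref) < 1e-9);
  }
  return 0;
}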
namespace { @@ -7373,7 +10642,7 @@ static FailureOr createNewIndices(Operation *op, Location loc = op->getLoc(); MLIRContext *context = op->getContext(); - auto inputType = input.getType().cast(); + auto inputType = cast(input.getType()); if (!inputType.hasSizes()) { return failure(); } @@ -7382,7 +10651,7 @@ static FailureOr createNewIndices(Operation *op, int64_t maxIndexRank = 0; for (auto index : oldIndices) { - auto indexType = index.getType().dyn_cast(); + auto indexType = dyn_cast(index.getType()); if (!indexType) // None index continue; if (!indexType.hasSizes()) @@ -7471,15 +10740,13 @@ class DecomposeAtenIndexTensorOp : public OpRewritePattern { int64_t inputRank = inputSizes.size(); auto isTensor = [](Value v) { - return v.getType().isa(); + return isa(v.getType()); }; // directly replace aten.Index.Tensor with aten.index.Tensor_hacked_twin if (llvm::all_of(indices, isTensor)) { // By default, we regard the first index type as the list element type. - auto indexElemType = indices[0] - .getType() - .template cast() + auto indexElemType = cast(indices[0].getType()) .getWithSizesAndDtype(std::nullopt, nullptr); auto newIndices = rewriter.create( loc, Torch::ListType::get(indexElemType), indices); @@ -7569,7 +10836,7 @@ class DecomposeAtenIndexPutLikeOp "failed to get elements of `indices`"); auto input = op.getSelf(); - auto inputType = input.getType().template cast(); + auto inputType = cast(input.getType()); if (!inputType.hasSizes()) { return rewriter.notifyMatchFailure( op, "only input with shape information is supported"); @@ -7578,15 +10845,13 @@ class DecomposeAtenIndexPutLikeOp int64_t inputRank = inputSizes.size(); auto isTensor = [](Value v) { - return v.getType().isa(); + return isa(v.getType()); }; // directly replace current op with aten.index_put.hacked_twin if (llvm::all_of(indices, isTensor)) { // By default, we regard the first index type as the list element type. 
- auto indexElemType = indices[0] - .getType() - .template cast() + auto indexElemType = cast(indices[0].getType()) .getWithSizesAndDtype(std::nullopt, nullptr); auto newIndex = rewriter.create( loc, Torch::ListType::get(indexElemType), indices); @@ -7716,7 +10981,7 @@ class DecomposeAtenLinalgNormOp : public OpRewritePattern { // default ord value is 2 for vector_norm auto ord = op.getOrd(); - if (ord.getType().isa()) { + if (isa(ord.getType())) { ord = rewriter.create(loc, rewriter.getI64IntegerAttr(2)); } rewriter.replaceOpWithNewOp( @@ -7741,7 +11006,6 @@ class DecomposeAtenFakeQuantizePerTensorAffineOp Value falseVal = rewriter.create(loc, false); Value one = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); - auto baseType = ValueTensorType::getWithLeastStaticInformation(context); // input/scale Value divScale = rewriter.create( @@ -7752,16 +11016,19 @@ class DecomposeAtenFakeQuantizePerTensorAffineOp Value addZeroPoint = rewriter.create( loc, op.getType(), round, op.getZeroPoint(), one); // max(quant_min, std::nearby_int(input/scale) + zero_point) + auto si64Type = IntegerType::get(context, 64, IntegerType::Signed); + auto tensorIntType = + ValueTensorType::get(context, ArrayRef{1}, si64Type); Value max = rewriter.create( loc, op.getType(), addZeroPoint, - rewriter.create(loc, baseType, op.getQuantMin(), + rewriter.create(loc, tensorIntType, op.getQuantMin(), /*dtype=*/none, /*device=*/none, /*requires_grad=*/falseVal)); // min(quant_max, max(quant_min, std::nearby_int(input/scale) + zero_point)) Value min = rewriter.create( loc, op.getType(), max, - rewriter.create(loc, baseType, op.getQuantMax(), + rewriter.create(loc, tensorIntType, op.getQuantMax(), /*dtype=*/none, /*device=*/none, /*requires_grad=*/falseVal)); // min(quant_max, max(quant_min, std::nearby_int(input/scale) + zero_point)) @@ -7778,6 +11045,532 @@ class DecomposeAtenFakeQuantizePerTensorAffineOp }; } // namespace +namespace { +// Decompose aten.fake_quantize_per_tensor_affine_cachemask +// into aten.fake_quantize_per_tensor_affine +// when the second result is unused. +class DecomposeAtenFakeQuantizePerTensorAffineCachemaskOp + : public OpRewritePattern { +public: + using OpRewritePattern< + AtenFakeQuantizePerTensorAffineCachemaskOp>::OpRewritePattern; + LogicalResult matchAndRewrite(AtenFakeQuantizePerTensorAffineCachemaskOp op, + PatternRewriter &rewriter) const override { + if (!op->getResult(1).use_empty()) + return failure(); + + auto newOp = rewriter.create( + op.getLoc(), op->getResult(0).getType(), op.getSelf(), op.getScale(), + op.getZeroPoint(), op.getQuantMin(), op.getQuantMax()); + + rewriter.replaceAllUsesWith(op->getResult(0), newOp); + rewriter.eraseOp(op); + return success(); + } +}; +} // namespace + +namespace { +// Decompose aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams +// into aten.fake_quantize_per_tensor_affine.tensor_qparams +// when the second result is unused. 
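For context, aten.fake_quantize_per_tensor_affine (which the decomposition above expands, and to which the cachemask variants here forward when their mask result is unused) follows the usual clamp-round-rescale definition. A scalar sketch of that formula (plain C++, illustrative values, assuming the standard definition rather than quoting the unshown remainder of the pattern):

// fq(x) = (clamp(round(x / scale) + zero_point, quant_min, quant_max)
//          - zero_point) * scale
#include <algorithm>
#include <cassert>
#include <cmath>

int main() {
  double scale = 0.1;
  int zero_point = 128, quant_min = 0, quant_max = 255;
  double x = 1.234;
  double q = std::nearbyint(x / scale) + zero_point;               // 12 + 128 = 140
  q = std::min<double>(quant_max, std::max<double>(quant_min, q)); // clamp to [qmin, qmax]
  double fq = (q - zero_point) * scale;                            // back to float
  assert(std::fabs(fq - 1.2) < 1e-12);
  return 0;
}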
+class DecomposeAten_FakeQuantizePerTensorAffineCachemaskTensorQparamsOp + : public OpRewritePattern< + Aten_FakeQuantizePerTensorAffineCachemaskTensorQparamsOp> { +public: + using OpRewritePattern< + Aten_FakeQuantizePerTensorAffineCachemaskTensorQparamsOp>:: + OpRewritePattern; + LogicalResult + matchAndRewrite(Aten_FakeQuantizePerTensorAffineCachemaskTensorQparamsOp op, + PatternRewriter &rewriter) const override { + if (!op->getResult(1).use_empty()) + return failure(); + + auto newOp = + rewriter.create( + op.getLoc(), op->getResult(0).getType(), op.getSelf(), + op.getScale(), op.getZeroPoint(), op.getQuantMin(), + op.getQuantMax()); + + rewriter.replaceAllUsesWith(op->getResult(0), newOp); + rewriter.eraseOp(op); + return success(); + } +}; +} // namespace + +namespace { +// Decompose aten.fake_quantize_per_channel_affine_cachemask +// into aten.fake_quantize_per_channel_affine +// when the second result is unused. +class DecomposeAtenFakeQuantizePerChannelAffineCachemaskOp + : public OpRewritePattern { +public: + using OpRewritePattern< + AtenFakeQuantizePerChannelAffineCachemaskOp>::OpRewritePattern; + LogicalResult matchAndRewrite(AtenFakeQuantizePerChannelAffineCachemaskOp op, + PatternRewriter &rewriter) const override { + if (!op->getResult(1).use_empty()) + return failure(); + + auto newOp = rewriter.create( + op.getLoc(), op->getResult(0).getType(), op.getSelf(), op.getScale(), + op.getZeroPoint(), op.getAxis(), op.getQuantMin(), op.getQuantMax()); + + rewriter.replaceAllUsesWith(op->getResult(0), newOp); + rewriter.eraseOp(op); + return success(); + } +}; +} // namespace + +namespace { +// Decompose aten.fmax/fmin to aten.maximum/minimum + aten.where(nanMask) +template +class DecomposeAtenFMaxMinOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenFOpT op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + BaseTensorType outType = cast(op.getType()); + Type nanMaskType = outType.getWithSizesAndDtype( + !outType.hasSizes() ? 
std::optional>() + : llvm::ArrayRef(outType.getSizes()), + rewriter.getI1Type()); + + Value self = op.getSelf(); + Value other = op.getOther(); + + Value normalResult = + rewriter.create(loc, outType, self, other).getResult(); + Value selfIsNan = + rewriter.create(loc, nanMaskType, self).getResult(); + Value otherIsNan = + rewriter.create(loc, nanMaskType, other) + .getResult(); + normalResult = rewriter.create( + loc, outType, otherIsNan, self, normalResult); + rewriter.replaceOpWithNewOp(op, outType, selfIsNan, other, + normalResult); + + return success(); + } +}; +} // namespace + +namespace { +class DecomposeAtenThresholdOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenThresholdOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = op.getSelf(); + auto selfType = dyn_cast(self.getType()); + if (!selfType || !selfType.hasSizes()) { + return rewriter.notifyMatchFailure(op, + "requires input is tensor with sizes"); + } + + Value threshold = op.getThreshold(); + Value value = op.getValue(); + + auto comOp = rewriter.create( + loc, + selfType.getWithSizesAndDtype(selfType.getSizes(), + rewriter.getI1Type()), + self, threshold); + + rewriter.replaceOpWithNewOp(op, op.getType(), comOp, + self, value); + return success(); + } +}; +} // namespace + +namespace { +class DecomposeAtenFloatPowerTensorTensorOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenFloatPowerTensorTensorOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = op.getSelf(); + Value exp = op.getExponent(); + + auto selfTy = dyn_cast(self.getType()); + if (!selfTy || !selfTy.hasDtype() || !selfTy.hasSizes()) { + return rewriter.notifyMatchFailure( + op, "requires input is tensor with dtype and sizes"); + } + + Value selfF64 = + convertTensorToDtype(rewriter, loc, self, rewriter.getF64Type()); + rewriter.replaceOpWithNewOp(op, op.getType(), + selfF64, exp); + + return success(); + } +}; +} // namespace + +namespace { +class DecomposeTorchvisionNmsOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TorchvisionNmsOp op, + PatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + MLIRContext *context = op->getContext(); + Value boxes = op.getDets(); + Value scores = op.getScores(); + Value iouThreshold = op.getIouThreshold(); + + Value cst0 = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + Value cst1 = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + Value cst2 = rewriter.create( + loc, rewriter.getI64IntegerAttr(2)); + Value cst4 = rewriter.create( + loc, rewriter.getI64IntegerAttr(4)); + Value cstNone = rewriter.create(loc); + Value cstTrue = + rewriter.create(loc, rewriter.getBoolAttr(true)); + Value cstFalse = rewriter.create( + loc, rewriter.getBoolAttr(false)); + + // Get number of boxes for the loop count + auto boxesTensorType = dyn_cast(boxes.getType()); + auto dType = boxesTensorType.getDtype(); + int64_t boxesSize = boxesTensorType.getSizes()[0]; + Value len = rewriter.create(loc, boxes, /*dim=*/cst0); + + // Calculate the area of each box: (x2 - x1) * (y2 - y1) + auto sliceTy = rewriter.getType( + SmallVector{boxesSize, 2}, dType); + Value lowSlice = rewriter.create( + loc, sliceTy, boxes, + /*dim=*/cst1, /*start=*/cst0, /*end=*/cst2, /*step=*/cst1); + Value highSlice = rewriter.create( + loc, 
sliceTy, boxes, + /*dim=*/cst1, /*start=*/cst2, /*end=*/cst4, /*step=*/cst1); + Value distance = rewriter.create( + loc, sliceTy, highSlice, lowSlice, cst1); + auto areaTy = rewriter.getType( + SmallVector{boxesSize}, dType); + Value area = rewriter.create( + loc, areaTy, distance, /*dim=*/cst1, /*keepdim=*/cstFalse, + /*dtype=*/cstNone); + + // Sort scores in descending order + // Use the sorted indices to iterate boxes + auto scoresType = dyn_cast(scores.getType()); + auto intTensorType = scoresType.getWithSizesAndDtype( + scoresType.getOptionalSizes(), + IntegerType::get(context, 64, IntegerType::Signed)); + auto sortResult = rewriter.create( + loc, TypeRange({scores.getType(), intTensorType}), scores, + /*dim=*/cst0, /*descending=*/cstTrue); + + // Create a mask to mark if we keep the boxes + Value lenShapeList = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(context)), + SmallVector{len}); + Value mask = rewriter.create( + loc, intTensorType, lenShapeList, cstNone, cstNone, cstNone, cstNone); + Value zeroShapeList = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(context)), + SmallVector{cst1}); + auto zeroTy = rewriter.getType( + SmallVector{1}, rewriter.getIntegerType(64, /*signed=*/true)); + Value falseMask = rewriter.create( + loc, zeroTy, zeroShapeList, cstNone, cstNone, cstNone, cstNone); + + // Create an empty tensor for result + Value result = rewriter.create( + loc, intTensorType, lenShapeList, /*dtype=*/cst4, /*layout=*/cstNone, + /*device=*/cstNone, /*pinMemory=*/cstNone, /*memoryFormat=*/cstNone); + + auto intTy = rewriter.getType(); + auto rowSliceTy = + rewriter.getType(SmallVector{1, 4}, dType); + auto pointTy = + rewriter.getType(SmallVector{1, 2}, dType); + auto extractTy = rewriter.getType( + SmallVector{1}, rewriter.getIntegerType(64, true)); + Value float0 = rewriter.create( + loc, rewriter.getFloatAttr(dType, 0.0)); + auto scalarFloatType = rewriter.getType( + SmallVector{1}, dType); + Value float0Tensor = rewriter.create( + loc, scalarFloatType, float0); + + // 1. Loop through the boxes based on sorted indices + // 2. Add the current box to result if it's not suppressed + // 3. Calculate the IoUs with all boxes + // 4. Loop through the rest boxes in sorted indices + // 5. 
Suppress the box if the corresponding IoU is larger than threshold + auto loop1 = rewriter.create( + loc, TypeRange({intTensorType, intTensorType, intTy}), len, cstTrue, + ValueRange({mask, result, cst0})); + { + PatternRewriter::InsertionGuard guard(rewriter); + Block *loopBody1 = rewriter.createBlock( + &loop1.getRegion(), loop1.getRegion().begin(), + TypeRange({intTy, intTensorType, intTensorType, intTy}), + {loc, loc, loc, loc}); + Value i = loopBody1->getArgument(0); + Value mask1 = loopBody1->getArgument(1); + Value curResult = loopBody1->getArgument(2); + Value curCnt = loopBody1->getArgument(3); + + // Extract the mask to check if the base box is suppressed + Value extract = rewriter.create( + loc, extractTy, mask1, /*dim=*/cst0, /*index=*/i); + Value scalar = rewriter.create(loc, intTy, extract); + Value iskept = rewriter.create( + loc, rewriter.getType(), scalar); + auto ifFilterOthers = rewriter.create( + loc, TypeRange({intTensorType, intTensorType, intTy}), iskept); + { + PatternRewriter::InsertionGuard guard(rewriter); + rewriter.createBlock(&ifFilterOthers.getThenRegion(), + ifFilterOthers.getThenRegion().begin()); + + // Scatter the selected indices into result + Value extractIdx1 = rewriter.create( + loc, extractTy, sortResult.getResults()[1], /*dim=*/cst0, + /*index=*/i); + Value next = rewriter.create(loc, curCnt, cst1); + Value updatedResult = rewriter.create( + loc, intTensorType, curResult, extractIdx1, /*dim=*/cst0, + /*start=*/curCnt, /*end=*/next, /*step=*/cst1); + + // Get the coordinates of base box + Value idx1 = + rewriter.create(loc, intTy, extractIdx1); + Value idx1End = rewriter.create(loc, idx1, cst1); + Value curBox = rewriter.create( + loc, rowSliceTy, boxes, + /*dim=*/cst0, /*start=*/idx1, /*end=*/idx1End, /*step=*/cst1); + + // Calculate IoUs: intersectionArea / unionArea + // Intersection area = intersectionWidth * intersectionHeight + Value point1 = rewriter.create( + loc, pointTy, curBox, + /*dim=*/cst1, /*start=*/cst0, /*end=*/cst2, /*step=*/cst1); + Value point2 = rewriter.create( + loc, pointTy, curBox, + /*dim=*/cst1, /*start=*/cst2, /*end=*/cst4, /*step=*/cst1); + Value innerLow = rewriter.create( + loc, sliceTy, lowSlice, point1); + Value innerHigh = rewriter.create( + loc, sliceTy, highSlice, point2); + Value innerDistance = rewriter.create( + loc, sliceTy, innerHigh, innerLow, cst1); + innerDistance = rewriter.create( + loc, sliceTy, innerDistance, float0Tensor); + Value intersectionArea = rewriter.create( + loc, areaTy, innerDistance, /*dim=*/cst1, /*keepdim=*/cstFalse, + /*dtype=*/cstNone); + Value iEnd = rewriter.create(loc, i, cst1); + Value curArea = rewriter.create( + loc, scalarFloatType, area, + /*dim=*/cst0, /*start=*/i, /*end=*/iEnd, /*step=*/cst1); + // Union area = area1 + area2 - intersectionArea + Value unionArea = rewriter.create( + loc, areaTy, area, curArea, cst1); + unionArea = rewriter.create( + loc, areaTy, unionArea, intersectionArea, cst1); + Value iou = rewriter.create( + loc, areaTy, intersectionArea, unionArea); + + // Loop through the rest of boxes in sorted indices + auto loop2 = rewriter.create(loc, intTensorType, len, + cstTrue, mask1); + { + PatternRewriter::InsertionGuard guard(rewriter); + Block *loopBody2 = rewriter.createBlock( + &loop2.getRegion(), loop2.getRegion().begin(), + TypeRange({intTy, intTensorType}), {loc, loc}); + Value j = loopBody2->getArgument(0); + Value mask2 = loopBody2->getArgument(1); + + // Check if current index is out of range + j = rewriter.create(loc, j, i); + j = 
rewriter.create(loc, j, cst1); + Value isInRange = rewriter.create(loc, j, len); + auto ifCalculateIou = rewriter.create( + loc, TypeRange({intTensorType}), isInRange); + { + PatternRewriter::InsertionGuard guard(rewriter); + rewriter.createBlock(&ifCalculateIou.getThenRegion(), + ifCalculateIou.getThenRegion().begin()); + + // Retrieve IoU and check if suppress the box + Value extractIdx2 = rewriter.create( + loc, extractTy, sortResult.getResults()[1], /*dim=*/cst0, + /*index=*/j); + Value idx2 = + rewriter.create(loc, intTy, extractIdx2); + Value idx2End = + rewriter.create(loc, idx2, cst1); + Value curIoU = rewriter.create( + loc, scalarFloatType, iou, + /*dim=*/cst0, /*start=*/idx2, /*end=*/idx2End, /*step=*/cst1); + curIoU = rewriter.create( + loc, rewriter.getType(), curIoU); + Value isSuppressed = rewriter.create( + loc, curIoU, iouThreshold); + + auto ifUnmask = rewriter.create( + loc, TypeRange({intTensorType}), isSuppressed); + { + PatternRewriter::InsertionGuard guard(rewriter); + rewriter.createBlock(&ifUnmask.getThenRegion(), + ifUnmask.getThenRegion().begin()); + + // Update the mask if suppress + Value jEnd = rewriter.create(loc, j, cst1); + Value updatedMask = rewriter.create( + loc, intTensorType, mask2, falseMask, /*dim=*/cst0, + /*start=*/j, /*end=*/jEnd, /*step=*/cst1); + rewriter.create(loc, updatedMask); + } + { + PatternRewriter::InsertionGuard guard(rewriter); + rewriter.createBlock(&ifUnmask.getElseRegion(), + ifUnmask.getElseRegion().begin()); + rewriter.create(loc, mask2); + } + + rewriter.create(loc, ifUnmask.getResult(0)); + } + { + PatternRewriter::InsertionGuard guard(rewriter); + rewriter.createBlock(&ifCalculateIou.getElseRegion(), + ifCalculateIou.getElseRegion().begin()); + rewriter.create(loc, mask2); + } + + rewriter.create( + loc, cstTrue, ifCalculateIou.getResult(0)); + } + + rewriter.create( + loc, ValueRange({loop2.getResult(0), updatedResult, next})); + } + { + PatternRewriter::InsertionGuard guard(rewriter); + rewriter.createBlock(&ifFilterOthers.getElseRegion(), + ifFilterOthers.getElseRegion().begin()); + rewriter.create( + loc, ValueRange({mask1, curResult, curCnt})); + } + + rewriter.create(loc, cstTrue, + ifFilterOthers.getResults()); + } + + rewriter.replaceOpWithNewOp( + op, op.getType(), loop1.getResult(1), /*dim=*/cst0, /*start=*/cst0, + /*end=*/loop1.getResult(2), /*step=*/cst1); + return success(); + } +}; +} // namespace + +namespace { +class DecomposeAtenSpecialExpm1Op + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenSpecialExpm1Op op, + PatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf()); + return success(); + } +}; +} // namespace + +namespace { +class DecomposeAtenConstrainRangeForSizeOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenSymConstrainRangeForSizeOp op, + PatternRewriter &rewriter) const override { + + auto loc = op.getLoc(); + auto min = op.getMin(); + auto max = op.getMax(); + + int64_t minValue, maxValue; + + if (isa(min.getType())) { + // Set min value to 0 + min = rewriter.create(loc, 0); + } else { + // Check if min value is a constant + if (!matchPattern(min, m_TorchConstantInt(&minValue))) + return rewriter.notifyMatchFailure( + op, "Expected min value to be constant integer"); + } + + if (!isa(max.getType())) { + // Verify that max value is greater than 2 + if (!matchPattern(max, m_TorchConstantInt(&maxValue))) + return 
rewriter.notifyMatchFailure( + op, "Expected max value to be constant integer"); + + if (maxValue <= 2) { + std::string errorMsg = "Max value to constrain_range_for_size must be " + "greater than 2, got: " + + std::to_string(maxValue); + return op.emitError(errorMsg); + } + } + + rewriter.replaceOpWithNewOp(op, op.getSize(), min, + max); + return success(); + } +}; +} // namespace + +namespace { +class DecomposeAten_AssertScalarOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(Aten_AssertScalarOp op, + PatternRewriter &rewriter) const override { + + auto loc = op.getLoc(); + auto assertCond = op.getSelf(); + + if (isa(assertCond.getType())) + assertCond = rewriter.create(loc, assertCond); + else if (isa(assertCond.getType())) + assertCond = rewriter.create(loc, assertCond); + assert(isa(assertCond.getType()) && + "Unhandled type encountered in aten._assert_scalar op"); + + std::string assertMessage; + if (!matchPattern(op.getAssertMsg(), m_TorchConstantStr(assertMessage))) + return rewriter.notifyMatchFailure( + op, "Assert message must be a constant string"); + + rewriter.replaceOpWithNewOp(op, assertCond, assertMessage); + return success(); + } +}; +} // namespace + namespace { class DecomposeComplexOpsPass : public DecomposeComplexOpsBase { @@ -7814,8 +11607,11 @@ class DecomposeComplexOpsPass legalOpsSet.clear(); legalOpsSet.insert(legalOps.begin(), legalOps.end()); + addPatternIfTargetOpIsIllegal( + patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); @@ -7827,6 +11623,8 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal< DecomposeConstantTensorAllocLikeOp>(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal( @@ -7839,18 +11637,21 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal( patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); @@ -7872,15 +11673,22 @@ class DecomposeComplexOpsPass DecomposeAten_ConvolutionLikeOp>( patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal< + 
DecomposeAtenAminAmaxOp>(patterns); + addPatternIfTargetOpIsIllegal< + DecomposeAtenAminAmaxOp>(patterns); addPatternIfTargetOpIsIllegal< DecomposeAtenArgMinMaxOp>(patterns); addPatternIfTargetOpIsIllegal< DecomposeAtenArgMinMaxOp>(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); @@ -7905,8 +11713,17 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal( + patterns); + addPatternIfTargetOpIsIllegal( + patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); @@ -7939,24 +11756,30 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); @@ -7964,6 +11787,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); @@ -7984,36 +11808,80 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal( + patterns); addPatternIfTargetOpIsIllegal(patterns); 
addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal( + patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal< + DecomposeAtenFakeQuantizePerTensorAffineCachemaskOp>(patterns); + addPatternIfTargetOpIsIllegal< + DecomposeAten_FakeQuantizePerTensorAffineCachemaskTensorQparamsOp>( + patterns); + addPatternIfTargetOpIsIllegal< + DecomposeAtenFakeQuantizePerChannelAffineCachemaskOp>(patterns); // More specific conv ops addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal< + DecomposeAtenConvPaddingOp>(patterns); + addPatternIfTargetOpIsIllegal< + DecomposeAtenConvPaddingOp>(patterns); + addPatternIfTargetOpIsIllegal< + DecomposeAtenConvPaddingOp>(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal( + patterns); + addPatternIfTargetOpIsIllegal(patterns); + + addPatternIfTargetOpIsIllegal< + DecomposeAtenFMaxMinOp>(patterns); + addPatternIfTargetOpIsIllegal< + DecomposeAtenFMaxMinOp>(patterns); + + // Torchvision ops + addPatternIfTargetOpIsIllegal(patterns); + + addPatternIfTargetOpIsIllegal( + patterns); + addPatternIfTargetOpIsIllegal(patterns); GreedyRewriteConfig config; config.useTopDownTraversal = true; config.maxIterations = GreedyRewriteConfig::kNoLimit; - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns), + config))) { return signalPassFailure(); } } diff --git a/lib/Dialect/Torch/Transforms/EraseModuleInitializer.cpp b/lib/Dialect/Torch/Transforms/EraseModuleInitializer.cpp index b5dcbbf584ee..db80714127e1 100644 --- a/lib/Dialect/Torch/Transforms/EraseModuleInitializer.cpp +++ b/lib/Dialect/Torch/Transforms/EraseModuleInitializer.cpp @@ -9,13 +9,9 @@ #include "PassDetail.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/IRMapping.h" #include "mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" diff --git a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp index 7870ff63cb40..da06e1c59a75 100644 --- a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp +++ b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp @@ -39,9 +39,16 @@ template <> struct QuantInfo { bool isQCommutingOp(mlir::Operation *op) { // if adding a new commuting op here, be sure to add a // RemoveUnused pattern for that op to clean up afterwards - return llvm::isa(op); + return llvm::isa( + op); } +struct QuantizedChain { + std::stack 
commutingOpStack; + Value dequantOpd, MPTQTOpd, scale, zeroPoint; +}; + // The following conversion takes patterns of the form [op0 -> MPTQT -> dequant // -> Op1 -> Op2 -> ... Opk -> SrcOp] to [op0 -> Int(Op1) -> Int(Op2) -> ... -> // Int(Opk) -> MPTQT -> SrcOp] for any sequence of q commuting ops @@ -56,15 +63,22 @@ class QuantizeOperandsPastCommutingOps : public OpRewritePattern { LogicalResult matchAndRewrite(SrcOp op, PatternRewriter &rewriter) const override { - mlir::Location loc = op.getLoc(); llvm::SmallVector operands(op->getOperands()); - bool dequanted = false; + // Prevent fusion for 1d convolution ops and just do it as an f32 conv since + // there isn't a linalg named op for quantized 1-d convolution yet. + // TODO: Remove this and add support for 1-d quantized convolution. + int64_t inputRank = + cast(operands[0].getType()).getSizes().size(); + if (isa(op) && inputRank < 4) + return rewriter.notifyMatchFailure( + op, "1-d quantized convolution is not supported"); + + SmallVector operandChains; for (unsigned i : QuantInfo::operandsToQuantize) { Value operand = operands[i]; - std::stack commutingOpStack; - Value dequantOpd, MPTQTOpd; + QuantizedChain chain; for (unsigned k = 0; k < depth + 1; k++) { auto currOp = operand.getDefiningOp(); // Case 0 : currOp is a nullptr (e.g., operand is a block argument) @@ -72,40 +86,103 @@ class QuantizeOperandsPastCommutingOps : public OpRewritePattern { break; // Case 1 : currOp is a q commuting op (continue loop) if (isQCommutingOp(currOp)) { - commutingOpStack.push(currOp); + chain.commutingOpStack.push(currOp); // set operand to currOp for next k-iteration operand = currOp->getOperand(0); continue; } // Case 2 : currOp is a dequant op (end loop) if (llvm::isa(currOp)) { - dequantOpd = currOp->getOperand(0); + chain.dequantOpd = currOp->getOperand(0); + // Bail out if any operand is per-channel quantized, which would + // require more complex fusion logic. + if (llvm::isa( + chain.dequantOpd.getDefiningOp())) + break; + auto MPTQTOp = - dequantOpd.getDefiningOp(); - MPTQTOpd = MPTQTOp.getOperand(0); + chain.dequantOpd + .getDefiningOp(); + chain.MPTQTOpd = MPTQTOp.getOperand(0); + chain.scale = MPTQTOp.getOperand(1); + chain.zeroPoint = MPTQTOp.getOperand(2); } // either a dequant was found or chain broken, so break loop break; } - // move to next operand if this trace was unsuccessful - if (!MPTQTOpd) - continue; + // if tracing this operand was successful, add it to operandChains. + if (chain.MPTQTOpd) + operandChains.push_back(std::move(chain)); + } - // a successful trace occured, so set dequant to true - dequanted = true; + // Continuing the rewriting with only some of the operandsToQuantize traced + // successfully is possible but leads to "half-quantized" ops which are + // expected to cause problems in later lowering steps. We opt out of + // treating these cases for now. + if (operandChains.size() != + std::size(QuantInfo::operandsToQuantize)) { + if (!operandChains.empty()) + op.emitWarning("Partially traced quantized operands. 
This op will " + "remain in QDQ form."); + return rewriter.notifyMatchFailure( + op, "did not find a complete quantized chain for all operands"); + } + for (auto &&[i, chain] : llvm::enumerate(operandChains)) { // rewrite stack - Value oldOpd = MPTQTOpd; + Value oldOpd = chain.MPTQTOpd; Type intDType = - cast(MPTQTOpd.getType()).getOptionalDtype(); - while (!commutingOpStack.empty()) { + cast(chain.MPTQTOpd.getType()).getOptionalDtype(); + while (!chain.commutingOpStack.empty()) { // get front of the commuting op stack and replace its first operand // with oldOpd - auto currOp = commutingOpStack.top(); - commutingOpStack.pop(); + auto currOp = chain.commutingOpStack.top(); + chain.commutingOpStack.pop(); llvm::SmallVector currOperands(currOp->getOperands()); currOperands[0] = oldOpd; + // pad ops aren't quite commuting, so we include some extra logic to + // quantize the padding value + if (isa(currOp)) { + Value floatPadValue = currOperands.back(); + Value quantPadValue; + if (isa(floatPadValue.getType())) + quantPadValue = + rewriter.create(loc, chain.zeroPoint); + else { + floatPadValue = + rewriter.create(loc, floatPadValue); + quantPadValue = rewriter.create( + loc, floatPadValue, chain.scale); + quantPadValue = rewriter.create( + loc, quantPadValue, chain.zeroPoint); + } + // clamp pad value to qint range + if (auto intType = dyn_cast(intDType)) { + bool isSigned = intType.isSignedInteger(); + int64_t width = intType.getWidth(); + assert(width < 64 && + "quantized int bitwidth should be less than 64"); + int64_t minInt = isSigned ? -(1 << (width - 1)) : 0; + int64_t maxInt = isSigned ? -minInt - 1 : ((1 << width) - 1); + Value minQValueFloat = rewriter.create( + loc, rewriter.getF64FloatAttr(minInt)); + Value maxQValueFloat = rewriter.create( + loc, rewriter.getF64FloatAttr(maxInt)); + SmallVector emptyShape; + auto floatTensorType = rewriter.getType( + emptyShape, rewriter.getF64Type()); + Value quantPadValueTensor = createRank0Tensor( + rewriter, loc, floatTensorType, quantPadValue); + Value clampedTensor = rewriter.create( + loc, floatTensorType, quantPadValueTensor, minQValueFloat, + maxQValueFloat); + quantPadValue = rewriter.create( + loc, rewriter.getType(), clampedTensor); + } + // quantPadValue is a float, but will get converted/truncated + currOperands.back() = quantPadValue; + } // get new result type auto oldType = cast(currOp->getResultTypes()[0]); auto intType = @@ -121,19 +198,15 @@ class QuantizeOperandsPastCommutingOps : public OpRewritePattern { // stack is empty, so oldOpd is now the corrected verion of the // SrcOp's original operand // convert operand -> SrcOp to oldOpd -> newMPTQTOp -> SrcOp - auto MPTQTOperands = dequantOpd.getDefiningOp()->getOperands(); + auto MPTQTOperands = chain.dequantOpd.getDefiningOp()->getOperands(); auto qTorchType = - cast(dequantOpd.getType()).getOptionalDtype(); + cast(chain.dequantOpd.getType()).getOptionalDtype(); auto newMPTQTType = rewriter.getType( cast(operands[i].getType()).getSizes(), qTorchType); operands[i] = rewriter.create( loc, newMPTQTType, oldOpd, MPTQTOperands[1], MPTQTOperands[2]); } - if (!dequanted) { - return rewriter.notifyMatchFailure(op, "No dequantizations found."); - } - rewriter.replaceOpWithNewOp(op, op.getType(), operands); return success(); } @@ -372,18 +445,20 @@ class FuseQuantizedOpsPass : public FuseQuantizedOpsBase { RemoveUnused, RemoveUnused, RemoveUnused, RemoveUnused, - RemoveUnused, - QuantizeOperandsPastCommutingOps, + RemoveUnused, RemoveUnused, + RemoveUnused, RemoveUnused, + RemoveUnused, 
+ QuantizeOperandsPastCommutingOps, QuantizeOperandsPastCommutingOps, QuantizeOperandsPastCommutingOps, - QuantizeOperandsPastCommutingOps, + QuantizeOperandsPastCommutingOps, QuantizeAccumulator, QuantizeAccumulator, QuantizeResultLikeOperand, QuantizeBias>( context); GreedyRewriteConfig config; - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns), + config))) { return signalPassFailure(); } } diff --git a/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp b/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp index dbf203584601..ff55081a6e67 100644 --- a/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp +++ b/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp @@ -9,18 +9,14 @@ #include "PassDetail.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/IRMapping.h" -#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringMap.h" -#include "llvm/ADT/StringSet.h" using namespace mlir; using namespace mlir::torch; diff --git a/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp b/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp index 2aa9f42307b1..9c8936c8bffa 100644 --- a/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp +++ b/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp @@ -29,7 +29,6 @@ #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/IRMapping.h" -#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "llvm/Support/Debug.h" @@ -50,22 +49,21 @@ using namespace mlir::torch::Torch; /// a single module. If we had to support complex nested symbol references, we /// would probably want to go through the effort to indirect through the symbol /// tables to make things clearer. -class FlatSymbolRefProgramPoint - : public GenericProgramPointBase { +class FlatSymbolRefLatticeAnchor + : public GenericLatticeAnchorBase { public: using Base::Base; void print(raw_ostream &os) const override { - os << "FlatSymbolRefProgramPoint(" << getValue() << ")"; + os << "FlatSymbolRefLatticeAnchor(" << getValue() << ")"; } Location getLoc() const override { - return UnknownLoc::get(getValue().getContext()); + return UnknownLoc::get(getValue()->getContext()); } }; static bool isTypeTriviallySafe(Type type) { - return type.isa(); + return isa(type); } static bool isUseTreatedWithValueSemantics(OpOperand &use) { @@ -85,7 +83,7 @@ static bool isUseTreatedWithValueSemantics(OpOperand &use) { /// State tracking if an IR construct is "safe". /// /// This state is tracked on Value's and also on global slots (via a -/// FlatSymbolRefProgramPoint). +/// FlatSymbolRefLatticeAnchor). /// /// In this context, "safe" means that the object is safe to inline. 
/// This covers a few concepts @@ -94,7 +92,7 @@ static bool isUseTreatedWithValueSemantics(OpOperand &use) { /// unsafe class InlineGlobalSlotsAnalysisState : public AnalysisState { public: - InlineGlobalSlotsAnalysisState(ProgramPoint point) : AnalysisState(point) { + InlineGlobalSlotsAnalysisState(LatticeAnchor point) : AnalysisState(point) { (void)setSafe(); } @@ -134,7 +132,7 @@ class InlineGlobalSlotsAnalysis : public DataFlowAnalysis { public: InlineGlobalSlotsAnalysis(DataFlowSolver &solver); LogicalResult initialize(Operation *top) override; - LogicalResult visit(ProgramPoint point) override; + LogicalResult visit(ProgramPoint *point) override; private: /// The local transfer function determining the safety of `value`. @@ -148,33 +146,33 @@ class InlineGlobalSlotsAnalysis : public DataFlowAnalysis { InlineGlobalSlotsAnalysis::InlineGlobalSlotsAnalysis(DataFlowSolver &solver) : DataFlowAnalysis(solver) { - registerPointKind(); + registerAnchorKind(); } LogicalResult InlineGlobalSlotsAnalysis::initialize(Operation *top) { auto walkResult = top->walk([this](Operation *op) { if (auto globalSlot = dyn_cast(op)) { auto *state = getOrCreate( - getProgramPoint( - FlatSymbolRefAttr::get(globalSlot.getSymNameAttr()))); + getLatticeAnchor(globalSlot)); propagateIfChanged(state, state->setSafe(globalSlot.getVisibility() != SymbolTable::Visibility::Public)); } if (auto globalSlotSet = dyn_cast(op)) { + auto globalSlot = SymbolTable::lookupNearestSymbolFrom( + globalSlotSet, globalSlotSet.getSlotAttr()); + auto *state = getOrCreate( - getProgramPoint( - globalSlotSet.getSlotAttr())); + getLatticeAnchor(globalSlot)); propagateIfChanged(state, state->setSafe(false)); } // Save the InitializeGlobalSlotsOp for later referencee if (auto initialize = dyn_cast(op)) { initializeGlobalSlotsOp = initialize; } - for (Value result : op->getResults()) { - if (failed(visit(result))) - return WalkResult::interrupt(); - } + if (failed(visit(getProgramPointAfter(op)))) + return WalkResult::interrupt(); + return WalkResult::advance(); }); if (walkResult.wasInterrupted()) @@ -182,51 +180,36 @@ LogicalResult InlineGlobalSlotsAnalysis::initialize(Operation *top) { return success(); } -LogicalResult InlineGlobalSlotsAnalysis::visit(ProgramPoint point) { - if (Value value = dyn_cast(point)) { - bool isSafe = isValueSafeTransferFunction(value); - auto *state = getOrCreate(value); - propagateIfChanged(state, state->setSafe(isSafe)); - - // Handle GlobalSlotGetOp's. 
- if (auto opResult = dyn_cast(value)) { - if (auto globalSlotGet = - dyn_cast(opResult.getOwner())) { - auto *flatSymbolRefPoint = getProgramPoint( - globalSlotGet.getSlotAttr()); - auto *valueState = getOrCreateFor( - flatSymbolRefPoint, globalSlotGet.getResult()); - auto *globalState = - getOrCreate(flatSymbolRefPoint); - propagateIfChanged(globalState, - globalState->incorporateSafetyOfUse(valueState)); - } - } - +LogicalResult InlineGlobalSlotsAnalysis::visit(ProgramPoint *point) { + if (point->isBlockStart()) return success(); - } - if (auto *genericProgramPoint = dyn_cast(point)) { - if (auto *flatSymbolRefPoint = - dyn_cast(genericProgramPoint)) { - if (initializeGlobalSlotsOp) { - auto it = - llvm::find(initializeGlobalSlotsOp.getSlotSymNames(), - static_cast(flatSymbolRefPoint->getValue())); - Value value = initializeGlobalSlotsOp->getOperand(std::distance( - initializeGlobalSlotsOp.getSlotSymNames().begin(), it)); - auto *flatSymbolRefState = - getOrCreateFor(value, - flatSymbolRefPoint); - auto *valueState = getOrCreate(value); - propagateIfChanged(valueState, - valueState->setSafe(flatSymbolRefState->isSafe)); + + if (auto op = point->getPrevOp()) { + for (auto value : op->getResults()) { + bool isSafe = isValueSafeTransferFunction(value); + auto *state = getOrCreate(value); + propagateIfChanged(state, state->setSafe(isSafe)); + + // Handle GlobalSlotGetOp's. + if (auto opResult = dyn_cast(value)) { + if (auto globalSlotGet = + dyn_cast(opResult.getOwner())) { + auto globalSlot = SymbolTable::lookupNearestSymbolFrom( + globalSlotGet, globalSlotGet.getSlotAttr()); + auto *flatSymbolRefPoint = + getLatticeAnchor(globalSlot); + auto *valueState = getOrCreateFor( + getProgramPointAfter(globalSlot), globalSlotGet.getResult()); + auto *globalState = + getOrCreate(flatSymbolRefPoint); + propagateIfChanged(globalState, + globalState->incorporateSafetyOfUse(valueState)); + } } - return success(); } } - LLVM_DEBUG( - { llvm::dbgs() << "visit failing because of: " << point << "\n"; }); - return failure(); + + return success(); } // This is only a member function to access protected get* functions. @@ -242,16 +225,20 @@ bool InlineGlobalSlotsAnalysis::isValueSafeTransferFunction(Value value) { // safe. This covers, for example, view-like ops that create aliases. 
if ((op->hasTrait() || isMemoryEffectFree(op)) && llvm::all_of(op->getResults(), [&](Value result) { - auto *state = - getOrCreateFor(value, result); + auto *state = getOrCreateFor( + getProgramPointAfter(value.getDefiningOp()), result); return state->isSafe; })) continue; if (auto initialize = dyn_cast(op)) { auto symName = cast( initialize.getSlotSymNames()[use.getOperandNumber()]); + auto globalSlot = + SymbolTable::lookupNearestSymbolFrom(op, symName); + auto *state = getOrCreateFor( - value, getProgramPoint(symName)); + getProgramPointAfter(value.getDefiningOp()), + getLatticeAnchor(globalSlot)); if (state->isSafe) continue; } @@ -300,8 +287,7 @@ class InlineGlobalSlotsPass module->walk([&](Operation *op) { if (auto globalSlot = dyn_cast(op)) { auto *state = solver.lookupState( - solver.getProgramPoint( - FlatSymbolRefAttr::get(globalSlot.getSymNameAttr()))); + solver.getLatticeAnchor(globalSlot)); state->print(llvm::dbgs()); llvm::dbgs() << ": " << FlatSymbolRefAttr::get(globalSlot.getSymNameAttr()) @@ -335,13 +321,16 @@ class InlineGlobalSlotsPass auto slotSymName = cast(initialize.getSlotSymNames()[i]); Value operand = initialize.getOperand(i); - auto symbolRefPoint = solver.getProgramPoint( - cast(initialize.getSlotSymNames()[i])); + auto globalSlot = SymbolTable::lookupNearestSymbolFrom( + initialize, slotSymName); + + auto symbolRefPoint = + solver.getLatticeAnchor(globalSlot); auto *state = solver.lookupState(symbolRefPoint); // We roll the analysis of whether a slot is set or public into the // main dataflow analysis, so we need to check the slot's - // FlatSymbolRefProgramPoint itself to see if it is safe to inline. + // FlatSymbolRefLatticeAnchor itself to see if it is safe to inline. // For example, a public !torch.int is not safe to inline, even though // it is a value-semantic type and so the actual initializer value // itself is conceptually safe to inline. diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index bda2d258aba3..f15911e2b5ba 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -12,7 +12,6 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" @@ -37,8 +36,8 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, static LogicalResult checkType(Operation *op, Type type, bool actuallyEmitDiagnostics) { // Allow various scalar types that backends are expected to be able to handle. 
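The change just below is one instance of a mechanical cleanup that recurs throughout this patch: the deprecated member-style MLIR casts (type.isa<...>(), type.cast<...>(), type.dyn_cast<...>()) are replaced by the free-function forms isa<...>(type), cast<...>(type), dyn_cast<...>(type). A minimal illustration, not taken from the patch (the helper name and the TensorType example are hypothetical):

#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Illustrative helper only: returns the element type when `t` is a tensor
// type, otherwise returns `t` unchanged, written with the free-function
// casts this patch migrates to. The older spelling would have been
// t.isa<TensorType>() and t.cast<TensorType>().
static Type getElementTypeIfTensor(Type t) {
  if (isa<TensorType>(t))
    return cast<TensorType>(t).getElementType();
  return t;
}

The same kind of rename shows up in how the passes are driven: applyPatternsAndFoldGreedily calls become applyPatternsGreedily, the current name of the greedy pattern rewrite driver entry point.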
- if (type.isa()) + if (isa( + type)) return success(); // Backends are not expected to support dynamic computations on these types, @@ -372,6 +371,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, llvm::StringSet<> backendLegalOpsSet) { target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -381,6 +381,8 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -390,11 +392,16 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -402,7 +409,11 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -416,6 +427,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, }); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -428,12 +440,17 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -457,6 +474,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -481,6 +499,10 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -496,7 +518,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); - target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -505,7 +526,9 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -518,8 +541,13 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + 
target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -529,13 +557,21 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + for (auto &opName : backendLegalOpsSet) { target.addLegalOp( OperationName(kTorchOpPrefix + opName.first().str(), context)); diff --git a/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp b/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp index c237ede12479..0e3cda033a18 100644 --- a/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp +++ b/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp @@ -21,10 +21,12 @@ using namespace mlir::torch::Torch; namespace { Type getQuantizedType(MLIRContext *context, Type t) { - if (t.isSignlessInteger(8)) + if (t.isSignlessInteger(8) || t.isUnsignedInteger(8)) return Torch::QUInt8Type::get(context); if (t.isInteger(8) || t.isSignedInteger(8)) return Torch::QInt8Type::get(context); + if (t.isInteger(16)) + return Torch::QInt16Type::get(context); if (t.isInteger(32)) return Torch::QInt32Type::get(context); return {}; @@ -57,10 +59,11 @@ class MatchQuantizeOperator : public OpRewritePattern { return success(); } - if (op.getName() == "torch.quantized_decomposed.dequantize_per_tensor") { - auto clamp = rewriter.create( - op.getLoc(), op.getOperand(0).getType(), op.getOperand(0), - op.getOperand(3), op.getOperand(4)); + auto prepareDequantize = [&](Value quantMin, Value quantMax, Value &clamp, + Type &qTy) { + clamp = + rewriter.create(op.getLoc(), op.getOperand(0).getType(), + op.getOperand(0), quantMin, quantMax); auto clampTy = cast(clamp.getType()); if (!clampTy.hasDtype()) @@ -73,8 +76,18 @@ class MatchQuantizeOperator : public OpRewritePattern { return rewriter.notifyMatchFailure(op, "dequantization has unknown qtype"); - Type qTy = Torch::ValueTensorType::get( - op.getContext(), clampTy.getOptionalSizes(), qetype); + qTy = Torch::ValueTensorType::get(op.getContext(), + clampTy.getOptionalSizes(), qetype); + return success(); + }; + + if (op.getName() == "torch.quantized_decomposed.dequantize_per_tensor") { + Value clamp; + Type qTy; + if (failed(prepareDequantize(op.getOperand(3), op.getOperand(4), clamp, + qTy))) + return failure(); + auto quant = rewriter.create( op.getLoc(), qTy, clamp, op.getOperand(1), op.getOperand(2)); rewriter.replaceOpWithNewOp( @@ -82,6 +95,20 @@ class MatchQuantizeOperator : public OpRewritePattern { return success(); } + if (op.getName() == "torch.quantized_decomposed.dequantize_per_channel") { + Value clamp; + Type qTy; + if (failed(prepareDequantize(op.getOperand(4), op.getOperand(5), clamp, + qTy))) + return failure(); + auto quant = rewriter.create( + op.getLoc(), qTy, clamp, op.getOperand(1), op.getOperand(2), + op.getOperand(3)); + rewriter.replaceOpWithNewOp(op, op.getResultTypes(), + quant); + return success(); + } + return failure(); } }; @@ -95,8 +122,8 @@ class MatchQuantizedCustomOpsPass patterns.insert(context); GreedyRewriteConfig config; - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config))) + if (failed( + applyPatternsGreedily(getOperation(), std::move(patterns), config))) 
return signalPassFailure(); } }; diff --git a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp index cd4b74be678e..10580b81876b 100644 --- a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp +++ b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp @@ -9,8 +9,6 @@ #include "PassDetail.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -189,7 +187,7 @@ class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock auto it = originalReturnTypes.find(i); if (it == originalReturnTypes.end()) continue; - auto originalType = it->second.cast(); + auto originalType = cast(it->second); rewriter.setInsertionPoint(returnOp); Value newReturnValue = copyTensorToType(rewriter, returnOp->getLoc(), originalType, operand.get()); @@ -352,7 +350,7 @@ class RewriteViewLikeSubgraph auto it = originalTypes.find(operand.get()); if (it == originalTypes.end()) continue; - auto originalType = it->second.cast(); + auto originalType = cast(it->second); rewriter.setInsertionPoint(op); Value newReturnValue = copyTensorToType(rewriter, op->getLoc(), originalType, operand.get()); @@ -374,7 +372,7 @@ class MaximizeValueSemanticsPass RewritePatternSet patterns(context); patterns.insert(context); - (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + (void)applyPatternsGreedily(func, std::move(patterns)); } }; diff --git a/lib/Dialect/Torch/Transforms/Passes.cpp b/lib/Dialect/Torch/Transforms/Passes.cpp index d01eac967b22..846470202c15 100644 --- a/lib/Dialect/Torch/Transforms/Passes.cpp +++ b/lib/Dialect/Torch/Transforms/Passes.cpp @@ -10,6 +10,7 @@ #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" +#include "torch-mlir/Conversion/TorchOnnxToTorch/Passes.h" void mlir::torch::registerTorchPasses() { mlir::torch::registerPasses(); @@ -17,10 +18,18 @@ void mlir::torch::registerTorchPasses() { "torchscript-module-to-torch-backend-pipeline", "Pipeline lowering TorchScript object graph IR to Torch backend form.", mlir::torch::Torch::createTorchScriptModuleToTorchBackendPipeline); + mlir::PassPipelineRegistration( + "torchdynamo-export-to-torch-backend-pipeline", + "Pipeline lowering TorchDynamo exported graph IR to Torch backend form.", + mlir::torch::Torch::createTorchDynamoExportToTorchBackendPipeline); mlir::PassPipelineRegistration( "torch-function-to-torch-backend-pipeline", "Pipeline lowering a Torch function to Torch backend form.", mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline); + mlir::PassPipelineRegistration( + "torch-onnx-to-torch-backend-pipeline", + "Pipeline lowering Torch Onnx IR to Torch backend form.", + mlir::torch::Torch::createTorchOnnxToTorchBackendPipeline); mlir::PassPipelineRegistration( "torch-simplification-pipeline", "Pipeline simplifying computations in the program.", @@ -59,6 +68,18 @@ void mlir::torch::Torch::createTorchScriptModuleToTorchBackendPipeline( createTorchFunctionToTorchBackendPipeline(pm, options); } +void mlir::torch::Torch::createTorchDynamoExportToTorchBackendPipeline( + OpPassManager &pm, const TorchLoweringPipelineOptions &options) { + pm.addNestedPass( + createReduceOpVariantsPass(options.extraLibrary)); + pm.addNestedPass(createCanonicalizerPass()); + if (options.decompose) { + pm.addNestedPass( + 
Torch::createDecomposeComplexOpsPass(options.backendLegalOps)); + pm.addNestedPass(createCanonicalizerPass()); + } +} + void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline( OpPassManager &pm, const TorchLoweringPipelineOptions &options) { // Incorporate user annotations and remove signature Python-isms. @@ -70,6 +91,37 @@ void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline( options.backendLegalOps, options.extraLibrary)); } +void mlir::torch::Torch::createTorchOnnxToTorchBackendPipeline( + OpPassManager &pm, const TorchLoweringPipelineOptions &options) { + pm.addNestedPass(onnx_c::createTorchOnnxToTorchPass()); + // The above pass just converts the torch onnx IR to torch, hence the given + // pipeline will make sure that the IR is transformed such that it satisfies + // the backend contract. + if (options.decompose) { + pm.addNestedPass( + Torch::createDecomposeComplexOpsPass(options.backendLegalOps)); + pm.addNestedPass(createCanonicalizerPass()); + } + // TODO: Move the combination of two passes i.e., ScalarizeShapes and + // TorchShapeRefinementPipeline out of here and create an onnx shape + // refinement pipeline which runs iteratively over the IR. + createTorchShapeRefinementPipeline(pm, options); + // This pass scalarizes the tensor shape computations. + pm.addNestedPass( + mlir::torch::Torch::createScalarizeShapesPass()); + createTorchShapeRefinementPipeline(pm, options); + pm.addPass(Torch::createRefinePublicReturnPass()); + pm.addNestedPass(createCanonicalizerPass()); + // The decompose pass is run again here since the scalarize shapes pass and + // shape refinement pipeline might create some ops for which decomposition + // exists. + if (options.decompose) { + pm.addNestedPass( + Torch::createDecomposeComplexOpsPass(options.backendLegalOps)); + pm.addNestedPass(createCanonicalizerPass()); + } +} + // A simplification pipeline to establish the invariants of the backend // contract (see `satisfiedBackendContract` in `LowerToBackendContract`). // diff --git a/lib/Dialect/Torch/Transforms/PrepareForGlobalizeObjectGraph.cpp b/lib/Dialect/Torch/Transforms/PrepareForGlobalizeObjectGraph.cpp index 93a44ac33adc..c7ff95270d98 100644 --- a/lib/Dialect/Torch/Transforms/PrepareForGlobalizeObjectGraph.cpp +++ b/lib/Dialect/Torch/Transforms/PrepareForGlobalizeObjectGraph.cpp @@ -9,12 +9,9 @@ #include "PassDetail.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" @@ -78,14 +75,13 @@ class PrepareForGlobalizeObjectGraphPass func::CallIndirectOp::getCanonicalizationPatterns(patterns, context); patterns.add(context); - // Use applyPatternsAndFoldGreedily because the CallIndirectOp folding + // Use applyPatternsGreedily because the CallIndirectOp folding // makes the ConstantOp unused, which does not work with the visitation // order of the dialect conversion infrastructure. // TODO: Do this with the dialect conversion infrastructure to avoid doing // folding as part of this. Or avoid folding during greedy pattern // application. 
See: https://llvm.org/PR49502 - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); } diff --git a/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp index c1e476a80a10..f0478686b7f9 100644 --- a/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp @@ -13,6 +13,7 @@ #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include using namespace mlir; using namespace mlir::torch; @@ -164,7 +165,7 @@ class RecomposeUnbindListUnpack : public OpRewritePattern { LogicalResult matchAndRewrite(PrimListUnpackOp op, PatternRewriter &rewriter) const override { // recompose AtenUnbindOp + PrimListUnpackOp to select.int - auto unbindOp = dyn_cast(op.getOperand().getDefiningOp()); + auto unbindOp = op.getOperand().getDefiningOp(); if (!unbindOp) return rewriter.notifyMatchFailure(op, "Input is not AtenUnbindIntOp"); if (isListPotentiallyMutated(unbindOp.getResult())) @@ -207,7 +208,7 @@ class RecomposeUnbindGetItem : public OpRewritePattern { LogicalResult matchAndRewrite(Aten__Getitem__TOp op, PatternRewriter &rewriter) const override { // recompose AtenUnbindIntOp + __getitem__t to select.int - auto unbind = dyn_cast(op.getList().getDefiningOp()); + auto unbind = op.getList().getDefiningOp(); if (!unbind) return rewriter.notifyMatchFailure(op, "Input is not AtenUnbindIntOp"); if (isListPotentiallyMutated(unbind.getResult())) @@ -243,15 +244,58 @@ class RecomposeUnbindGetItem : public OpRewritePattern { } }; -class RecomposeSplitTensorGetItemOp +class RecomposeSplitTensorPrimListUnpackOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(PrimListUnpackOp op, + PatternRewriter &rewriter) const override { + + auto torchList = op.getOperand(); + if (isListPotentiallyMutated(torchList)) + return failure(); + + auto split = torchList.getDefiningOp(); + if (!split) + return failure(); + int64_t size = 0; + if (!matchPattern(split.getSplitSize(), m_TorchConstantInt(&size))) + return failure(); + + Value constOne = rewriter.create( + op->getLoc(), rewriter.getI64IntegerAttr(1)); + std::vector results; + int64_t start = 0; + + for (size_t i = 0; i < op->getNumResults(); ++i) { + results.push_back(rewriter.create( + op->getLoc(), op.getResult(i).getType(), split.getSelf(), + /*dim=*/split.getDim(), + /*start=*/ + rewriter.create( + op->getLoc(), rewriter.getI64IntegerAttr(start)), + /*end=*/ + rewriter.create( + op->getLoc(), rewriter.getI64IntegerAttr(start + size)), + /*step=*/constOne)); + start += size; + } + rewriter.replaceOp(op, results); + if (split->use_empty()) + rewriter.eraseOp(split); + + return success(); + } +}; + +class RecomposeSplitTensorGetItem : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(Aten__Getitem__TOp op, PatternRewriter &rewriter) const override { // recompose AtenSplitTensorOp + __getitem__t to AtenSliceTensorOp - auto splitTensorOp = - dyn_cast(op.getList().getDefiningOp()); + auto splitTensorOp = op.getList().getDefiningOp(); if (!splitTensorOp) return rewriter.notifyMatchFailure(op, "Input is not AtenSplitTensorOp"); if (isListPotentiallyMutated(splitTensorOp.getResult())) @@ -308,8 +352,7 @@ class 
RecomposeSplitTensorListUnpack LogicalResult matchAndRewrite(PrimListUnpackOp op, PatternRewriter &rewriter) const override { // recompose AtenSplitTensorOp + PrimListUnpackOp to AtenSliceTensorOps - auto splitTensorOp = - dyn_cast(op.getOperand().getDefiningOp()); + auto splitTensorOp = op.getOperand().getDefiningOp(); if (!splitTensorOp) return rewriter.notifyMatchFailure(op, "Input is not AtenSplitTensorOp"); if (isListPotentiallyMutated(splitTensorOp.getResult())) @@ -362,6 +405,78 @@ class RecomposeSplitTensorListUnpack } }; +class RecomposeSplitWithSizesGetItem + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(Aten__Getitem__TOp op, + PatternRewriter &rewriter) const override { + // recompose AtenSplitWithSizes + __getitem__t to AtenSliceTensorOp + auto splitWithSizesOp = op.getList().getDefiningOp(); + if (!splitWithSizesOp) + return rewriter.notifyMatchFailure(op, + "Input is not AtenSplitWithSizesOp"); + if (isListPotentiallyMutated(splitWithSizesOp.getResult())) + return rewriter.notifyMatchFailure( + op, "AtenSplitWithSizesOp result is potentially mutated"); + if (isListPotentiallyMutated(splitWithSizesOp.getSplitSizes())) { + return rewriter.notifyMatchFailure( + op, "splitWithSizesOp's split_sizes is potentially mutated"); + } + + SmallVector splitSizes; + if (!matchPattern(splitWithSizesOp.getSplitSizes(), + m_TorchListOfConstantInts(splitSizes))) { + return rewriter.notifyMatchFailure( + op, "split_sizes must be list of constant int"); + } + + int64_t index; + if (!matchPattern(op.getIdx(), m_TorchConstantInt(&index))) + return rewriter.notifyMatchFailure( + op, "Expected `idx` of `Aten__Getitem__TOp` to be a constant int"); + index = toPositiveDim(index, splitSizes.size()); + if (!isValidDim(index, splitSizes.size())) + return rewriter.notifyMatchFailure( + op, "Expected `idx` in range of split_sizes"); + + Location loc = op.getLoc(); + Value input = splitWithSizesOp.getSelf(); + Value dim = splitWithSizesOp.getDim(); + + // add runtime.assert to check dimension constraint + Value totalSize = rewriter.create(loc, input, dim); + int64_t sumSplitSize = + std::accumulate(splitSizes.begin(), splitSizes.end(), 0); + Value cstSumSplitSize = rewriter.create( + loc, rewriter.getI64IntegerAttr(sumSplitSize)); + Value eqOrNot = + rewriter.create(loc, totalSize, cstSumSplitSize); + rewriter.create( + loc, eqOrNot, + rewriter.getStringAttr("split dim must be sum of split_sizes")); + + // replace with AtenSliceTensorOp + SmallVector boundaryOfSliceOp(splitSizes.size() + 1, 0); + for (size_t i = 1; i < boundaryOfSliceOp.size(); i++) { + boundaryOfSliceOp[i] = boundaryOfSliceOp[i - 1] + splitSizes[i - 1]; + } + Value cstOne = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + auto start = rewriter.create( + loc, rewriter.getI64IntegerAttr(boundaryOfSliceOp[index])); + auto end = rewriter.create( + loc, rewriter.getI64IntegerAttr(boundaryOfSliceOp[index + 1])); + Value slice = rewriter.create( + loc, op.getType(), input, dim, start, end, /*step=*/cstOne); + rewriter.replaceOp(op, slice); + // erase splitOp if no user left + if (splitWithSizesOp.getResult().use_empty()) + rewriter.eraseOp(splitWithSizesOp); + return success(); + } +}; + class RecomposeSplitWithSizesListUnpack : public OpRewritePattern { public: @@ -369,8 +484,7 @@ class RecomposeSplitWithSizesListUnpack LogicalResult matchAndRewrite(PrimListUnpackOp op, PatternRewriter &rewriter) const override { // recompose AtenSplitWithSizesOp + PrimListUnpackOp to 
AtenSliceTensorOps - auto splitOp = - dyn_cast(op.getOperand().getDefiningOp()); + auto splitOp = op.getOperand().getDefiningOp(); if (!splitOp) { return rewriter.notifyMatchFailure(op, "Input is not AtenSplitWithSizesOp"); @@ -390,20 +504,11 @@ class RecomposeSplitWithSizesListUnpack op, "split_sizes is not from PrimListConstructOp"); } - int64_t sumSplitSize = 0; SmallVector splitSizes; - for (auto operand : splitSizesConstruct.getOperands()) { - int64_t value = -1; - // TODO: support when split_sizes are not constant int - if (!matchPattern(operand, m_TorchConstantInt(&value))) { - return rewriter.notifyMatchFailure( - op, "one of split_sizes is not constant int"); - } - if (value < 0) { - return rewriter.notifyMatchFailure(op, "all of split_sizes must > 0"); - } - sumSplitSize += value; - splitSizes.push_back(value); + if (!matchPattern(splitOp.getSplitSizes(), + m_TorchListOfConstantInts(splitSizes))) { + return rewriter.notifyMatchFailure( + op, "split_sizes must be list of constant int"); } if (splitSizes.size() != op.getNumResults()) { return rewriter.notifyMatchFailure( @@ -416,6 +521,8 @@ class RecomposeSplitWithSizesListUnpack // add runtime.assert to check rank constraint Value totalSize = rewriter.create(loc, input, dim); + int64_t sumSplitSize = + std::accumulate(splitSizes.begin(), splitSizes.end(), 0); Value cstSumSplitSize = rewriter.create( loc, rewriter.getI64IntegerAttr(sumSplitSize)); Value eqOrNot = @@ -450,13 +557,156 @@ class RecomposeSplitWithSizesListUnpack } }; +class RecomposeTensorSplitSectionsGetItem + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(Aten__Getitem__TOp op, + PatternRewriter &rewriter) const override { + // recompose AtenTensorSplitSectionsOp + __getitem__t to AtenSliceTensorOp + auto splitOp = op.getList().getDefiningOp(); + if (!splitOp) + return rewriter.notifyMatchFailure( + op, "Input is not AtenTensorSplitSectionsOp"); + if (isListPotentiallyMutated(splitOp.getResult())) + return rewriter.notifyMatchFailure( + op, "AtenTensorSplitSectionsOp result is potentially mutated"); + + int64_t sections; + if (!matchPattern(splitOp.getSections(), m_TorchConstantInt(§ions))) + return rewriter.notifyMatchFailure( + op, "Expected `sections` of AtenTensorSplitSectionsOp to be a " + "constant int"); + + int64_t index; + if (!matchPattern(op.getIdx(), m_TorchConstantInt(&index))) + return rewriter.notifyMatchFailure( + op, "Expected `idx` of `Aten__Getitem__TOp` to be a constant int"); + index = toPositiveDim(index, sections); + if (!isValidDim(index, sections)) + return rewriter.notifyMatchFailure( + op, "Expected `idx` in range of split_sizes"); + + Location loc = op.getLoc(); + Value input = splitOp.getSelf(); + Value dim = splitOp.getDim(); + + // only recompose to slice when split dim size is static, otherwise we need + // control flow like prim.if + Value dimSizeValue = rewriter.createOrFold(loc, input, dim); + int64_t splitDimSize; + if (!matchPattern(dimSizeValue, m_TorchConstantInt(&splitDimSize))) + return rewriter.notifyMatchFailure(splitOp, + "split dim size must be static"); + + int64_t chunkSize = splitDimSize / sections; + int64_t remain = splitDimSize % sections; + Value cstOne = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + Value result; + if (index < remain) { + Value start = rewriter.create( + loc, rewriter.getI64IntegerAttr(index * (chunkSize + 1))); + Value end = rewriter.create( + loc, rewriter.getI64IntegerAttr((index + 1) * (chunkSize + 1))); + result = 
rewriter.create(loc, op.getType(), input, dim, + start, end, + /*step=*/cstOne); + } else { + Value start = rewriter.create( + loc, rewriter.getI64IntegerAttr(index * chunkSize + remain)); + Value end = rewriter.create( + loc, rewriter.getI64IntegerAttr((index + 1) * chunkSize + remain)); + result = rewriter.create(loc, op.getType(), input, dim, + start, end, + /*step=*/cstOne); + } + rewriter.replaceOp(op, result); + // erase AtenTensorSplitSectionsOp if no user left + if (splitOp.getResult().use_empty()) + rewriter.eraseOp(splitOp); + return success(); + } +}; + +class RecomposeTensorSplitSectionsListUnpack + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(PrimListUnpackOp op, + PatternRewriter &rewriter) const override { + // recompose AtenTensorSplitSectionsOp + PrimListUnpackOp to + // AtenSliceTensorOps + auto splitOp = op.getOperand().getDefiningOp(); + if (!splitOp) + return rewriter.notifyMatchFailure( + op, "Input is not AtenTensorSplitSectionsOp"); + if (isListPotentiallyMutated(splitOp.getResult())) + return rewriter.notifyMatchFailure( + op, "AtenTensorSplitSectionsOp result is potentially mutated"); + + int64_t sections; + if (!matchPattern(splitOp.getSections(), m_TorchConstantInt(§ions))) + return rewriter.notifyMatchFailure( + op, "Expected `sections` of AtenTensorSplitSectionsOp to be a " + "constant int"); + if (op->getNumResults() != sections) + return rewriter.notifyMatchFailure( + op, "`sections` must be same as ListUnpack's NumResults"); + + Location loc = op.getLoc(); + Value input = splitOp.getSelf(); + Value dim = splitOp.getDim(); + + // only recompose to slice when split dim size is static, otherwise we need + // control flow like prim.if + Value dimSizeValue = rewriter.createOrFold(loc, input, dim); + int64_t splitDimSize; + if (!matchPattern(dimSizeValue, m_TorchConstantInt(&splitDimSize))) + return rewriter.notifyMatchFailure(splitOp, + "split dim size must be static"); + + int64_t chunkSize = splitDimSize / sections; + int64_t remain = splitDimSize % sections; + Value cstOne = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + SmallVector results; + for (int64_t i = 0; i < sections; i++) { + if (i < remain) { + Value start = rewriter.create( + loc, rewriter.getI64IntegerAttr(i * (chunkSize + 1))); + Value end = rewriter.create( + loc, rewriter.getI64IntegerAttr((i + 1) * (chunkSize + 1))); + Value slice = rewriter.create( + loc, op.getResult(i).getType(), input, dim, start, end, + /*step=*/cstOne); + results.push_back(slice); + } else { + Value start = rewriter.create( + loc, rewriter.getI64IntegerAttr(i * chunkSize + remain)); + Value end = rewriter.create( + loc, rewriter.getI64IntegerAttr((i + 1) * chunkSize + remain)); + Value slice = rewriter.create( + loc, op.getResult(i).getType(), input, dim, start, end, + /*step=*/cstOne); + results.push_back(slice); + } + } + rewriter.replaceOp(op, results); + // erase AtenTensorSplitSectionsOp if no user left + if (splitOp.getResult().use_empty()) + rewriter.eraseOp(splitOp); + return success(); + } +}; + class RecomposeChunkListUnpack : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(PrimListUnpackOp op, PatternRewriter &rewriter) const override { // recompose AtenChunkOp + PrimListUnpackOp to AtenSliceTensorOps - auto chunkOp = dyn_cast(op.getOperand().getDefiningOp()); + auto chunkOp = op.getOperand().getDefiningOp(); if (!chunkOp) return rewriter.notifyMatchFailure(op, "Input is not 
AtenChunkOp"); if (isListPotentiallyMutated(chunkOp.getResult())) @@ -470,10 +720,13 @@ class RecomposeChunkListUnpack : public OpRewritePattern { // chunkSize = floordiv(totalSize + chunks - 1, chunks) Value chunkSize = getIntCeilDiv(rewriter, loc, totalSize, chunks); - // add runtime.assert to check chunks == NumResults + // add runtime.assert to check floordiv(totalSize + chunkSize - 1, + // chunkSize) == NumResults Value cstNumResults = rewriter.create( loc, rewriter.getI64IntegerAttr(op.getNumResults())); - Value eqOrNot = rewriter.create(loc, chunks, cstNumResults); + Value realChunks = getIntCeilDiv(rewriter, loc, totalSize, chunkSize); + Value eqOrNot = + rewriter.create(loc, realChunks, cstNumResults); rewriter.create( loc, eqOrNot, rewriter.getStringAttr( @@ -510,6 +763,81 @@ class RecomposeChunkListUnpack : public OpRewritePattern { }; } // namespace +namespace { +class RecomposeMeshgridIndexingListUnpack + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(PrimListUnpackOp op, + PatternRewriter &rewriter) const override { + auto meshgridIndexingOp = + op.getOperand().getDefiningOp(); + if (!meshgridIndexingOp) + return rewriter.notifyMatchFailure(op, + "Input is not AtenMeshgridIndexingOp"); + Location loc = meshgridIndexingOp.getLoc(); + auto context = meshgridIndexingOp.getContext(); + auto baseType = NonValueTensorType::getWithLeastStaticInformation(context); + SmallVector tensors; + if (!getListConstructElements(meshgridIndexingOp.getTensors(), tensors)) + return rewriter.notifyMatchFailure(meshgridIndexingOp, + "Unable to get tensors"); + + int64_t numTensors = tensors.size(); + bool swapFirstAndSecondTensors = false; + + std::string indexing; + if (!matchPattern(meshgridIndexingOp.getIndexing(), + m_TorchConstantStr(indexing))) { + return rewriter.notifyMatchFailure(meshgridIndexingOp, + "Unable to get indexing"); + } + + if (indexing == "xy" && numTensors >= 2) { + swapFirstAndSecondTensors = true; + std::swap(tensors[0], tensors[1]); + } + + SmallVector expandShapeValues; + for (int64_t i = 0; i < numTensors; i++) { + expandShapeValues.push_back( + rewriter.create(loc, tensors[i])); + } + Value expandShapeList = rewriter.create( + loc, ListType::get(IntType::get(context)), expandShapeValues); + + SmallVector meshgrids; + Value constFalse = + rewriter.create(loc, rewriter.getBoolAttr(false)); + + for (auto [idx, tensor] : llvm::enumerate(tensors)) { + Value constantOne = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + SmallVector tensorViewShapeValues(numTensors, constantOne); + tensorViewShapeValues[idx] = expandShapeValues[idx]; + + Value viewShapeList = rewriter.create( + loc, ListType::get(IntType::get(context)), tensorViewShapeValues); + Value view = + rewriter.create(loc, baseType, tensor, viewShapeList); + + Value expandView = rewriter.create( + loc, baseType, view, expandShapeList, constFalse); + meshgrids.push_back(expandView); + } + + if (swapFirstAndSecondTensors) { + std::swap(meshgrids[0], meshgrids[1]); + } + rewriter.replaceOp(op, meshgrids); + // erase meshgridIndexingOp if no user left + if (meshgridIndexingOp.getResult().use_empty()) + rewriter.eraseOp(meshgridIndexingOp); + return success(); + } +}; +} // namespace + namespace { class RecomposeComplexOpsPass : public RecomposeComplexOpsBase { @@ -521,19 +849,27 @@ class RecomposeComplexOpsPass // pattern.add calls go here patterns.add(context); patterns.add(context); - patterns.add(context); + + // TODO: cloud move these patterns to 
Decompose pass, but should handle + // shape and value semantics carefully + patterns.add(context); patterns.add(context); + patterns.add(context); patterns.add(context); + patterns.add(context); + patterns.add(context); patterns.add(context); patterns.add(context); + patterns.add(context); patterns.add(context); + patterns.add(context); GreedyRewriteConfig config; config.useTopDownTraversal = true; config.maxIterations = GreedyRewriteConfig::kNoLimit; - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns), + config))) { return signalPassFailure(); } } diff --git a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp index 8b758a135751..5712b66f6c1d 100644 --- a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp +++ b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp @@ -118,7 +118,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { if (auto optionalType = dyn_cast(listType.getContainedType())) { if (!llvm::all_of(listConstruct.getElements(), [](Value val) { - return val.getType().isa(); + return isa(val.getType()); })) { rewriter.cancelOpModification(op); return rewriter.notifyMatchFailure( @@ -246,6 +246,9 @@ void TorchMatchSpecializedBackendOp::populateSpecializedConversions( llvm::SmallVector newOperands{ oldOperands[0], oldOperands[1], oldOperands[2], oldOperands[5], oldOperands[3], oldOperands[4], oldOperands[6]}; + Value enableGQA = + rewriter.create(op->getLoc(), false); + newOperands.push_back(enableGQA); auto newOp = rewriter.create( op.getLoc(), op->getResultTypes()[0], newOperands, diff --git a/lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp b/lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp index 373680495f41..6f45e8876ee1 100644 --- a/lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp +++ b/lib/Dialect/Torch/Transforms/RefinePublicReturn.cpp @@ -9,8 +9,6 @@ #include "PassDetail.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" diff --git a/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp b/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp index 3b25e12c3a8e..cd6126aa4da5 100644 --- a/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp +++ b/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp @@ -81,7 +81,7 @@ LogicalResult Torch::wrapWithCalculateOpIfLibraryFunctionAvailable( if (name.starts_with("valsem.")) name = name.drop_front(strlen("valsem.")); if (isa(op)) - name = cast(op)->getAttr("name").cast().getValue(); + name = cast(cast(op)->getAttr("name")).getValue(); std::string libFuncName = (getLibraryFunctionPrefix(libFuncKind) + Twine(name)).str(); auto libFunc = library.lookupSymbol(libFuncName); @@ -191,8 +191,8 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand, // to match the library function signature. 
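A recurring cleanup in these hunks (and in the ReduceOpVariants and RecomposeComplexOps changes above) is the move from member-style `x.isa<T>()` / `x.cast<T>()` calls to the free-function casts, alongside the switch from `applyPatternsAndFoldGreedily` to `applyPatternsGreedily`. A minimal sketch of the casting idiom, with illustrative types that are not taken from this patch:

#include "mlir/IR/BuiltinTypes.h"

// Prefer the free functions; the member cast forms are deprecated upstream.
static bool isNarrowInt(mlir::Type ty) {
  // was: ty.isa<mlir::IntegerType>()
  if (!mlir::isa<mlir::IntegerType>(ty))
    return false;
  // was: ty.cast<mlir::IntegerType>().getWidth()
  return mlir::cast<mlir::IntegerType>(ty).getWidth() <= 32;
}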
if (auto unionType = dyn_cast(desiredType)) { if (llvm::all_of(unionType.getContainedTypes(), [](Type containedType) { - return containedType - .isa(); + return isa( + containedType); })) return b.create(loc, desiredType, operand).getResult(); } diff --git a/lib/Dialect/Torch/Transforms/RestructureNonConstantAxes.cpp b/lib/Dialect/Torch/Transforms/RestructureNonConstantAxes.cpp new file mode 100644 index 000000000000..bd6b1daaf99d --- /dev/null +++ b/lib/Dialect/Torch/Transforms/RestructureNonConstantAxes.cpp @@ -0,0 +1,277 @@ +//===- RestructureNonConstantAxes.cpp --------------------------------*- +// C++-*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "torch-lower-to-backend-contract" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::Torch; + +namespace { + +template +class ConstantifyDimArgument : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + bool isDimConstant(SrcOp op) const { + SmallVector dimList; + int64_t dim; + return matchPattern(op.getDim(), m_TorchListOfConstantInts(dimList)) || + matchPattern(op.getDim(), m_TorchConstantInt(&dim)); + } + + /* + This function renders the reduction dim constant by reshaping the input tensor + such that the dim argument is the middle dimension. + + For example, if the input tensor has shape [3,4,5,6,7] and the dim argument is + -2, the input tensor is reshaped to [3,4,5,6,7] -> [12,5,42], the reduction + operation is applied, and the result is reshaped back to [3,4,1,6,7]. + + Since we don't know the dim argument at compile time, we need to compute the + arguments to the reshape op at runtime. We do this by computing the new shape + of the tensor by multiplying the shapes of the tensor before and after the dim + argument, and then reshaping the tensor to this new shape. 
+ */ + LogicalResult matchAndRewrite(SrcOp op, + PatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + + Value self = op.getSelf(); + Value dim = op.getDim(); + + if (isDimConstant(op)) { + return rewriter.notifyMatchFailure(op, + "dim argument is already constant"); + } + + if (isa(dim.getType())) { + return rewriter.notifyMatchFailure( + op, "RestructureNonConstantAxes does not support None dim"); + } + + // when keepdim is not constant, check the ranks of the input and output + // tensors + ValueTensorType selfTy = + llvm::cast(op.getSelf().getType()); + ValueTensorType resultTy = + llvm::cast(op.getResult().getType()); + if (selfTy.hasSizes() && resultTy.hasSizes() && + selfTy.getSizes().size() != resultTy.getSizes().size()) { + return rewriter.notifyMatchFailure( + op, + "RestructureNonConstantAxes does not yet support keepdim=false, but " + "the input and output tensors have different ranks"); + } + + Type intType = rewriter.getType(); + Type boolType = rewriter.getType(); + auto createInt = [&](int value) { + return rewriter.create( + loc, intType, + rewriter.getIntegerAttr(rewriter.getIntegerType(64), value)); + }; + Value zero = createInt(0); + Value one = createInt(1); + + // handle when dim is a single element list + bool oldDimIsList = isa(dim.getType()); + if (oldDimIsList) { + Value len = rewriter.create(loc, intType, dim); + Value dimListIsLengthOne = + rewriter.create(loc, boolType, len, one); + rewriter.create( + loc, dimListIsLengthOne, + rewriter.getStringAttr("RestructureNonConstantAxes does not support " + "dim lists with more than one element")); + dim = rewriter.create(loc, intType, dim, zero); + } + + // Normalize negative dim + Value rank = rewriter.create(loc, intType, self); + Value isNegative = rewriter.create(loc, dim, zero); + Value rankOffset = rewriter.create( + loc, intType, + rewriter.create(loc, intType, isNegative), rank); + dim = rewriter.create(loc, intType, dim, rankOffset); + + auto createConditionalMult = [&](Value self, Value multiplier, + Value condition) { + // compute: + // result = codition ? 
(self * multiplier) : self + // via + // result = self * (1 + (multiplier - 1) * condition) + // which translates to: + + // result = multiplier - 1 + Value result = rewriter.create( + loc, intType, multiplier, createInt(1)); + // result = result * condition + result = + rewriter.create(loc, intType, result, condition); + // result = result + 1 + result = rewriter.create(loc, intType, result, + createInt(1)); + // result = self * result + result = rewriter.create(loc, intType, self, result); + return result; + }; + + // new shape = [beforeDim, dimSize, afterDim] + Value beforeProd = createInt(1); + Value afterProd = createInt(1); + Value dimSize = createInt(1); + + for (size_t i = 0; i < selfTy.getSizes().size(); ++i) { + Value idx = createInt(i); + Value size = + rewriter.create(loc, intType, self, idx); + + Value isBeforeDim = + rewriter.create(loc, boolType, idx, dim); + isBeforeDim = + rewriter.create(loc, intType, isBeforeDim); + Value isAfterDim = + rewriter.create(loc, boolType, idx, dim); + isAfterDim = + rewriter.create(loc, intType, isAfterDim); + + Value isEqualToDim = + rewriter.create(loc, boolType, idx, dim); + isEqualToDim = + rewriter.create(loc, intType, isEqualToDim); + dimSize = createConditionalMult(dimSize, size, isEqualToDim); + + beforeProd = createConditionalMult(beforeProd, size, isBeforeDim); + afterProd = createConditionalMult(afterProd, size, isAfterDim); + } + + Value newShape = rewriter.create( + loc, rewriter.getType(intType), + ValueRange{beforeProd, dimSize, afterProd}); + + // Reshape input + auto newSelfTy = selfTy.getWithSizesAndDtype( + SmallVector{Torch::kUnknownSize, Torch::kUnknownSize, + Torch::kUnknownSize}, + selfTy.getDtype()); + Value reshapedSelf = + rewriter.create(loc, newSelfTy, self, newShape); + + // construct new operange range where self is replaced with reshapedSelf + // tensor, and dim is replaced with 1 + Value newDim; + if (oldDimIsList) { + newDim = rewriter.create( + loc, rewriter.getType(intType), ValueRange{one}); + } else { + newDim = one; + } + ValueRange oldOperands = op->getOperands(); + SmallVector newOperandsVect; + for (size_t i = 0; i < oldOperands.size(); ++i) { + if (oldOperands[i] == op.getSelf()) { + newOperandsVect.push_back(reshapedSelf); + } else if (oldOperands[i] == op.getDim()) { + newOperandsVect.push_back(newDim); + } else { + newOperandsVect.push_back(oldOperands[i]); + } + } + ValueRange newOperands = ValueRange(newOperandsVect); + + // construct new reduction op result type + ValueTensorType newResultTy = + cast(resultTy.getWithSizesAndDtype( + SmallVector{Torch::kUnknownSize, 1, Torch::kUnknownSize}, + resultTy.getDtype())); + + Value newReductionOp = + rewriter.create(loc, newResultTy, newOperands, op->getAttrs()); + + // Reshape the result back to original shape + ValueTensorType oldResultTy = + cast(op.getResult().getType()); + SmallVector shapeValues; + for (auto dim : oldResultTy.getSizes()) { + shapeValues.push_back(createInt(dim)); + } + Value originalShape = rewriter.create( + loc, rewriter.getType(intType), shapeValues); + Value result = rewriter.create( + loc, op->getResult(0).getType(), newReductionOp, originalShape); + + rewriter.replaceOp(op, result); + return success(); + }; +}; + +template +void addConstantifyDimArgumentPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + // simple variadic template to sugar up adding the patterns + (patterns.add>(context), ...); +} + +void populateRestructureNonConstantAxesPattern(RewritePatternSet &patterns, + MLIRContext *context) { + // 
these are the reduction ops with a dim argument + + addConstantifyDimArgumentPatterns< + // not supported because they have multiple results + // AtenMaxDimOp, + // AtenMinDimOp, + AtenSumDimIntListOp, AtenAllDimOp, AtenLinalgVectorNormOp, + AtenFrobeniusNormDimOp>(patterns, context); +} + +class RestructureNonConstantAxesPass + : public RestructureNonConstantAxesBase { +public: + RestructureNonConstantAxesPass() = default; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + + RewritePatternSet patterns(context); + + populateRestructureNonConstantAxesPattern(patterns, context); + + // TODO: Debug visitation order to make this more efficient. + // A single linear scan should suffice. + GreedyRewriteConfig config; + config.useTopDownTraversal = true; + config.maxIterations = GreedyRewriteConfig::kNoLimit; + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns), + config))) { + return signalPassFailure(); + } + } +}; +} // namespace + +std::unique_ptr> +mlir::torch::Torch::createRestructureNonConstantAxesPass() { + return std::make_unique(); +} diff --git a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp index cb33b75fee03..0914d5b0eed6 100644 --- a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp +++ b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp @@ -9,14 +9,15 @@ #include "PassDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Iterators.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" using namespace mlir; @@ -27,7 +28,7 @@ namespace { LogicalResult materializeFolds(ImplicitLocOpBuilder b, ArrayRef fold, - SmallVector &values) { + SmallVectorImpl &values) { for (auto f : fold) { if (auto val = dyn_cast(f)) { values.push_back(val); @@ -36,14 +37,14 @@ LogicalResult materializeFolds(ImplicitLocOpBuilder b, if (auto attr = dyn_cast(f)) { if (auto val = dyn_cast(attr)) { - values.push_back(b.create( - b.getType(), val)); + values.push_back( + b.create(APFloat(val.getValueAsDouble()))); continue; } if (auto val = dyn_cast(attr)) { values.push_back( - b.create(b.getType(), val)); + b.create(val.getValue().getSExtValue())); continue; } } @@ -65,10 +66,14 @@ LogicalResult getListOperands(Value value, SmallVector &vals) { return success(); } -LogicalResult getListFromTensor(Value value, SmallVector &vals) { +LogicalResult getListFromTensor(Value value, SmallVector &vals) { constexpr int64_t kMaxFold = 16; - if (auto tensor = value.getDefiningOp()) - return getListOperands(tensor.getData(), vals); + if (auto tensor = value.getDefiningOp()) { + SmallVector unfolded; + LogicalResult gotList = getListOperands(tensor.getData(), unfolded); + vals = getAsOpFoldResult(unfolded); + return gotList; + } if (auto full = value.getDefiningOp()) { auto ty = cast(full.getType()); @@ -78,14 +83,126 @@ LogicalResult getListFromTensor(Value value, SmallVector &vals) { if (ty.getSizes()[0] > kMaxFold) return failure(); - vals.resize(vals.size() + ty.getSizes()[0], full.getFillValue()); + 
vals.resize(vals.size() + ty.getSizes()[0], + getAsOpFoldResult(full.getFillValue())); + return success(); + } + + if (auto unsqueeze = value.getDefiningOp()) { + Value usqSelf = unsqueeze.getSelf(); + if (auto numToTensor = + usqSelf.getDefiningOp()) { + vals.push_back(getAsOpFoldResult(numToTensor.getA())); + return success(); + } + } + + // A common rank 0 tensor producer + if (auto numToTensor = + value.getDefiningOp()) { + vals.push_back(getAsOpFoldResult(numToTensor.getA())); + return success(); + } + + // Last supported case: ValueTensorLiteralOp + auto literalOp = value.getDefiningOp(); + if (!literalOp) + return failure(); + + // Check the type. + auto ty = cast(literalOp.getType()); + if (!ty.hasSizes() || ty.getSizes().size() > 1) + return failure(); + // make sure the type is not unsigned here before trying to materialize + auto intTy = dyn_cast_or_null(ty.getDtype()); + if (!intTy || intTy.isUnsigned()) + return failure(); + + // if we have a rank 0 literal, we will be adding one element to the list + int64_t listSize = ty.getSizes().size() == 1 ? ty.getSizes().front() : 1; + + if (listSize > kMaxFold) + return failure(); + + // check for a splat or dense attr + auto splattr = dyn_cast_or_null(literalOp.getValue()); + auto denseAttr = dyn_cast_or_null(literalOp.getValue()); + + if (!splattr && !denseAttr) + return failure(); + + // These are not mutually exclusive, so try splat first. + if (splattr) { + auto attr = splattr.getSplatValue(); + vals.resize((int64_t)vals.size() + listSize, attr); return success(); } - return failure(); + // remaining case: denseAttr + if ((int64_t)denseAttr.getValues().size() != listSize) + return failure(); + for (auto e : denseAttr.getValues()) + vals.push_back(e); + return success(); +} + +Value constructAtenTensorOpFromList(ImplicitLocOpBuilder b, mlir::Type resultTy, + SmallVector &listValues) { + auto dimList = b.create( + b.getType(listValues.front().getType()), listValues); + Value cstNone = b.create(); + Value cstFalse = b.create(b.getBoolAttr(false)); + return b.create(resultTy, dimList, cstNone, cstNone, + cstFalse); } } // namespace +/// ------ Propagation Patterns ------ /// +// The general goal of these patterns is to convert SomeTensorOp to [scalarOps +// -> PrimListOfInts -> AtenTensorOp] Since these tensorized shape calculation +// ops are chained together, sequences like OpA -> OpB will propagate OpA first: +// [scalarOpsA -> ListA -> TensorA] -> OpB. Then OpB will be able to +// getListFromTensor(A), and further propagate scalarization. + +namespace { +class PropagateAtenBroadcastToPattern + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenBroadcastToOp op, + PatternRewriter &rewriter) const override { + constexpr int64_t kMaxFold = 16; + // for tensor, or tensor<1xsi64>, broadcasted to tensor, grab + // the element and convert to a full op. 
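    // Illustration (schematic IR, hypothetical value names): given a scalar
    // shape value wrapped into a tensor and then broadcast,
    //   %st = torch.prim.NumToTensor.Scalar %dim : !torch.int -> !torch.vtensor<[],si64>
    //   %bc = torch.aten.broadcast_to %st, %sizes : ... -> !torch.vtensor<[4],si64>
    // this pattern recovers %dim via getListFromTensor and rewrites %bc into
    //   %full = torch.aten.full %size_list, %dim, %none, %none, %none, %false
    // so downstream propagation patterns can read the four elements as known
    // scalars instead of a tensor.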
+ auto ty = cast(op.getType()); + if (!ty.areAllSizesKnown() || ty.getSizes().size() != 1) + return failure(); + + if (ty.getSizes()[0] > kMaxFold) + return failure(); + + SmallVector fillFold; + if (failed(getListFromTensor(op.getSelf(), fillFold)) || + fillFold.size() != 1) + return failure(); + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + SmallVector fillVals; + if (failed(materializeFolds(b, fillFold, fillVals))) + return failure(); + + Value size = b.create(ty.getSizes().front()); + Value sizeList = b.create( + rewriter.getType(rewriter.getType()), + size); + Value none = b.create(); + Value cstFalse = b.create(false); + rewriter.replaceOpWithNewOp(op, ty, sizeList, fillVals.front(), + none, none, none, cstFalse); + return success(); + } +}; +} // namespace + namespace { class PropagateAtenShapeToTensorPattern : public OpRewritePattern { @@ -94,30 +211,27 @@ class PropagateAtenShapeToTensorPattern LogicalResult matchAndRewrite(Aten_ShapeAsTensorOp op, PatternRewriter &rewriter) const override { auto loc = op.getLoc(); + ImplicitLocOpBuilder b(loc, rewriter); auto self = op.getSelf(); auto selfTy = cast(self.getType()); if (!selfTy.hasSizes()) return rewriter.notifyMatchFailure(op, "self has unknown rank"); int64_t rank = selfTy.getSizes().size(); - SmallVector dims; + SmallVector dims; for (int64_t i = 0; i < rank; ++i) { - auto iv = rewriter.create( - loc, rewriter.getI64IntegerAttr(i)); - dims.push_back(rewriter.create( - loc, rewriter.getType(), self, iv)); + auto iv = b.create(i); + dims.push_back(b.createOrFold( + rewriter.getType(), self, iv)); + } + SmallVector materializedDims; + if (failed(materializeFolds(b, dims, materializedDims))) { + return failure(); } - auto dimList = rewriter.create( - loc, - rewriter.getType(rewriter.getType()), - dims); - - Value cstNone = rewriter.create(loc); - Value cstFalse = rewriter.create( - loc, rewriter.getBoolAttr(false)); - rewriter.replaceOpWithNewOp( - op, op.getType(), dimList, cstNone, cstNone, cstFalse); + Value result = + constructAtenTensorOpFromList(b, op.getType(), materializedDims); + rewriter.replaceOp(op, result); return success(); } }; @@ -150,56 +264,20 @@ class PropagateAtenCatPattern : public OpRewritePattern { SmallVector scalars; for (auto element : tensors) { - llvm::SmallVector delisted; - if (succeeded(getListFromTensor(element, delisted))) { - for (auto scalar : delisted) - scalars.push_back(scalar); - continue; - } - - DenseElementsAttr attr; - if (matchPattern(element, m_Constant(&attr))) { - if (attr.isSplat()) { - scalars.resize(scalars.size() + attr.getNumElements(), - attr.getSplatValue()); - continue; - } - - for (auto e : attr.getValues()) { - scalars.push_back(e); - } - continue; - } + llvm::SmallVector delisted; + if (failed(getListFromTensor(element, delisted))) + return rewriter.notifyMatchFailure(op, "unknown op fold type"); - return rewriter.notifyMatchFailure(op, "unknown op fold type"); - } - - for (auto &scalar : scalars) { - if (auto attr = dyn_cast(scalar)) { - if (auto iattr = dyn_cast(attr)) { - auto i64 = iattr.getValue().getSExtValue(); - scalar = rewriter.getI64IntegerAttr(i64); - } - } + for (auto scalar : delisted) + scalars.push_back(scalar); } SmallVector values; - if (failed(materializeFolds(b, scalars, values))) + if (failed(materializeFolds(b, scalars, values)) || values.empty()) return rewriter.notifyMatchFailure(op, "unable to materialize constants"); - Type eTy = b.getType(); - if (isa(resultTy.getDtype())) - eTy = rewriter.getType(); - - auto elementsList = b.create( - 
rewriter.getType(eTy), values); - - Value cstNone = b.create(); - Value cstFalse = - b.create(rewriter.getBoolAttr(false)); - rewriter.replaceOpWithNewOp( - op, op.getType(), elementsList, cstNone, cstNone, cstFalse); - + Value result = constructAtenTensorOpFromList(b, resultTy, values); + rewriter.replaceOp(op, result); return success(); } }; @@ -215,7 +293,7 @@ class PropagateAtenIndexSelectPattern auto loc = op.getLoc(); ImplicitLocOpBuilder b(loc, rewriter); - SmallVector elements; + SmallVector elements; if (failed(getListFromTensor(op.getSelf(), elements))) return failure(); @@ -223,8 +301,8 @@ class PropagateAtenIndexSelectPattern if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return rewriter.notifyMatchFailure(op, "requires a constant dim"); - DenseElementsAttr idx; - if (!matchPattern(op.getIndex(), m_Constant(&idx))) + SmallVector idxFolds; + if (failed(getListFromTensor(op.getIndex(), idxFolds))) return rewriter.notifyMatchFailure(op, "requires a constant index"); auto selfTy = cast(op.getSelf().getType()); @@ -233,7 +311,9 @@ class PropagateAtenIndexSelectPattern auto selfShape = selfTy.getSizes(); int64_t selfRank = selfShape.size(); - dim = dim < 0 ? dim + selfRank : dim; + dim = toPositiveDim(dim, selfRank); + if (!isValidDim(dim, selfRank)) + return failure(); int64_t dimLength = elements.size(); if (selfShape[dim] != dimLength) return rewriter.notifyMatchFailure( @@ -247,28 +327,25 @@ class PropagateAtenIndexSelectPattern "expects unary non-dim dimension"); } - SmallVector selected; - if (idx.isSplat()) { - int64_t indexInt = idx.getSplatValue().getSExtValue(); + SmallVector selected; + for (auto idx : idxFolds) { + auto attr = dyn_cast_or_null(dyn_cast(idx)); + if (!attr) + return failure(); + int64_t indexInt = attr.getValue().getSExtValue(); indexInt = indexInt < 0 ? 
indexInt + dimLength : indexInt; - selected.resize(idx.getNumElements(), elements[indexInt]); - } else { - for (APInt val : idx.getValues()) { - int64_t indexInt = val.getSExtValue(); - selected.push_back(elements[indexInt]); - } + if (indexInt < 0 || indexInt >= dimLength) + return failure(); + selected.push_back(elements[indexInt]); } - auto eTy = elements.front().getType(); - - auto dimList = rewriter.create( - loc, rewriter.getType(eTy), selected); + SmallVector materializedSelected; + if (failed(materializeFolds(b, selected, materializedSelected))) + return failure(); - Value cstNone = rewriter.create(loc); - Value cstFalse = rewriter.create( - loc, rewriter.getBoolAttr(false)); - rewriter.replaceOpWithNewOp( - op, op.getType(), dimList, cstNone, cstNone, cstFalse); + Value result = + constructAtenTensorOpFromList(b, op.getType(), materializedSelected); + rewriter.replaceOp(op, result); return success(); } }; @@ -288,7 +365,12 @@ class PropagateAtenSliceTensorPattern auto loc = op.getLoc(); ImplicitLocOpBuilder b(loc, rewriter); - SmallVector elements; + auto selfTy = cast(op.getSelf().getType()); + auto resultTy = cast(op.getType()); + if (!selfTy.areAllSizesKnown() || !resultTy.areAllSizesKnown()) + return rewriter.notifyMatchFailure(op, "requires static sizes"); + + SmallVector elements; if (failed(getListFromTensor(op.getSelf(), elements))) return failure(); @@ -305,49 +387,316 @@ class PropagateAtenSliceTensorPattern if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) return rewriter.notifyMatchFailure(op, "requires a constant step"); - if (step < 0) - return rewriter.notifyMatchFailure(op, "requires a positive step value"); - - auto selfTy = cast(op.getSelf().getType()); auto selfShape = selfTy.getSizes(); + auto resultShape = resultTy.getSizes(); int64_t selfRank = selfShape.size(); // Correct for negative indexing: - dim = dim < 0 ? dim + selfRank : dim; + dim = toPositiveDim(dim, selfRank); + if (!isValidDim(dim, selfRank)) + return failure(); - int64_t dimLength = elements.size(); + int64_t dimLength = selfShape[dim]; start = start < 0 ? start + dimLength : start; end = end < 0 ? end + dimLength : end; + end = (end < 0) ? -1 : end; + end = (end < 0 && step > 0) ? 0 : end; start = start < 0 ? 0 : start; - end = end < 0 ? 0 : end; end = end > dimLength ? 
dimLength : end; - if (selfShape[dim] != dimLength) - return rewriter.notifyMatchFailure( - op, "dim length does not match number of elements"); + int64_t frontDimProd = 1, backDimProd = 1; + for (int64_t i = 0; i < selfRank; i++) { + if (i < dim) + frontDimProd *= selfShape[i]; + if (i > dim) + backDimProd *= selfShape[i]; + } + int64_t fullDimProd = frontDimProd * dimLength * backDimProd; + if (fullDimProd != (int64_t)elements.size()) + return rewriter.notifyMatchFailure(op, "unexpected number of elements."); + + // [d0,d1] i -> (i//d1, i % d1) -> (i//d1) * d1 + (i % d1) + // [d0,d1,d2] i -> (i//d2, i%d2) -> ((i//(d1*d2), (i//d2) % d1, i % d2) + + auto isSliceIdx = [&](int64_t i) { + int64_t dimidx = (i / backDimProd) % dimLength; + bool onStep = ((dimidx - start) % step == 0); + bool beforeEnd = (step < 0 && dimidx > end); + beforeEnd = beforeEnd || (step > 0 && dimidx < end); + bool afterBegin = (step < 0 && dimidx <= start); + afterBegin = afterBegin || (step > 0 && dimidx >= start); + return onStep && beforeEnd && afterBegin; + }; - for (int64_t i = 0; i < selfRank; ++i) { - if (i == dim) + auto flipIdx = [&](int64_t i) { + int64_t frontIdx = (i / (backDimProd * dimLength)); + int64_t dimIdx = (i / (backDimProd)) % dimLength; + int64_t flipDimIdx = dimLength - 1 - dimIdx; + int64_t backIdx = i % (backDimProd); + return frontIdx * (dimLength * backDimProd) + flipDimIdx * (backDimProd) + + backIdx; + }; + SmallVector selected; + for (int64_t i = 0; i < (int64_t)elements.size(); i++) { + if (!isSliceIdx(i)) continue; - if (selfShape[i] != 1) - return rewriter.notifyMatchFailure(op, - "expects unary non-dim dimension"); + int64_t index = (step > 0) ? i : flipIdx(i); + selected.push_back(elements[index]); } - SmallVector selected; - for (int i = start; i < end; i += step) - selected.push_back(elements[i]); + fullDimProd = (fullDimProd * resultShape[dim]) / selfShape[dim]; + if ((int64_t)selected.size() != fullDimProd) + return rewriter.notifyMatchFailure( + op, "Constructed slice values have an incompatable number of " + "elements to match the provided return type."); + + SmallVector values; + if (failed(materializeFolds(b, selected, values))) + return failure(); + + Value result = constructAtenTensorOpFromList(b, op.getType(), values); + rewriter.replaceOp(op, result); + return success(); + } +}; +} // namespace + +namespace { +class PropagateAtenTransposeIntPattern + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenTransposeIntOp op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + ImplicitLocOpBuilder b(loc, rewriter); + + auto selfTy = cast(op.getSelf().getType()); + auto resultTy = cast(op.getType()); + if (!selfTy.areAllSizesKnown() || !resultTy.areAllSizesKnown()) + return rewriter.notifyMatchFailure(op, "requires static sizes"); + + SmallVector elements; + if (failed(getListFromTensor(op.getSelf(), elements))) + return failure(); + + int64_t dim0, dim1; + if (!matchPattern(op.getDim0(), m_TorchConstantInt(&dim0))) + return failure(); + if (!matchPattern(op.getDim1(), m_TorchConstantInt(&dim1))) + return failure(); + + ArrayRef selfSizes = selfTy.getSizes(); + int64_t rank = selfSizes.size(); + + dim0 = toPositiveDim(dim0, rank); + dim1 = toPositiveDim(dim1, rank); + if (!isValidDim(dim0, rank) || !isValidDim(dim0, rank)) + return failure(); + + if (dim0 == dim1) { + rewriter.replaceOp(op, op.getSelf()); + return success(); + } + + if (dim0 > dim1) { + // swap dim0 and dim1 + dim0 = dim0 + 
dim1; + dim1 = dim0 - dim1; + dim0 -= dim1; + } + + // A generic transpose will look like... + // [frontDimsFlat, dim0, midDimsFlat, dim1, backDimsFlat] -> . + // [frontDimsFlat, dim1, midDimsFlat, dim0, backDimsFlat] . + // If any of front, mid, or back don't actually exist (e.g. dim0 = 0, or + // dim1 = dim0 + 1), the reassociation of completely flattened indices will + // remain unaffected by the artificially unsqueezed dims. + // -------- + // Setting some notation, let D0,D1,D2,D3,D4 be the respective dim sizes of + // "self". Let D'j be the transpose dim sizes, and Djk = Dj*Dk. Let fl_trans + // and fl_self be 1-D flattened tensors. Then: + // -------- + // fl_trans[i] = + // = trans[i/D'1234, i/(D'234) % D'1, i/(D'34) % D'2, i/D'4 % D'3, i % D'4] + // = trans[i/D1234, i/D214 % D3, i/D14 % D2, i/D4 % D1, i % D4] + // = self[i/D1234, i/D4 % D1, i/D14 % D2, i/D214 % D3, i % D4] + // = fl_self[dot.prod(indices, (D1234,D234,D34,D4,1))] . + // -------- + // reassoc(i) = (i/(D1234)) * D1234 + + // (i/D4 % D1) * D234 + + // (i/(D14) % D2) * D34 + + // (i/(D214) % D3) * D4 + + // (i % D4) . + + SmallVector D(5, 1); + int64_t i = -1; + // D[0] corresponds to flattened front dims + while (++i < dim0) + D[0] *= selfSizes[i]; + // D[1] is the earliest transpose dim + D[1] = selfSizes[i]; + // D[2] corresponds to flattened middle dims + while (++i < dim1) + D[2] *= selfSizes[i]; + // D[3] is the later transpose dim + D[3] = selfSizes[i]; + // D[4] corresponds to flattened back dims + while (++i < rank) + D[4] *= selfSizes[i]; + + int64_t D1234 = D[1] * D[2] * D[3] * D[4]; + int64_t fullDP = D[0] * D1234; + if (fullDP != (int64_t)elements.size()) + return failure(); + auto reassoc = [&](int64_t i) { + return (i / D1234) * D1234 + ((i / D[4]) % D[1]) * D[2] * D[3] * D[4] + + ((i / (D[1] * D[4])) % D[2]) * D[3] * D[4] + + ((i / (D[2] * D[1] * D[4])) % D[3]) * D[4] + (i % D[4]); + }; + SmallVector transposedFolds; + transposedFolds.reserve(fullDP); + for (int64_t i = 0; i < fullDP; i++) + transposedFolds.push_back(elements[reassoc(i)]); + + SmallVector transposedVals; + if (failed(materializeFolds(b, transposedFolds, transposedVals))) + return failure(); + + Value result = constructAtenTensorOpFromList(b, resultTy, transposedVals); + rewriter.replaceOp(op, result); + return success(); + } +}; +} // namespace +namespace { +class PropagateAtenWhereSelfPattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenWhereSelfOp op, + PatternRewriter &rewriter) const override { + Value condition = op.getCondition(); + Value self = op.getSelf(); + Value other = op.getOther(); + auto conditionTy = dyn_cast(condition.getType()); + if (!conditionTy || !conditionTy.hasSizes() || + conditionTy.getSizes().size() != 1) + return rewriter.notifyMatchFailure(op, "bad condition type"); + auto selfTy = dyn_cast(self.getType()); + if (!selfTy || !selfTy.hasSizes() || selfTy.getSizes().size() != 1) + return rewriter.notifyMatchFailure(op, "bad self type"); + auto otherTy = dyn_cast(other.getType()); + if (!otherTy || !otherTy.hasSizes() || otherTy.getSizes().size() != 1) + return rewriter.notifyMatchFailure(op, "bad other type"); + int64_t conditionSize = selfTy.getSizes()[0]; + int64_t selfSize = selfTy.getSizes()[0]; + int64_t otherSize = otherTy.getSizes()[0]; + + if (selfSize != otherSize || selfSize != conditionSize) + return rewriter.notifyMatchFailure( + op, + "unimplemented: support for propogating with implicit broadcasting."); + + constexpr int64_t 
kMaxFold = 16; + if (selfSize == Torch::kUnknownSize || selfSize > kMaxFold) + return rewriter.notifyMatchFailure(op, + "arguments are dynamic or too big"); + + SmallVector conditionFolds, selfFolds, otherFolds; + if (failed(getListFromTensor(condition, conditionFolds)) || + failed(getListFromTensor(self, selfFolds)) || + failed(getListFromTensor(other, otherFolds))) + return failure(); + + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + + SmallVector conditionList, selfList, otherList; + if (failed(materializeFolds(b, conditionFolds, conditionList)) || + failed(materializeFolds(b, selfFolds, selfList)) || + failed(materializeFolds(b, otherFolds, otherList))) + return failure(); + + SmallVector whereVals; + auto rank0IntTy = rewriter.getType( + ArrayRef({}), selfTy.getDtype()); + auto rank0BoolTy = rewriter.getType( + ArrayRef({}), conditionTy.getDtype()); + for (uint64_t i = 0; i < selfList.size(); i++) { + Value rank0Cond = b.create( + rank0BoolTy, conditionList[i]); + Value rank0Self = + b.create(rank0IntTy, selfList[i]); + Value rank0Other = + b.create(rank0IntTy, otherList[i]); + Value rank0Where = b.create(rank0IntTy, rank0Cond, + rank0Self, rank0Other); + whereVals.push_back( + b.create(rewriter.getType(), rank0Where)); + } + Value result = constructAtenTensorOpFromList(b, op.getType(), whereVals); + rewriter.replaceOp(op, result); + return success(); + } +}; +} // namespace + +namespace { +class PropagateAtenEqTensorPattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenEqTensorOp op, + PatternRewriter &rewriter) const override { + Value self = op.getSelf(); + Value other = op.getOther(); + auto selfTy = dyn_cast(self.getType()); + if (!selfTy || !selfTy.hasSizes() || selfTy.getSizes().size() != 1) + return rewriter.notifyMatchFailure(op, "bad self type"); + auto otherTy = dyn_cast(other.getType()); + if (!otherTy || !otherTy.hasSizes() || otherTy.getSizes().size() != 1) + return rewriter.notifyMatchFailure(op, "bad other type"); + int64_t selfSize = selfTy.getSizes()[0]; + int64_t otherSize = otherTy.getSizes()[0]; + + if (selfSize != otherSize) + return rewriter.notifyMatchFailure( + op, + "unimplemented: support for propogating with implicit broadcasting."); + + constexpr int64_t kMaxFold = 16; + if (selfSize == Torch::kUnknownSize || selfSize > kMaxFold || + otherSize == Torch::kUnknownSize || otherSize > kMaxFold) + return rewriter.notifyMatchFailure(op, + "self or other is dynamic or too big"); - auto eTy = elements.front().getType(); - auto dimList = rewriter.create( - loc, rewriter.getType(eTy), selected); + SmallVector selfFolds, otherFolds; + if (failed(getListFromTensor(self, selfFolds)) || + failed(getListFromTensor(other, otherFolds))) + return rewriter.notifyMatchFailure(op, "failed to get list from tensor"); + + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + SmallVector selfList, otherList; + if (failed(materializeFolds(b, selfFolds, selfList)) || + failed(materializeFolds(b, otherFolds, otherList))) + return rewriter.notifyMatchFailure(op, "failed to materialize folds"); + + SmallVector eqBoolFolds; + for (uint64_t i = 0; i < selfList.size(); i++) { + OpFoldResult eqInt = + b.createOrFold(selfList[i], otherList[i]); + if (auto eqIntVal = dyn_cast(eqInt)) + eqInt = b.createOrFold(eqIntVal); + // if eqInt was an Attribute, it will materialize to a constant int op, + // which is what we want. 
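      // For instance (hypothetical values): comparing a computed shape
      // (%d0, 3) against a literal (2, 3), element 0 keeps a real
      // torch.aten.eq.int on %d0, while element 1 folds to a constant true
      // attribute; materializeFolds below rebuilds constants only for the
      // folded entries before the per-element results are regrouped into the
      // boolean result tensor.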
+ eqBoolFolds.push_back(eqInt); + } + SmallVector eqVals; + if (failed(materializeFolds(b, eqBoolFolds, eqVals))) { + return failure(); + } - Value cstNone = rewriter.create(loc); - Value cstFalse = rewriter.create( - loc, rewriter.getBoolAttr(false)); - rewriter.replaceOpWithNewOp( - op, op.getType(), dimList, cstNone, cstNone, cstFalse); + Value result = constructAtenTensorOpFromList(b, op.getType(), eqVals); + rewriter.replaceOp(op, result); return success(); } }; @@ -359,20 +708,332 @@ class PropagateAtenItemPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenItemOp op, PatternRewriter &rewriter) const override { + SmallVector elements; + Value self = op.getSelf(); + auto selfTy = cast(self.getType()); ImplicitLocOpBuilder b(op.getLoc(), rewriter); - SmallVector elements; + + // Rank 0 item op prop + if (selfTy.getSizes().empty()) { + auto numToTensor = self.getDefiningOp(); + auto squeezeDim = self.getDefiningOp(); + if (!squeezeDim && !numToTensor) + return rewriter.notifyMatchFailure(op, + "unhandled item of rank 0 operand"); + if (numToTensor) { + rewriter.replaceOp(op, numToTensor.getA()); + return success(); + } + rewriter.replaceOpWithNewOp(op, op.getType(), + squeezeDim.getSelf()); + return success(); + } + + // Rank 1 item op prop if (failed(getListFromTensor(op.getSelf(), elements))) return failure(); if (elements.size() != 1) - return rewriter.notifyMatchFailure(op, "expected no elements"); + return rewriter.notifyMatchFailure(op, "expected one element"); + + SmallVector materialized; + if (failed(materializeFolds(b, elements, materialized))) + return failure(); - rewriter.replaceOp(op, elements[0]); + rewriter.replaceOp(op, materialized.front()); return success(); } }; } // namespace +namespace { + +LogicalResult convertOpFoldResults(ImplicitLocOpBuilder &b, + SmallVector &converted, + SmallVector &elements, + Type inputDtype, Type resultDtype) { + auto inputIsInt = dyn_cast(inputDtype); + auto resultIsInt = dyn_cast(resultDtype); + if (!inputIsInt && !isa(inputDtype)) + return failure(); + if (!resultIsInt && !isa(resultDtype)) + return failure(); + + // if dtypes are both int or both float, no conversion needed + if (static_cast(inputIsInt) == static_cast(resultIsInt)) { + converted = elements; + return success(); + } + + if (resultIsInt) { + for (auto &e : elements) { + auto eValue = dyn_cast(e); + if (eValue) { + converted.push_back(b.createOrFold(eValue)); + continue; + } + auto eAttr = dyn_cast(e); + auto eFloatAttr = dyn_cast_or_null(eAttr); + if (!eFloatAttr) + return failure(); + + converted.push_back(IntegerAttr::get( + resultDtype, static_cast(eFloatAttr.getValueAsDouble()))); + } + return success(); + } + + // result is float + for (auto &e : elements) { + auto eValue = dyn_cast(e); + if (eValue) { + converted.push_back(b.createOrFold(eValue)); + continue; + } + auto eAttr = dyn_cast(e); + auto eIntAttr = dyn_cast(eAttr); + if (!eIntAttr) + return failure(); + + auto eInt = (inputIsInt.isSigned()) ? eIntAttr.getValue().getSExtValue() + : eIntAttr.getValue().getZExtValue(); + converted.push_back(FloatAttr::get(resultDtype, static_cast(eInt))); + } + return success(); +} + +class PropagateAtenToDtypePattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenToDtypeOp op, + PatternRewriter &rewriter) const override { + bool nonBlocking, copyArg; + // The non_blocking arg must be `False`. 
+ if (!matchPattern(op.getNonBlocking(), m_TorchConstantBool(&nonBlocking)) || + nonBlocking) + return failure(); + // The copy arg must be `False`. + if (!matchPattern(op.getCopy(), m_TorchConstantBool(©Arg)) || copyArg) + return failure(); + // The memory_format arg must be `none`. + if (!isa(op.getMemoryFormat().getType())) + return failure(); + + auto inputType = dyn_cast(op.getSelf().getType()); + auto resultType = dyn_cast(op.getType()); + if (!inputType || !resultType || !inputType.hasDtype() || + !resultType.hasDtype()) + return failure(); + auto inputDtype = inputType.getDtype(); + auto resultDtype = resultType.getDtype(); + + SmallVector elements; + if (failed(getListFromTensor(op.getSelf(), elements))) + return failure(); + + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + SmallVector converted; + if (failed(convertOpFoldResults(b, converted, elements, inputDtype, + resultDtype))) + return rewriter.notifyMatchFailure( + op, "Unhandled attribute type encountered."); + + SmallVector vals; + if (failed(materializeFolds(b, converted, vals))) + return failure(); + + Value result = constructAtenTensorOpFromList(b, op.getType(), vals); + rewriter.replaceOp(op, result); + return success(); + } +}; +} // namespace + +namespace { +template +class PropagateAtenViewLikePattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenViewLikeOp op, + PatternRewriter &rewriter) const override { + SmallVector selfFolds; + if (failed(getListFromTensor(op.getSelf(), selfFolds))) + return failure(); + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + SmallVector selfVals; + if (failed(materializeFolds(b, selfFolds, selfVals))) + return failure(); + Value result = constructAtenTensorOpFromList(b, op.getType(), selfVals); + rewriter.replaceOp(op, result); + return success(); + } +}; +} // namespace + +namespace { + +template struct ArithmeticHelper { + static LogicalResult getAlphaAndVerify(OpTy &op, int64_t &alpha) { + alpha = 1; + return success(); + } +}; + +template <> struct ArithmeticHelper { + static LogicalResult getAlphaAndVerify(AtenAddTensorOp &op, int64_t &alpha) { + if (!matchPattern(op.getAlpha(), m_TorchConstantInt(&alpha)) || alpha != 1) + return failure(); + return success(); + } +}; + +template <> struct ArithmeticHelper { + static LogicalResult getAlphaAndVerify(AtenSubTensorOp &op, int64_t &alpha) { + if (!matchPattern(op.getAlpha(), m_TorchConstantInt(&alpha)) || alpha != 1) + return failure(); + return success(); + } +}; + +template +class PropagateAtenArithmeticPattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + // Check type + auto resultTy = cast(op.getType()); + if (resultTy.getSizes().size() > 1) + return rewriter.notifyMatchFailure(op, "unsupported: rank > 1"); + if (!resultTy.hasDtype() || !isa(resultTy.getDtype())) + return rewriter.notifyMatchFailure(op, "not an int type"); + + int64_t alpha; + if (failed(ArithmeticHelper::getAlphaAndVerify(op, alpha))) + return rewriter.notifyMatchFailure(op, "alpha must be 1"); + + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + SmallVector selfFold, otherFold; + if (failed(getListFromTensor(op.getSelf(), selfFold)) || + failed(getListFromTensor(op.getOther(), otherFold)) || + selfFold.size() != otherFold.size()) + return failure(); + SmallVector selfVals, otherVals; + if (failed(materializeFolds(b, selfFold, selfVals)) || + 
failed(materializeFolds(b, otherFold, otherVals))) + return failure(); + SmallVector resultFolds; + for (uint64_t i = 0; i < selfVals.size(); i++) { + resultFolds.push_back(b.createOrFold( + selfVals[i].getType(), selfVals[i], otherVals[i])); + } + SmallVector resultVals; + if (failed(materializeFolds(b, resultFolds, resultVals))) + return failure(); + + if (resultTy.getSizes().empty()) { + rewriter.replaceOpWithNewOp( + op, resultTy, resultVals.front()); + return success(); + } + + Value result = constructAtenTensorOpFromList(b, resultTy, resultVals); + rewriter.replaceOp(op, result); + return success(); + } +}; +} // namespace + +namespace { +template +class PropagateAtenUnaryPattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + // Check type + auto resultTy = cast(op.getType()); + if (resultTy.getSizes().size() > 1) + return rewriter.notifyMatchFailure(op, "unsupported: rank > 1"); + if (!resultTy.hasDtype() || !isa(resultTy.getDtype())) + return rewriter.notifyMatchFailure(op, "not an int type"); + + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + SmallVector selfFold; + if (failed(getListFromTensor(op.getSelf(), selfFold))) + return failure(); + SmallVector selfVals; + if (failed(materializeFolds(b, selfFold, selfVals))) + return failure(); + SmallVector resultFolds; + for (uint64_t i = 0; i < selfVals.size(); i++) { + resultFolds.push_back( + b.createOrFold(selfVals[i].getType(), selfVals[i])); + } + SmallVector resultVals; + if (failed(materializeFolds(b, resultFolds, resultVals))) + return failure(); + + if (resultTy.getSizes().size() == 0) { + rewriter.replaceOpWithNewOp( + op, resultTy, resultVals.front()); + return success(); + } + + Value result = constructAtenTensorOpFromList(b, resultTy, resultVals); + rewriter.replaceOp(op, result); + return success(); + } +}; +} // namespace +/// ------ Fold Patterns ------ /// +// These are shape-specific folding patterns + +namespace { +class FoldAtenEqIntPattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenEqIntOp op, + PatternRewriter &rewriter) const override { + // replaces (size.int == 0) with false and adds an assert + // these comparisons are getting generated because onnx.Reshape considers 0 + // to mean "don't change this dim". However, if the size we are passing to + // onnx.Reshape is a tensor dim, this is definitely never supposed to be + // interpreted as "don't change this dim". + int64_t otherInt; + if (!matchPattern(op.getB(), m_TorchConstantInt(&otherInt)) || + otherInt != 0) + return failure(); + + // in case the shape is a product of two ints, check each + if (auto mulOp = op.getA().getDefiningOp()) { + Value self = mulOp.getA(); + Value other = mulOp.getB(); + Value selfEq = rewriter.create(op.getLoc(), self, op.getB()); + Value otherEq = + rewriter.create(op.getLoc(), other, op.getB()); + rewriter.replaceOpWithNewOp(op, selfEq, otherEq); + return success(); + } + + // if lhs is size.int op, assert size > 0 and replace with false. 
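      // Illustration (hypothetical IR): given
      //   %d = torch.aten.size.int %t, %int0 : ... -> !torch.int
      //   %c = torch.aten.eq.int %d, %int0 : ... -> !torch.bool
      // %c is replaced by a constant false, and
      //   torch.runtime.assert %gt, "Expected dim size > 0."
      // with %gt = torch.aten.gt.int %d, %int0 is emitted so the assumption
      // is still checked at runtime.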
+ if (auto sizeOp = op.getA().getDefiningOp()) { + Value selfGtOther = rewriter.create( + op.getLoc(), op.getType(), op.getA(), op.getB()); + rewriter.create( + op.getLoc(), selfGtOther, + rewriter.getStringAttr("Expected dim size > 0.")); + Value cstFalse = + rewriter.create(op.getLoc(), false); + rewriter.replaceOp(op, cstFalse); + return success(); + } + + return failure(); + } +}; +} // namespace + namespace { class FoldAtenTensorSplatPattern : public OpRewritePattern { public: @@ -399,6 +1060,11 @@ class FoldAtenTensorSplatPattern : public OpRewritePattern { auto resultTy = cast(op.getType()); if (!resultTy.hasSizes() || !resultTy.areAllSizesKnown()) return rewriter.notifyMatchFailure(op, "dynamic output shape"); + if (resultTy.getSizes().size() == 0) { + rewriter.replaceOpWithNewOp( + op, op.getType(), elements.front()); + return success(); + } auto loc = op.getLoc(); SmallVector sizes; @@ -406,12 +1072,10 @@ class FoldAtenTensorSplatPattern : public OpRewritePattern { sizes.push_back(rewriter.create( loc, rewriter.getI64IntegerAttr(size))); - Value one = rewriter.create( - loc, rewriter.getType(), 1); Value sizeList = rewriter.create( loc, rewriter.getType(rewriter.getType()), - one); + sizes); Value none = rewriter.create(loc); Value cstFalse = rewriter.create(loc, false); @@ -423,16 +1087,24 @@ class FoldAtenTensorSplatPattern : public OpRewritePattern { } // namespace namespace { -class FoldAtenSqueezePattern : public OpRewritePattern { +template +class FoldAtenSqueezePattern : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtenSqueezeOp op, + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(SqueezeOp op, PatternRewriter &rewriter) const override { auto resultTy = cast(op.getType()); if (!resultTy.hasSizes() || !resultTy.areAllSizesKnown()) return rewriter.notifyMatchFailure(op, "Unknown result shape"); - if (auto atenFull = op.getSelf().getDefiningOp()) { + Value self = op.getSelf(); + if (auto atenFull = self.getDefiningOp()) { + // in the rank 0 case, just return the rank 0 scalar + if (resultTy.getSizes().size() == 0) { + rewriter.replaceOpWithNewOp( + op, resultTy, atenFull.getFillValue()); + return success(); + } SmallVector sizes; for (int i = 0, s = resultTy.getSizes().size(); i < s; ++i) sizes.push_back(rewriter.create( @@ -490,20 +1162,41 @@ class FoldAtenWhereSelf : public OpRewritePattern { if (selfSize && otherSize) { if (selfSize.getSelf() != otherSize.getSelf()) - return failure(); - - if (selfSize.getDim() != otherSize.getDim()) - return failure(); + return rewriter.notifyMatchFailure(op, "sizes not of same tensor"); + int64_t dimSelf, dimOther; + if ((selfSize.getDim() != otherSize.getDim()) && + (!matchPattern(selfSize.getDim(), m_TorchConstantInt(&dimSelf)) || + !matchPattern(otherSize.getDim(), m_TorchConstantInt(&dimOther)) || + (dimSelf != dimOther))) + return rewriter.notifyMatchFailure(op, "sizes not of same dim"); rewriter.replaceOp(op, op.getSelf()); return success(); } - return failure(); + return rewriter.notifyMatchFailure(op, "unable to fold"); } }; } // namespace +namespace { +// fold ridiculous patterns like size.int -> float.scalar -> int.scalar +class FoldAtenIntScalarPattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenIntScalarOp op, + PatternRewriter &rewriter) const override { + auto floatScalarOp = op.getA().getDefiningOp(); + if (!floatScalarOp) + return failure(); + auto sizeOp = 
floatScalarOp.getA().getDefiningOp(); + if (!sizeOp) + return failure(); + rewriter.replaceOp(op, floatScalarOp.getA()); + return success(); + } +}; +} // namespace namespace { class FoldAtenUnsqueezePattern : public OpRewritePattern { public: @@ -532,11 +1225,177 @@ class FoldAtenUnsqueezePattern : public OpRewritePattern { none, none, none, none); return success(); } + auto squeezeOp = op.getSelf().getDefiningOp(); + if (squeezeOp && resultTy.getSizes().size() == 1) { + rewriter.replaceOp(op, squeezeOp.getSelf()); + return success(); + } return failure(); } }; } // namespace + +/// ------ Canonicalization Patterns ------ /// + +namespace { +// This is a specific pattern for converting views like [?,...,?,lastDim] -> +// [?,...,?,factor0,factor1] to unflatten, and views like +// [?,...,?,factor0,factor1] -> [?,...,?,lastDim] to flatten, whenever it is +// possible to infer that all but last shared dim match +// TODO: move this to an actual canonicalizer for view after deleting the +// conflicting decompositions for flatten/unflatten -> view. +class CanonicalizeAtenViewPattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenViewOp op, + PatternRewriter &rewriter) const override { + SmallVector viewSizes; + if (failed(getListOperands(op.getSize(), viewSizes))) + return rewriter.notifyMatchFailure( + op, "view size must be from a list construct"); + auto selfTy = dyn_cast(op.getSelf().getType()); + if (!selfTy || !selfTy.hasSizes()) + return rewriter.notifyMatchFailure(op, "missing input type or sizes"); + auto resultTy = dyn_cast(op.getType()); + if (!resultTy || !resultTy.hasSizes() || + resultTy.getSizes().size() != viewSizes.size()) + return rewriter.notifyMatchFailure(op, "missing result type or sizes"); + int64_t inRank = selfTy.getSizes().size(); + int64_t outRank = resultTy.getSizes().size(); + + SmallVector sizes(selfTy.getSizes()); + int64_t leftMatchEnd = 0; + // compare input sizes with provided dims from left + for (; leftMatchEnd < std::min(outRank, inRank); leftMatchEnd++) { + int64_t providedSize; + bool providedStatic = matchPattern(viewSizes[leftMatchEnd], + m_TorchConstantInt(&providedSize)); + // static dim case + if (sizes[leftMatchEnd] != Torch::kUnknownSize) { + // if can't infer equality of dims, set end index and break + if (!providedStatic || providedSize != sizes[leftMatchEnd]) + break; + continue; + } + // the remaining assumes sizes[leftMatchEnd] is dynamic + // if provided dim is static, we can't match. 
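      // That is, a dynamic input dim can only be proven equal to the provided
      // size when the provided size is literally aten.size.int of the same
      // tensor at this position, e.g. (hypothetical names)
      //   %s0 = torch.aten.size.int %self, %int0
      // appearing as entry 0 of the view's size list proves dim 0 matches.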
+ if (providedStatic) + break; + auto sizeIntOp = viewSizes[leftMatchEnd].getDefiningOp(); + // if we don't have a size int op on self, break + if (!sizeIntOp || sizeIntOp.getSelf() != op.getSelf()) + break; + int64_t dim; + // if the dim of the size int op doesn't match, fail + if (!matchPattern(sizeIntOp.getDim(), m_TorchConstantInt(&dim)) || + dim != leftMatchEnd) + break; + } + + int64_t rightMatchEnd = 0; + // compare input sizes with provided dims from right + for (; rightMatchEnd < std::min(outRank, inRank) - leftMatchEnd; + rightMatchEnd++) { + int64_t providedSize; + bool providedStatic = matchPattern(viewSizes[outRank - 1 - rightMatchEnd], + m_TorchConstantInt(&providedSize)); + // static dim case + if (sizes[inRank - 1 - rightMatchEnd] != Torch::kUnknownSize) { + // if can't infer equality of dims, set end index and break + if (!providedStatic || + providedSize != sizes[inRank - 1 - rightMatchEnd]) + break; + continue; + } + // the remaining assumes sizes[inRank - 1 - rightMatchEnd] is dynamic + // if provided dim is static, we can't match. + if (providedStatic) + break; + auto sizeIntOp = + viewSizes[outRank - 1 - rightMatchEnd].getDefiningOp(); + // if we don't have a size int op on self, break + if (!sizeIntOp || sizeIntOp.getSelf() != op.getSelf()) + break; + int64_t dim; + // if the dim of the size int op doesn't match, break + if (!matchPattern(sizeIntOp.getDim(), m_TorchConstantInt(&dim)) || + dim != inRank - 1 - rightMatchEnd) + break; + } + // the unmatched input dims start at leftMatchEnd, and end before inRank - + // rightMatchEnd + int64_t inputUnmatched = (inRank - rightMatchEnd) - leftMatchEnd; + int64_t outputUnmatched = (outRank - rightMatchEnd) - leftMatchEnd; + // if too many dims are unmatched in input/output, cannot canonicalize. + if (inputUnmatched > 1 && outputUnmatched > 1) + return rewriter.notifyMatchFailure( + op, + "View op is not simple enough to canonicalize.\n# Unmatched Input " + "dims = " + + std::to_string(inputUnmatched) + + "\n# Unmatched Output Dims = " + std::to_string(outputUnmatched) + + "\nStarting unmatched index = " + std::to_string(leftMatchEnd)); + + // if all dims match, return self. 
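    // Worked example (hypothetical shapes) for the branches below: self of
    // type [?,?,12] viewed with sizes (%s0, %s1, 3, 4), where %s0 and %s1 are
    // aten.size.int of self at dims 0 and 1, gives leftMatchEnd = 2 and
    // rightMatchEnd = 0, hence inputUnmatched = 1 and outputUnmatched = 2,
    // so the view becomes an unflatten of dim 2 with sizes [3, 4]. The
    // mirrored case [?, 3, 4] viewed as (%s0, 12) takes the flatten branch,
    // flattening dims 1 through 2.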
+ if (inputUnmatched == outputUnmatched && + (inputUnmatched == 1 || inputUnmatched == 0)) { + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getSelf()); + return success(); + } + // if input has 1 unmatched dim, and output has multiple, unflatten + if (inputUnmatched == 1 && outputUnmatched > 1) { + Value dimVal = + rewriter.create(op.getLoc(), leftMatchEnd); + SmallVector unflattenSizes(viewSizes.begin() + leftMatchEnd, + viewSizes.end() - rightMatchEnd); + // try to convert a single dynamic size input to -1 + int64_t dynCount = 0; + int64_t dynIdx = 0; + for (auto [i, v] : llvm::enumerate(unflattenSizes)) { + int64_t szeInt; + if (!matchPattern(v, m_TorchConstantInt(&szeInt))) { + dynCount++; + dynIdx = i; + continue; + } + // if we have a -1 already, make dynCount invalid and break + if (szeInt == -1) { + dynCount = -1; + break; + } + } + // if only one size is dynamic, make it -1 + if (dynCount == 1) + unflattenSizes[dynIdx] = + rewriter.create(op.getLoc(), -1); + + Value unflattenList = rewriter.create( + op.getLoc(), op.getSize().getType(), unflattenSizes); + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getSelf(), dimVal, unflattenList); + return success(); + } + // if multiple unmatched input dims map to one output dim, flatten + if (inputUnmatched > 1 && outputUnmatched == 1) { + Value startDim = + rewriter.create(op.getLoc(), leftMatchEnd); + // note: flatten end is inclusive for some reason. + int64_t endInt = inRank - rightMatchEnd - 1; + Value endDim = rewriter.create(op.getLoc(), endInt); + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getSelf(), startDim, endDim); + return success(); + } + // the remaining cases involve maximal matching dims, but mismatched ranks. + // This could only occur if squeezing or unsqueezing. + return rewriter.notifyMatchFailure( + op, "unhandled view op canonicalization to squeeze/unsqueeze."); + } +}; +} // namespace + namespace { template class RemoveUnusedPattern : public OpRewritePattern { public: @@ -553,6 +1412,124 @@ template class RemoveUnusedPattern : public OpRewritePattern { }; } // namespace +namespace { + +bool isItemForSliceOp(Operation *op) { + auto itemOp = dyn_cast_or_null(op); + if (!itemOp) + return false; + for (OpOperand &use : op->getUses()) { + Operation *userOp = use.getOwner(); + if (isa(userOp)) + return true; + } + return false; +} + +bool isSourceOpForShapeScalarization(Operation *op) { + return llvm::isa(op); +} + +bool isPrimListOfInts(Operation *op) { + auto primListOp = dyn_cast(op); + if (!primListOp) + return false; + auto listType = dyn_cast(primListOp.getType()); + if (!listType) + return false; + return llvm::isa(listType.getContainedType()); +} + +bool isAnchorOp(Operation *op) { + return isa(op) || isa(op) || + isPrimListOfInts(op) || isItemForSliceOp(op); +} + +// The argument to this function, op, is the use of some source op, srcOp. If +// this function returns true, we want to invalidate srcOp as a target for shape +// scalarization. +bool isInvalidValidViewConsumer(Operation *op, + SetVector &workList) { + // if the consumer isn't a view op, don't invalidate it + auto view = dyn_cast_or_null(op); + if (!view) + return false; + auto resultTy = dyn_cast(view.getType()); + if (!resultTy || !resultTy.hasDtype()) + return true; + // if the view op doesn't return integer types, then srcOp is not a shape + // tensor. note: prim lists will always get added before reaching this + // function call. + if (!isa(resultTy.getDtype())) + return true; + // check uses of the view op. 
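  // For example (schematic): an integer view
  //   %v = torch.aten.view %shape, %dims : ... -> !torch.vtensor<[2],si64>
  // feeding an op that is already in the worklist is genuine shape
  // arithmetic, so its producer stays eligible; when no consumer is in the
  // worklist, keeping %v would only drag unrelated producers into the
  // scalarization set.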
+ // If the view op has a use in our worklist, then it needs to be scalarized. + for (OpOperand &use : op->getUses()) { + Operation *userOp = use.getOwner(); + if (workList.contains(userOp)) + return false; + } + // invalidate, since the view op was added as a one-off for canonicalization. + return true; +} + +void populateScalarizationFoldPatterns(RewritePatternSet &patterns) { + patterns.insert, + FoldAtenSqueezePattern, + FoldAtenIntScalarPattern, FoldAtenUnsqueezePattern, + FoldAtenWhereSelf, FoldAtenTensorSplatPattern, + FoldAtenEqIntPattern>(patterns.getContext()); +} + +void populateScalarizationCanonicalizePatterns(RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} + +void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) { + patterns.add>(patterns.getContext(), + /*benefit=*/10); + patterns.insert, + PropagateAtenViewLikePattern>( + patterns.getContext()); + // A note on division: onnx.Div from int, int -> int types rounds towards + // zero. The torch DivTensorOp actually doesn't allow returning an int dtype, + // but this was artificially plummbed through. Unfortunately, there is no + // scalar trunc div op in torch; however, we can safely assume all operands + // are positive so floor divide should be a sufficient scalar replacement. + patterns.insert< + PropagateAtenCatPattern, PropagateAtenIndexSelectPattern, + PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern, + PropagateAtenSliceTensorPattern, PropagateAtenEqTensorPattern, + PropagateAtenWhereSelfPattern, PropagateAtenBroadcastToPattern, + PropagateAtenTransposeIntPattern, PropagateAtenToDtypePattern, + PropagateAtenUnaryPattern, + PropagateAtenArithmeticPattern, + PropagateAtenArithmeticPattern, + PropagateAtenArithmeticPattern, + PropagateAtenArithmeticPattern, + PropagateAtenArithmeticPattern>( + patterns.getContext()); +} + +void populateScalarizationRemovePatterns(RewritePatternSet &patterns) { + patterns.insert, + RemoveUnusedPattern, + RemoveUnusedPattern, + RemoveUnusedPattern, + RemoveUnusedPattern, + RemoveUnusedPattern, + RemoveUnusedPattern, + RemoveUnusedPattern, + RemoveUnusedPattern, + RemoveUnusedPattern, + RemoveUnusedPattern, + RemoveUnusedPattern, + RemoveUnusedPattern>( + patterns.getContext()); +} + +} // namespace namespace { class ScalarizeShapesPass : public ScalarizeShapesBase { public: @@ -563,25 +1540,74 @@ class ScalarizeShapesPass : public ScalarizeShapesBase { void runOnOperation() override { MLIRContext *context = &getContext(); RewritePatternSet patterns(context); - patterns - .insert, - RemoveUnusedPattern, - RemoveUnusedPattern, - RemoveUnusedPattern, - RemoveUnusedPattern, - RemoveUnusedPattern, - RemoveUnusedPattern>(context); + // populate patterns + populateScalarizationPropagationPatterns(patterns); + populateScalarizationFoldPatterns(patterns); + populateScalarizationCanonicalizePatterns(patterns); + populateScalarizationRemovePatterns(patterns); context->getLoadedDialect() ->getCanonicalizationPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + // don't load torch canonicalization patterns, since these may lead to + // issues with propagation + + // walk func op bottom-up to collect a SetVector of shape-related operations + // When we pass this SetVector to the pattern rewrite driver, it will + // process the operations top-down, thereby propagating scalarization + // starting from sources. 
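The division note above leans on the operands being non-negative, since C-style truncating division and floor division only coincide in that case. A quick standalone check of that assumption (plain C++, illustrative only, not part of the patch):

#include <cassert>
#include <cmath>
#include <cstdint>

int main() {
  // For non-negative operands, truncating division ('/') equals floor division.
  int64_t a = 7, b = 2;
  assert(a / b == (int64_t)std::floor((double)a / b)); // both give 3
  // For mixed signs they differ, which is why the positivity assumption matters.
  int64_t c = -7;
  assert(c / b == -3);                              // truncates toward zero
  assert((int64_t)std::floor((double)c / b) == -4); // floors toward -infinity
  return 0;
}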
+ auto funcOp = getOperation();
+ llvm::SetVector<Operation *> shapeCalculationOps;
+ funcOp.walk(
+ [&](Operation *op) {
+ // Walking bottom-up, start adding ops when we reach an anchor point
+ // (a prim list of ints)
+ if (isAnchorOp(op)) {
+ shapeCalculationOps.insert(op);
+ return;
+ }
+ // add view ops for now until the decompositions for flatten and
+ // unflatten are removed.
+ if (isa(op)) {
+ shapeCalculationOps.insert(op);
+ return;
+ }
+ // Insert the op if any of its consumers have already been identified
+ // as a shape calculation op. To avoid adding the producer of
+ // something like a size.int op, don't add ops when their consumer is
+ // a source op for shape scalarization. Here is some sample IR:
+ // ------
+ // %0 = aten.matmul %arg0, %arg1 : ... -> !torch.vtensor<[?,?,?],f32>
+ // %1 = aten.size.int %0, %int0 : !torch.int
+ // %2 = prim.ListConstruct %1 : (!torch.int) -> !torch.list<int>
+ // return %2 : !torch.list<int>
+ // ------
+ // In this example, don't add the matmul (%0), or its producers, to
+ // shapeCalculationOps. Its consumer (%1) is indeed a shape
+ // calculation op, but the size.int op is an elementary unit of shape
+ // computation. No further gathering of producers is necessary to
+ // reduce this. Similarly, don't always add the `self` of a view op.
+ for (OpOperand &use : op->getUses()) {
+ Operation *userOp = use.getOwner();
+ if (shapeCalculationOps.contains(userOp) &&
+ !isSourceOpForShapeScalarization(userOp) &&
+ !isInvalidValidViewConsumer(userOp, shapeCalculationOps)) {
+ shapeCalculationOps.insert(op);
+ return;
+ }
+ }
+ });
+
+ GreedyRewriteConfig config;
+ // When propagating, we need to go back and clean up aten.Tensor ops that
+ // have been further propagated. It is also necessary to add newly created
+ // ops for custom folding after scalarizing a where.self op.
+ config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
+ if (failed(applyOpPatternsGreedily(shapeCalculationOps.getArrayRef(),
+ std::move(patterns), config))) {
 return signalPassFailure();
 }
+
+ // TODO: Warn when failing to process operations in the worklist.
 }
 };
 } // namespace
diff --git a/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.cpp b/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.cpp
index f1ebeb307976..d599fd5369f4 100644
--- a/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.cpp
+++ b/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.cpp
@@ -32,9 +32,6 @@ class FoldPrimUncheckedCastOp : public OpRewritePattern {
 } // namespace
 namespace {
-// TODO: Only unroll inside the shape calculation region.
-// Maybe do this by only applying patterns and folding greedily on the ops
-// inside the region + the shape.calculate op itself?
 class FullyUnrollPrimLoopOp : public OpRewritePattern<PrimLoopOp> {
 public:
 using OpRewritePattern<PrimLoopOp>::OpRewritePattern;
 LogicalResult matchAndRewrite(PrimLoopOp op,
 PatternRewriter &rewriter) const override {
 Location loc = op->getLoc();
 MLIRContext *context = op->getContext();
+ // Only unroll loops if they are contained in a shape calculate region.
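The walk above is essentially a bottom-up backward-slice collection: an op joins the worklist if it is an anchor, or if one of its users is already in the set (the source-op and view-consumer exclusions refine this). A toy standalone version of that idea (plain C++, illustrative names, not part of the patch):

#include <set>
#include <vector>

struct ToyOp {
  std::vector<int> users; // indices of ops that consume this op's result
  bool isAnchor = false;  // e.g. a prim list of ints
};

// Visit ops in reverse (bottom-up); collect anchors and any op whose user is
// already collected. The real pass additionally skips producers of "source"
// ops such as aten.size.int and skips invalid view consumers.
std::set<int> collectShapeOps(const std::vector<ToyOp> &ops) {
  std::set<int> collected;
  for (int i = static_cast<int>(ops.size()) - 1; i >= 0; --i) {
    if (ops[i].isAnchor) {
      collected.insert(i);
      continue;
    }
    for (int user : ops[i].users)
      if (collected.count(user)) {
        collected.insert(i);
        break;
      }
  }
  return collected;
}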
+ Region *region = op->getParentRegion(); + Operation *parentOp = region->getParentOp(); + if (!parentOp || !isa(parentOp)) + return rewriter.notifyMatchFailure( + op, "Loop is not contained in a shape calculation region."); if (!op.isForLike()) return rewriter.notifyMatchFailure(op, "Loop is not for-like"); int64_t maxTripCount; diff --git a/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp b/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp index 6b18af04dca6..0935af83a803 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp @@ -179,11 +179,10 @@ class RefineNumToTensorScalarOpType "should have concrete Scalar Type."); } Type inputType = getBuiltInTypeForTorchScalar(op.getA().getType()); - auto impliedTypeFromInputType = + auto impliedTypeFromInputType = cast( cast(originalResultType) .getWithSizesAndDtype(originalResultType.getOptionalSizes(), - inputType) - .cast(); + inputType)); op.getResult().setType(impliedTypeFromInputType); return success(); @@ -214,8 +213,8 @@ class SimplifyDtypeCalculationsPass GreedyRewriteConfig config; config.useTopDownTraversal = true; config.maxIterations = GreedyRewriteConfig::kNoLimit; - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns), + config))) { return signalPassFailure(); } } diff --git a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp index 37ce829cb731..a2d2c6450693 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp @@ -46,6 +46,62 @@ class DecomposeAtenSizeOp : public OpRewritePattern { }; } // namespace +namespace { +class InferTensorOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenTensorOp op, + PatternRewriter &rewriter) const override { + auto context = op.getContext(); + auto loc = op.getLoc(); + auto result = op.getResult(); + auto resultType = cast(result.getType()); + if (resultType.hasSizes() && resultType.hasDtype()) { + return rewriter.notifyMatchFailure( + op, "The result of aten.tensor is already a BaseTensorType."); + } + + auto inputList = op.getOperand(0); + auto listConstruct = inputList.getDefiningOp(); + if (!listConstruct) { + return rewriter.notifyMatchFailure( + op, "The operand 0 of aten.tensor is not PrimListConstructOp."); + } + + // Currently only support the 1d input list. 
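Given the 1-D restriction, the InferTensorOp pattern above only has to read the element count from the prim.ListConstruct operand and pick the dtype from the scalar kind of its elements. A simplified standalone model of that inference (plain C++, illustrative names; the actual code below uses torch scalar types and value tensor types):

#include <cstdint>
#include <optional>
#include <utility>
#include <variant>
#include <vector>

enum class Dtype { I64, F32 }; // stand-ins for the Long / Float scalar types
using Scalar = std::variant<int64_t, double>;

// A 1-D list of N scalars becomes a [N] tensor whose dtype follows the
// scalar kind of the (assumed homogeneous) elements.
std::optional<std::pair<int64_t, Dtype>>
inferListTensorType(const std::vector<Scalar> &list) {
  if (list.empty())
    return std::nullopt;
  Dtype dtype =
      std::holds_alternative<int64_t>(list.front()) ? Dtype::I64 : Dtype::F32;
  return std::make_pair(static_cast<int64_t>(list.size()), dtype);
}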
+ SmallVector sizes; + sizes.push_back(listConstruct->getOperands().size()); + FailureOr torchType; + auto eleType = listConstruct->getOperands()[0].getType(); + if (isa(eleType)) { + torchType = getTypeForScalarType(op->getContext(), + torch_upstream::ScalarType::Long); + } else if (isa(eleType)) { + torchType = getTypeForScalarType(op->getContext(), + torch_upstream::ScalarType::Float); + } else { + return rewriter.notifyMatchFailure( + op, "Currently only support Int and Float Type."); + } + auto newResultType = ValueTensorType::get(context, sizes, *torchType); + + Value originalTypedValue; + for (OpOperand &use : llvm::make_early_inc_range(result.getUses())) { + if (!originalTypedValue) { + rewriter.setInsertionPointAfter(op); + originalTypedValue = + rewriter.create(loc, resultType, result); + } + use.set(originalTypedValue); + } + + result.setType(newResultType); + + return success(); + } +}; +} // namespace + static LogicalResult refineShapeCalculateResult(ShapeCalculateOp op, int resultNum, PatternRewriter &rewriter) { @@ -97,11 +153,10 @@ static LogicalResult refineShapeCalculateResult(ShapeCalculateOp op, } auto originalResultType = cast(result.getType()); - auto impliedTypesFromShape = + auto impliedTypesFromShape = cast( cast(originalResultType) .getWithSizesAndDtype(ArrayRef(sizes), - originalResultType.getOptionalDtype()) - .cast(); + originalResultType.getOptionalDtype())); return updateCalculateOpResultTypes(op, resultNum, impliedTypesFromShape, rewriter); @@ -136,20 +191,22 @@ class SimplifyShapeCalculationsPass populateFoldPrimUncheckedCastOpPattern(patterns, context); patterns.insert(context); patterns.insert(context); + patterns.insert(context); PrimIfOp::getCanonicalizationPatterns(patterns, context); Aten__Getitem__TOp::getCanonicalizationPatterns(patterns, context); AtenSizeOp::getCanonicalizationPatterns(patterns, context); AtenLenTOp::getCanonicalizationPatterns(patterns, context); AtenAddTOp::getCanonicalizationPatterns(patterns, context); + AtenSliceTOp::getCanonicalizationPatterns(patterns, context); // TODO: Debug visitation order to make this more efficient. // A single linear scan should suffice. GreedyRewriteConfig config; config.useTopDownTraversal = true; config.maxIterations = GreedyRewriteConfig::kNoLimit; - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns), + config))) { return signalPassFailure(); } } diff --git a/lib/Dialect/Torch/Utils/CMakeLists.txt b/lib/Dialect/Torch/Utils/CMakeLists.txt index 91088078891d..45b3e1b987aa 100644 --- a/lib/Dialect/Torch/Utils/CMakeLists.txt +++ b/lib/Dialect/Torch/Utils/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(TorchMLIRTorchUtils Utils.cpp + SparsityUtils.cpp TorchUpstream.cpp ADDITIONAL_HEADER_DIRS diff --git a/lib/Dialect/Torch/Utils/SparsityUtils.cpp b/lib/Dialect/Torch/Utils/SparsityUtils.cpp new file mode 100644 index 000000000000..985316261b58 --- /dev/null +++ b/lib/Dialect/Torch/Utils/SparsityUtils.cpp @@ -0,0 +1,50 @@ +//===----------------------------------------------------------------------===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. 
+// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Dialect/Torch/Utils/SparsityUtils.h" +#include "mlir/Dialect/SparseTensor/IR/Enums.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include "llvm/ADT/SmallVector.h" +#include + +using namespace mlir; +using namespace mlir::sparse_tensor; +using namespace mlir::torch; +using namespace mlir::torch::Torch; + +FailureOr Torch::getSparsityWithDenseLTAtDim(Attribute attr, + Value dim) { + if (!attr) + return Attribute(); + + auto enc = cast(attr); + int64_t dimInt = 0; + int64_t rank = enc.getDimRank() + 1; + if (matchPattern(dim, m_TorchConstantInt(&dimInt))) { + dimInt = toPositiveDim(dimInt, rank); + if (!isValidDim(dimInt, rank)) { + return failure(); + } + if (!enc.isIdentity()) { + // TODO: support block sparsity and permutation (CSC). + return failure(); + } + auto denseLT = *LevelType::buildLvlType(LevelFormat::Dense, true, true); + SmallVector lvlTps = llvm::to_vector(enc.getLvlTypes()); + lvlTps.insert(lvlTps.begin() + dimInt, denseLT); + auto dim2Lvl = AffineMap::getMultiDimIdentityMap(rank, attr.getContext()); + return SparseTensorEncodingAttr::get( + enc.getContext(), lvlTps, dim2Lvl, AffineMap(), enc.getPosWidth(), + enc.getCrdWidth(), enc.getExplicitVal(), enc.getImplicitVal()); + } + // Do not know how to handle dynamic dimension. + return failure(); +} diff --git a/lib/Dialect/Torch/Utils/TorchUpstream.cpp b/lib/Dialect/Torch/Utils/TorchUpstream.cpp index 2dce14ef964c..0136ed0f0892 100644 --- a/lib/Dialect/Torch/Utils/TorchUpstream.cpp +++ b/lib/Dialect/Torch/Utils/TorchUpstream.cpp @@ -21,7 +21,7 @@ static inline bool isQIntType(ScalarType t) { // Don't forget to extend this when adding new QInt types return t == ScalarType::QInt8 || t == ScalarType::QUInt8 || t == ScalarType::QInt32 || t == ScalarType::QUInt4x2 || - t == ScalarType::QUInt2x4; + t == ScalarType::QUInt2x4 || t == ScalarType::QInt16; } //===----------------------------------------------------------------------===// @@ -128,6 +128,21 @@ ScalarType result_type(const ResultTypeState &in_state) { combine_categories(in_state.zeroResult, in_state.wrappedResult)); } +Reduction get_loss_reduction_enum(const llvm::StringRef &reduce) { + if (reduce == "none") { + return torch_upstream::Reduction::None; + } else if (reduce == "mean") { + return torch_upstream::Reduction::Mean; + } else if (reduce == "sum") { + return torch_upstream::Reduction::Sum; + } else if (reduce == "end") { + return torch_upstream::Reduction::END; + } else { + llvm_unreachable( + "'reduction' argument must be either none, mean, sum or end"); + } +} + ReductionType get_reduction_enum(const llvm::StringRef &reduce) { if (reduce == "max" || reduce == "amax") { return torch_upstream::ReductionType::MAX; diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index d634556c98a1..489ef3b64478 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -9,8 +9,10 @@ #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinTypes.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" +#include "torch-mlir/Dialect/Torch/Utils/SparsityUtils.h" using namespace mlir; using namespace mlir::torch; @@ -35,6 +37,18 @@ Torch::matchLegalConstantIndexIntoListOfSize(Value v, int64_t length) { 
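getSparsityWithDenseLTAtDim above normalizes the requested dim against the unsqueezed rank (input rank plus one) and splices a dense level into the level-type list at that position. A small standalone sketch of just that index handling (plain C++, illustrative names, not part of the patch):

#include <cstdint>
#include <vector>

// Normalize a possibly negative unsqueeze dim against the output rank
// (input rank + 1) and insert a marker for the new dense level there.
bool insertDenseLevelAt(std::vector<int> &lvlTypes, int64_t dim,
                        int denseLevelMarker = 0) {
  int64_t rank = static_cast<int64_t>(lvlTypes.size()) + 1; // output rank
  if (dim < 0)
    dim += rank; // toPositiveDim
  if (dim < 0 || dim >= rank)
    return false; // invalid dim
  lvlTypes.insert(lvlTypes.begin() + dim, denseLevelMarker);
  return true;
}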
return dim; } +Value Torch::toIntListConstruct(PatternRewriter &rewriter, Location loc, + ArrayRef cstInput) { + SmallVector cstValues; + for (int64_t i : cstInput) { + cstValues.push_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(i))); + } + return rewriter.create( + loc, Torch::ListType::get(IntType::get(rewriter.getContext())), + cstValues); +} + bool Torch::getListConstructElements(Value v, SmallVectorImpl &elems) { auto listConstruct = v.getDefiningOp(); if (!listConstruct) @@ -43,6 +57,49 @@ bool Torch::getListConstructElements(Value v, SmallVectorImpl &elems) { return true; } +Value Torch::toTorchList(Location loc, PatternRewriter &rewriter, + ArrayRef vals) { + SmallVector intConsts; + for (int64_t v : vals) { + intConsts.push_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(v))); + } + + auto listType = + Torch::ListType::get(Torch::IntType::get(rewriter.getContext())); + return rewriter.create(loc, listType, intConsts); +} + +TypedValue Torch::broadcastTo(Location loc, + PatternRewriter &rewriter, + Value val, + ArrayRef newShape) { + + auto ty = dyn_cast(val.getType()); + assert(ty); + auto newTy = ty.getWithSizesAndDtype(newShape, ty.getOptionalDtype()); + return cast>( + rewriter + .create(loc, newTy, val, + toTorchList(loc, rewriter, newShape)) + .getResult()); +} + +TypedValue Torch::reshapeTo(Location loc, + PatternRewriter &rewriter, + Value val, + ArrayRef newShape) { + + auto ty = dyn_cast(val.getType()); + assert(ty); + auto newTy = ty.getWithSizesAndDtype(newShape, ty.getOptionalDtype()); + return cast>( + rewriter + .create(loc, newTy, val, + toTorchList(loc, rewriter, newShape)) + .getResult()); +} + torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) { if (isa(type)) return torch_upstream::ScalarType::Float; @@ -68,6 +125,8 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) { return torch_upstream::ScalarType::QUInt8; if (isa(type)) return torch_upstream::ScalarType::QInt8; + if (isa(type)) + return torch_upstream::ScalarType::QInt16; if (isa(type)) return torch_upstream::ScalarType::QInt32; if (isa(type)) { @@ -79,7 +138,36 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) { if (complexElemType.isF64()) return torch_upstream::ScalarType::ComplexDouble; } - llvm::report_fatal_error("unhandled type for getScalarTypeForType"); + if (isa(type)) + return torch_upstream::ScalarType::Float8_e5m2; + if (isa(type)) + return torch_upstream::ScalarType::Float8_e4m3fn; + if (isa(type)) + return torch_upstream::ScalarType::Float8_e5m2fnuz; + if (isa(type)) + return torch_upstream::ScalarType::Float8_e4m3fnuz; + std::string errorMsg = "Unhandled type in getScalarTypeForType: "; + llvm::raw_string_ostream os(errorMsg); + type.print(os); + // os << "\nType ID: " << type.getTypeID(); + os << "\nType properties:"; + os << "\n Is integer: " << (type.isInteger() ? "yes" : "no"); + os << "\n Is float: " + << (type.isIntOrFloat() && !type.isInteger() ? "yes" : "no"); + os << "\n Is index: " << (type.isIndex() ? "yes" : "no"); + os << "\n Bit width: " + << (type.isIntOrFloat() ? std::to_string(type.getIntOrFloatBitWidth()) + : "N/A"); + os << "\n Is signless: " << (type.isSignlessInteger() ? "yes" : "no"); + os << "\n Is signed: " << (type.isSignedInteger() ? "yes" : "no"); + // special error message for unsigned integer + if (type.isUnsignedInteger()) { + os << "\n Is unsigned: yes"; + os << "\nUnsigned integer support is currently spotty. 
Please seeheck " + "https://github.com/llvm/torch-mlir/issues/3720 " + "for more details."; + } + llvm::report_fatal_error(llvm::StringRef(errorMsg)); } Type Torch::getTypeForTorchType( MLIRContext *context, Type type, @@ -108,9 +196,9 @@ Torch::getTypeForScalarType(MLIRContext *context, case torch_upstream::ScalarType::Bool: return IntegerType::get(context, 1); case torch_upstream::ScalarType::BFloat16: - return mlir::FloatType::getBF16(context); + return mlir::BFloat16Type::get(context); case torch_upstream::ScalarType::Half: - return mlir::FloatType::getF16(context); + return mlir::Float16Type::get(context); case torch_upstream::ScalarType::Byte: return mlir::IntegerType::get(context, 8, mlir::IntegerType::Unsigned); case torch_upstream::ScalarType::Char: @@ -119,6 +207,8 @@ Torch::getTypeForScalarType(MLIRContext *context, return QUInt8Type::get(context); case torch_upstream::ScalarType::QInt8: return QInt8Type::get(context); + case torch_upstream::ScalarType::QInt16: + return QInt16Type::get(context); case torch_upstream::ScalarType::QInt32: return QInt32Type::get(context); case torch_upstream::ScalarType::ComplexHalf: @@ -127,6 +217,14 @@ Torch::getTypeForScalarType(MLIRContext *context, return mlir::ComplexType::get(Float32Type::get(context)); case torch_upstream::ScalarType::ComplexDouble: return mlir::ComplexType::get(Float64Type::get(context)); + case torch_upstream::ScalarType::Float8_e5m2: + return Float8E5M2Type::get(context); + case torch_upstream::ScalarType::Float8_e4m3fn: + return Float8E4M3FNType::get(context); + case torch_upstream::ScalarType::Float8_e5m2fnuz: + return Float8E5M2FNUZType::get(context); + case torch_upstream::ScalarType::Float8_e4m3fnuz: + return Float8E4M3FNUZType::get(context); case torch_upstream::ScalarType::Undefined: return failure(); default: @@ -208,21 +306,35 @@ std::optional Torch::getTensorRank(Value tensor) { return tensorType.getSizes().size(); } +std::optional Torch::getTensorNumel(Value tensor) { + BaseTensorType tensorType = cast(tensor.getType()); + if (!tensorType.hasSizes()) + return std::nullopt; + int64_t numel = 1; + for (auto dim : tensorType.getSizes()) { + if (dim == ShapedType::kDynamic) + return ShapedType::kDynamic; + numel *= dim; + } + return numel; +} + bool Torch::isViewLikeOp(Operation *op) { // AtenContiguousOp might return a view, so this is conservatively // correct. We could potentially be more precise and identify the cases // that it does not return a view and treat those as having value // semantics. - return isa(op); + AtenPixelShuffleOp, AtenDiagonalOp, AtenUnfoldOp>(op); } Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter, @@ -275,6 +387,32 @@ SmallVector Torch::makeShapeTorchCompatible(ArrayRef shape) { return updatedShape; } +ValueTensorType Torch::getTensorTypeFromShapeValues(ArrayRef shapes, + Type dtype) { + assert(!shapes.empty() && "shape vector cannot be empty"); + SmallVector shapeInts; + for (Value shape : shapes) { + int64_t dim; + if (matchPattern(shape, m_TorchConstantInt(&dim))) + shapeInts.push_back(dim); + else + shapeInts.push_back(kUnknownSize); + } + return Torch::ValueTensorType::get(shapes[0].getContext(), shapeInts, dtype); +} + +// Helper function to get the size of the tensor at the given dimension. 
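The getTensorNumel helper above multiplies the static sizes but gives up as soon as any dimension is dynamic. A standalone sketch of that contract (plain C++, with -1 standing in for ShapedType::kDynamic; illustrative only, not part of the patch):

#include <cstdint>
#include <optional>
#include <vector>

constexpr int64_t kDynamicDim = -1; // stand-in for ShapedType::kDynamic

// Returns std::nullopt when the shape is entirely unknown, kDynamicDim when
// any single dimension is dynamic, and the element count otherwise.
std::optional<int64_t>
numel(const std::optional<std::vector<int64_t>> &sizes) {
  if (!sizes)
    return std::nullopt;
  int64_t n = 1;
  for (int64_t d : *sizes) {
    if (d == kDynamicDim)
      return kDynamicDim;
    n *= d;
  }
  return n;
}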
+Value Torch::getTensorDimSize(PatternRewriter &rewriter, Value tensor, + int64_t dim) { + auto loc = tensor.getLoc(); + auto dimVal = + rewriter.create(loc, rewriter.getI64IntegerAttr(dim)); + // Use 'createOrFold' instead of 'create': + // If the dimension is a constant, then the AtenSizeIntOp is folded to a + // ContantIntOp. + return rewriter.createOrFold(loc, tensor, dimVal); +} + // Helper function to squeeze the input tensor at given dim. // Return the squeezed tensor or failure. FailureOr Torch::squeezeTensor(PatternRewriter &rewriter, Operation *op, @@ -318,6 +456,11 @@ FailureOr Torch::unsqueezeTensor(PatternRewriter &rewriter, if (!inputType.hasSizes()) { return rewriter.notifyMatchFailure(op, "input tensor must have size"); } + FailureOr enc = + getSparsityWithDenseLTAtDim(inputType.getOptionalSparsity(), dim); + if (failed(enc)) { + return failure(); + } SmallVector unsqueezedShape; ArrayRef inputShape = inputType.getSizes(); @@ -334,8 +477,8 @@ FailureOr Torch::unsqueezeTensor(PatternRewriter &rewriter, } else { unsqueezedShape.resize(unsqueezedRank, kUnknownSize); } - Type unsqueezedType = inputType.getWithSizesAndDtype( - unsqueezedShape, inputType.getOptionalDtype()); + Type unsqueezedType = inputType.getWithSizesAndDtypeAndSparsity( + unsqueezedShape, inputType.getOptionalDtype(), enc.value()); Value unsqueezed = rewriter.create( op->getLoc(), unsqueezedType, input, dim); return unsqueezed; @@ -525,6 +668,24 @@ LogicalResult Torch::getTransposedType(BaseTensorType inType, int64_t dimA, return success(); } +LogicalResult Torch::getPermutedType(BaseTensorType inType, + SmallVector permuteDims, + Type &permutedType) { + if (!inType.hasSizes()) + return failure(); + + SmallVector shape(inType.getSizes()); + if (shape.size() != permuteDims.size()) + return failure(); + + SmallVector permutedShape; + for (unsigned i = 0; i < shape.size(); i++) + permutedShape.push_back(shape[permuteDims[i]]); + permutedType = inType.getWithSizesAndDtype(llvm::ArrayRef(permutedShape), + inType.getOptionalDtype()); + return success(); +} + Type Torch::getDefaultAccType(PatternRewriter &rewriter, Type inputType) { if (inputType.isF16()) return rewriter.getF32Type(); @@ -534,23 +695,22 @@ Type Torch::getDefaultAccType(PatternRewriter &rewriter, Type inputType) { return rewriter.getF32Type(); if (isa(inputType)) return rewriter.getF64Type(); - if (inputType.isFloat8E5M2()) + if (isa(inputType)) return rewriter.getF32Type(); - if (inputType.isFloat8E4M3FN()) + if (isa(inputType)) return rewriter.getF32Type(); - if (inputType.isFloat8E5M2FNUZ()) + if (isa(inputType)) return rewriter.getF32Type(); - if (inputType.isFloat8E4M3FNUZ()) + if (isa(inputType)) return rewriter.getF32Type(); - if (inputType.isSignedInteger(8)) - return rewriter.getI64Type(); - if (inputType.isUnsignedInteger(8)) - return rewriter.getI64Type(); - if (inputType.isSignedInteger(16)) + if (inputType.isInteger(8)) + // this is an intentional deviation from CUDA (which accumulates i8 to i64) + return rewriter.getI32Type(); + if (inputType.isInteger(16)) return rewriter.getI64Type(); - if (inputType.isSignedInteger(32)) + if (inputType.isInteger(32)) return rewriter.getI64Type(); - if (inputType.isSignedInteger(64)) + if (inputType.isInteger(64)) return rewriter.getI64Type(); return inputType; } diff --git a/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp b/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp index 4b89b8da1d6b..5d9122fd7bc6 100644 --- a/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp +++ 
b/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp @@ -15,7 +15,6 @@ #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" -#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" using namespace mlir; diff --git a/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp b/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp index a81c27d92845..3a667b81d942 100644 --- a/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp +++ b/lib/Dialect/TorchConversion/IR/TorchConversionOps.cpp @@ -9,10 +9,8 @@ #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" -#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/PatternMatch.h" -#include "mlir/IR/TypeUtilities.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "llvm/ADT/StringMap.h" @@ -25,7 +23,18 @@ static bool haveSameSizeAndElementType(TensorType lhs, TensorType rhs) { if (lhs.hasRank() != rhs.hasRank()) return false; bool sameSize = lhs.hasRank() ? lhs.getShape().equals(rhs.getShape()) : true; - bool sameElementType = lhs.getElementType() == rhs.getElementType(); + bool sameElementType = false; + // Namely, it is worth mentioning that the backends can have different + // expectations for signedness when converting from and to the builtin MLIR + // types. Therefore, the verifier cannot expect the input and output types to + // match in their signedness. + if (isa(lhs.getElementType()) && + isa(rhs.getElementType())) { + sameElementType = lhs.getElementType().getIntOrFloatBitWidth() == + rhs.getElementType().getIntOrFloatBitWidth(); + } else { + sameElementType = lhs.getElementType() == rhs.getElementType(); + } return sameElementType && sameSize; } @@ -44,18 +53,6 @@ LogicalResult ToBuiltinTensorOp::verify() { return success(); } -LogicalResult ToBuiltinTensorOp::inferReturnTypes( - MLIRContext *context, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - auto resultType = - cast(operands[0].getType()).toBuiltinTensor(); - if (!resultType) - return failure(); - inferredReturnTypes.push_back(resultType); - return success(); -} - //===----------------------------------------------------------------------===// // FromBuiltinTensorOp //===----------------------------------------------------------------------===// @@ -76,7 +73,7 @@ LogicalResult FromBuiltinTensorOp::verify() { //===----------------------------------------------------------------------===// OpFoldResult FromI1Op::fold(FoldAdaptor adaptor) { - auto attr = adaptor.getOperand().dyn_cast_or_null(); + auto attr = dyn_cast_or_null(adaptor.getOperand()); if (attr) { return attr; } else { @@ -89,7 +86,7 @@ OpFoldResult FromI1Op::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// OpFoldResult ToI1Op::fold(FoldAdaptor adaptor) { - auto attr = adaptor.getOperand().dyn_cast_or_null(); + auto attr = dyn_cast_or_null(adaptor.getOperand()); if (attr) { return attr; } else { @@ -102,7 +99,7 @@ OpFoldResult ToI1Op::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// OpFoldResult FromI64Op::fold(FoldAdaptor adaptor) { - auto attr = adaptor.getOperand().dyn_cast_or_null(); + auto attr = dyn_cast_or_null(adaptor.getOperand()); if (attr) { return attr; } else { @@ -115,7 +112,7 @@ 
OpFoldResult FromI64Op::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// OpFoldResult ToI64Op::fold(FoldAdaptor adaptor) { - auto attr = adaptor.getOperand().dyn_cast_or_null(); + auto attr = dyn_cast_or_null(adaptor.getOperand()); if (attr) { return attr; } else { @@ -128,7 +125,7 @@ OpFoldResult ToI64Op::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// OpFoldResult ToF64Op::fold(FoldAdaptor adaptor) { - auto attr = adaptor.getOperand().dyn_cast_or_null(); + auto attr = dyn_cast_or_null(adaptor.getOperand()); if (attr) { return attr; } else { @@ -141,7 +138,7 @@ OpFoldResult ToF64Op::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// OpFoldResult FromF64Op::fold(FoldAdaptor adaptor) { - auto attr = adaptor.getOperand().dyn_cast_or_null(); + auto attr = dyn_cast_or_null(adaptor.getOperand()); if (attr) { return attr; } else { diff --git a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp index 947011ea8338..53de48f21934 100644 --- a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp +++ b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp @@ -23,22 +23,22 @@ void mlir::torch::TorchConversion::getBackendTypeConversionDependentDialects( // Type conversion setup. //===----------------------------------------------------------------------===// -static void -setupValueTensorToBuiltinTensorConversion(ConversionTarget &target, - TypeConverter &typeConverter) { +using ValueTensorTypeConversionFn = + std::function(Torch::ValueTensorType)>; + +static void setupValueTensorToBuiltinTensorConversion( + ConversionTarget &target, TypeConverter &typeConverter, + const ValueTensorTypeConversionFn &conversionFn) { target.addLegalOp(); - typeConverter.addConversion( - [](Torch::ValueTensorType type) -> std::optional { - return type.toBuiltinTensor(); - }); + typeConverter.addConversion(conversionFn); typeConverter.addTargetMaterialization([](OpBuilder &builder, TensorType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); if (!isa(inputs[0].getType())) return {}; - return builder.create(loc, inputs[0]); + return builder.create(loc, type, inputs[0]); }); auto sourceMaterialization = [](OpBuilder &builder, Torch::ValueTensorType type, @@ -57,16 +57,16 @@ static void setupTorchBoolToI1Conversion(ConversionTarget &target, typeConverter.addConversion([](Torch::BoolType type) -> std::optional { return IntegerType::get(type.getContext(), 1); }); - typeConverter.addTargetMaterialization( - [](OpBuilder &builder, IntegerType type, ValueRange inputs, - Location loc) -> std::optional { - // Other builtin integer types could be handled by other materializers. - if (!(type.getWidth() == 1 && type.isSignless())) - return std::nullopt; - assert(inputs.size() == 1); - assert(isa(inputs[0].getType())); - return builder.create(loc, inputs[0]).getResult(); - }); + typeConverter.addTargetMaterialization([](OpBuilder &builder, + IntegerType type, ValueRange inputs, + Location loc) -> Value { + // Other builtin integer types could be handled by other materializers. 
+ if (!(type.getWidth() == 1 && type.isSignless())) + return Value(); + assert(inputs.size() == 1); + assert(isa(inputs[0].getType())); + return builder.create(loc, inputs[0]).getResult(); + }); auto sourceMaterialization = [](OpBuilder &builder, Torch::BoolType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); @@ -83,19 +83,19 @@ static void setupTorchIntToI64Conversion(ConversionTarget &target, typeConverter.addConversion([](Torch::IntType type) -> std::optional { return IntegerType::get(type.getContext(), 64); }); - typeConverter.addTargetMaterialization( - [](OpBuilder &builder, IntegerType type, ValueRange inputs, - Location loc) -> std::optional { - // Other builtin integer types could be handled by other materializers. - if (!(type.getWidth() == 64 && type.isSignless())) - return std::nullopt; - // Other input type to be converted to i64 are handled by other - // materializers. - if (!inputs[0].getType().isa()) - return std::nullopt; - assert(inputs.size() == 1); - return builder.create(loc, inputs[0]).getResult(); - }); + typeConverter.addTargetMaterialization([](OpBuilder &builder, + IntegerType type, ValueRange inputs, + Location loc) -> Value { + // Other builtin integer types could be handled by other materializers. + if (!(type.getWidth() == 64 && type.isSignless())) + return Value(); + // Other input type to be converted to i64 are handled by other + // materializers. + if (!isa(inputs[0].getType())) + return Value(); + assert(inputs.size() == 1); + return builder.createOrFold(loc, inputs[0]); + }); auto sourceMaterialization = [](OpBuilder &builder, Torch::IntType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); @@ -112,13 +112,13 @@ static void setupTorchFloatToF64Conversion(ConversionTarget &target, typeConverter.addConversion([](Torch::FloatType type) -> std::optional { return Float64Type::get(type.getContext()); }); - typeConverter.addTargetMaterialization( - [](OpBuilder &builder, Float64Type type, ValueRange inputs, - Location loc) -> std::optional { - assert(inputs.size() == 1); - assert(isa(inputs[0].getType())); - return builder.create(loc, inputs[0]).getResult(); - }); + typeConverter.addTargetMaterialization([](OpBuilder &builder, + Float64Type type, ValueRange inputs, + Location loc) -> Value { + assert(inputs.size() == 1); + assert(isa(inputs[0].getType())); + return builder.create(loc, inputs[0]).getResult(); + }); auto sourceMaterialization = [](OpBuilder &builder, Torch::FloatType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); @@ -137,19 +137,19 @@ static void setupTorchGeneratorToI64Conversion(ConversionTarget &target, [](Torch::GeneratorType type) -> std::optional { return IntegerType::get(type.getContext(), 64); }); - typeConverter.addTargetMaterialization( - [](OpBuilder &builder, IntegerType type, ValueRange inputs, - Location loc) -> std::optional { - // Other builtin integer types could be handled by other materializers. - if (!(type.getWidth() == 64 && type.isSignless())) - return std::nullopt; - // Other input type to be converted to i64 are handled by other - // materializers. - if (!inputs[0].getType().isa()) - return std::nullopt; - assert(inputs.size() == 1); - return builder.create(loc, inputs[0]).getResult(); - }); + typeConverter.addTargetMaterialization([](OpBuilder &builder, + IntegerType type, ValueRange inputs, + Location loc) -> Value { + // Other builtin integer types could be handled by other materializers. 
+ if (!(type.getWidth() == 64 && type.isSignless())) + return Value(); + // Other input type to be converted to i64 are handled by other + // materializers. + if (!isa(inputs[0].getType())) + return Value(); + assert(inputs.size() == 1); + return builder.create(loc, inputs[0]).getResult(); + }); auto sourceMaterialization = [](OpBuilder &builder, Torch::GeneratorType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); @@ -162,9 +162,54 @@ static void setupTorchGeneratorToI64Conversion(ConversionTarget &target, void mlir::torch::TorchConversion::setupBackendTypeConversion( ConversionTarget &target, TypeConverter &typeConverter) { - setupValueTensorToBuiltinTensorConversion(target, typeConverter); + auto valueTensorTypeConversion = + [](Torch::ValueTensorType type) -> std::optional { + auto builtinType = type.toBuiltinTensor(); + if (!builtinType) + return std::nullopt; + + // convert any integer type to signless + if (type.getDtype().isInteger()) { + return builtinType.clone(IntegerType::get( + builtinType.getContext(), type.getDtype().getIntOrFloatBitWidth(), + IntegerType::Signless)); + } + + return builtinType; + }; + setupValueTensorToBuiltinTensorConversion(target, typeConverter, + valueTensorTypeConversion); + setupTorchBoolToI1Conversion(target, typeConverter); + setupTorchIntToI64Conversion(target, typeConverter); + setupTorchFloatToF64Conversion(target, typeConverter); + setupTorchGeneratorToI64Conversion(target, typeConverter); +} + +#ifdef TORCH_MLIR_ENABLE_STABLEHLO +void mlir::torch::TorchConversion::setupBackendTypeConversionForStablehlo( + ConversionTarget &target, TypeConverter &typeConverter) { + auto valueTensorTypeConversion = + [](Torch::ValueTensorType type) -> std::optional { + auto builtinType = type.toBuiltinTensor(); + if (!builtinType) + return std::nullopt; + + // convert signed integer type to signless, keep unsigned as unsigned + if (type.getDtype().isUnsignedInteger()) { + return builtinType.clone(type.getDtype()); + } else if (type.getDtype().isSignedInteger()) { + return builtinType.clone(IntegerType::get( + builtinType.getContext(), type.getDtype().getIntOrFloatBitWidth(), + IntegerType::Signless)); + } + + return builtinType; + }; + setupValueTensorToBuiltinTensorConversion(target, typeConverter, + valueTensorTypeConversion); setupTorchBoolToI1Conversion(target, typeConverter); setupTorchIntToI64Conversion(target, typeConverter); setupTorchFloatToF64Conversion(target, typeConverter); setupTorchGeneratorToI64Conversion(target, typeConverter); } +#endif diff --git a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp index 896dd9577617..dadd865a54a7 100644 --- a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp +++ b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp @@ -9,11 +9,12 @@ #include "PassDetail.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/Transforms/FuncConversions.h" -#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" @@ -27,6 +28,51 @@ using namespace mlir::torch::TorchConversion; 
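The two conversion hooks above differ only in how integer element types are mapped: the default backend contract makes every integer dtype signless, while the StableHLO variant keeps unsigned integers and only strips signedness from signed ones. A compact standalone statement of that policy (plain C++, illustrative names, not MLIR API):

enum class Signedness { Signless, Signed, Unsigned };

// Default (linalg-on-tensors) contract: every integer dtype becomes signless.
Signedness convertForDefaultBackend(Signedness) { return Signedness::Signless; }

// StableHLO contract: unsigned stays unsigned, signed becomes signless.
Signedness convertForStablehlo(Signedness s) {
  return s == Signedness::Unsigned ? Signedness::Unsigned
                                   : Signedness::Signless;
}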
//===----------------------------------------------------------------------===// namespace { + +// TODO: Consider upstreaming this to an `arith::ExtFOp` folder: +struct ExtFTruncFPattern : public OpRewritePattern { + ExtFTruncFPattern(MLIRContext *context) : OpRewritePattern(context) {} + LogicalResult matchAndRewrite(arith::TruncFOp truncf, + PatternRewriter &rewriter) const override { + Value operand = truncf.getOperand(); + auto extf = operand.getDefiningOp(); + if (!extf) + return failure(); + + auto parentOperand = extf.getOperand(); + if (truncf.getType() != parentOperand.getType()) + return failure(); + + rewriter.replaceOp(truncf, parentOperand); + return success(); + } +}; + +void populateFuncBackendTypeConversionPatterns(TypeConverter &typeConverter, + RewritePatternSet &patterns, + ConversionTarget &target) { + populateFunctionOpInterfaceTypeConversionPattern(patterns, + typeConverter); + target.addDynamicallyLegalOp([&](func::FuncOp op) { + return typeConverter.isSignatureLegal(op.getFunctionType()) && + typeConverter.isLegal(&op.getBody()); + }); + populateCallOpTypeConversionPattern(patterns, typeConverter); + target.addDynamicallyLegalOp( + [&](func::CallOp op) { return typeConverter.isLegal(op); }); + + populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter); + populateReturnOpTypeConversionPattern(patterns, typeConverter); + target.addLegalOp(); + + target.markUnknownOpDynamicallyLegal([&](Operation *op) { + return isNotBranchOpInterfaceOrReturnLikeOp(op) || + isLegalForBranchOpInterfaceTypeConversionPattern(op, + typeConverter) || + isLegalForReturnOpTypeConversionPattern(op, typeConverter); + }); +} + struct FuncBackendTypeConversionPass : public FuncBackendTypeConversionBase { using FuncBackendTypeConversionBase< @@ -44,31 +90,41 @@ struct FuncBackendTypeConversionPass typeConverter.addConversion([](Type type) { return type; }); TorchConversion::setupBackendTypeConversion(target, typeConverter); - populateFunctionOpInterfaceTypeConversionPattern( - patterns, typeConverter); - target.addDynamicallyLegalOp([&](func::FuncOp op) { - return typeConverter.isSignatureLegal(op.getFunctionType()) && - typeConverter.isLegal(&op.getBody()); - }); - populateCallOpTypeConversionPattern(patterns, typeConverter); - target.addDynamicallyLegalOp( - [&](func::CallOp op) { return typeConverter.isLegal(op); }); - - populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter); - populateReturnOpTypeConversionPattern(patterns, typeConverter); - target.addLegalOp(); - - target.markUnknownOpDynamicallyLegal([&](Operation *op) { - return isNotBranchOpInterfaceOrReturnLikeOp(op) || - isLegalForBranchOpInterfaceTypeConversionPattern(op, - typeConverter) || - isLegalForReturnOpTypeConversionPattern(op, typeConverter); - }); + populateFuncBackendTypeConversionPatterns(typeConverter, patterns, target); if (failed(applyFullConversion(module, target, std::move(patterns)))) signalPassFailure(); } }; + +#ifdef TORCH_MLIR_ENABLE_STABLEHLO +struct FuncBackendTypeConversionForStablehloPass + : public FuncBackendTypeConversionForStablehloBase< + FuncBackendTypeConversionForStablehloPass> { + using FuncBackendTypeConversionForStablehloBase< + FuncBackendTypeConversionForStablehloPass>:: + FuncBackendTypeConversionForStablehloBase; + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnOperation() override { + auto module = getOperation(); + auto *context = &getContext(); + + TypeConverter typeConverter; + RewritePatternSet 
patterns(context); + ConversionTarget target(*context); + typeConverter.addConversion([](Type type) { return type; }); + TorchConversion::setupBackendTypeConversionForStablehlo(target, + typeConverter); + + populateFuncBackendTypeConversionPatterns(typeConverter, patterns, target); + + if (failed(applyFullConversion(module, target, std::move(patterns)))) + signalPassFailure(); + } +}; +#endif // TORCH_MLIR_ENABLE_STABLEHLO } // namespace std::unique_ptr> @@ -76,6 +132,13 @@ mlir::torch::TorchConversion::createFuncBackendTypeConversionPass() { return std::make_unique(); } +#ifdef TORCH_MLIR_ENABLE_STABLEHLO +std::unique_ptr> mlir::torch::TorchConversion:: + createFuncBackendTypeConversionForStablehloPass() { + return std::make_unique(); +} +#endif // TORCH_MLIR_ENABLE_STABLEHLO + //===----------------------------------------------------------------------===// // FinalizingBackendTypeConversionPass //===----------------------------------------------------------------------===// @@ -148,6 +211,55 @@ struct FinalizingBackendTypeConversionPass typeConverter.addConversion([](Type type) { return type; }); TorchConversion::setupBackendTypeConversion(target, typeConverter); + // Mark materializations as illegal in this pass (since we are finalizing) + // and add patterns that eliminate them. + setupFinalization(target, patterns, typeConverter); + + // If all result types are legal, and all block arguments are legal, then + // all types in the program are legal. + // + // We also check that the operand types are legal to avoid creating invalid + // IR. For example, this prevents the patterns from updating + // the types of the operands to a return op without updating the enclosing + // function. + target.markUnknownOpDynamicallyLegal( + [&](Operation *op) { return typeConverter.isLegal(op); }); + + if (failed(applyFullConversion(func, target, std::move(patterns)))) + signalPassFailure(); + + RewritePatternSet greedyPatterns(context); + greedyPatterns.insert(context); + if (failed(applyPatternsGreedily(func, std::move(greedyPatterns)))) + signalPassFailure(); + + // Drop attributes that are no longer used after conversion out of Torch. + stripTorchAttrs(func); + } +}; + +#ifdef TORCH_MLIR_ENABLE_STABLEHLO +struct FinalizingBackendTypeConversionForStablehloPass + : public FinalizingBackendTypeConversionForStablehloBase< + FinalizingBackendTypeConversionForStablehloPass> { + using FinalizingBackendTypeConversionForStablehloBase< + FinalizingBackendTypeConversionForStablehloPass>:: + FinalizingBackendTypeConversionForStablehloBase; + + void runOnOperation() override { + auto func = getOperation(); + auto *context = &getContext(); + + TypeConverter typeConverter; + RewritePatternSet patterns(context); + ConversionTarget target(*context); + + typeConverter.addConversion([](Type type) { return type; }); + TorchConversion::setupBackendTypeConversionForStablehlo(target, + typeConverter); + // Mark materializations as illegal in this pass (since we are finalizing) // and add patterns that eliminate them. 
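The ExtFTruncFPattern used above folds a truncf of an extf back to the original value only when the truncation target is exactly the type the value started with; that direction of the round trip is lossless, while the opposite one is not. A quick standalone check (plain C++, illustrative only, not part of the patch):

#include <cassert>

int main() {
  float x = 0.1f;
  double widened = static_cast<double>(x);  // extf: exact, double covers float
  float back = static_cast<float>(widened); // truncf back to the original type
  assert(back == x);                        // lossless round trip

  double y = 0.1;                 // not exactly representable as a float
  float narrowed = static_cast<float>(y);
  assert(static_cast<double>(narrowed) != y); // lossy the other way around
  return 0;
}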
setupFinalization> mlir::torch::TorchConversion::createFinalizingBackendTypeConversionPass() { return std::make_unique(); } + +#ifdef TORCH_MLIR_ENABLE_STABLEHLO +std::unique_ptr> mlir::torch:: + TorchConversion::createFinalizingBackendTypeConversionForStablehloPass() { + return std::make_unique(); +} +#endif // TORCH_MLIR_ENABLE_STABLEHLO diff --git a/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp b/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp index 36292a0f0570..5c30889c45a8 100644 --- a/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp +++ b/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp @@ -13,11 +13,9 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" -#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp index 5209e6683db3..f80570b30e41 100644 --- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp +++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp @@ -9,11 +9,7 @@ #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" #include "mlir/Conversion/Passes.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Func/Transforms/Passes.h" -#include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" -#include "mlir/Dialect/Tosa/Transforms/Passes.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" #include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h" @@ -22,17 +18,20 @@ #include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h" #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h" #include "torch-mlir/Conversion/TorchToTensor/TorchToTensor.h" -#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" + #ifdef TORCH_MLIR_ENABLE_STABLEHLO #include "stablehlo/transforms/Passes.h" #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #endif -#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" + +#ifdef TORCH_MLIR_ENABLE_TOSA +#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" +using namespace mlir::tosa; +#endif using namespace mlir; using namespace mlir::torch; -using namespace mlir::tosa; //===----------------------------------------------------------------------===// // Pass registration @@ -45,17 +44,19 @@ namespace reg { void mlir::torch::registerTorchConversionPasses() { reg::registerPasses(); - mlir::PassPipelineRegistration<>( + mlir::PassPipelineRegistration< + TorchConversion::TorchBackendToLinalgOnTensorsBackendPipelineOptions>( "torch-backend-to-linalg-on-tensors-backend-pipeline", "Pipeline lowering torch backend contract to linalg-on-tensors backend " "contract.", TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline); - +#ifdef TORCH_MLIR_ENABLE_TOSA mlir::PassPipelineRegistration<>( "torch-backend-to-tosa-backend-pipeline", "Pipeline lowering torch backend contract to TOSA backend " "contract.", 
TorchConversion::createTorchBackendToTosaBackendPipeline); +#endif #ifdef TORCH_MLIR_ENABLE_STABLEHLO mlir::PassPipelineRegistration< TorchConversion::StablehloBackendPipelineOptions>( @@ -67,10 +68,14 @@ void mlir::torch::registerTorchConversionPasses() { } void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline( - OpPassManager &pm) { + OpPassManager &pm, + const TorchBackendToLinalgOnTensorsBackendPipelineOptions &options) { + // Fix non constant dims passed to reduction ops + pm.addNestedPass( + torch::Torch::createRestructureNonConstantAxesPass()); + // We want to fuse quantized operations together before lowering to linalg. pm.addNestedPass(Torch::createFuseQuantizedOpsPass()); - pm.addNestedPass(Torch::createScalarizeShapesPass()); // Lower to linalg + guards which is the input to codegen backends. // We do this first as it tends to involve pattern-matching against constants, @@ -83,7 +88,8 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline( pm.addNestedPass(createConvertTorchToSCFPass()); pm.addNestedPass(createConvertTorchToArithPass()); pm.addNestedPass(createConvertTorchToTensorPass()); - pm.addPass(createConvertTorchConversionToMLProgramPass()); + if (options.useMlprogram) + pm.addPass(createConvertTorchConversionToMLProgramPass()); pm.addNestedPass(memref::createExpandOpsPass()); // Clean up any non-canonical code introduced above.. @@ -105,9 +111,12 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline( // Verify that we have lowered to the form that linalg on tensors backends // expect. This fails compilation (signalPassFailure) if the IR is not in the // correct form. - pm.addPass(TorchConversion::createVerifyLinalgOnTensorsBackendContractPass()); + if (options.verify) + pm.addPass( + TorchConversion::createVerifyLinalgOnTensorsBackendContractPass()); } +#ifdef TORCH_MLIR_ENABLE_TOSA void TorchConversion::createTorchBackendToTosaBackendPipeline( OpPassManager &pm) { pm.addNestedPass(createConvertTorchToTosaPass()); @@ -131,6 +140,7 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline( // correct form. pm.addPass(TorchConversion::createVerifyTosaBackendContractPass()); } +#endif #ifdef TORCH_MLIR_ENABLE_STABLEHLO void TorchConversion::createTorchBackendToStablehloBackendPipeline( @@ -152,10 +162,11 @@ void TorchConversion::createTorchBackendToStablehloBackendPipeline( // Finish the type conversion from `torch` types to the types of the // StableHLO backend contract. - pm.addPass(TorchConversion::createFuncBackendTypeConversionPass()); + pm.addPass( + TorchConversion::createFuncBackendTypeConversionForStablehloPass()); pm.addNestedPass(createCanonicalizerPass()); pm.addNestedPass( - TorchConversion::createFinalizingBackendTypeConversionPass()); + TorchConversion::createFinalizingBackendTypeConversionForStablehloPass()); // Verify that we have lowered to Stablehlo ops. 
pm.addPass(TorchConversion::createVerifyStablehloBackendContractPass()); @@ -169,5 +180,12 @@ void TorchConversion::createTorchBackendToStablehloBackendPipeline( pm.addNestedPass( stablehlo::createStablehloCanonicalizeDynamismPass()); pm.addNestedPass(createCanonicalizerPass()); + + // Legalize deprecated ops to Stablehlo ops + stablehlo::StablehloLegalizeDeprecatedOpsPassOptions stablehloOptions; + stablehloOptions.failOnUnusedOps = false; + pm.addNestedPass( + stablehlo::createStablehloLegalizeDeprecatedOpsPass(stablehloOptions)); + pm.addPass(createCanonicalizerPass()); } #endif diff --git a/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp b/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp index 1e6879530ce6..1b7360e14a7f 100644 --- a/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp +++ b/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp @@ -104,7 +104,8 @@ class UnpackQuantizedMatmulWeights char mask = (1 << unpackedBitWidth) - 1; for (int b = 0; b < packRatio; b++) { newData[i * packRatio + b] = - APInt(unpackedBitWidth, (el & mask) >> (unpackedBitWidth * b)); + APInt(unpackedBitWidth, (el & mask) >> (unpackedBitWidth * b), + /*isSigned=*/false, /*implicitTrunc=*/true); mask = mask << unpackedBitWidth; } } @@ -130,8 +131,7 @@ class UnpackQuantTensorPass RewritePatternSet patterns(context); patterns.add(context); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) signalPassFailure(); } }; diff --git a/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp b/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp index e8789a05a3be..5189a17fc942 100644 --- a/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp +++ b/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp @@ -10,7 +10,6 @@ #include "PassDetail.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -20,10 +19,8 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/OpDefinition.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h" -#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" diff --git a/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp b/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp index 0c8cdf2fc54d..3ff6e4732db2 100644 --- a/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp +++ b/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp @@ -11,10 +11,9 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/OpDefinition.h" #include "mlir/Transforms/DialectConversion.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" @@ -47,13 +46,21 @@ class 
VerifyStablehloBackendContractPass // Structural operations. target.addDynamicallyLegalOp( opHasLegalTypes); - // Shape operations. - target.addDynamicallyLegalOp(opHasLegalTypes); target.addLegalDialect(); target.addLegalDialect(); target.addLegalDialect(); target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + auto moduleOp = getOperation(); + RewritePatternSet patterns(context); + if (failed(applyFullConversion(moduleOp, target, std::move(patterns)))) { + emitError(moduleOp.getLoc()) + << "Module does not conform to the Stablehlo backend contract. " + "See dialect conversion legality information above."; + return signalPassFailure(); + } } }; } // namespace diff --git a/lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp b/lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp index a29e14a3d705..efa40a02aeb0 100644 --- a/lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp +++ b/lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp @@ -6,15 +6,13 @@ // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// - +#ifdef TORCH_MLIR_ENABLE_TOSA #include "PassDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/OpDefinition.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" @@ -65,3 +63,4 @@ std::unique_ptr> mlir::torch::TorchConversion::createVerifyTosaBackendContractPass() { return std::make_unique(); } +#endif // TORCH_MLIR_ENABLE_TOSA diff --git a/lib/InitAll.cpp b/lib/InitAll.cpp index e8f9622c3088..d9096929e3bb 100644 --- a/lib/InitAll.cpp +++ b/lib/InitAll.cpp @@ -16,8 +16,9 @@ #include "mlir/Dialect/MLProgram/IR/MLProgram.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h" #include "mlir/IR/Dialect.h" #include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h" #include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h" @@ -34,12 +35,20 @@ #include "stablehlo/transforms/Passes.h" #endif +#ifdef TORCH_MLIR_ENABLE_TOSA +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#endif + void mlir::torch::registerAllDialects(mlir::DialectRegistry ®istry) { registry.insert(); registry.insert(); registry.insert(); registry.insert(); +} + +void mlir::torch::registerAllExtensions(mlir::DialectRegistry ®istry) { mlir::func::registerInlinerExtension(registry); + tensor::registerInferTypeOpInterfaceExternalModels(registry); } // TODO: Break this up when backends are separated. 
@@ -47,7 +56,11 @@ void mlir::torch::registerOptionalInputDialects( mlir::DialectRegistry ®istry) { registry.insert(); + scf::SCFDialect, sparse_tensor::SparseTensorDialect, + tensor::TensorDialect>(); +#ifdef TORCH_MLIR_ENABLE_TOSA + registry.insert(); +#endif } void mlir::torch::registerAllPasses() { @@ -61,6 +74,9 @@ void mlir::torch::registerAllPasses() { mlir::stablehlo::registerStablehloLegalizeToLinalgPass(); mlir::stablehlo::registerStablehloAggressiveSimplificationPass(); mlir::stablehlo::registerStablehloRefineShapesPass(); + mlir::stablehlo::registerStablehloConvertToSignlessPass(); + mlir::stablehlo::registerShapeLegalizeToStablehloPass(); + mlir::stablehlo::registerStablehloLegalizeDeprecatedOpsPass(); #endif #ifdef TORCH_MLIR_ENABLE_REFBACKEND diff --git a/lib/RefBackend/CMakeLists.txt b/lib/RefBackend/CMakeLists.txt index a8ed0439d815..b62da2954966 100644 --- a/lib/RefBackend/CMakeLists.txt +++ b/lib/RefBackend/CMakeLists.txt @@ -7,6 +7,7 @@ add_mlir_library(TorchMLIRRefBackend DEPENDS MLIRTorchTypesIncGen TorchMLIRRefBackendPassIncGen + MLIRTorchConversionOpsIncGen LINK_COMPONENTS Core diff --git a/lib/RefBackend/RefBackend.cpp b/lib/RefBackend/RefBackend.cpp index 3bd16ed38940..d40d02d43ffc 100644 --- a/lib/RefBackend/RefBackend.cpp +++ b/lib/RefBackend/RefBackend.cpp @@ -56,7 +56,7 @@ void mlir::torch::RefBackend::registerRefBackendPasses() { ::registerPasses(); } static bool isArgMemRefTypeValid(Type type) { if (auto memRefType = dyn_cast(type)) { Type elemTy = memRefType.getElementType(); - if (elemTy.isa()) { + if (isa(elemTy)) { return true; } else if (auto integerTy = dyn_cast(elemTy)) { if (integerTy.isSignlessInteger(64)) @@ -70,7 +70,7 @@ static bool isArgMemRefTypeValid(Type type) { if (integerTy.isSignlessInteger(1)) return true; } else if (auto complexTy = dyn_cast(elemTy)) { - return complexTy.getElementType().isa(); + return isa(complexTy.getElementType()); } } return false; @@ -425,8 +425,7 @@ class MungeMemrefCopy : public MungeMemrefCopyBase { MLIRContext *context = &getContext(); RewritePatternSet patterns(&getContext()); patterns.insert(context); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); } } @@ -448,8 +447,7 @@ class GeneralizeTensorConcat void runOnOperation() override { RewritePatternSet patterns(&getContext()); tensor::populateDecomposeTensorConcatPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); } } @@ -471,9 +469,8 @@ class GeneralizeTensorPad void runOnOperation() override { MLIRContext *context = &getContext(); RewritePatternSet patterns(&getContext()); - patterns.insert(context); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + patterns.insert(context); + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); } } diff --git a/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp b/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp index 2e42e4fed3ba..04f81dac0446 100644 --- a/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp +++ b/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp @@ -252,6 +252,18 @@ std::vector compute_shape_native_group_norm( return shapes; } +std::vector +compute_shape_prod(const at::Tensor &self, + c10::optional dtype) { + 
if (dtype.has_value()) { + return {Shape(dtype.value(), {})}; + } + if (isIntegralType(self.scalar_type(), true)) { + return {Shape(c10::ScalarType::Long, {})}; + } + return {Shape(self.scalar_type(), {})}; +} + std::vector compute_shape_im2col(const at::Tensor &self, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding, diff --git a/projects/pt1/e2e_testing/main.py b/projects/pt1/e2e_testing/main.py index 1ec7aa43f538..d7d56e48df5f 100644 --- a/projects/pt1/e2e_testing/main.py +++ b/projects/pt1/e2e_testing/main.py @@ -7,6 +7,10 @@ import re import sys +import torch + +torch.device("cpu") + from torch_mlir_e2e_test.framework import run_tests from torch_mlir_e2e_test.reporting import report_results from torch_mlir_e2e_test.registry import GLOBAL_TEST_REGISTRY @@ -28,9 +32,6 @@ from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import ( RefBackendLinalgOnTensorsBackend, ) -from torch_mlir_e2e_test.onnx_backends.linalg_on_tensors import ( - LinalgOnTensorsOnnxBackend, -) from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import ( LinalgOnTensorsTosaBackend, ) @@ -41,10 +42,10 @@ from .xfail_sets import ( LINALG_XFAIL_SET, LINALG_CRASHING_SET, - MAKE_FX_TOSA_PASS_SET, STABLEHLO_PASS_SET, STABLEHLO_CRASHING_SET, TOSA_PASS_SET, + TOSA_CRASHING_SET, LTC_XFAIL_SET, LTC_CRASHING_SET, TORCHDYNAMO_XFAIL_SET, @@ -55,6 +56,10 @@ FX_IMPORTER_CRASHING_SET, FX_IMPORTER_STABLEHLO_XFAIL_SET, FX_IMPORTER_STABLEHLO_CRASHING_SET, + FX_IMPORTER_TOSA_CRASHING_SET, + FX_IMPORTER_TOSA_XFAIL_SET, + ONNX_TOSA_XFAIL_SET, + ONNX_TOSA_CRASHING_SET, ) # Import tests to register them in the global registry. @@ -69,13 +74,14 @@ def _get_argparse(): "torchscript", "linalg", "stablehlo", - "make_fx_tosa", "tosa", "lazy_tensor_core", "torchdynamo", "onnx", + "onnx_tosa", "fx_importer", "fx_importer_stablehlo", + "fx_importer_tosa", ] parser = argparse.ArgumentParser(description="Run torchscript e2e tests.") parser.add_argument( @@ -95,6 +101,8 @@ def _get_argparse(): "onnx": export to the model via onnx and reimport using the torch-onnx-to-torch path. "fx_importer": run the model through the fx importer frontend and execute the graph using Linalg-on-Tensors. "fx_importer_stablehlo": run the model through the fx importer frontend and execute the graph using Stablehlo backend. +"fx_importer_tosa": run the model through the fx importer frontend and execute the graph using the TOSA backend. +"onnx_tosa": Import ONNX to Torch via the torch-onnx-to-torch path and execute the graph using the TOSA backend. 
""", ) parser.add_argument( @@ -154,11 +162,7 @@ def main(): elif args.config == "tosa": config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend()) xfail_set = all_test_unique_names - TOSA_PASS_SET - crashing_set = set() - elif args.config == "make_fx_tosa": - config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend(), use_make_fx=True) - xfail_set = all_test_unique_names - MAKE_FX_TOSA_PASS_SET - crashing_set = set() + crashing_set = TOSA_CRASHING_SET elif args.config == "native_torch": config = NativeTorchTestConfig() xfail_set = set() @@ -179,14 +183,25 @@ def main(): config = FxImporterTestConfig(LinalgOnTensorsStablehloBackend(), "stablehlo") xfail_set = FX_IMPORTER_STABLEHLO_XFAIL_SET crashing_set = FX_IMPORTER_STABLEHLO_CRASHING_SET + elif args.config == "fx_importer_tosa": + config = FxImporterTestConfig(LinalgOnTensorsTosaBackend(), "tosa") + xfail_set = FX_IMPORTER_TOSA_XFAIL_SET + crashing_set = FX_IMPORTER_TOSA_CRASHING_SET elif args.config == "torchdynamo": - config = TorchDynamoTestConfig(RefBackendLinalgOnTensorsBackend()) + # TODO: Enanble runtime verification and extend crashing set. + config = TorchDynamoTestConfig( + RefBackendLinalgOnTensorsBackend(generate_runtime_verification=False) + ) xfail_set = TORCHDYNAMO_XFAIL_SET crashing_set = TORCHDYNAMO_CRASHING_SET elif args.config == "onnx": - config = OnnxBackendTestConfig(LinalgOnTensorsOnnxBackend()) + config = OnnxBackendTestConfig(RefBackendLinalgOnTensorsBackend()) xfail_set = ONNX_XFAIL_SET crashing_set = ONNX_CRASHING_SET + elif args.config == "onnx_tosa": + config = OnnxBackendTestConfig(LinalgOnTensorsTosaBackend(), output_type="tosa") + xfail_set = ONNX_TOSA_XFAIL_SET + crashing_set = ONNX_TOSA_CRASHING_SET do_not_attempt = set( args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed or [] diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 3fcb272f423e..b043d48ac2a0 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -16,15 +16,71 @@ print(f"TORCH_VERSION_FOR_COMPARISON =", torch_version_for_comparison()) LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | { - # Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed - # 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8 - "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", - "IscloseStaticModule_basic", - "IscloseStaticModuleTrue_basic", - "SplitWithSizes_Module_basic", + # lowering to torch backend IR fails due to unsupported op: aten.upsample_[mode/dims].vec + # these interpolate tests are added specifically to test onnx.Resize. + "InterpolateDynamicModule_sizes_bilinear", + "InterpolateDynamicModule_sizes_nearest", + "InterpolateStaticModule_scales_bilinear_align_corners", + "InterpolateDynamicModule_scales_recompute_bilinear", + "ElementwiseFloatTensorGtIntTensorModule_basic", + "AtenIntMM_basic", + # unimplemented lowering torch -> linalg for torchvision.deform_conv2d + # this is added to check the torch.onnx.export -> import_onnx -> torch path + "DeformConv2D_basic", + "ReduceAnyDimFloatModule_basic", + "UnfoldModule_basic", + # _trilinear is an implementation of einsum, but sets dimensions to zero + # if a dimension is specified in all expand lists, and not in sumdim list. + # This is a bug in the implementation of _trilinear in PyTorch. 
+ "Aten_TrilinearModuleZerodDimBug_basic", + # missing lowering from aten.pow.Tensor_Tensor for integer result + "PowIntIntModule_basic", + # Unknown builtin op: aten::_check_is_size in TorchScript + "AtenSymConstrainRange_basic", + "AtenSymConstrainRangeForSize_basic", + "Aten_AssertScalar_basic", } +if torch_version_for_comparison() < version.parse("2.5.0.dev"): + LINALG_XFAIL_SET = LINALG_XFAIL_SET | { + # Error: 'torch.aten.scaled_dot_product_attention' op expected 8 operands, but found 7 + # WORKS FOR TORCH VERSION 2.5.0.dev20240902, REMOVE WHEN ENABLE_GQA IS PUT IN STABLE + "ScaledDotProductAttentionBoolMaskModule_basic", + "ScaledDotProductAttentionDifferentCausalModule_basic", + "ScaledDotProductAttentionDifferentDynamicCausalModule_basic", + "ScaledDotProductAttentionDifferentModule_basic", + "ScaledDotProductAttentionMaskModule_basic", + "ScaledDotProductAttentionSameCausalModule_basic", + "ScaledDotProductAttentionSameDynamicModule_basic", + "ScaledDotProductAttentionSameModule_basic", + } + LINALG_CRASHING_SET = { + # Runtime op verification: Out of bounds access + "AtenDiagEmbedNegOffsetDiag_basic", + "AtenDiagEmbedNonDefault4DDiag_basic", + "AtenDiagEmbedOffsetDiag_basic", + "AtenDiagEmbedRevDimDiag_basic", + "AtenEmbeddingBagStaticModule_basic", + "AtenEmbeddingBagSumExample_basic", + "Aten_EmbeddingBagExample_basic", + # Runtime op verification: subview is out-of-bounds of the base memref + "Conv_Transpose1dModule_basic", + "Conv_Transpose1dStaticModule_basic", + "Conv_Transpose2dModule_basic", + "Conv_Transpose2dStaticModule_basic", + "Conv_Transpose3dModule_basic", + "Conv_Transpose3dStaticModule_basic", + "ConvolutionModule2DTransposeStridedStatic_basic", + "ConvolutionModule2DTransposeStrided_basic", + "GridSamplerBasic1_basic", + "GridSamplerBasic2_basic", + "GridSamplerBasic3_basic", + "GridSamplerBasic4_basic", + # Runtime op verification: stride mismatch in memref.cast + "ReduceAllDimEmpty_basic", + "TraceUnsignedIntModule_empty", + "TraceModule_empty", # Crashes due to copy to a smaller destination buffer than the source buffer. "SliceCopyStartGreaterThanDimSize_Module_basic", } @@ -33,6 +89,7 @@ #### General TorchDynamo/PyTorch errors # torch._dynamo.exc.Unsupported: Tensor.item "CumsumModule_basic", + "CumprodModule_basic", # TypeError: new_empty(): argument 'size' (position 1) must be tuple of ints, but found element of type NoneType at pos 0 # RuntimeError: Failed running call_function aten.convolution_backward(... 
# https://github.com/pytorch/pytorch/issues/89629 @@ -108,6 +165,7 @@ # END tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {} # START tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {} "AddIntModule_basic", + "AddFloatIntModule_basic", "AtenIntTensorCharDtypeModule_basic", "BoolIntFalseModule_basic", "BoolIntTrueModule_basic", @@ -168,7 +226,6 @@ "AtenIntBoolOpConstFalseModule_basic", "AtenIntBoolOpConstTrueModule_basic", "IntFloatModule_basic", - "PowIntFloatModule_basic", # END tests failing due to: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.Int # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.len "LenStrModule_basic", @@ -239,11 +296,6 @@ "ScatterValueIntModule_basic", # AssertionError: Unregistered operation: torch.aten._unsafe_index_put "UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic", - # Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed - # 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8 - "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", - # AssertionError: Unregistered operation: torch.aten._scaled_dot_product_flash_attention_for_cpu - "ScaledDotProductAttentionDifferentModule_basic", # AssertionError: Unregistered operation: torch.aten._embedding_bag_forward_only "AtenEmbeddingBagStaticModule_basic", # Lowering not present for this case @@ -272,6 +324,12 @@ "QuantizedReluInt8_basic", "QuantizedReluUint8_basic", "Conv2dQInt8Module_basic", + "Conv2dQInt8Module_depthwise", + "Conv2dQInt8Module_grouped", + "Conv2dQInt8Module_not_depthwise", + "Conv2dQInt8PerChannelModule_basic", + "Conv2dQInt8PerChannelModule_depthwise", + "Conv2dQInt8PerChannelModule_grouped", "ConvTranspose2DQInt8_basic", # Dynamo not supporting conv_tbc "ConvTbcModule_basic", @@ -331,31 +389,25 @@ } FX_IMPORTER_XFAIL_SET = { + "TimeOutModule_basic", # this test is expected to time out + "ReduceAnyDimFloatModule_basic", + "AddFloatIntModule_basic", "AllBoolFalseModule_basic", "AllBoolTrueModule_basic", "AnyBoolFalseModule_basic", "AnyBoolTrueModule_basic", "ArangeStartOutViewModule_basic", - "AtenEmbeddingBagStaticModule_basic", - "AtenEmbeddingBagSumExample_basic", "AtenFloatScalarModule_basic", "AtenIntBoolOpConstFalseModule_basic", "AtenIntBoolOpConstTrueModule_basic", "AtenIntBoolOpModule_basic", - "AtenItemFpOpModule_basic", - "AtenMatmulQMixedSigni8Transpose_basic", - "AtenMatmulQMixedSigni8_basic", - "AtenMatmulQint8MV_basic", - "AtenMatmulQint8_basic", - "AtenMatmulQint8VM_basic", - "AtenMatmulQint8VV_basic", - "AtenMmQMixedSigni8_basic", - "AtenMmQint8_basic", - "AtenMmQuint8_basic", + "AtenIntMM_basic", + "AtenNonzero1DDynamicModule_basic", # no lowering for torch.aten.sym_constrain_range_for_size + "Aten_TrilinearModuleVaryingRanks_basic", + "Aten_TrilinearModuleZerodDimBug_basic", "QuantizedReluInt32_basic", "QuantizedReluInt8_basic", "QuantizedReluUint8_basic", - "AtenSubFloatModule_basic", "BincountMinlengthModule_basic", "BincountModule_basic", "BincountStaticSizeModule_basic", @@ -367,62 +419,42 @@ "BoolIntTrueModule_basic", "BroadcastDynamicDimModule_basic", "CeilFloatModule_basic", - "ConstantBoolParameterModule_basic", "ContainsIntList_False", "ContainsIntList_True", - "Conv2dQInt8Module_basic", - "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", 
"ConvTbcModule_basic", - "ConvTranspose2DQInt8_basic", "ConvolutionBackwardModule2DPadded_basic", "ConvolutionBackwardModule2DStrided_basic", "ConvolutionBackwardModule2D_basic", "CumsumModule_basic", - "DivFloatModule_basic", + "CumprodModule_basic", + "CrossEntropyLossModule_basic", + "CrossEntropyLossNoReductionModule_basic", + "DeformConv2D_basic", "DivIntModule_basic", - "ElementwiseAddScalar_NumToTensorFloat_Module_basic", "ElementwiseDequantizePerChannelModule_basic", "ElementwiseDequantizePerTensorModule_basic", "ElementwiseQuantizePerTensorModule_basic", "ElementwiseQuantizePerTensorUIntModule_basic", "ElementwiseToDtypeI64ToUI8Module_basic", - "EqIntModule_basic", - "FakeQuantizePerTensorAffineDynamicShapeModule_basic", - "FakeQuantizePerTensorAffineModule_basic", - "FakeQuantizePerTensorAffineRoundToEvenModule_basic", + "ExponentialModule_basic", "FloatImplicitModule_basic", "GeFloatIntModule_basic", - "GeFloatModule_basic", "GeIntModule_basic", "GtFloatIntModule_basic", - "GtIntModule_basic", "IntFloatModule_basic", "IntImplicitModule_basic", - "IsFloatingPointFloat_True", - "IsFloatingPointInt_False", "LenStrModule_basic", - "MaxPool3dCeilModeTrueModule_basic", - "MaxPool3dEmptyStrideStaticModule_basic", - "MaxPool3dLargeDatadModule_basic", - "MaxPool3dModuleRandomSimple_basic", - "MaxPool3dModule_basic", - "MaxPool3dStaticCeilModeTrueModule_basic", - "MaxPool3dStaticModule_basic", - "MulFloatModule_basic", "NativeGroupNormBackwardModule_basic", "NeFloatIntModule_basic", - "NeIntModule_basic", "NllLossModuleBackward1DMeanWeight_basic", "NllLossModuleBackward1DMean_basic", "NllLossModuleBackward1DSumWeight_basic", "NllLossModuleBackward1DSum_basic", "NllLossModuleBackward1DWeight_basic", "NllLossModuleBackward1D_basic", - "NumToTensorFloatModule_basic", - "NumToTensorIntModule_basic", "NumelModule_basic", "NumelZeroRankModule_basic", - "PowIntFloatModule_basic", + "PowIntIntModule_basic", "PrimMaxIntModule_basic", "PrimMinIntDynamicModule_basic", "PrimMinIntModule_basic", @@ -436,39 +468,214 @@ "QuantizedSingleLayer_basic", "ReduceMaxAlongDimUnsignedInt_basic", "ReduceMinAlongDimUnsignedInt_basic", - "RsubInt0d_NumToTensor_Module_basic", - "ScalarConstantTupleModule_basic", - "ScalarImplicitFloatModule_basic", - "SortIntListReverse_basic", - "SortIntList_basic", + "RepeatInterleaveFillModule_basic", + "RepeatInterleaveModule_basic", + "RepeatInterleaveStaticModule_basic", "SplitDimDynamicModule_basic", "SplitDimStaticModule_basic", - "SqrtIntConstantModule_basic", "SqrtIntModule_basic", - "SubFloatModule_basic", - "TModuleRank0_basic", "TensorToBoolZeroRank_basic", "TensorToBool_basic", - "TensorToFloatZeroRank_basic", - "TensorToFloat_basic", - "TestMultipleTensorAndPrimitiveTypesReturn_basic", + "TensorsSplitTensorLastSmallerModule_basic", + "TensorsSplitTensorModule_basic", + "TensorsSplitTensorNegativeDimModule_basic", "ThresholdBackward2dMixedModule_basic", - "TorchPrimLoopForLikeModule_basic", - "TorchPrimLoopWhileLikeModule_basic", - "UnbindIntGetItem_Module_basic", - "UnbindIntListUnpack_Module_basic", - "UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic", "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic", "UpSampleNearest2dDynamicFactor_basic", "ViewCollapseDynamicWithAtenSizeIntModule_basic", "ViewSizeFromOtherTensor_basic", + "ViewDtypeStaticModule_basic", + "WeightNormInterfaceModule_basic", + # Error: `aten.as_strided` op is not supported + "ChunkListUnpackDynamic_Module_basic", + "ChunkListUnpackUnevenDynamic_Module_basic", + 
"ChunkListUnpackUneven_Module_basic", + "ChunkListUnpack_Module_basic", + "SplitTensorGetItem_Module_basic", + "SplitTensorLastSmallerModule_basic", + "SplitTensorListUnpackModule_basic", + "SplitTensorNegativeDimModule_basic", + "SplitWithSizesListUnpackModule_basic", + "SplitWithSizes_Module_basic", + "AdaptiveAvgPool1dGeneralDynamic_basic", + "AdaptiveAvgPool1dStaticEvenMultiple_basic", + "AdaptiveAvgPool1dStaticLargerOutput_basic", + "AdaptiveAvgPool2dDynamicNoBatch_basic", + "AdaptiveAvgPool2dDynamic_basic", + "AdaptiveMaxPool1dDynamicNoBatch_basic", + "AdaptiveMaxPool1dDynamic_basic", + "AdaptiveMaxPool1dStatic_basic", + "IsInfiniteModule_basic", + "InterpolateDynamicModule_sizes_nearest", + "IouOfModule_basic", + "MeshgridIndexingIJ_basic", + "MeshgridIndexingXY_basic", + "Meshgrid_basic", + # RuntimeError: cannot mutate tensors with frozen storage + "BernoulliFloatModule_basic", + "BernoulliTensorModule_basic", + "UniformModule_basic", + "UniformStaticShapeModule_basic", + "ElementwiseSignbitModule_basic", + "ElementwiseCopysignModule_basic", + "BernoulliFloatModule_basic", + "BernoulliTensorModule_basic", + "UniformModule_basic", + "UniformStaticShapeModule_basic", + "ReflectionPad3dModule_basic", + "ReflectionPad3dModuleTop_basic", + "ReflectionPad3dModuleBottom_basic", + "ReflectionPad3dModuleLeft_basic", + "ReflectionPad3dModuleRight_basic", + "ReflectionPad3dModuleFront_basic", + "ReflectionPad3dModuleBack_basic", } -FX_IMPORTER_CRASHING_SET = { +if torch_version_for_comparison() < version.parse("2.6.0.dev"): + # Passing on stable but failing on nightly + FX_IMPORTER_XFAIL_SET -= { + "CrossEntropyLossModule_basic", + "CrossEntropyLossNoReductionModule_basic", + "ChunkListUnpackDynamic_Module_basic", + "ChunkListUnpackUnevenDynamic_Module_basic", + "ChunkListUnpackUneven_Module_basic", + "ChunkListUnpack_Module_basic", + "ElementwiseRreluWithNoiseTrainModule_basic", + "ElementwiseRreluWithNoiseTrainStaticModule_basic", + "ExponentialModule_basic", + "SplitTensorGetItem_Module_basic", + "SplitTensorLastSmallerModule_basic", + "SplitTensorListUnpackModule_basic", + "SplitTensorNegativeDimModule_basic", + "SplitWithSizesListUnpackModule_basic", + "SplitWithSizes_Module_basic", + "TensorsSplitTensorLastSmallerModule_basic", + "TensorsSplitTensorModule_basic", + "TensorsSplitTensorNegativeDimModule_basic", + "AdaptiveAvgPool1dGeneralDynamic_basic", + "AdaptiveAvgPool1dStaticEvenMultiple_basic", + "AdaptiveAvgPool1dStaticLargerOutput_basic", + "AdaptiveAvgPool2dDynamicNoBatch_basic", + "AdaptiveAvgPool2dDynamic_basic", + "AdaptiveMaxPool1dDynamicNoBatch_basic", + "AdaptiveMaxPool1dDynamic_basic", + "AdaptiveMaxPool1dStatic_basic", + "BernoulliFloatModule_basic", + "BernoulliTensorModule_basic", + "ElementwiseExpm1IntModule_basic", + "ElementwiseExpm1Module_basic", + "InterpolateDynamicModule_sizes_nearest", + "IouOfModule_basic", + "MeshgridIndexingIJ_basic", + "MeshgridIndexingXY_basic", + "Meshgrid_basic", + "UniformModule_basic", + "UniformStaticShapeModule_basic", + } + # Failing on stable but not on nightly + FX_IMPORTER_XFAIL_SET |= { + "AtenSubFloatModule_basic", + "Conv2dWithValidPaddingModule_basic", + "Conv3dWithValidPaddingModule_basic", + "ElementwiseAddScalar_NumToTensorFloat_Module_basic", + "EqIntModule_basic", + "GeFloatModule_basic", + "GtIntModule_basic", + "NeIntModule_basic", + "NumToTensorFloatModule_basic", + "NumToTensorIntModule_basic", + "RsubInt0d_NumToTensor_Module_basic", + "SignAndLogarithmOfDeterminantBatchedModule_F32", + 
"SignAndLogarithmOfDeterminantDynamicModule_F32", + "SignAndLogarithmOfDeterminantModule_F32", + "SortIntListReverse_basic", + "SortIntList_basic", + "SqrtIntConstantModule_basic", + "AtenFftRfft2DLastDim_basic", + "AtenFftRfft2DMiddleDim_basic", + "AtenItemFpOpModule_basic", + "DivFloatModule_basic", + "ElementwiseRreluWithNoiseEvalModule_basic", + "ElementwiseRreluWithNoiseEvalStaticModule_basic", + "ElementwiseRreluWithNoiseTrainModule_basic", + "ElementwiseRreluWithNoiseTrainStaticModule_basic", + "MulFloatModule_basic", + "ScalarImplicitFloatModule_basic", + "SubFloatModule_basic", + "TensorToFloatZeroRank_basic", + "TensorToFloat_basic", + } + +FX_IMPORTER_CRASHING_SET = LINALG_CRASHING_SET | { "HBC_basic", + # Runtime op verification: out-of-bounds access + "_SoftmaxModule_basic", + "UpSampleNearest2dDynamicFactor_basic", + "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", + # Randomly mismatching values + "ConvolutionModule2DTranspose_basic", + # ? + "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic", + "Aten_TrilinearModuleSumAllDims_basic", + "Aten_TrilinearModuleSumdims_basic", + # only on stable: mismatched number of results + "ElementwiseRreluEvalModule_basic", + "ElementwiseRreluEvalStaticModule_basic", + "ElementwiseRreluTrainModule_basic", + "ElementwiseRreluTrainStaticModule_basic", } FX_IMPORTER_STABLEHLO_XFAIL_SET = { + "AddFloatIntModule_basic", + "AtenKthvalueDynamicDimsModule_basic", + "AtenKthvalueFloat64DynamicDimsModule_basic", + "AtenKthvalueFloat64Module_basic", + "AtenKthvalueKeepDimModule_basic", + "AtenKthvalueModule_basic", + "AtenPolarDoubleModule_basic", + "AtenPolarFloatModule_basic", + "DiagonalWithStaticShapeModule_basic", + "EinsumStaticDiagonalDimensionModule_basic", + "ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic", + "ElementwiseRemainderScalarModule_Float_NegativeDividend_basic", + "ElementwiseRemainderScalarModule_Float_NegativeDivisor_basic", + "ElementwiseRemainderScalarModule_Int_Float_NegativeDividend_basic", + "ElementwiseRemainderScalarModule_Int_NegativeDividend_basic", + "ElementwiseRemainderScalarModule_Int_NegativeDivisor_basic", + "ElementwiseRemainderTensorModule_Float_NegativeDivisor_basic", + "ElementwiseRemainderTensorModule_Int_Float_NegativeDividend_basic", + "ElementwiseRemainderTensorModule_Int_Float_NegativeDivisor_basic", + "ElementwiseRemainderTensorModule_Int_NegativeDividend_basic", + "ElementwiseRemainderTensorModule_Int_NegativeDivisor_basic", + "MaxPool1dCeilModeTrueModule_basic", + "MaxPool1dStaticCeilModeTrueModule_basic", + "MaxUnpool3dModulePad0_basic", + "MaxUnpool3dModule_basic", + "MultinomialModule2D_F32", + "MultinomialModule2D_basic", + "MultinomialModule_basic", + "QuantizedReluInt32_basic", + "QuantizedReluInt8_basic", + "QuantizedReluUint8_basic", + "ScatterAddStaticModule_basic", + "SelectScattertModule_basic", + "SelectScattertStaticModule_basic", + "SignAndLogarithmOfDeterminantBatchedModule_F32", + "SignAndLogarithmOfDeterminantDynamicModule_F32", + "SignAndLogarithmOfDeterminantModule_F32", + "SliceCopyEndGreaterThanDimSize_Module_basic", + "SliceCopyNegative_Module_basic", + "SliceCopyNonZeroDim_Module_basic", + "SliceCopyStartGreaterThanDimSize_Module_basic", + "SliceCopy_Module_basic", + "SliceScatterModule_basic", + "SliceScatterNegativeDimModule_basic", + "SliceScatterNegativeEndModule_basic", + "SliceScatterStaticModule_basic", + "SliceScatterStepVariationModule_basic", + "SliceScatterZeroDimModule_basic", + "TimeOutModule_basic", + "WeightNormInterfaceModule_basic", 
"AdaptiveAvgPool3dDynamicNoBatch_basic", "AdaptiveAvgPool3dDynamic_basic", "AdaptiveMaxPool1dDynamicNoBatch_basic", @@ -489,11 +696,6 @@ "AnyBoolFalseModule_basic", "AnyBoolTrueModule_basic", "ArangeStartOutViewModule_basic", - "ArgminIntModule_basic", - "ArgminIntModule_multiple_mins", - "ArgminModule_basic", - "ArgminModule_keepDim", - "ArgminModule_with_dim", "AtenComplexImagModule_basic", "AtenComplexRealModule_basic", "AtenComplexViewModule_basic", @@ -503,8 +705,9 @@ "AtenDiagEmbedNonDefault4DDiag_basic", "AtenDiagEmbedOffsetDiag_basic", "AtenDiagEmbedRevDimDiag_basic", - "AtenEmbeddingBagStaticModule_basic", "AtenEmbeddingBagSumExample_basic", + "AtenFftRfft2DLastDim_basic", + "AtenFftRfft2DMiddleDim_basic", "AtenFloatScalarModule_basic", "AtenIntBoolOpConstFalseModule_basic", "AtenIntBoolOpConstTrueModule_basic", @@ -524,10 +727,9 @@ "AtenSubFloatModule_basic", "AtenTopKModule_basic", "AtenTopKSmallestModule_basic", - "AtenTrilModule_basic", - "AtenTrilWithNegDiagonalModule_basic", - "AtenTrilWithPosDiagonalModule_basic", "Aten_EmbeddingBagExample_basic", + "Aten_TrilinearModuleVaryingRanks_basic", + "Aten_TrilinearModuleZerodDimBug_basic", "AvgPool2dDivisorOverrideModule_basic", "BernoulliTensorModule_basic", "BincountMinlengthModule_basic", @@ -545,12 +747,26 @@ "ContainsIntList_False", "ContainsIntList_True", "Conv2dQInt8Module_basic", + "Conv2dQInt8Module_depthwise", + "Conv2dQInt8Module_grouped", + "Conv2dQInt8Module_not_depthwise", + "Conv2dQInt8PerChannelModule_basic", + "Conv2dQInt8PerChannelModule_depthwise", + "Conv2dQInt8PerChannelModule_grouped", "ConvTbcModule_basic", "ConvTranspose2DQInt8_basic", "ConvolutionBackwardModule2DPadded_basic", "ConvolutionBackwardModule2DStrided_basic", "ConvolutionBackwardModule2D_basic", "CumsumModule_basic", + "CumprodModule_basic", + "CumprodInputDtypeInt32Module_basic", + "CumprodStaticModule_basic", + "CumprodStaticNegativeDimModule_basic", + "DeformConv2D_basic", + "DeterminantBatchedModule_F32", + "DeterminantDynamicModule_F32", + "DeterminantModule_F32", "DiagonalModule_basic", "DiagonalModule_nonsquare", "DiagonalModule_transposed", @@ -560,34 +776,16 @@ "DiagonalModule_with_offset", "DivFloatModule_basic", "DivIntModule_basic", - "ElementwiseAddScalar_NumToTensorFloat_Module_basic", - "ElementwiseAtan2FloatIntModule_basic", - "ElementwiseAtan2TensorFloatModule_basic", - "ElementwiseAtan2TensorIntModule_basic", - "ElementwiseBitwiseLeftShiftInt32Module_basic", - "ElementwiseBitwiseLeftShiftInt64Module_basic", - "ElementwiseBitwiseLeftShiftInt8Module_basic", - "ElementwiseBitwiseRightShiftInt32Module_basic", - "ElementwiseBitwiseRightShiftInt64Module_basic", - "ElementwiseBitwiseRightShiftInt8Module_basic", "ElementwiseDequantizePerChannelModule_basic", "ElementwiseDequantizePerTensorModule_basic", "ElementwiseErfIntModule_basic", - "ElementwiseLogitModule_basic", "ElementwiseMulTensorComplexModule_basic", "ElementwiseMulTensorComplexDiffModule_basic", "ElementwiseQuantizePerTensorModule_basic", "ElementwiseQuantizePerTensorUIntModule_basic", "ElementwiseReciprocalIntModule_basic", - "ElementwiseTanIntModule_basic", - "ElementwiseTanModule_basic", - "ElementwiseTernaryModule_basic", "ElementwiseToDtypeI64ToUI8Module_basic", - "EmptyModule_uint8", "EqIntModule_basic", - "FakeQuantizePerTensorAffineDynamicShapeModule_basic", - "FakeQuantizePerTensorAffineModule_basic", - "FakeQuantizePerTensorAffineRoundToEvenModule_basic", "Fill_TensorFloat32WithFloat32_basic", "Fill_TensorFloat32WithFloat64_basic", 
"Fill_TensorFloat32WithInt64_basic", @@ -635,29 +833,31 @@ "IndexPutImpl3DFloatAccumulateModule_basic", "IndexPutImpl3DFloatNonAccumulateModule_basic", "IndexPutImplIndexWithNoneModule_basic", + "IndexPutWithNoneAndBroadcastModule_basic", "IndexSelectRank0IdxModule_basic", "IndexTensorNegativeIndexModule_basic", "IntFloatModule_basic", "IntImplicitModule_basic", - "IsFloatingPointFloat_True", - "IsFloatingPointInt_False", "LenStrModule_basic", "MaxPool2dCeilModeTrueModule_basic", - "MaxPool2dEmptyStrideStaticModule_basic", "MaxPool2dStaticCeilModeTrueModule_basic", + "MaxPool2dStaticCeilModeTrueReduceOutputModule_basic", "MaxPool2dWithIndicesBackwardDynamic3DModule_basic", "MaxPool2dWithIndicesBackwardDynamic4DModule_basic", "MaxPool2dWithIndicesBackwardStatic3DModule_basic", "MaxPool2dWithIndicesBackwardStatic4DModule_basic", "MaxPool3dCeilModeTrueModule_basic", - "MaxPool3dEmptyStrideStaticModule_basic", - "MaxPool3dLargeDatadModule_basic", - "MaxPool3dModuleRandomSimple_basic", - "MaxPool3dModule_basic", "MaxPool3dStaticCeilModeTrueModule_basic", - "MaxPool3dStaticModule_basic", - "MseLossMeanReductionModule_basic", - "MseLossSumReductionWithDifferentElemTypeModule_basic", + "MaxPool3dWithIndicesAllNegativeValuesModule_basic", + "MaxPool3dWithIndicesAllOnesModule_basic", + "MaxPool3dWithIndicesCeilModeTrueModule_basic", + "MaxPool3dWithIndicesFullSizeKernelModule_basic", + "MaxPool3dWithIndicesModule_basic", + "MaxPool3dWithIndicesNonDefaultDilationModule_basic", + "MaxPool3dWithIndicesNonDefaultPaddingModule_basic", + "MaxPool3dWithIndicesNonDefaultParamsModule_basic", + "MaxPool3dWithIndicesNonDefaultStrideModule_basic", + "MaxPool3dWithIndicesStaticModule_basic", "MulFloatModule_basic", "NativeGroupNormBackwardModule_basic", "NeFloatIntModule_basic", @@ -678,11 +878,8 @@ "NormScalarComplexModule_basic", "NormScalarModule_basic", "NormalFunctionalModule_basic", - "NumToTensorFloatModule_basic", - "NumToTensorIntModule_basic", "NumelModule_basic", "NumelZeroRankModule_basic", - "PowIntFloatModule_basic", "PrimMaxIntModule_basic", "PrimMinIntDynamicModule_basic", "PrimMinIntModule_basic", @@ -700,17 +897,6 @@ "RandnLikeDtypeModule_basic", "RandnLikeModule_basic", "RandnModule_basic", - "ReduceAllDimBool_basic", - "ReduceAllDimEmpty_basic", - "ReduceAllDimFloat_basic", - "ReduceAllDimInt_basic", - "ReduceMaxAlongDimUnsignedInt_basic", - "ReduceMinAlongDimNegative_basic", - "ReduceMinAlongDimSignedInt_basic", - "ReduceMinAlongDimUnsignedInt_basic", - "ReduceMinAlongDim_basic", - "ReduceMinKeepDimReturnBoth_basic", - "ReduceMinKeepDim_basic", "ReduceProdDimIntFloatModule_basic", "ReflectionPad1dModule2dInput_Right", "ReflectionPad1dModule2dInput_basic", @@ -721,15 +907,21 @@ "ReflectionPad2dModule_Right", "ReflectionPad2dModule_Top", "ReflectionPad2dModule_basic", + "ReflectionPad3dModule_basic", + "ReflectionPad3dModuleTop_basic", + "ReflectionPad3dModuleBottom_basic", + "ReflectionPad3dModuleLeft_basic", + "ReflectionPad3dModuleRight_basic", + "ReflectionPad3dModuleFront_basic", + "ReflectionPad3dModuleBack_basic", "ReplicationPad2dModule_basic", "ReplicationPad2dModule_bottom0", "ReplicationPad2dModule_left0", "ReplicationPad2dModule_right0", "ReplicationPad2dModule_top0", - "RsubInt0d_NumToTensor_Module_basic", - "ScalarConstantTupleModule_basic", "ScalarImplicitFloatModule_basic", - "ScaledDotProductAttentionDifferentModule_basic", + # REMOVE WHEN ENABLE_GQA IS ADDED + "ScatterAddDynamicModule_basic", "ScatterReduceFloatMaxModule", "ScatterReduceFloatMaxModuleIncludeSelf", 
"ScatterReduceFloatMeanModule", @@ -755,8 +947,6 @@ "ScatterValueFloatModule_basic", "ScatterValueIntModule_basic", "SliceOutOfLowerBoundEndIndexModule_basic", - "SortIntListReverse_basic", - "SortIntList_basic", "SortTensorDescending_basic", "SortTensorInteger_basic", "SortTensorNegativeDimension_basic", @@ -773,14 +963,6 @@ "TensorToFloatZeroRank_basic", "TensorToFloat_basic", "TensorToInt_basic", - "TestMultipleTensorAndPrimitiveTypesReturn_basic", - "Threshold1dFloatModule_basic", - "Threshold1dIntI32Module_basic", - "Threshold1dIntModule_basic", - "Threshold2dFloatModule_basic", - "Threshold2dIntModule_basic", - "Threshold3dFloatModule_basic", - "Threshold3dIntModule_basic", "ThresholdBackward1dFloatModule_basic", "ThresholdBackward1dIntModule_basic", "ThresholdBackward1dMixedModule_basic", @@ -790,25 +972,78 @@ "ThresholdBackward3dFloatModule_basic", "ThresholdBackward3dIntModule_basic", "ThresholdBackward3dMixedModule_basic", - "TorchPrimLoopForLikeModule_basic", - "TorchPrimLoopWhileLikeModule_basic", "TraceModule_basic", "TraceModule_empty", "TraceModule_nonsquare", "TraceSignedIntModule_basic", "TraceUnsignedIntModule_basic", "TraceUnsignedIntModule_empty", - "UnbindIntGetItem_Module_basic", - "UnbindIntListUnpack_Module_basic", "UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic", "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic", "UpSampleNearest2dBackwardScalesNone_basic", "UpSampleNearest2dBackward_basic", - "VarMeanBiasedModule_basic", - "VarMeanCorrectionNoneModule_basic", - "VarMeanUnbiasedModule_basic", "ViewCollapseDynamicWithAtenSizeIntModule_basic", "ViewSizeFromOtherTensor_basic", + # Error: `aten.as_strided` op is not supported + "ChunkListUnpackDynamic_Module_basic", + "ChunkListUnpackUnevenDynamic_Module_basic", + "ChunkListUnpackUneven_Module_basic", + "ChunkListUnpack_Module_basic", + "SplitTensorGetItem_Module_basic", + "SplitTensorLastSmallerModule_basic", + "SplitTensorListUnpackModule_basic", + "SplitTensorNegativeDimModule_basic", + "SplitWithSizesListUnpackModule_basic", + "SplitWithSizes_Module_basic", + "Unfold_Module_basic", + "Unfold_Module_Rank_4", + "Unfold_Module_Rank_Zero_basic", + "Unfold_Module_Rank_Zero_Size_Zero_basic", + "Unfold_Module_Dynamic_basic", + "AdaptiveAvgPool1dGeneralDynamic_basic", + "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", + "AdaptiveAvgPool1dStaticEvenMultiple_basic", + "AdaptiveAvgPool1dStaticLargerOutput_basic", + "AdaptiveAvgPool2dDynamicNoBatch_basic", + "AdaptiveAvgPool2dDynamic_basic", + "AddIntModule_basic", + "AtenIntTensorByteDtypeModule_basic", + "AtenIntTensorCharDtypeModule_basic", + "AtenItemIntOpModule_basic", + "CrossEntropyLossModule_basic", + "CrossEntropyLossNoReductionModule_basic", + "InterpolateDynamicModule_sizes_nearest", + "IouOfModule_basic", + "IscloseStaticModuleTrue_basic", + "IscloseStaticModule_basic", + "MeshgridIndexingIJ_basic", + "MeshgridIndexingXY_basic", + "Meshgrid_basic", + "MulIntModule_basic", + "ReduceFrobeniusNormComplexModule_basic", + "ScalarImplicitIntModule_basic", + "ScaledDotProductAttentionBoolMaskModule_basic", + "ScaledDotProductAttentionDifferentCausalModule_basic", + "ScaledDotProductAttentionDifferentDynamicCausalModule_basic", + "ScaledDotProductAttentionDifferentModule_basic", + "ScaledDotProductAttentionMaskModule_basic", + "ScaledDotProductAttentionSameCausalModule_basic", + "ScaledDotProductAttentionSameDynamicModule_basic", + "ScaledDotProductAttentionSameModule_basic", + "SubIntModule_basic", + "TensorToIntZeroRank_basic", + 
"UpSampleNearest2dDynamicFactor_basic", + "UpSampleNearest2dDynamicSize_basic", + "UpSampleNearest2dStaticFactor_basic", + "UpSampleNearest2dStaticSize_basic", + "UpSampleNearest2d_basic", + "BernoulliFloatModule_basic", + "UniformModule_basic", + "UniformStaticShapeModule_basic", + "ScaledDotProductAttentionGQAModule_basic", + "AtenSymConstrainRange_basic", + "AtenSymConstrainRangeForSize_basic", + "Aten_AssertScalar_basic", } FX_IMPORTER_STABLEHLO_CRASHING_SET = { @@ -819,16 +1054,62 @@ "ResNet18StaticModule_basic", "MobilenetV3Module_basic", "Conv2dBiasNoPaddingModule_basic", + # llvm-project/llvm/include/llvm/ADT/ArrayRef.h:257: + # const T &llvm::ArrayRef::operator[](size_t) const [T = long]: + # Assertion `Index < Length && "Invalid index!" + "IndexPutWithNoneAndBroadcastModule_basic", + # Assertion `newMaterialization.getType() == outputType + # materialization callback produced value of incorrect type failed + "ReduceMaxAlongDimUnsignedInt_basic", + "ReduceMinAlongDimUnsignedInt_basic", + "Aten_TrilinearModuleSumdims_basic", + "Aten_TrilinearModuleSumAllDims_basic", + "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic", + "CrossEntropyLossModule_basic", + "CrossEntropyLossNoReductionModule_basic", + "AtenNonzero1DDynamicModule_basic", # error: Mismatched ranks of types2 vs 1 } STABLEHLO_PASS_SET = { + "ReduceAminmaxSingleDim_basic", + "ReduceAminmaxAllDims_basic", + "ReduceAmaxEmptyDim_basic", + "ReduceMinAlongDimNegative_basic", + "ReduceMinAlongDim_basic", + "ArgminModule_with_dim", + "ReduceMinAlongDimSignedInt_basic", + "ReduceAnyDimFloatModule_basic", + "MeshgridIndexingIJ_basic", + "MeshgridIndexingXY_basic", + "Meshgrid_basic", + "SplitWithSizes_Module_basic", + "TensorSplitSections_GetItemModule_basic", + "TensorSplitSections_ListUnpackModule_basic", + "EmptyModule_uint8", + "TypeConversionUint8ToF32Module_basic", + "Atleast1dModule0dInput_basic", + "Atleast1dModule1dInput_basic", + "Atleast2dModule0dInput_basic", + "Atleast2dModule1dInput_basic", + "Atleast2dModule2dInput_basic", + "AtenLinear1D_basic", + "AtenLinear2D_basic", + "AtenLinear3DBias_basic", + "AtenLinearMatVec_basic", + "AtenLinearVecMatBias_basic", + "AtenLinearVecMat_basic", + "ReduceAminSingleDim_basic", + "AtenDotModule_basic", "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic", + "AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic", "AddIntModule_basic", + "AddFloatIntModule_basic", "AliasModule_basic", + "TrueFalseOrBoolOpModule_basic", "AllBoolFalseModule_basic", "AllBoolTrueModule_basic", "AnyBoolFalseModule_basic", @@ -852,24 +1133,31 @@ "ArgmaxModule_with_dim", "AtenComplex64Module_basic", "AtenFloatScalarModule_basic", + "AtenHannWindowPeriodicFalseModule_basic", + "AtenHannWindowPeriodicTrueModule_basic", "AtenIntBoolOpConstFalseModule_basic", "AtenIntBoolOpConstTrueModule_basic", "AtenIntBoolOpModule_basic", "AtenIntTensorByteDtypeModule_basic", "AtenIntTensorCharDtypeModule_basic", - "AtenItemFpOpModule_basic", "AtenItemIntOpModule_basic", "AtenMmFloatTypes_basic", "AtenMmIntTypes_basic", + "AtenIntMM_basic", "AtenRoundFloatHalfToEvenModule_basic", "AtenRoundFloatModule_basic", "AtenRoundIntModule_basic", "AtenSubFloatModule_basic", "AtenToDeviceModule_basic", + "AtenTrilStaticModule_basic", + "AtenTrilWithNegDiagonalStaticModule_basic", + 
"AtenTrilWithPosDiagonalStaticModule_basic", "Aten_CastFloatModule_basic", "Aten_CastLongModule_basic", "AvgPool1dStaticModule_basic", "AvgPool2dStaticModule_basic", + "AvgPool2dCountIncludePadFalseStaticModule_basic", + "AvgPool3dStaticModule_basic", "BaddbmmBroadcast1DInputModule_basic", "BaddbmmBroadcast2DInputModule_basic", "BaddbmmStaticModule_basic", @@ -895,6 +1183,7 @@ "ContainsIntList_False", "ContainsIntList_True", "ContiguousModule_basic", + "Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic", "Conv2dWithPaddingDilationStrideStaticModule_basic", "Conv2dWithPaddingDilationStrideStaticModule_depthwise", "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", @@ -903,6 +1192,9 @@ "Convolution2DStaticModule_basic", "ConvolutionBackwardModule2DStatic_basic", "ConvolutionModule2DTransposeStridedStatic_basic", + "Conv_Transpose1dStaticModule_basic", + "Conv_Transpose2dStaticModule_basic", + "Conv_Transpose3dStaticModule_basic", "ConstantPad2dStaticModule_basic", "ConstantPadNdModule_basic", "ConstantPadNdPartialStaticModule_basic", @@ -912,6 +1204,9 @@ "CumsumInputDtypeInt32Module_basic", "CumsumStaticModule_basic", "CumsumStaticNegativeDimModule_basic", + "CumprodInputDtypeInt32Module_basic", + "CumprodStaticModule_basic", + "CumprodStaticNegativeDimModule_basic", "DetachModule_basic", "DivFloatModule_basic", "DivIntModule_basic", @@ -988,6 +1283,7 @@ "ElementwiseLog2Module_basic", "ElementwiseLog10IntModule_basic", "ElementwiseLog2IntModule_basic", + "ElementwiseNanToNumWithNoneModule_Basic", "ElementwiseNanToNumModule_Basic", "ElementwiseNeFloatTensorStaticModule_basic", "ElementwiseNeIntTensorStaticModule_basic", @@ -1000,15 +1296,25 @@ "ElementwisePreluStaticModule_basic", "ElementwiseReciprocalModule_basic", "ElementwiseReluModule_basic", + "ElementwiseRreluWithNoiseEvalStaticModule_basic", + "ElementwiseRreluWithNoiseTrainStaticModule_basic", + "RreluWithNoiseBackwardEvalStaticModule_basic", + "RreluWithNoiseBackwardTrainStaticModule_basic", "ElementwiseRemainderTensorModule_Float_basic", + "ElementwiseRemainderTensorModule_Float_NegativeDividend_basic", "ElementwiseRemainderTensorModule_Int_Float_basic", "ElementwiseRemainderTensorModule_Int_basic", + "ElementwiseRreluEvalStaticModule_basic", + "ElementwiseRreluTrainStaticModule_basic", "ElementwiseRsqrtModule_basic", "ElementwiseSigmoidModule_basic", "ElementwiseSinModule_basic", + "ElementwiseSpecialExpm1IntModule_basic", + "ElementwiseSpecialExpm1Module_basic", "ElementwiseSqrtModule_basic", "ElementwiseTanIntModule_basic", "ElementwiseTanModule_basic", + "ElementwiseTernaryStaticShapeModule_basic", "ElementwiseToDtypeF32ToI64Module_basic", "ElementwiseToDtypeI64ToI8Module_basic", "ElementwiseToDtypeIdentityModule_basic", @@ -1026,6 +1332,7 @@ "EmptyStridedModule_basic", "EqIntModule_basic", "ExpandAsIntModule_basic", + "FakeQuantizePerTensorAffineCachemaskModule_basic", "FakeQuantizePerTensorAffineModule_basic", "FakeQuantizePerTensorAffineRoundToEvenModule_basic", "Fill_TensorFloat64WithFloat32Static_basic", @@ -1055,8 +1362,16 @@ "GeIntModule_basic", "GeluBackwardModule_basic", "GluStaticModule_basic", + "GridSamplerBasic1_basic", + "GridSamplerBasic2_basic", + "GridSamplerBasic3_basic", + "GridSamplerBasic4_basic", "GtFloatIntModule_basic", "GtIntModule_basic", + "HstackBasicComplexModule_basic", + "HstackBasicFloatModule_basic", + "HstackBasicIntFloatModule_basic", + "HstackBasicIntModule_basic", "IndexTensorMultiIndexStaticModule_basic", "IndexTensorStaticModule_basic", "IntFloatModule_basic", @@ 
-1072,6 +1387,8 @@ "LinspaceTwoSizeModule_basic", "MaskedFillScalarFloatValueStaticModule_basic", "MaskedFillScalarIntValueStaticModule_basic", + "MaskedFillTensorIntValueStaticModule_basic", + "MaskedScatterStaticBasic_basic", "Matmul4dStatic_basic", "Matmul_2d", "Matmul_dot", @@ -1098,6 +1415,7 @@ "MoveDimIntModule_basic", "MoveDimIntNegativeIndexModule_basic", "MulFloatModule_basic", + "MulFloatModule_basic", "MulIntModule_basic", "Mv_basic", "NarrowHorizontalTest2_basic", @@ -1221,6 +1539,10 @@ "ReturnThreeTensorFloat32_basic", "ReturnTwoTensorF32I64_basic", "RollModule_basic", + "Rot90BasicModule_basic", + "Rot90MultipleRotationsModule_basic", + "Rot90NegativeEvenRotationsModule_basic", + "Rot90NegativeOddRotationsModule_basic", "RsubInt0d_NumToTensor_Module_basic", "ScalarConstantTupleModule_basic", "ScalarImplicitFloatModule_basic", @@ -1234,15 +1556,10 @@ "SliceOutOfLowerBoundStartIndexModule_basic", "SliceOutOfUpperBoundIndexModule_basic", "SliceOutOfUpperBoundIndexStaticModule_basic", - "SliceScatterModule_basic", - "SliceScatterNegativeDimModule_basic", - "SliceScatterNegativeEndModule_basic", - "SliceScatterStaticModule_basic", - "SliceScatterStepVariationModule_basic", - "SliceScatterZeroDimModule_basic", "SliceSizeTwoStepModule_basic", "SliceStartEqEndModule_basic", "SliceStaticModule_basic", + "SliceStaticComplexInputModule_basic", "SliceWholeTensorModule_basic", "SortIntListReverse_basic", "SortIntList_basic", @@ -1270,6 +1587,9 @@ "TensorToFloatZeroRank_basic", "TensorToIntZeroRank_basic", "TensorsConcatModule_basic", + "TensorsConcatComplex128FloatModule_basic", + "TensorsConcatComplex64FloatModule_basic", + "TensorsConcatComplex128IntModule_basic", "TensorsConcatNegativeDimModule_basic", "TensorsConcatNegativeDimStaticModule_basic", "TensorsConcatPromoteDTypeModule_basic", @@ -1285,6 +1605,13 @@ "TorchPrimLoopForLikeTensorArgModule_basic", "TransposeIntModule_basic", "TransposeIntNegDimsModule_basic", + "TriuIndicesModule_basic", + "TriuIndicesAllZerosModule_basic", + "TriuIndicesNegativeOffsetModule_basic", + "TrilIndicesAllZerosModule_basic", + "TrilIndicesModule_basic", + "TrilIndicesNegativeOffsetModule_basic", + "TrilIndicesOfssetGreaterThanRowModule_basic", "TupleModule_basic", "TypeAsDifferentModule_basic", "TypeAsSameModule_basic", @@ -1360,7 +1687,6 @@ "RandModule_basic", "ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic", "SelectIntNegativeDimAndIndexStaticModule_basic", - "SelectScattertStaticModule_basic", "SqueezeDimModule_static", "SqueezeModule_static", "TriuBroadcastModule_basic", @@ -1442,24 +1768,373 @@ "ElementwiseLogSigmoidModule_basic", "ElementwiseHardshrinkStaticModule_basic", "ElementwiseSoftshrinkStaticModule_basic", + "RenormModuleFloat16_basic", + "RenormModuleFloat32NegativeDim_basic", + "RenormModuleFloat32_basic", } STABLEHLO_CRASHING_SET = { - "AtenEmbeddingBagSumExample_basic", + "IndexPutWithNoneAndBroadcastModule_basic", + "ReduceMaxAlongDimUnsignedInt_basic", + "ReduceMinAlongDimUnsignedInt_basic", + # LLVM ERROR: Failed to infer result type(s) + "ElementwiseClampMinTensorFloatModule_basic", + "ElementwiseClampMinTensorIntModule_basic", + "ElementwiseClampTensorFloatModule_basic", + "ElementwiseClampTensorIntModule_basic", +} + +TOSA_CRASHING_SET = { + "ArangeStartOutDtypeModule_basic", + "ArangeStartOutModule_basic", + "ScatterSrcStaticModule_basic", + # Runtime op verification: Out of bounds access + "ReduceAllDimEmpty_basic", + # SmallVector unable to grow for ThresholdBackward1d + "ThresholdBackward1dFloatModule_basic", + 
"ThresholdBackward1dIntModule_basic", + "ThresholdBackward1dMixedModule_basic", +} + +FX_IMPORTER_TOSA_CRASHING_SET = { + "Aten_TrilinearModuleSumAllDims_basic", + "Aten_TrilinearModuleSumdims_basic", + "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic", + "CrossEntropyLossModule_basic", + "CrossEntropyLossNoReductionModule_basic", + "ScatterSrcModule_basic", + "ScatterSrcStaticModule_basic", + "HBC_basic", + # subview is out-of-bounds of the base memref + "RollModule_basic", + # element type cannot be iterated + "TriuModule_basic", + # Randomly mismatching values + "ConvolutionModule2DTranspose_basic", + # ? + "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic", + "Aten_TrilinearModuleSumAllDims_basic", + "Aten_TrilinearModuleSumdims_basic", + # out-of-bounds access + "UpSampleNearest2d_basic", + "UpSampleNearest2dStaticSize_basic", + "UpSampleNearest2dDynamicSize_basic", + "UpSampleNearest2dDynamicFactor_basic", + "UpSampleNearest2dStaticFactor_basic", + # only on stable: mismatched number of results + "ElementwiseRreluEvalModule_basic", + "ElementwiseRreluEvalStaticModule_basic", + "ElementwiseRreluTrainModule_basic", + "ElementwiseRreluTrainStaticModule_basic", + # Crash in tosa to tensor: inferReshapeCollapsedType(TensorType, TensorType): Assertion `lhsShape[currLhsDim] == 1' failed. + "TrilIndicesAllZerosModule_basic", + "TriuIndicesAllZerosModule_basic", } # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. TOSA_PASS_SET = { + "AtenEyeMModuleInt2D_basic", + "AtenEyeModuleInt2D_basic", + "ElementwiseWhereScalarOtherStaticModule_basic", + "FullModuleFalsePinMemory_basic", + "FullModuleInt2D_basic", + "MaskedFillScalarFloatValueModule_basic", + "MaskedFillScalarFloatValueStaticModule_basic", + "NewFullModuleInt2D_basic", + "NewFullModuleInt3D_basic", + "Threshold3dIntModule_basic", + "TrilIndicesModule_basic", + "TrilIndicesOfssetGreaterThanRowModule_basic", + "TriuIndicesNegativeOffsetModule_basic", + "BmmFloat16Module_basic", + "ElementwiseRreluWithNoiseTrainStaticModule_basic", + "LinspaceDtypeModule_basic", + "Aten_CastLongModule_basic", + "Unfold_Module_Rank_4", + "Unfold_Module_Rank_Zero_basic", + "Unfold_Module_basic", + "ElementwiseErfIntModule_basic", + "ElementwiseExpm1IntModule_basic", + "ElementwiseExpm1Module_basic", + "ElementwiseIntTensorLtFloatScalarModule_basic", + "ElementwiseSigmoidIntModule_basic", + "ElementwiseSpecialExpm1IntModule_basic", + "ElementwiseSpecialExpm1Module_basic", + "ElementwiseTanIntModule_basic", + "ElementwiseTanModule_basic", + "ElementwiseUnaryIntModule_basic", + "PowIntFloatModule_basic", + "Deg2radModule_basic", + "ElementwiseIntTensorLtFloatTensorModule_basic", + "L1LossMeanReductionModule_basic", + "L1LossNoReductionModule_basic", + "L1LossSumReductionModule_basic", + "PixelShuffleModuleStaticRank3Int64_basic", + "PixelShuffleModuleStaticRank4Float32_basic", + "RandIntLowModule_basic", + "RandIntModule_basic", + "RandIntPinMemoryModule_basic", + "RenormModuleFloat16_basic", + "SplitDimStaticModule_basic", + "Deg2radModule_basic", + "ElementwiseExpIntModule_basic", + "ElementwiseLog10IntModule_basic", + "ElementwiseLog10Module_basic", + "ElementwiseLog1pModule_basic", + "ElementwiseLog2IntModule_basic", + "ElementwiseLogIntModule_basic", + "ElementwiseLogitModule_basic", + "ElementwiseMishModule_basic", + "L1LossMeanReductionModule_basic", + "L1LossNoReductionModule_basic", + "L1LossSumReductionModule_basic", + "RandIntLowModule_basic", + "RandIntModule_basic", + 
"RandIntPinMemoryModule_basic", + "SoftplusModule_basic", + "ReflectionPad1dModule2dInput_Right", + "ReflectionPad1dModule2dInput_basic", + "ReflectionPad1dModule3dInput_Left", + "ReflectionPad1dModule3dInput_basic", + "ReflectionPad2dModule_Bottom", + "ReflectionPad2dModule_Left", + "ReflectionPad2dModule_Right", + "ReflectionPad2dModule_Top", + "ReflectionPad2dModule_basic", + "ReplicationPad2dModule_basic", + "ReplicationPad2dModule_bottom0", + "ReplicationPad2dModule_left0", + "ReplicationPad2dModule_right0", + "ReplicationPad2dModule_top0", + "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic", + "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic", + "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", + "ElementwiseAtenLogicalNotOpPromoteModule_basic", + "ElementwiseCosIntModule_basic", + "ElementwiseReciprocalIntModule_basic", + "ElementwiseRsqrtIntModule_basic", + "ElementwiseSinIntModule_basic", + "FloatPowerTensorTensorStaticModule_basic", + "AdaptiveMaxPool1dDimOneStatic_basic", + "CollapseAllDimensionsModule_basic", + "CollapseRank1DynamicModule_basic", + "CollapseStaticModule_basic", + "ElementwiseClampMinTensorFloatModule_basic", + "ElementwiseClampMinTensorIntModule_basic", + "ElementwiseClampTensorFloatModule_basic", + "ElementwiseClampTensorIntModule_basic", + "ElementwiseFracModule_basic", + "ElementwiseLdexpModule_basic", + "ElementwiseSignbitIntModule_basic", + "Exp2StaticIntModule_basic", + "MaxPool1dEmptyStrideStaticModule_basic", + "MaxPool1dStaticCeilModeTrueModule_basic", + "MaxPool1dStaticModule_basic", + "RepeatInterleaveSelfIntModule_basic", + "RsubIntModule_noalpha_basic", + "Aten_TrilinearModuleSumAllDims_basic", + "Aten_TrilinearModuleSumdims_basic", + "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic", + "Aten_TrilinearModuleVaryingRanks_basic", + "Aten_TrilinearModule_basic", + "ElementwiseAddBoolModule_basic", + "Exp2StaticModule_basic", + "CosineSimilarityStaticBroadcastModule_basic", + "DropoutTrainStaticShapeModule_basic", + "ElementwiseAtenLogicalAndOpModule_basic", + "ElementwiseAtenLogicalAndOpPromoteBroadcastModule_basic", + "ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic", + "ElementwiseBitwiseAndScalarInt8Module_basic", + "ElementwiseRreluTrainStaticModule_basic", + "IndexSelectRank0IdxModule_basic", + "MseLossSumReductionWithDifferentElemTypeModule_basic", + "NativeDropoutTrainStaticShapeModule_basic", + "RandIntDtypeModule_basic", + "RandIntLowDtypeModule_basic", + "RandModule_basic", + "ReduceL3NormAllDimsModule_basic", + "ReduceL3NormKeepDimModule_basic", + "SliceCopy_Module_basic", + "Threshold1dIntModule_basic", + "Threshold2dIntModule_basic", + "EmptyModule_contiguous", + "EmptyModule_defaultDtype", + "EmptyModule_falsePinMemory", + "EmptyModule_float", + "EmptyModule_int", + "EmptyModule_uint8", + "EmptyStridedModule_basic", + "NewEmptyModuleDefaultDtype_basic", + "NewEmptyModuleFalsePinMemory_basic", + "NewEmptyModuleFloat2D_basic", + "NewEmptyModuleFloat3D_basic", + "NewEmptyModuleInt2D_basic", + "NewEmptyModuleInt3D_basic", + "NewEmptyModuleLayoutIntDtype_basic", + "NewEmptyModuleNonDefaultFloatDtype_basic", + "NewEmptyModuleNonDefaultIntDtype_basic", + "NewEmptyStridedModuleDefaultDtype_basic", + "SelectScattertStaticModule_basic", + "SliceScatterStaticModule_basic", + "TensorAlloc1dStaticModule_basic", + 
"AtenRoundFloatHalfToEvenModule_basic", + "AtenRoundFloatModule_basic", + "FakeQuantizePerTensorAffineCachemaskModule_basic", + "FakeQuantizePerTensorAffineDynamicShapeModule_basic", + "FakeQuantizePerTensorAffineModule_basic", + "FakeQuantizePerTensorAffineRoundToEvenModule_basic", + "Fill_TensorFloat64WithFloat32Static_basic", + "Fill_TensorFloat64WithInt64Static_basic", + "FlipModuleStaticShape_basic", + "FlipModule_basic", + "FlipNegativeIndexModule_basic", + "Rot90BasicModule_basic", + "Rot90DynamicDimsModule_basic", + "Rot90MultipleRotationsModule_basic", + "Rot90NegativeEvenRotationsModule_basic", + "Rot90NegativeOddRotationsModule_basic", + "AtenLinalgCrossBroadcast_basic", + "AtenLinalgCrossCustomDim_basic", + "AtenLinalgCrossFloat_basic", + "AtenLinalgCrossInt_basic", + "AtenLinalgCrossNegativeDim_basic", + "BinaryCrossEntropyWithLogitsStaticModule_basic", + "IndexSelectNegativeDimModule_basic", + "IndexSelectSingleIdxModule_basic", + "IndexSelectTwoIdxModule_basic", + "IndexSelectWholeDimensionModule_basic", + "IndexSelectWholeTensorModule_basic", + "DiagonalWithStaticShapeModule_basic", + "EinsumStaticDiagonalDimensionModule_basic", + "ElementwiseAtenFloorDivideBroadcastModule_basic", + "ElementwiseAtenFloorDivideScalarModule_basic", + "ElementwiseAtenFloorDivideScalarNegativeModule_basic", + "ElementwiseAtenFloorDivideTensorNegativeModule_basic", + "ElementwiseAtenFloorDivideTensorPositiveModule_basic", + "ElementwiseDivScalarRoundingModeFloorIntStaticModule_basic", + "ElementwiseDivScalarRoundingModeFloorModule_basic", + "ElementwiseDivScalarRoundingModeFloorStaticModule_basic", + "ElementwiseDivScalarRoundingModeTruncIntStaticModule_basic", + "ElementwiseDivScalarRoundingModeTruncModule_basic", + "ElementwiseDivScalarRoundingModeTruncStaticModule_basic", + "ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic", + "ElementwiseDivTensorRoundingModeFloorModule_basic", + "ElementwiseDivTensorRoundingModeFloorStaticModule_basic", + "ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic", + "ElementwiseDivTensorRoundingModeTruncModule_basic", + "ElementwiseDivTensorRoundingModeTruncStaticModule_basic", + "ElementwiseGeFloatTensorModule_basic", + "ElementwiseGeIntTensorModule_basic", + "ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic", + "ElementwiseRemainderScalarModule_Bool_basic", + "ElementwiseRemainderScalarModule_Int_NegativeDividend_basic", + "ElementwiseRemainderScalarModule_Int_NegativeDivisor_basic", + "ElementwiseRemainderTensorModule_Float_NegativeDividend_basic", + "ElementwiseRemainderTensorModule_Float_NegativeDivisor_basic", + "ElementwiseRemainderTensorModule_Float_basic", + "ElementwiseRemainderTensorModule_Int_Float_NegativeDividend_basic", + "ElementwiseRemainderTensorModule_Int_Float_NegativeDivisor_basic", + "ElementwiseRemainderTensorModule_Int_Float_basic", + "ElementwiseRemainderTensorModule_Int_NegativeDividend_basic", + "ElementwiseRemainderTensorModule_Int_NegativeDivisor_basic", + "ElementwiseRemainderTensorModule_Int_basic", + "TriuBroadcastModule_basic", + "TriuModule_basic", + "AtenHannWindowPeriodicFalseModule_basic", + "AtenHannWindowPeriodicTrueModule_basic", + "ElementwiseAndScalarModule_basic", + "ElementwiseAndScalarStaticShapeModule_basic", + "ElementwiseAtenLogicalNotOpModule_basic", + "ElementwiseAtenLogicalXorOpModule_basic", + "ElementwiseAtenLogicalXorOpPromoteBroadcastModule_basic", + "ElementwiseAtenLogicalXorOpPromoteBroadcastStaticShapeModule_basic", + "ElementwiseBitwiseAndScalarInt32Module_basic", + 
"ElementwiseBitwiseAndScalarInt64Module_basic", + "ElementwiseBitwiseLeftShiftInt32Module_basic", + "ElementwiseBitwiseLeftShiftInt64Module_basic", + "ElementwiseBitwiseLeftShiftInt8Module_basic", + "ElementwiseBitwiseRightShiftInt32Module_basic", + "ElementwiseBitwiseRightShiftInt64Module_basic", + "ElementwiseBitwiseRightShiftInt8Module_basic", + "ElementwiseCosModule_basic", + "ElementwiseErfModule_basic", + "ElementwiseLeFloatIntScalarModule_basic", + "ElementwiseLeFloatScalarModule_basic", + "ElementwiseLeFloatTensorNanModule_basic", + "ElementwiseLeIntScalarModule_basic", + "ElementwiseLeMixedIntScalarModule_basic", + "ElementwisePowScalarModule_basic", + "ElementwisePowTensorBroadcastModule_basic", + "ElementwisePowTensorBroadcastStaticModule_basic", + "ElementwisePowTensorModule_basic", + "ElementwisePowTensorStaticModule_basic", + "ElementwiseSinModule_basic", + "ArgminIntModule_basic", + "ArgminIntModule_multiple_mins", + "ArgminModule_basic", + "ArgminModule_keepDim", + "ReduceAllDimBool_basic", + "ReduceAllDimFloat_basic", + "ReduceAllDimInt_basic", + "ReduceAllFloatModule_basic", + "ReduceAllIntModule_basic", + "ReduceAnyFloatModule_basic", + "ReduceAnyIntModule_basic", + "ReduceMaxAllDims_basic", + "ReduceMaxFloatModule_basic", + "ReduceMaxSignedIntModule_basic", + "ReduceMaxUnsignedIntModule_basic", + "ReduceMinFloatModule_basic", + "ReduceMinSignedIntModule_basic", + "ReduceMinUnsignedIntModule_basic", + "ReduceProdDtypeFloatModule_basic", + "ReduceProdDtypeIntModule_basic", + "ReduceProdElementTypeBoolModule_basic", + "ReduceProdFloatModule_basic", + "ReduceProdSignedIntModule_basic", + "ReduceProdUnsignedIntModule_basic", + "ReduceSumDimIntListDtypeFloatModule_basic", + "ReduceSumDimIntListDtypeIntModule_basic", + "ReduceSumDimIntListElementTypeBoolModule_basic", + "ReduceSumDtypeFloatModule_basic", + "ReduceSumDtypeIntModule_basic", + "ReduceSumElementTypeBoolModule_basic", + "AtenTrilStaticModule_basic", + "AtenTrilWithNegDiagonalStaticModule_basic", + "AtenTrilWithPosDiagonalStaticModule_basic", + "ArgmaxKeepdimModule_basic", + "MeshgridIndexingIJ_basic", + "MeshgridIndexingXY_basic", + "Meshgrid_basic", + "AvgPool2dCountIncludePadFalseStaticModule_basic", + "TensorSplitSections_GetItemModule_basic", + "TensorSplitSections_ListUnpackModule_basic", + "Atleast1dModule0dInput_basic", + "Atleast1dModule1dInput_basic", + "Atleast2dModule0dInput_basic", + "Atleast2dModule1dInput_basic", + "Atleast2dModule2dInput_basic", + "AtenLinear2D_basic", + "AtenLinear3DBias_basic", + "ElementwiseAddScalar_NumToTensorFloat_Module_basic", + "ElementwiseDivTensorFloatModule_basic", + "ElementwiseMulTensorFloatModule_basic", + "ElementwiseWhereScalarSelfStaticModule_basic", + "GroupNormModule_basic", + "GroupNormNoWeightAndBiasModule_basic", + "NativeGroupNormModule_basic", + "AtenDotModule_basic", "ElementwiseFloatTensorGtIntScalarModule_basic", "ElementwiseLogSigmoidModule_basic", + "ElementwiseTernaryStaticShapeModule_basic", "ElementwiseTruncModule_basic", "ElementwiseTruncIntModule_basic", "ElementwiseSgnModule_basic", "ElementwiseSignIntModule_basic", - "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", - "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", - "AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic", "AddCDivModule_basic", "AddCDiv_Module_basic", "AddCMulModule_basic", @@ -1474,7 +2149,6 @@ "ArangeStartOutModule_basic", "ArangeStartOutViewModule_basic", "ArangeStartStepIntModule_basic", - "ArangeZeroElementOutputModule_basic", "ArangeDtypeIntModule_basic", 
"ArangeFalsePinMemoryModule_basic", "ArangeFloatModule_basic", @@ -1500,6 +2174,7 @@ "AtenInstanceNormModule_basic", "AtenToDeviceModule_basic", "Aten_CastFloatModule_basic", + "TrueFalseOrBoolOpModule_basic", "BaddbmmBroadcast1DInputModule_basic", "BaddbmmBroadcast2DInputModule_basic", "BaddbmmDynamicModule_basic", @@ -1538,6 +2213,8 @@ "Conv2dWithPaddingDilationStrideStaticModule_depthwise", "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", "Conv2dWithPaddingModule_basic", + "Conv2dWithValidPaddingModule_basic", + "Conv2dWithSamePaddingModule_basic", "Convolution2DStaticModule_basic", "CosineSimilarityStaticModule_basic", "DetachModule_basic", @@ -1608,6 +2285,11 @@ "ElementwiseFlattenBroadcastModule_basic", "ElementwiseFloorIntModule_basic", "ElementwiseFloorModule_basic", + "ElementwiseFmaxModule_basic", + "ElementwiseFminModule_basic", + "ElementwiseFmodTensor_Float_basic", + "ElementwiseFmodTensor_Int_Float_basic", + "ElementwiseFmodTensor_Int_basic", "ElementwiseGeFloatIntScalarModule_basic", "ElementwiseGeFloatScalarModule_basic", "ElementwiseGeIntScalarModule_basic", @@ -1661,13 +2343,22 @@ "ElementwisePowModule_basic", "ElementwisePreluModule_basic", "ElementwisePreluStaticModule_basic", + "ElementwiseRad2DegModule_basic", + "ElementwiseRad2DegIntModule_basic", "ElementwiseReciprocalModule_basic", "ElementwiseRelu6Module_basic", "ElementwiseReluModule_basic", + "ElementwiseRreluWithNoiseEvalStaticModule_basic", + "ElementwiseRemainderScalarModule_Float_NegativeDividend_basic", + "ElementwiseRemainderScalarModule_Float_NegativeDivisor_basic", + "ElementwiseRemainderScalarModule_Int_Float_NegativeDividend_basic", + "ElementwiseRemainderScalarModule_Int_Float_NegativeDivisor_basic", "ElementwiseRemainderScalarModule_Float_basic", "ElementwiseRemainderScalarModule_Int_Float_basic", "ElementwiseRemainderScalarModule_Int_basic", "ElementwiseRemainderScalarModule_Int_basic", + "ElementwiseRreluEvalModule_basic", + "ElementwiseRreluEvalStaticModule_basic", "ElementwiseRsqrtModule_basic", "ElementwiseSeluModule_basic", "ElementwiseSigmoidModule_basic", @@ -1685,6 +2376,7 @@ "ElementwiseUnaryModule_basic", "ElementwiseUnsqueezeBroadcastModule_basic", "ElementwiseWhereScalarModule_basic", + "ElementwiseNanToNumWithNoneModule_Basic", "ElementwiseNanToNumModule_Basic", "EmbeddingModule1DIndices_basic", "EmbeddingModuleI32Static_basic", @@ -1708,6 +2400,7 @@ "HardswishRandomModule_basic", "HardtanhBackward_basic", "IndexTensorMultiIndexStaticModule_basic", + "IndexTensorNegativeIndexModule_basic", "IndexTensorStaticModule_basic", "IscloseStaticModuleTrue_basic", "IscloseStaticModule_basic", @@ -1728,6 +2421,7 @@ "MatmulStaticBroadcast_basic", "MaxPool2dEmptyStrideStaticModule_basic", "MaxPool2dStaticCeilModeTrueModule_basic", + "MaxPool2dStaticCeilModeTrueReduceOutputModule_basic", "MaxPool2dStaticModule_basic", "MeanModule_basic", "MmDagModule_basic", @@ -1760,7 +2454,6 @@ "NormScalarOptDimModule_basic", "NumToTensorFloatModule_basic", "NumToTensorIntModule_basic", - "NumpyTRank0Module_basic", "NumpyTRank1Module_basic", "NumpyTRank2Module_basic", "NumpyTRankNDynamicModule_basic", @@ -1772,9 +2465,10 @@ "OnesModuleInt_basic", "PadModule_basic", "PadWithNoneValModule_basic", - "Permute0RankModule_basic", "PermuteModule_basic", "PermuteNegativeIndexModule_basic", + "PowFloatFloatModule_basic", + "PowFloatIntModule_basic", "PrimListUnpackNumMismatchModule_basic", "PrimsIotaModule_basic", "PrimsSqueezeEmptyDimensionsModule_basic", @@ -1791,9 +2485,12 @@ 
"ReduceSumFloatModule_basic", "ReduceSumSignedIntModule_basic", "ReduceSumUnsignedIntModule_basic", + "RreluWithNoiseBackwardEvalModule_basic", + "RreluWithNoiseBackwardEvalStaticModule_basic", + "RreluWithNoiseBackwardTrainModule_basic", + "RreluWithNoiseBackwardTrainStaticModule_basic", "RepeatModule_basic", "RepeatInterleaveSelfIntNoDimModule_basic", - "ResNet18StaticModule_basic", "ReshapeAliasCollapseModule_basic", "ReshapeAliasExpandModule_basic", "ReshapeAsModule_basic", @@ -1801,16 +2498,17 @@ "ReshapeExpandModule_basic", "ReturnThreeTensorFloat32_basic", "ReturnTwoTensorF32I64_basic", + "ResNet18StaticModule_basic", "RsubFloatModule_basic", "RsubFloatModule_noalpha_basic", "RsubInt0d_NumToTensor_Module_basic", + "RsubIntModule_basic", "ScalarTensorDefaultDtypeModule_basic", "ScalarTensorFloat32Module_basic", "ScalarTensorInt32Module_basic", "ScalarTensorInt64Module_basic", "SelectIntNegativeDimAndIndexStaticModule_basic", "SiluModule_basic", - "SliceOutOfUpperBoundIndexStaticModule_basic", "SliceStaticModule_basic", "SplitTensorGetItem_Module_basic", "SplitTensorLastSmallerModule_basic", @@ -1833,8 +2531,6 @@ "TensorIntModule_basic", "TensorLiteralModule_basic", "TensorOpaqueLiteralModule_basic", - "TensorsConcatNegativeDimStaticModule_basic", - "TensorsConcatStaticModule_basic", "TestF16Return_basic", "TestMultipleTensorReturn_basic", "Threshold1dFloatModule_basic", @@ -1900,70 +2596,10 @@ "LinspaceModule_basic", "LinspaceOneSizeModule_basic", "LinspaceTwoSizeModule_basic", - "TorchPrimLoopForLikeTensorArgModule_basic", -} - -MAKE_FX_TOSA_PASS_SET = ( - TOSA_PASS_SET - | { - ### Tests additionally passing in make_fx_tosa - "MaxPool1dEmptyStrideStaticModule_basic", - "MaxPool1dStaticCeilModeTrueModule_basic", - "MaxPool1dStaticModule_basic", - "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", - "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", - "AdaptiveAvgPool1dStaticEvenMultiple_basic", - "CosineSimilarityModule_basic", - "NativeGroupNormBackwardModule_basic", - "ReduceFrobeniusNormKeepDimModule_basic", - "ReduceFrobeniusNormModule_basic", - "SliceWholeTensorModule_basic", - "TensorFloatModule_basic", - "TensorIntModule_basic", - "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", - "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", - "RepeatInterleaveSelfIntModule_basic", - "TorchPrimLoopForLikeTensorArgModule_basic", - "ViewSizeDimFollowedByCollapsedOnesModule_basic", - "ViewSizeDimFollowedByExpandedOnesModule_basic", - "ViewSizeDimLedAndFollowedByCollapsedOnesModule_basic", - "ViewSizeDimLedByCollapsedOnesModule_basic", - "ViewSizeFromOtherTensor_basic", - } -) - { - ### Test failing in make_fx_tosa but not in tosa - # Dynamic shape, has extra unsupported broadcast ops - "Matmul_3d", - "MatmulStaticBroadcast_basic", - # Unimplemented operator 'aten._index_put_impl_.hacked_twin' - "IndexPutImpl1DFloatNonAccumulateModule_basic", - "IndexPutImpl1DIntNonAccumulateModule_basic", - # RuntimeError: The size of tensor a (7) must match the size of tensor b (3) at non-singleton dimension 1 - "Add_Module_basic", - # failed to legalize operation 'torch.aten.to.dtype' that was explicitly marked illegal - "AtenEyeModuleInt2D_basic", - "AtenEyeMModuleInt2D_basic", - "Conv2dBiasNoPaddingModule_basic", - "Conv2dNoPaddingModule_basic", - "Conv2dWithPaddingDilationStrideModule_basic", - "Conv2dWithPaddingModule_basic", - "AtenInstanceNormModule_basic", - # failed to legalize operation 'torch.operator' - "ElementwisePreluModule_basic", - "ElementwisePreluStaticModule_basic", - 
"ElementwiseLogSigmoidModule_basic", - # Shape Related failures - "PrimListUnpackNumMismatchModule_basic", - "ReshapeExpandModule_basic", - "UnsafeViewCollapseModule_basic", - "UnsafeViewDynamicExpandModule_basic", - "ViewCollapseModule_basic", - "ViewDynamicExpandCollapseModule_basic", - "ViewDynamicExpandModule_basic", - "ViewExpandDynamicDimModule_basic", - "ViewNoChange1dModule_basic", - "ViewNoChange2dModule_basic", - "ViewNoChange3dModule_basic", + "RenormModuleFloat32NegativeDim_basic", + "RenormModuleFloat32_basic", + "IndexTensorStaticContiguousWithNoneModule_basic", + "IndexTensorStaticNonContiguousWithNoneModule_basic", } LTC_CRASHING_SET = { @@ -1997,6 +2633,7 @@ "_ConvolutionDeprecated2DDeterministicModule_basic", "MaxPool3dEmptyStrideStaticModule_basic", "AddIntModule_basic", + "AddFloatIntModule_basic", "ArangeStartOutViewModule_basic", "AtenIntBoolOpModule_basic", "BernoulliTensorModule_basic", @@ -2109,25 +2746,25 @@ "ElementwiseBitwiseAndScalarInt64Module_basic", "ElementwiseBitwiseAndScalarInt32Module_basic", "ElementwiseBitwiseAndScalarInt8Module_basic", + "Conv1dGroupModule_basic", "Conv2dQInt8Module_basic", + "Conv2dQInt8Module_depthwise", + "Conv2dQInt8Module_grouped", + "Conv2dQInt8Module_not_depthwise", + "Conv2dQInt8PerChannelModule_basic", + "Conv2dQInt8PerChannelModule_depthwise", + "Conv2dQInt8PerChannelModule_grouped", "ConvTranspose2DQInt8_basic", } ONNX_XFAIL_SET = { + "ToDtypeIntFromFloatModule_basic", + # This test is expected to time out + "TimeOutModule_basic", # Failure - cast error "PermuteNegativeIndexModule_basic", - # Failure - expand multiple dynamic dims - "EmbeddingModuleF16_basic", - "EmbeddingModuleI32_basic", - "EmbeddingModuleI64_basic", - "IndexTensorHackedTwinModule3dInput_basic", - "IndexTensorHackedTwinModule_basic", - "IndexTensorModule3dInput_basic", - "IndexTensorModule_basic", - "IndexTensorMultiInputContiguousOneDimDynamic_basic", - "IndexTensorMultiInputNonContiguousOneDimDynamic_basic", - "IndexTensorSelectDimModule_basic", # Failure - incorrect numerics + "ReduceAnyDimFloatModule_basic", "AvgPool2dDivisorOverrideModule_basic", "BroadcastDynamicDimModule_basic", "ElementwiseAtan2TensorIntModule_basic", @@ -2136,31 +2773,21 @@ "ElementwiseAtenFloorDivideTensorNegativeModule_basic", "ElementwiseLog10IntModule_basic", "ElementwiseLog2IntModule_basic", - "FlipModuleStaticShape_basic", - "FlipNegativeIndexModule_basic", - "HardsigmoidModule_basic", - "HardsigmoidRandomModule_basic", + "ElementwiseFminModule_basic", + "ElementwiseFmaxModule_basic", + "Exp2StaticModule_basic", + "FloatPowerTensorTensorStaticModule_basic", + "MultinomialModule2D_basic", + "MultinomialModule2D_F32", "PixelShuffleModuleStaticRank4Float32_basic", - "ReflectionPad1dModule2dInput_Right", - "ReflectionPad1dModule2dInput_basic", - "ReflectionPad1dModule3dInput_Left", - "ReflectionPad1dModule3dInput_basic", - "ReflectionPad2dModule_Bottom", - "ReflectionPad2dModule_Left", - "ReflectionPad2dModule_Right", - "ReflectionPad2dModule_Top", - "ReflectionPad2dModule_basic", - "ReplicationPad2dModule_basic", - "ReplicationPad2dModule_bottom0", - "ReplicationPad2dModule_left0", - "ReplicationPad2dModule_right0", - "ReplicationPad2dModule_top0", "SliceCopyEndGreaterThanDimSize_Module_basic", "SliceCopyNegative_Module_basic", "SliceCopyNonZeroDim_Module_basic", "SliceCopy_Module_basic", + "SliceStaticComplexInputModule_basic", "StdCorrectionLargeInputModule_basic", "TupleModule_basic", + "ThresholdStaticModule_basic", "VarCorrectionLargeInputModule_basic", # Failure - 
incorrect shape "ArangeStartOutDtypeModule_basic", @@ -2168,6 +2795,8 @@ "MoveDimIntNegativeIndexModule_basic", "ReduceL3NormKeepDimModule_basic", "ViewSizeFromOtherTensor_basic", + # incorrect shape generated by torch.onnx.export (needs an unsqueeze) + "MultinomialModule_basic", # Failure - onnx_export "AdaptiveAvgPool1dGeneralDynamic_basic", "AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic", @@ -2176,6 +2805,7 @@ "AdaptiveAvgPool2dDynamic_basic", "AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic", "AdaptiveAvgPool2dOutputSizeDivisibleByInputDynamicModule_basic", + "AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic", "AdaptiveAvgPool3dDynamicNoBatch_basic", "AdaptiveAvgPool3dDynamic_basic", "AdaptiveMaxPool1dDynamicNoBatch_basic", @@ -2193,6 +2823,7 @@ "AdaptiveMaxPool3dStatic_basic", "AddCDivModule_basic", "AddIntModule_basic", + "AddFloatIntModule_basic", "Add_Module_basic", "AllBoolFalseModule_basic", "AllBoolTrueModule_basic", @@ -2210,14 +2841,22 @@ "AtenDiagEmbedRevDimDiag_basic", "AtenEmbeddingBagStaticModule_basic", "AtenEmbeddingBagSumExample_basic", + "AtenFftRfft2DLastDim_basic", + "AtenFftRfft2DMiddleDim_basic", "AtenFloatScalarModule_basic", "AtenIntBoolOpConstFalseModule_basic", "AtenIntBoolOpConstTrueModule_basic", "AtenIntBoolOpModule_basic", "AtenIntTensorByteDtypeModule_basic", "AtenIntTensorCharDtypeModule_basic", + "AtenIntMM_basic", "AtenItemFpOpModule_basic", "AtenItemIntOpModule_basic", + "AtenKthvalueModule_basic", + "AtenKthvalueKeepDimModule_basic", + "AtenKthvalueDynamicDimsModule_basic", + "AtenKthvalueFloat64Module_basic", + "AtenKthvalueFloat64DynamicDimsModule_basic", "AtenLinalgCrossDynamic_basic", "AtenMatmulQMixedSigni8Transpose_basic", "AtenMatmulQMixedSigni8_basic", @@ -2228,6 +2867,8 @@ "AtenMmQMixedSigni8_basic", "AtenMmQint8_basic", "AtenMmQuint8_basic", + "AtenPolarFloatModule_basic", + "AtenPolarDoubleModule_basic", "AtenRealView128Module_basic", "AtenRealView64Module_basic", "AtenSubFloatModule_basic", @@ -2253,17 +2894,33 @@ "CollapsePartialDynamicModule_basic", "CollapseRank1DynamicModule_basic", "CollapseStaticModule_basic", + "ColumnStackBasicIntModule_basic", + "ColumnStack1dModule_basic", + "ColumnStack0dModule_basic", "ConstantBoolParameterModule_basic", "ContainsIntList_False", "ContainsIntList_True", "Conv1dModule_basic", + "Conv1dWithSamePaddingModule_basic", + "Conv1dWithValidPaddingModule_basic", + "Conv1dGroupModule_basic", "Conv2dBiasNoPaddingModule_basic", "Conv2dModule_basic", "Conv2dNoPaddingModule_basic", "Conv2dQInt8Module_basic", + "Conv2dQInt8Module_depthwise", + "Conv2dQInt8Module_grouped", + "Conv2dQInt8Module_not_depthwise", + "Conv2dQInt8PerChannelModule_basic", + "Conv2dQInt8PerChannelModule_depthwise", + "Conv2dQInt8PerChannelModule_grouped", "Conv2dWithPaddingDilationStrideModule_basic", "Conv2dWithPaddingModule_basic", + "Conv2dWithSamePaddingModule_basic", + "Conv2dWithValidPaddingModule_basic", "Conv3dModule_basic", + "Conv3dWithSamePaddingModule_basic", + "Conv3dWithValidPaddingModule_basic", "ConvTbcModule_basic", "ConvTranspose2DQInt8_basic", "Conv_Transpose2dModule_basic", @@ -2277,26 +2934,17 @@ "ConvolutionModule2DTransposeNonUnitOutputPadding_basic", "ConvolutionModule2DTransposeStrided_basic", "ConvolutionModule2DTranspose_basic", + "Deg2radModule_basic", "DivFloatModule_basic", "DivIntModule_basic", "ElementwiseAcoshIntModule_basic", "ElementwiseAcoshModule_basic", - "ElementwiseAndScalarModule_basic", - "ElementwiseAndScalarStaticShapeModule_basic", "ElementwiseAsinhIntModule_basic", 
"ElementwiseAsinhModule_basic", "ElementwiseAtanhIntModule_basic", "ElementwiseAtanhModule_basic", "ElementwiseAtenIsneginfOpModule_basic", "ElementwiseAtenIsposinfOpModule_basic", - "ElementwiseBitwiseAndModule_basic", - "ElementwiseBitwiseAndScalarInt32Module_basic", - "ElementwiseBitwiseAndScalarInt64Module_basic", - "ElementwiseBitwiseAndScalarInt8Module_basic", - "ElementwiseBitwiseAndStaticShapeModule_basic", - "ElementwiseBitwiseLeftShiftInt32Module_basic", - "ElementwiseBitwiseLeftShiftInt64Module_basic", - "ElementwiseBitwiseLeftShiftInt8Module_basic", "ElementwiseBitwiseNotInt32Module_basic", "ElementwiseBitwiseNotInt64Module_basic", "ElementwiseBitwiseOrModule_basic", @@ -2319,19 +2967,31 @@ "ElementwiseEluNonDefaultModule_basic", "ElementwiseExpm1IntModule_basic", "ElementwiseExpm1Module_basic", + "ElementwiseSpecialExpm1IntModule_basic", + "ElementwiseSpecialExpm1Module_basic", "ElementwiseFmodTensor_Int_basic", + "ElementwiseCreateComplexModule_basic", "ElementwiseMulTensorComplexModule_basic", "ElementwiseMulTensorComplexDiffModule_basic", "ElementwiseOrTensorModule_basic", "ElementwiseOrTensorStaticShapeModule_basic", "ElementwiseQuantizePerTensorModule_basic", "ElementwiseQuantizePerTensorUIntModule_basic", + "ElementwiseRad2DegModule_basic", + "ElementwiseRad2DegIntModule_basic", "ElementwiseRemainderTensorModule_Int_basic", + "ElementwiseRemainderTensorModule_Int_NegativeDividend_basic", + "ElementwiseRemainderTensorModule_Int_NegativeDivisor_basic", + "ElementwiseRreluWithNoiseEvalModule_basic", + "ElementwiseRreluWithNoiseEvalStaticModule_basic", + "ElementwiseRreluWithNoiseTrainModule_basic", + "ElementwiseRreluWithNoiseTrainStaticModule_basic", "ElementwiseSgnModule_basic", "EmptyStridedModule_basic", "EmptyStridedSizeIntStrideModule_basic", "EqIntModule_basic", "ExponentialModule_basic", + "FakeQuantizePerTensorAffineCachemaskModule_basic", "FloatImplicitModule_basic", "GeFloatIntModule_basic", "GeFloatModule_basic", @@ -2340,6 +3000,10 @@ "GtFloatIntModule_basic", "GtIntModule_basic", "HardtanhBackward_basic", + "HstackBasicComplexModule_basic", + "HstackBasicFloatModule_basic", + "HstackBasicIntFloatModule_basic", + "HstackBasicIntModule_basic", "IndexPutImpl1DFloatAccumulateModule_basic", "IndexPutImpl1DFloatNonAccumulateModule_basic", "IndexPutImpl1DIntAccumulateModule_basic", @@ -2352,6 +3016,7 @@ "IndexPutImpl3DFloatAccumulateModule_basic", "IndexPutImpl3DFloatNonAccumulateModule_basic", "IndexPutImplIndexWithNoneModule_basic", + "IndexPutWithNoneAndBroadcastModule_basic", "IntFloatModule_basic", "IntImplicitModule_basic", "IouOfModule_basic", @@ -2359,6 +3024,9 @@ "IsFloatingPointInt_False", "IscloseStaticModuleTrue_basic", "IscloseStaticModule_basic", + "L1LossNoReductionModule_basic", + "L1LossMeanReductionModule_basic", + "L1LossSumReductionModule_basic", "LeakyReluBackwardModule_basic", "LeakyReluBackwardStaticModule_basic", "LenStrModule_basic", @@ -2367,10 +3035,7 @@ "LinalgVectorNormComplexModule_basic", "LogSoftmaxBackwardModule_basic", "MaxPool1dCeilModeTrueModule_basic", - "MaxPool1dEmptyStrideStaticModule_basic", "MaxPool1dModule_basic", - "MaxPool1dStaticCeilModeTrueModule_basic", - "MaxPool1dStaticModule_basic", "MaxPool2dCeilModeTrueModule_basic", "MaxPool2dModule_basic", "MaxPool2dWithIndicesAllOnesModule_basic", @@ -2388,6 +3053,15 @@ "MaxPool3dLargeDatadModule_basic", "MaxPool3dModuleRandomSimple_basic", "MaxPool3dModule_basic", + "MaxPool3dWithIndicesAllOnesModule_basic", + "MaxPool3dWithIndicesCeilModeTrueModule_basic", + 
"MaxPool3dWithIndicesFullSizeKernelModule_basic", + "MaxPool3dWithIndicesModule_basic", + "MaxPool3dWithIndicesNonDefaultDilationModule_basic", + "MaxPool3dWithIndicesNonDefaultParamsModule_basic", + "MaxPool3dWithIndicesNonDefaultStrideModule_basic", + "MaxUnpool3dModule_basic", + "MaxUnpool3dModulePad0_basic", "MeanDimEmptyDimModule_basic", "Mlp1LayerModule_basic", "Mlp2LayerModuleNoBias_basic", @@ -2426,9 +3100,13 @@ "NllLossModuleBackward_ignore_index", "NllLossModule_1D_basic", "NllLossModule_basic", + "NllLossStaticModule_basic", + "NllLossStaticModule_weight_basic", "NllLossModule_ignore_index_out_of_bounds_basic", "NllLossModule_mean_basic", + "NllLossStaticModule_mean_basic", "NllLossModule_sum_basic", + "NllLossStaticModule_sum_basic", "NormScalarComplexModule_basic", "NormScalarModule_basic", "NormScalarOptDimKeepDimComplexModule_basic", @@ -2443,7 +3121,7 @@ "PixelShuffleModuleSpatiallyDynamic_basic", "PixelShuffleModuleSpatiallyStatic_basic", "PixelShuffleModuleStaticRank3Int64_basic", - "PowIntFloatModule_basic", + "PowIntIntModule_basic", "PrimMaxIntModule_basic", "PrimMinIntDynamicModule_basic", "PrimMinIntModule_basic", @@ -2462,9 +3140,20 @@ "ReduceL1NormComplexModule_basic", "ReduceL2NormComplexModule_basic", "ReduceL3NormKeepDimComplexModule_basic", + "ReflectionPad3dModule_basic", + "ReflectionPad3dModuleFront_basic", + "ReflectionPad3dModuleBack_basic", + "RreluWithNoiseBackwardEvalModule_basic", + "RreluWithNoiseBackwardEvalStaticModule_basic", + "RreluWithNoiseBackwardTrainModule_basic", + "RreluWithNoiseBackwardTrainStaticModule_basic", + "RreluWithNoiseForwardBackwardModule_basic", "ReshapeAliasCollapseModule_basic", "ReshapeAliasExpandModule_basic", "ReshapeExpandModule_basic", + "Rot90DynamicDimsModule_basic", + "SafeSoftmaxModule_basic", + "SafeSoftmaxNonNoneDtypeModule_basic", "ScalarConstantTupleModule_basic", "ScalarImplicitFloatModule_basic", "ScalarImplicitIntModule_basic", @@ -2482,6 +3171,9 @@ "ScatterReduceIntSumModule", "SelectScattertModule_basic", "SelectScattertStaticModule_basic", + "SignAndLogarithmOfDeterminantModule_F32", + "SignAndLogarithmOfDeterminantBatchedModule_F32", + "SignAndLogarithmOfDeterminantDynamicModule_F32", "SliceEndSleStartModule_basic", "SliceOutOfUpperBoundIndexModule_basic", "SliceScatterModule_basic", @@ -2498,11 +3190,12 @@ "SplitDimStaticModule_basic", "SqrtIntConstantModule_basic", "SqrtIntModule_basic", - "StdCorrectionEmptyDimModule_basic", - "StdDimEmptyDimModule_basic", "SubFloatModule_basic", "SubIntModule_basic", "TanhBackward_basic", + "TensorsConcatComplex128FloatModule_basic", + "TensorsConcatComplex64FloatModule_basic", + "TensorsConcatComplex128IntModule_basic", "TensorToBoolZeroRank_basic", "TensorToBool_basic", "TensorToFloatZeroRank_basic", @@ -2553,8 +3246,6 @@ "UpSampleNearest2dDynamicFactor_basic", "UpSampleNearest2dStaticFactor_basic", "UpSampleNearest2d_basic", - "VarCorrectionEmptyDimModule_basic", - "VarDimEmptyDimModule_basic", "ViewCollapseDynamicWithAtenSizeIntModule_basic", "ViewCollapseModule_basic", "ViewDynamicExpandCollapseModule_basic", @@ -2565,6 +3256,7 @@ "ViewNoChange1dModule_basic", "ViewNoChange2dModule_basic", "ViewNoChange3dModule_basic", + "WeightNormInterfaceModule_basic", "_Convolution2DAllFalseModule_basic", "_Convolution2DBenchmarkModule_basic", "_Convolution2DCudnnModule_basic", @@ -2576,6 +3268,10 @@ "_ConvolutionDeprecated2DDeterministicModule_basic", "_SoftmaxModule_basic", # Failure - onnx_import + # Failure - onnx_lowering: onnx.SplitToSequence + 
"ChunkListUnpackUneven_Module_basic", + "TensorSplitSections_GetItemModule_basic", + "TensorSplitSections_ListUnpackModule_basic", # Failure - onnx_lowering: onnx.AveragePool "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", # these diagonal modules are currently failing due to dynamic shape. @@ -2583,10 +3279,6 @@ # when the issue is fixed, please remove DiagonalWithStaticShapeModule as well as the xfails here. "TileBigDimsSizeModule_basic", "TileSmallDimsSizeModule_basic", - # Failure - onnx_lowering: onnx.MaxPool - "MaxPool2dWithIndicesAllNegativeValuesModule_basic", - "MaxPool2dWithIndicesNonDefaultPaddingModule_basic", - "MaxPool2dWithIndicesStaticModule_basic", # Failure - onnx_lowering: onnx.ReduceProd "ReduceProdFloatModule_basic", "ReduceProdDtypeFloatModule_basic", @@ -2605,9 +3297,6 @@ "BernoulliTensorModule_basic", # Failure - onnx_lowering: onnx.ReduceProd "ReduceProdDimIntFloatModule_basic", - # Failure - onnx_lowering: onnx.Resize - "UpSampleNearest2dDynamicSize_basic", - "UpSampleNearest2dStaticSize_basic", # Failure - onnx_lowering: onnx.ScatterElements "ScatterReduceFloatMaxModuleIncludeSelf", "ScatterReduceFloatMinModuleIncludeSelf", @@ -2616,38 +3305,43 @@ "ScatterValueFloatModule_basic", # Failure - onnx_lowering: onnx.ScatterND "IndexPut1DFloatAccumulateModule_basic", - "IndexPut1DFloatNonAccumulateModule_basic", "IndexPut1DIntAccumulateModule_basic", - "IndexPut1DIntNonAccumulateModule_basic", "IndexPut2DFloatAccumulateModule_basic", - "IndexPut2DFloatNonAccumulateModule_basic", "IndexPut2DIntAccumulateModule_basic", - "IndexPut2DIntNonAccumulateModule_basic", "IndexPut3DFloatAccumulateModule_basic", - "IndexPut3DFloatNonAccumulateModule_basic", "IndexPut3DIntAccumulateModule_basic", - "IndexPut3DIntNonAccumulateModule_basic", "IndexPutHackedTwin1DFloatAccumulateModule_basic", - "IndexPutHackedTwin1DFloatNonAccumulateModule_basic", "IndexPutHackedTwin1DIntAccumulateModule_basic", - "IndexPutHackedTwin1DIntNonAccumulateModule_basic", "IndexPutHackedTwin2DFloatAccumulateModule_basic", - "IndexPutHackedTwin2DFloatNonAccumulateModule_basic", "IndexPutHackedTwin2DIntAccumulateModule_basic", - "IndexPutHackedTwin2DIntNonAccumulateModule_basic", "IndexPutHackedTwin3DFloatAccumulateModule_basic", - "IndexPutHackedTwin3DFloatNonAccumulateModule_basic", "IndexPutHackedTwin3DIntAccumulateModule_basic", - "IndexPutHackedTwin3DIntNonAccumulateModule_basic", # RuntimeError: unsupported input type: Device "PrimsIotaModule_basic", + # unimplemented torchvision.deform_conv2d torch->linalg + "DeformConv2D_basic", + # Error: 'aten::renorm' to ONNX opset version 17 is not supported. 
+ "RenormModuleFloat16_basic", + "RenormModuleFloat32NegativeDim_basic", + "RenormModuleFloat32_basic", + "RenormModuleFloat32DynamicDims_basic", + "Rot90BasicModule_basic", + "Rot90DynamicDymsModule_basic", + "Rot90MultipleRotationsModule_basic", + "Rot90NegativeEvenRotationsModule_basic", + "Rot90NegativeOddRotationsModule_basic", # Failure - unknown "BernoulliModule_basic", - "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", + "Conv_Transpose1dModule_basic", + "Conv_Transpose3dModule_basic", "CopyWithDifferentDTypesAndSizesModule_basic", "CopyWithDifferentDTypesModule_basic", "CosineSimilarityStaticBroadcastModule_basic", "CumsumInputDtypeInt32Module_basic", + "CumprodModule_basic", + "CumprodInputDtypeInt32Module_basic", + "CumprodStaticModule_basic", + "CumprodStaticNegativeDimModule_basic", "ElementwiseAcosIntModule_basic", "ElementwiseAsinIntModule_basic", "ElementwiseAtanTensorIntModule_basic", @@ -2662,12 +3356,42 @@ "ElementwiseTanIntModule_basic", "ElementwiseToDtypeI64ToUI8Module_basic", "ElementwiseUnaryIntModule_basic", + "ElementwiseFloatTensorGtIntTensorModule_basic", + "ElementwiseSignbitModule_basic", + "ElementwiseSignbitIntModule_basic", + "ElementwiseFracModule_basic", + "ElementwiseCopysignModule_basic", + "ElementwiseLdexpModule_basic", + "Exp2StaticIntModule_basic", "MaskedFillTensorFloatValueModule_basic", "NativeDropoutTrainModule_basic", "NativeDropoutTrainStaticShapeModule_basic", "ReduceAnyFloatModule_basic", "ReduceMaxAlongDimUnsignedInt_basic", "ReduceMinAlongDimUnsignedInt_basic", + "UnfoldModule_basic", + "Unfold_Module_Rank_4", + "Unfold_Module_Rank_Zero_basic", + "Unfold_Module_Rank_Zero_Size_Zero_basic", + "Unfold_Module_Dynamic_basic", + "ViewDtypeStaticModule_basic", + "IndexPutImpl2DNoneIndexBroadcastStaticModule_basic", + "PowIntIntModule_basic", + "PrimsSumFloatModule_basic", + "RepeatInterleaveFillModule_basic", + "RepeatInterleaveModule_basic", + "RepeatInterleaveStaticModule_basic", + "SliceCopyMax_Module_basic", + "Aten_TrilinearModule_basic", + "Aten_TrilinearModuleSumdims_basic", + "Aten_TrilinearModuleSumAllDims_basic", + "Aten_TrilinearModuleVaryingRanks_basic", + "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic", + "Aten_TrilinearModuleZerodDimBug_basic", + "ScaledDotProductAttentionGQAModule_basic", + "AtenSymConstrainRange_basic", + "AtenSymConstrainRangeForSize_basic", + "Aten_AssertScalar_basic", } if torch_version_for_comparison() < version.parse("2.3.0.dev"): @@ -2676,8 +3400,57 @@ "RepeatInterleaveSelfIntNoDimModule_basic", } +if torch_version_for_comparison() < version.parse("2.4.0.dev"): + ONNX_XFAIL_SET = ONNX_XFAIL_SET | { + # torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::bitwise_left_shift' to ONNX opset version 17 is not supported. 
+ "ElementwiseBitwiseLeftShiftInt32Module_basic", + "ElementwiseBitwiseLeftShiftInt64Module_basic", + "ElementwiseBitwiseLeftShiftInt8Module_basic", + # bitwise and support has been added in torch nightly + "ElementwiseAndScalarModule_basic", + "ElementwiseAndScalarStaticShapeModule_basic", + "ElementwiseBitwiseAndModule_basic", + "ElementwiseBitwiseAndScalarInt32Module_basic", + "ElementwiseBitwiseAndScalarInt64Module_basic", + "ElementwiseBitwiseAndScalarInt8Module_basic", + "ElementwiseBitwiseAndStaticShapeModule_basic", + } + +if torch_version_for_comparison() >= version.parse("2.5.0.dev"): + ONNX_XFAIL_SET = ONNX_XFAIL_SET | { + # ERROR: value (Tensor with shape=[2, 3, 8, 20], dtype=torch.float32, min=+nan, max=+nan, mean=+nan) is not close to golden value (Tensor with shape=[2, 3, 8, 20], dtype=torch.float32, min=-2.394, max=+2.454, mean=-0.02828) + "ScaledDotProductAttentionBoolMaskModule_basic", + } + +if torch_version_for_comparison() > version.parse("2.5.1"): + ONNX_XFAIL_SET = ONNX_XFAIL_SET | { + # error: 'memref.cast' op operand type 'memref<2x6x4x3xf32>' and result type 'memref<2x6x5x3xf32>' are cast incompatible + # torch.onnx.export produces onnx.MaxPool op with incorrect output shape of 2x6x5x3 instead of 2x6x4x3 + "MaxPool2dStaticCeilModeTrueReduceOutputModule_basic", + } + +if torch_version_for_comparison() < version.parse("2.4.0.dev"): + STABLEHLO_PASS_SET = STABLEHLO_PASS_SET - { + "AtenIntMM_basic", + } + FX_IMPORTER_STABLEHLO_XFAIL_SET = FX_IMPORTER_STABLEHLO_XFAIL_SET | { + "AtenIntMM_basic", + } + +if torch_version_for_comparison() > version.parse("2.4.0.dev"): + STABLEHLO_PASS_SET = STABLEHLO_PASS_SET - { + "ElementwiseCreateComplexModule_basic", + "ElementwiseTanIntModule_basic", + "ElementwiseTanModule_basic", + } + FX_IMPORTER_STABLEHLO_XFAIL_SET = FX_IMPORTER_STABLEHLO_XFAIL_SET | { + "ElementwiseCreateComplexModule_basic", + "ElementwiseTanIntModule_basic", + "ElementwiseTanModule_basic", + } + -ONNX_CRASHING_SET = { +ONNX_CRASHING_SET = LINALG_CRASHING_SET | { "FakeQuantizePerTensorAffineModule_basic", "FakeQuantizePerTensorAffineDynamicShapeModule_basic", "ElementwisePreluModule_basic", @@ -2689,4 +3462,1621 @@ # The following test sporadically stopped producing correct numerics for the golden value in the CI. # For now, we are removing the test until this issue has been debugged. 
"QuantizedMLP_basic", + # Runtime crash: mismatched size for broadcast + "MaxPool2dWithIndicesAllNegativeValuesModule_basic", + "MaxPool2dWithIndicesNonDefaultPaddingModule_basic", + "MaxPool3dWithIndicesAllNegativeValuesModule_basic", + "MaxPool3dWithIndicesNonDefaultPaddingModule_basic", + "StdDimEmptyDimModule_basic", + "StdCorrectionEmptyDimModule_basic", + "VarCorrectionEmptyDimModule_basic", + "VarDimEmptyDimModule_basic", + # Runtime op verification: rank mismatch in memref.cast + "ViewSizeFromOtherTensor_basic", } + +FX_IMPORTER_TOSA_XFAIL_SET = { + "AtenSymConstrainRangeForSize_basic", + "AtenSymConstrainRange_basic", + "Aten_AssertScalar_basic", + "ScatterAddDynamicModule_basic", + "UniformModule_basic", + "UniformStaticShapeModule_basic", + "AtenFftRfft2DLastDim_basic", + "AtenFftRfft2DMiddleDim_basic", + "IsInfiniteModule_basic", + "LayerNormFwAndBwModule_basic", + "LayerNormManualFwAndBwModule_basic", + "SelfAttentionFwAndBwModule_basic", + "ElementwiseCopysignModule_basic", + "ElementwiseSignbitModule_basic", + "Aten_TrilinearModuleVaryingRanks_basic", + "Aten_TrilinearModuleZerodDimBug_basic", + "AdaptiveAvgPool1dStaticEvenMultiple_basic", + "Aten_TrilinearModuleVaryingRanks_basic", + "Aten_TrilinearModuleZerodDimBug_basic", + "AtenNonzero1DDynamicModule_basic", + "MaxPool3dEmptyStrideStaticModule_basic", + "MaxPool3dLargeDatadModule_basic", + "MaxPool3dModuleRandomSimple_basic", + "MaxPool3dModule_basic", + "MaxPool3dStaticModule_basic", + "ViewDtypeStaticModule_basic", + "Unfold_Module_Rank_Zero_Size_Zero_basic", + "ElementwiseCreateComplexModule_basic", + "AtenPolarDoubleModule_basic", + "AtenPolarFloatModule_basic", + "HstackBasicComplexModule_basic", + "AtenIntMM_basic", + "AtenKthvalueDynamicDimsModule_basic", + "AtenKthvalueFloat64DynamicDimsModule_basic", + "AtenKthvalueFloat64Module_basic", + "AtenKthvalueKeepDimModule_basic", + "AtenKthvalueModule_basic", + "AvgPool3dStaticModule_basic", + "Conv_Transpose1dModule_basic", + "Conv_Transpose1dStaticModule_basic", + "Conv_Transpose2dStaticModule_basic", + "Conv_Transpose3dModule_basic", + "Conv_Transpose3dStaticModule_basic", + "IndexPutWithNoneAndBroadcastModule_basic", + "MaskedScatterStaticBasic_basic", + "MaxUnpool3dModulePad0_basic", + "MaxUnpool3dModule_basic", + "MultinomialModule2D_F32", + "MultinomialModule2D_basic", + "MultinomialModule_basic", + # REMOVE WHEN ENABLE_GQA IS ADDED + "ScatterAddStaticModule_basic", + "TensorsConcatComplex128FloatModule_basic", + "TensorsConcatComplex128IntModule_basic", + "TensorsConcatComplex64FloatModule_basic", + "TimeOutModule_basic", + "TypeConversionUint8ToF32Module_basic", + "WeightNormInterfaceModule_basic", + "AdaptiveAvgPool3dDynamicNoBatch_basic", + "AdaptiveAvgPool3dDynamic_basic", + "AdaptiveMaxPool1dDynamicNoBatch_basic", + "AdaptiveMaxPool1dDynamic_basic", + "AdaptiveMaxPool1dStatic_basic", + "AdaptiveMaxPool2dDynamicNoBatch_basic", + "AdaptiveMaxPool2dDynamicWithIndices_basic", + "AdaptiveMaxPool2dDynamic_basic", + "AdaptiveMaxPool2dStaticWithIndices_basic", + "AdaptiveMaxPool2dStatic_basic", + "AdaptiveMaxPool3dDynamicNoBatch_basic", + "AdaptiveMaxPool3dDynamicWithIndices_basic", + "AdaptiveMaxPool3dDynamic_basic", + "AdaptiveMaxPool3dStaticWithIndices_basic", + "AdaptiveMaxPool3dStatic_basic", + "AddIntModule_basic", + "AddFloatIntModule_basic", + "AllBoolFalseModule_basic", + "AllBoolTrueModule_basic", + "AnyBoolFalseModule_basic", + "AnyBoolTrueModule_basic", + "ArangeStartOutViewModule_basic", + "AtenComplexImagModule_basic", + 
"AtenComplexRealModule_basic", + "AtenComplexViewModule_basic", + "AtenEmbeddingBagStaticModule_basic", + "AtenEmbeddingBagSumExample_basic", + "AtenFloatScalarModule_basic", + "AtenIntBoolOpConstFalseModule_basic", + "AtenIntBoolOpConstTrueModule_basic", + "AtenIntBoolOpModule_basic", + "AtenIntTensorByteDtypeModule_basic", + "AtenIntTensorCharDtypeModule_basic", + "AtenItemFpOpModule_basic", + "AtenItemIntOpModule_basic", + "AtenMatmulQMixedSigni8Transpose_basic", + "AtenMatmulQMixedSigni8_basic", + "AtenMatmulQint8MV_basic", + "AtenMatmulQint8VM_basic", + "AtenMatmulQint8VV_basic", + "AtenMatmulQint8_basic", + "AtenMmIntTypes_basic", + "AtenMmQMixedSigni8_basic", + "AtenMmQint8_basic", + "AtenMmQuint8_basic", + "AtenRealView128Module_basic", + "AtenRealView64Module_basic", + "AtenSubFloatModule_basic", + "AtenTopKModule_basic", + "AtenTopKSmallestModule_basic", + "Aten_EmbeddingBagExample_basic", + "AvgPool1dFloatModule_basic", + "AvgPool1dIntModule_basic", + "AvgPool1dStaticModule_basic", + "AvgPool2dCeilModeTrueModule_basic", + "AvgPool2dDivisorOverrideModule_basic", + "AvgPool2dFloatModule_basic", + "AvgPool2dIntModule_basic", + "AvgPool2dStaticModule_basic", + "BatchMlpLayerModule_basic", + "BatchNorm1DModule_basic", + "BatchNorm1DWith2DInputModule_basic", + "BatchNorm2DModule_basic", + "BatchNorm3DModule_basic", + "BernoulliFloatModule_basic", + "BernoulliPModule_basic", + "BernoulliTensorModule_basic", + "BincountMinlengthModule_basic", + "BincountModule_basic", + "BincountStaticSizeModule_basic", + "BmmIntModule_basic", + "BoolFloatConstantModule_basic", + "BoolFloatFalseModule_basic", + "BoolFloatTrueModule_basic", + "BoolIntConstantModule_basic", + "BoolIntFalseModule_basic", + "BoolIntTrueModule_basic", + "BroadcastDynamicDimModule_basic", + "CeilFloatModule_basic", + "ConstantBoolParameterModule_basic", + "ContainsIntList_False", + "ContainsIntList_True", + "Conv1dNoPaddingTransposeModule_basic", + "Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic", + "Conv1dModule_basic", + "Conv1dNoPaddingGroupModule_basic", + "Conv1dNoPaddingModule_basic", + "Conv2dBiasNoPaddingModule_basic", + "Conv2dNoPaddingModule_basic", + "Conv2dWithPaddingDilationStrideModule_basic", + "Conv2dWithPaddingDilationStrideStaticModule_basic", + "Conv2dWithPaddingDilationStrideStaticModule_depthwise", + "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", + "Conv2dWithPaddingModule_basic", + "Conv1dWithSamePaddingModule_basic", + "Conv1dWithValidPaddingModule_basic", + "Conv1dGroupModule_basic", + "Conv2dQInt8Module_basic", + "Conv2dQInt8Module_depthwise", + "Conv2dQInt8Module_grouped", + "Conv2dQInt8Module_not_depthwise", + "Conv2dQInt8PerChannelModule_basic", + "Conv2dQInt8PerChannelModule_depthwise", + "Conv2dQInt8PerChannelModule_grouped", + "Conv2dWithPaddingDilationStrideStaticModule_grouped", + "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier", + "Conv2dWithValidPaddingModule_basic", + "Conv2dWithSamePaddingModule_basic", + "Conv3dModule_basic", + "Conv3dWithSamePaddingModule_basic", + "Conv3dWithValidPaddingModule_basic", + "ConvTbcModule_basic", + "ConvTranspose2DQInt8_basic", + "Conv_Transpose2dModule_basic", + "ConvolutionBackwardModule2DPadded_basic", + "ConvolutionBackwardModule2DStatic_basic", + "ConvolutionBackwardModule2DStrided_basic", + "ConvolutionBackwardModule2D_basic", + "ConvolutionModule2DTransposeNonUnitOutputPadding_basic", + "ConvolutionModule2DTransposeStridedStatic_basic", + "ConvolutionModule2DTransposeStrided_basic", + 
"ConvolutionModule2DTranspose_basic", + "CumsumModule_basic", + "CumprodModule_basic", + "CumprodInputDtypeInt32Module_basic", + "CumprodStaticModule_basic", + "CumprodStaticNegativeDimModule_basic", + "DeformConv2D_basic", + "DeterminantBatchedModule_F32", + "DeterminantDynamicModule_F32", + "DeterminantModule_F32", + "DivFloatModule_basic", + "DivIntModule_basic", + "ElementwiseAcosIntModule_basic", + "ElementwiseAcosModule_basic", + "ElementwiseAcoshIntModule_basic", + "ElementwiseAcoshModule_basic", + "ElementwiseAsinIntModule_basic", + "ElementwiseAsinModule_basic", + "ElementwiseAsinhIntModule_basic", + "ElementwiseAsinhModule_basic", + "ElementwiseAtan2FloatIntModule_basic", + "ElementwiseAtan2FloatIntStaticModule_basic", + "ElementwiseAtan2TensorFloatModule_basic", + "ElementwiseAtan2TensorFloatStaticModule_basic", + "ElementwiseAtan2TensorIntModule_basic", + "ElementwiseAtan2TensorIntStaticModule_basic", + "ElementwiseAtanTensorFloatModule_basic", + "ElementwiseAtanTensorIntModule_basic", + "ElementwiseAtanhIntModule_basic", + "ElementwiseAtanhModule_basic", + "ElementwiseCopysignModule_basic", + "ElementwiseCoshIntModule_basic", + "ElementwiseCoshModule_basic", + "ElementwiseCreateComplexModule_basic", + "ElementwiseDequantizePerChannelModule_basic", + "ElementwiseDequantizePerTensorModule_basic", + "ElementwiseMulTensorComplexDiffModule_basic", + "ElementwiseMulTensorComplexModule_basic", + "ElementwiseQuantizePerTensorModule_basic", + "ElementwiseQuantizePerTensorUIntModule_basic", + "ElementwiseSinhIntModule_basic", + "ElementwiseSinhModule_basic", + "ElementwiseToDtypeI64ToUI8Module_basic", + "ElementwiseSignbitModule_basic", + "EmbeddingModule1DIndices_basic", + "EmbeddingModuleF16_basic", + "EmbeddingModuleI32Static_basic", + "EmbeddingModuleI32_basic", + "EmbeddingModuleI64_basic", + "EqIntModule_basic", + "ExponentialModule_basic", + "FloatImplicitModule_basic", + "GeFloatIntModule_basic", + "GeFloatModule_basic", + "GeIntModule_basic", + "GtFloatIntModule_basic", + "GtIntModule_basic", + "IndexPut1DFloatAccumulateModule_basic", + "IndexPut1DIntAccumulateModule_basic", + "IndexPut2DFloatAccumulateModule_basic", + "IndexPut2DIntAccumulateModule_basic", + "IndexPut3DFloatAccumulateModule_basic", + "IndexPut3DIntAccumulateModule_basic", + "IndexPutHackedTwin1DFloatAccumulateModule_basic", + "IndexPutHackedTwin1DIntAccumulateModule_basic", + "IndexPutHackedTwin2DFloatAccumulateModule_basic", + "IndexPutHackedTwin2DIntAccumulateModule_basic", + "IndexPutHackedTwin3DFloatAccumulateModule_basic", + "IndexPutHackedTwin3DIntAccumulateModule_basic", + "IndexPutImpl1DFloatAccumulateModule_basic", + "IndexPutImpl1DIntAccumulateModule_basic", + "IndexPutImpl2DFloatAccumulateModule_basic", + "IndexPutImpl2DImplicitModule_basic", + "IndexPutImpl2DIndexModule_basic", + "IndexPutImpl2DNoneIndexBroadcastStaticModule_basic", + "IndexPutImpl2DNoneIndexStaticModule_basic", + "IndexPutImpl3DFloatAccumulateModule_basic", + "IndexPutImplIndexWithNoneModule_basic", + "IntFloatModule_basic", + "IntImplicitModule_basic", + "IsFloatingPointFloat_True", + "IsFloatingPointInt_False", + "IsInfiniteModule_basic", + "LayerNormLastDimModule_basic", + "LayerNormModule_basic", + "LayerNormNormalizeOverAllDimsModule_basic", + "LenStrModule_basic", + "LinalgNormKeepDimComplexModule_basic", + "LinalgVectorNormComplexModule_basic", + "MaskedScatterStaticBasic_basic", + "MaxPool1dCeilModeTrueModule_basic", + "MaxPool1dModule_basic", + "MaxPool2dCeilModeTrueModule_basic", + "MaxPool2dModule_basic", + 
"MaxPool2dWithIndicesAllNegativeValuesModule_basic", + "MaxPool2dWithIndicesAllOnesModule_basic", + "MaxPool2dWithIndicesBackwardDynamic3DModule_basic", + "MaxPool2dWithIndicesBackwardDynamic4DModule_basic", + "MaxPool2dWithIndicesBackwardStatic3DModule_basic", + "MaxPool2dWithIndicesBackwardStatic4DModule_basic", + "MaxPool2dWithIndicesCeilModeTrueModule_basic", + "MaxPool2dWithIndicesFullSizeKernelModule_basic", + "MaxPool2dWithIndicesModule_basic", + "MaxPool2dWithIndicesNonDefaultDilationModule_basic", + "MaxPool2dWithIndicesNonDefaultPaddingModule_basic", + "MaxPool2dWithIndicesNonDefaultParamsModule_basic", + "MaxPool2dWithIndicesNonDefaultStrideModule_basic", + "MaxPool2dWithIndicesStaticModule_basic", + "MaxPool3dCeilModeTrueModule_basic", + "MaxPool3dStaticCeilModeTrueModule_basic", + "MaxPool3dWithIndicesAllNegativeValuesModule_basic", + "MaxPool3dWithIndicesAllOnesModule_basic", + "MaxPool3dWithIndicesCeilModeTrueModule_basic", + "MaxPool3dWithIndicesFullSizeKernelModule_basic", + "MaxPool3dWithIndicesModule_basic", + "MaxPool3dWithIndicesNonDefaultDilationModule_basic", + "MaxPool3dWithIndicesNonDefaultPaddingModule_basic", + "MaxPool3dWithIndicesNonDefaultParamsModule_basic", + "MaxPool3dWithIndicesNonDefaultStrideModule_basic", + "MaxPool3dWithIndicesStaticModule_basic", + "MeanDimEmptyDimModule_basic", + "MlGroupNormManualModule_basic", + "MlGroupNormModule_basic", + "MlLayerNormManualModule_basic", + "MlLayerNormModule_basic", + "MulFloatModule_basic", + "MulIntModule_basic", + "NativeBatchNorm1DModule_basic", + "NativeBatchNorm2DModule_basic", + "NativeBatchNorm3DModule_basic", + "NativeBatchNormNoneWeightModule_basic", + "NativeGroupNormBackwardModule_basic", + "NeFloatIntModule_basic", + "NeIntModule_basic", + "NllLossModuleBackward1DMeanWeight_basic", + "NllLossModuleBackward1DMean_basic", + "NllLossModuleBackward1DSumWeight_basic", + "NllLossModuleBackward1DSum_basic", + "NllLossModuleBackward1DWeight_basic", + "NllLossModuleBackward1D_basic", + "NormScalarComplexModule_basic", + "NormScalarModule_basic", + "NormScalarOptDimKeepDimComplexModule_basic", + "NormalFunctionalModule_basic", + "NumelModule_basic", + "NumelZeroRankModule_basic", + "PowIntIntModule_basic", + "PowIntIntModule_basic", + "PrimMaxIntModule_basic", + "PrimMinIntDynamicModule_basic", + "PrimMinIntModule_basic", + "PrimsSqueezeEmptyDimensionsModule_basic", + "PrimsSqueezeModule_basic", + "PrimsViewOfModule_basic", + "PrimsViewOfZeroRankModule_basic", + "QuantizedBatchedInputSingleLayer_basic", + "QuantizedMLP_basic", + "QuantizedNoLayer_basic", + "QuantizedReluInt32_basic", + "QuantizedReluInt8_basic", + "QuantizedReluUint8_basic", + "QuantizedSingleLayer_basic", + "RandnDtypeDeviceModule_basic", + "RandnGeneratorF64Module_basic", + "RandnGeneratorModule_basic", + "RandnLikeDtypeModule_basic", + "RandnLikeModule_basic", + "RandnModule_basic", + "ReduceAllDimEmpty_basic", + "ReduceFrobeniusNormComplexModule_basic", + "ReduceL1NormComplexModule_basic", + "ReduceL2NormComplexModule_basic", + "ReduceL3NormKeepDimComplexModule_basic", + "ReduceMaxAlongDimUnsignedInt_basic", + "ReduceMinAlongDimUnsignedInt_basic", + "ReduceSumDimIntListEmptyDimModule_basic", + "RepeatInterleaveFillModule_basic", + "RepeatInterleaveModule_basic", + "RepeatInterleaveStaticModule_basic", + "RollModule_basic", + "ResNet18Module_basic", + "ResNet18StaticModule_basic", + "ScalarConstantTupleModule_basic", + "ScalarImplicitFloatModule_basic", + "ScalarImplicitIntModule_basic", + "ScatterAddDynamicModule_basic", + 
"ScatterReduceFloatMaxModule", + "ScatterReduceFloatMaxModuleIncludeSelf", + "ScatterReduceFloatMeanModule", + "ScatterReduceFloatMeanModuleIncludeSelf", + "ScatterReduceFloatMinModule", + "ScatterReduceFloatMinModuleIncludeSelf", + "ScatterReduceFloatProdModule", + "ScatterReduceFloatProdModuleIncludeSelf", + "ScatterReduceFloatSumModule", + "ScatterReduceFloatSumModuleIncludeSelf", + "ScatterReduceIntMaxModule", + "ScatterReduceIntMaxModuleIncludeSelf", + "ScatterReduceIntMeanModule", + "ScatterReduceIntMeanModuleIncludeSelf", + "ScatterReduceIntMinModule", + "ScatterReduceIntMinModuleIncludeSelf", + "ScatterReduceIntProdModule", + "ScatterReduceIntProdModuleIncludeSelf", + "ScatterReduceIntSumModule", + "ScatterReduceIntSumModuleIncludeSelf", + "ScatterSrcModule_basic", + "ScatterSrcStaticModule_basic", + "ScatterValueFloatModule_basic", + "ScatterValueIntModule_basic", + "SignAndLogarithmOfDeterminantModule_F32", + "SignAndLogarithmOfDeterminantBatchedModule_F32", + "SignAndLogarithmOfDeterminantDynamicModule_F32", + "SliceStaticComplexInputModule_basic", + "SliceCopyStartGreaterThanDimSize_Module_basic", + "SliceOutOfLowerBoundEndIndexModule_basic", + "SliceSizeTwoStepModule_basic", + "SortIntListReverse_basic", + "SortIntList_basic", + "SortTensorDescending_basic", + "SortTensorInteger_basic", + "SortTensorNegativeDimension_basic", + "SortTensorSpecificDimension_basic", + "SortTensor_basic", + "SplitDimDynamicModule_basic", + "SplitDimStaticModule_basic", + "SqrtIntConstantModule_basic", + "SqrtIntModule_basic", + "SubFloatModule_basic", + "SubIntModule_basic", + "TModuleRank0_basic", + "TensorToBoolZeroRank_basic", + "TensorToBool_basic", + "TensorToFloatZeroRank_basic", + "TensorToFloat_basic", + "TensorToIntZeroRank_basic", + "TensorToInt_basic", + "TestMultipleTensorAndPrimitiveTypesReturn_basic", + "ThresholdBackward2dMixedModule_basic", + "TorchPrimLoopForLikeModule_basic", + "TorchPrimLoopWhileLikeModule_basic", + "TraceModule_empty", + "TraceUnsignedIntModule_empty", + "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic", + "UpSampleNearest2dBackwardScalesNone_basic", + "UpSampleNearest2dBackward_basic", + "ViewCollapseDynamicWithAtenSizeIntModule_basic", + "ViewDtypeStaticModule_basic", + "ViewSizeFromOtherTensor_basic", + "VisionTransformerModule_basic", + # Unexpected failures due to new PyTorch version update + "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", + "AdaptiveAvgPool1dGeneralDynamic_basic", + "AdaptiveAvgPool1dStaticLargerOutput_basic", + "AdaptiveAvgPool2dDynamicNoBatch_basic", + "AdaptiveAvgPool2dDynamic_basic", + "CrossEntropyLossModule_basic", + "CrossEntropyLossNoReductionModule_basic", + "IouOfModule_basic", + "MeshgridIndexingIJ_basic", + "MeshgridIndexingXY_basic", + "Meshgrid_basic", + "ReduceFrobeniusNormKeepDimModule_basic", + "ReduceFrobeniusNormModule_basic", + "ScaledDotProductAttentionBoolMaskModule_basic", + "ScaledDotProductAttentionDifferentCausalModule_basic", + "ScaledDotProductAttentionDifferentDynamicCausalModule_basic", + "ScaledDotProductAttentionDifferentModule_basic", + "ScaledDotProductAttentionMaskModule_basic", + "ScaledDotProductAttentionSameCausalModule_basic", + "ScaledDotProductAttentionSameDynamicModule_basic", + "ScaledDotProductAttentionSameModule_basic", + "GridSamplerBasic1_basic", + "GridSamplerBasic2_basic", + "GridSamplerBasic3_basic", + "GridSamplerBasic4_basic", + "MaxPool3dEmptyStrideStaticModule_basic", + "MaxPool3dLargeDatadModule_basic", + "MaxPool3dModuleRandomSimple_basic", + "MaxPool3dModule_basic", + 
"MaxPool3dStaticModule_basic", + "Mlp1LayerModule_basic", + "Mlp2LayerModuleNoBias_basic", + "Mlp2LayerModule_basic", + "MobilenetV3Module_basic", + "UniformModule_basic", + "UniformNoCorrelationModule_basic", + "UniformStaticShapeModule_basic", + # Missing support for: torch.aten.Int.Tensor, + "AtenSymConstrainRangeForSize_basic", + "AtenSymConstrainRange_basic", + "Aten_AssertScalar_basic", +} + +if torch_version_for_comparison() < version.parse("2.6.0.dev"): + # Passing on stable but not on nightly + FX_IMPORTER_TOSA_XFAIL_SET -= { + "AdaptiveAvgPool1dStaticEvenMultiple_basic", + "AdaptiveAvgPool1dGeneralDynamic_basic", + "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", + "AdaptiveAvgPool1dStaticLargerOutput_basic", + "AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic", + "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool2dDynamic_basic", + "AdaptiveAvgPool2dDynamicNoBatch_basic", + "AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic", + "AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic", + "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", + "ChunkListUnpack_Module_basic", + "ChunkListUnpackDynamic_Module_basic", + "ChunkListUnpackUneven_Module_basic", + "ChunkListUnpackUnevenDynamic_Module_basic", + "CrossEntropyLossModule_basic", + "CrossEntropyLossNoReductionModule_basic", + "EinsumStaticContractRhsModule_basic", + "EinsumStaticDiagonalDimensionModule_basic", + "EinsumStaticFourDimensionModule_basic", + "EinsumStaticModule_basic", + "EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic", + "EinsumStaticWithEllipsisSlicingModule_basic", + "ElementwiseRreluEvalModule_basic", + "ElementwiseRreluEvalStaticModule_basic", + "ExponentialModule_basic", + "GridSamplerBasic1_basic", + "GridSamplerBasic2_basic", + "GridSamplerBasic3_basic", + "GridSamplerBasic4_basic", + "IouOfModule_basic", + "MaxPool1dEmptyStrideStaticModule_basic", + "MaxPool1dStaticCeilModeTrueModule_basic", + "MaxPool1dStaticModule_basic", + "Meshgrid_basic", + "MeshgridIndexingIJ_basic", + "MeshgridIndexingXY_basic", + "ReduceFrobeniusNormKeepDimModule_basic", + "ReduceFrobeniusNormModule_basic", + "RepeatInterleaveSelfIntModule_basic", + "ScaledDotProductAttentionBoolMaskModule_basic", + "ScaledDotProductAttentionDifferentCausalModule_basic", + "ScaledDotProductAttentionDifferentDynamicCausalModule_basic", + "ScaledDotProductAttentionDifferentModule_basic", + "ScaledDotProductAttentionMaskModule_basic", + "ScaledDotProductAttentionSameCausalModule_basic", + "ScaledDotProductAttentionSameDynamicModule_basic", + "ScaledDotProductAttentionSameModule_basic", + "SplitTensorGetItem_Module_basic", + "SplitTensorLastSmallerModule_basic", + "SplitTensorListUnpackModule_basic", + "SplitTensorNegativeDimModule_basic", + "SplitWithSizes_Module_basic", + "SplitWithSizesListUnpackModule_basic", + "TensorsSplitTensorLastSmallerModule_basic", + "TensorsSplitTensorModule_basic", + "TensorsSplitTensorNegativeDimModule_basic", + "UniformModule_basic", + "UniformStaticShapeModule_basic", + } + # Failing on stable but not on nightly + FX_IMPORTER_TOSA_XFAIL_SET |= { + "ElementwiseAddScalar_NumToTensorFloat_Module_basic", + "ElementwiseRreluWithNoiseTrainModule_basic", + "ElementwiseRreluWithNoiseTrainStaticModule_basic", + "NumToTensorFloatModule_basic", + "NumToTensorIntModule_basic", + "RsubInt0d_NumToTensor_Module_basic", + "AdaptiveMaxPool1dDimOneStatic_basic", + "ElementwiseRreluWithNoiseEvalModule_basic", + "ElementwiseRreluWithNoiseEvalStaticModule_basic", + } + +ONNX_TOSA_CRASHING_SET = { + 
"ScatterSrcStaticModule_basic", + "StdCorrectionEmptyDimModule_basic", + "StdDimEmptyDimModule_basic", + "VarCorrectionEmptyDimModule_basic", + "VarDimEmptyDimModule_basic", + "ViewSizeFromOtherTensor_basic", +} + +ONNX_TOSA_XFAIL_SET = { + "AtenFftRfft2DLastDim_basic", + "AtenFftRfft2DMiddleDim_basic", + "AtenNonzero1DDynamicModule_basic", + "PowFloatIntModule_basic", + "PowIntFloatModule_basic", + "PowIntIntModule_basic", + "ColumnStack0dModule_basic", + "ColumnStack1dModule_basic", + "ColumnStackBasicIntModule_basic", + "Deg2radModule_basic", + "L1LossMeanReductionModule_basic", + "L1LossNoReductionModule_basic", + "L1LossSumReductionModule_basic", + "FloatPowerTensorTensorStaticModule_basic", + "IsInfiniteModule_basic", + "ElementwiseCopysignModule_basic", + "ElementwiseFracModule_basic", + "ElementwiseLdexpModule_basic", + "ElementwiseSignbitIntModule_basic", + "ElementwiseSignbitModule_basic", + "Exp2StaticIntModule_basic", + "NllLossStaticModule_basic", + "NllLossStaticModule_mean_basic", + "NllLossStaticModule_sum_basic", + "NllLossStaticModule_weight_basic", + "Exp2StaticModule_basic", + "ElementwiseRreluWithNoiseEvalModule_basic", + "ElementwiseRreluWithNoiseEvalStaticModule_basic", + "ElementwiseRreluWithNoiseTrainModule_basic", + "ElementwiseRreluWithNoiseTrainStaticModule_basic", + "RreluWithNoiseBackwardEvalModule_basic", + "RreluWithNoiseBackwardEvalStaticModule_basic", + "RreluWithNoiseBackwardTrainModule_basic", + "RreluWithNoiseBackwardTrainStaticModule_basic", + "RreluWithNoiseForwardBackwardModule_basic", + "Unfold_Module_Dynamic_basic", + "Unfold_Module_Rank_4", + "Unfold_Module_Rank_Zero_Size_Zero_basic", + "Unfold_Module_Rank_Zero_basic", + "ViewDtypeStaticModule_basic", + "ArangeZeroElementOutputModule_basic", + "LinspaceEmptyModule_basic", + "RepeatInterleaveSelfIntNoDimModule_basic", + "SliceOutOfUpperBoundIndexStaticModule_basic", + "TrilIndicesAllZerosModule_basic", + "TriuIndicesAllZerosModule_basic", + "ElementwiseCreateComplexModule_basic", + "ReduceAllDimFloatModule_basic", + "ScaledDotProductAttentionDifferentCausalModule_basic", + "HstackBasicComplexModule_basic", + "HstackBasicFloatModule_basic", + "HstackBasicIntFloatModule_basic", + "HstackBasicIntModule_basic", + "Rot90BasicModule_basic", + "Rot90DynamicDimsModule_basic", + "Rot90MultipleRotationsModule_basic", + "Rot90NegativeEvenRotationsModule_basic", + "Rot90NegativeOddRotationsModule_basic", + "SafeSoftmaxModule_basic", + "SafeSoftmaxNonNoneDtypeModule_basic", + "AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic", + "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic", + "ArgmaxKeepdimModule_basic", + "AtenIntMM_basic", + "AtenKthvalueDynamicDimsModule_basic", + "AtenKthvalueFloat64DynamicDimsModule_basic", + "AtenKthvalueFloat64Module_basic", + "AtenKthvalueKeepDimModule_basic", + "AtenKthvalueModule_basic", + "AvgPool2dCountIncludePadFalseStaticModule_basic", + "AvgPool3dStaticModule_basic", + "Conv_Transpose1dModule_basic", + "Conv_Transpose1dStaticModule_basic", + "Conv_Transpose2dStaticModule_basic", + "Conv_Transpose3dModule_basic", + "Conv_Transpose3dStaticModule_basic", + "ElementwiseFmaxModule_basic", + "ElementwiseFminModule_basic", + "ElementwiseGeluApproximateTanhModule_basic", + "ElementwiseIntTensorLtFloatTensorModule_basic", + "ElementwiseRad2DegIntModule_basic", + "ElementwiseRad2DegModule_basic", + "ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic", + 
"ElementwiseRemainderScalarModule_Int_Float_NegativeDividend_basic", + "ElementwiseRemainderScalarModule_Int_Float_NegativeDivisor_basic", + "ElementwiseRemainderTensorModule_Int_Float_NegativeDividend_basic", + "ElementwiseRemainderTensorModule_Int_Float_NegativeDivisor_basic", + "ElementwiseRemainderTensorModule_Int_NegativeDividend_basic", + "ElementwiseRemainderTensorModule_Int_NegativeDivisor_basic", + "ElementwiseRreluTrainModule_basic", + "ElementwiseRreluTrainStaticModule_basic", + "FakeQuantizePerTensorAffineCachemaskModule_basic", + "IndexPutWithNoneAndBroadcastModule_basic", + "MaskedScatterStaticBasic_basic", + "MaxUnpool3dModulePad0_basic", + "MaxUnpool3dModule_basic", + "MultinomialModule2D_F32", + "MultinomialModule2D_basic", + "MultinomialModule_basic", + "ReduceAmaxEmptyDim_basic", + "ReduceAminSingleDim_basic", + "ReduceAminmaxAllDims_basic", + "ReduceAminmaxSingleDim_basic", + "ReduceAnyDimFloatModule_basic", + "RenormModuleFloat16_basic", + "RenormModuleFloat32DynamicDims_basic", + "RenormModuleFloat32NegativeDim_basic", + "RenormModuleFloat32_basic", + "ScatterAddStaticModule_basic", + "TensorSplitSections_GetItemModule_basic", + "TensorSplitSections_ListUnpackModule_basic", + "TensorsConcatComplex128FloatModule_basic", + "TensorsConcatComplex128IntModule_basic", + "TensorsConcatComplex64FloatModule_basic", + "TimeOutModule_basic", + "TypeConversionUint8ToF32Module_basic", + "UnfoldModule_basic", + "WeightNormInterfaceModule_basic", + "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", + "AdaptiveAvgPool1dGeneralDynamic_basic", + "AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic", + "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool1dStaticEvenMultiple_basic", + "AdaptiveAvgPool1dStaticLargerOutput_basic", + "AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic", + "AdaptiveAvgPool2dDynamicNoBatch_basic", + "AdaptiveAvgPool2dDynamic_basic", + "AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic", + "AdaptiveAvgPool2dOutputSizeDivisibleByInputDynamicModule_basic", + "AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic", + "AdaptiveAvgPool3dDynamicNoBatch_basic", + "AdaptiveAvgPool3dDynamic_basic", + "AdaptiveMaxPool1dDynamicNoBatch_basic", + "AdaptiveMaxPool1dDynamic_basic", + "AdaptiveMaxPool1dStatic_basic", + "AdaptiveMaxPool2dDynamicNoBatch_basic", + "AdaptiveMaxPool2dDynamicWithIndices_basic", + "AdaptiveMaxPool2dDynamic_basic", + "AdaptiveMaxPool2dStaticWithIndices_basic", + "AdaptiveMaxPool2dStatic_basic", + "AdaptiveMaxPool3dDynamicNoBatch_basic", + "AdaptiveMaxPool3dDynamicWithIndices_basic", + "AdaptiveMaxPool3dDynamic_basic", + "AdaptiveMaxPool3dStaticWithIndices_basic", + "AdaptiveMaxPool3dStatic_basic", + "AddCDivModule_basic", + "AddIntModule_basic", + "AddFloatIntModule_basic", + "AddSizeIntModule_basic", + "AddSizeIntNegDimModule_basic", + "Add_MixPModule_basic", + "Add_Module_basic", + "AddmmModuleFloat_basic", + "AddmmModule_broadcastable", + "AddmmModule_differentRankBroadcastable", + "AllBoolFalseModule_basic", + "AllBoolTrueModule_basic", + "AnyBoolFalseModule_basic", + "AnyBoolTrueModule_basic", + "ArangeStartOutDtypeModule_basic", + "ArangeStartOutViewModule_basic", + "ArgmaxIntModule_basic", + "ArgmaxIntModule_multiple_maxs", + "ArgmaxModule_basic", + "ArgmaxModule_with_dim", + "ArgminIntModule_basic", + "ArgminIntModule_multiple_mins", + "ArgminModule_basic", + "ArgminModule_with_dim", + "AtenComplex64Module_basic", + "AtenComplexImagModule_basic", + "AtenComplexRealModule_basic", + "AtenComplexViewModule_basic", + 
"AtenDiagEmbedDefaultDiag_basic", + "AtenDiagEmbedDimDiag_basic", + "AtenDiagEmbedNegOffsetDiag_basic", + "AtenDiagEmbedNonDefault4DDiag_basic", + "AtenDiagEmbedOffsetDiag_basic", + "AtenDiagEmbedRevDimDiag_basic", + "AtenEmbeddingBagStaticModule_basic", + "AtenEmbeddingBagSumExample_basic", + "AtenFloatScalarModule_basic", + "AtenIntBoolOpConstFalseModule_basic", + "AtenIntBoolOpConstTrueModule_basic", + "AtenIntBoolOpModule_basic", + "AtenIntTensorByteDtypeModule_basic", + "AtenIntTensorCharDtypeModule_basic", + "AtenItemFpOpModule_basic", + "AtenItemIntOpModule_basic", + "AtenLinalgCrossDynamic_basic", + "AtenMatmulQMixedSigni8Transpose_basic", + "AtenMatmulQMixedSigni8_basic", + "AtenMatmulQint8MV_basic", + "AtenMatmulQint8VM_basic", + "AtenMatmulQint8VV_basic", + "AtenMatmulQint8_basic", + "AtenMmFloatTypes_basic", + "AtenMmIntTypes_basic", + "AtenMmQMixedSigni8_basic", + "AtenMmQint8_basic", + "AtenMmQuint8_basic", + "AtenPolarFloatModule_basic", + "AtenPolarDoubleModule_basic", + "AtenRealView128Module_basic", + "AtenRealView64Module_basic", + "AtenSubFloatModule_basic", + "AtenTopKModule_basic", + "AtenTopKSmallestModule_basic", + "Aten_TrilinearModule_basic", + "Aten_TrilinearModuleSumdims_basic", + "Aten_TrilinearModuleSumAllDims_basic", + "Aten_TrilinearModuleVaryingRanks_basic", + "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic", + "Aten_TrilinearModuleZerodDimBug_basic", + "AtenTrilModule_basic", + "AtenTrilWithNegDiagonalModule_basic", + "AtenTrilWithPosDiagonalModule_basic", + "AtenTriuModule_basic", + "AtenTriuWithNegDiagonalModule_basic", + "AtenTriuWithPosDiagonalModule_basic", + "Aten_EmbeddingBagExample_basic", + "AvgPool1dFloatModule_basic", + "AvgPool1dIntModule_basic", + "AvgPool1dStaticModule_basic", + "AvgPool2dCeilModeTrueModule_basic", + "AvgPool2dDivisorOverrideModule_basic", + "AvgPool2dFloatModule_basic", + "AvgPool2dIntModule_basic", + "AvgPool2dStaticModule_basic", + "AvgPool2dWithoutPadModule_basic", + "BatchMlpLayerModule_basic", + "BernoulliFloatModule_basic", + "BernoulliModule_basic", + "BernoulliOnesModule_basic", + "BernoulliPModule_basic", + "BernoulliTensorModule_basic", + "BernoulliZerosModule_basic", + "BincountMinlengthModule_basic", + "BincountModule_basic", + "BincountStaticSizeModule_basic", + "BmmIntModule_basic", + "BoolFloatConstantModule_basic", + "BoolFloatFalseModule_basic", + "BoolFloatTrueModule_basic", + "BoolIntConstantModule_basic", + "BoolIntFalseModule_basic", + "BoolIntTrueModule_basic", + "BroadcastDynamicDimModule_basic", + "BroadcastToModule_basic", + "BucketizeTensorFloatModule_basic", + "BucketizeTensorModule_basic", + "BucketizeTensorOutInt32RightModule_basic", + "CeilFloatModule_basic", + "ChunkListUnpackDynamic_Module_basic", + "ChunkListUnpackUnevenDynamic_Module_basic", + "ChunkListUnpackUneven_Module_basic", + "CollapseAllDimensionsModule_basic", + "CollapseFullDynamicModule_basic", + "CollapsePartialDynamicModule_basic", + "CollapseRank1DynamicModule_basic", + "CollapseStaticModule_basic", + "ConstantBoolParameterModule_basic", + "ContainsIntList_False", + "ContainsIntList_True", + "Conv1dModule_basic", + "Conv1dWithSamePaddingModule_basic", + "Conv1dWithValidPaddingModule_basic", + "Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic", + "Conv1dGroupModule_basic", + "Conv2dBiasNoPaddingModule_basic", + "Conv2dModule_basic", + "Conv2dNoPaddingModule_basic", + "Conv2dQInt8Module_basic", + "Conv2dQInt8Module_depthwise", + "Conv2dQInt8Module_grouped", + "Conv2dQInt8Module_not_depthwise", + 
"Conv2dQInt8PerChannelModule_basic", + "Conv2dQInt8PerChannelModule_depthwise", + "Conv2dQInt8PerChannelModule_grouped", + "Conv2dWithPaddingDilationStrideModule_basic", + "Conv2dWithPaddingDilationStrideStaticModule_grouped", + "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier", + "Conv2dWithPaddingModule_basic", + "Conv2dWithSamePaddingModule_basic", + "Conv2dWithValidPaddingModule_basic", + "Conv3dModule_basic", + "Conv3dWithSamePaddingModule_basic", + "Conv3dWithValidPaddingModule_basic", + "ConvTbcModule_basic", + "ConvTranspose2DQInt8_basic", + "Conv_Transpose2dModule_basic", + "Convolution2DModule_basic", + "Convolution2DStridedModule_basic", + "ConvolutionBackwardModule2DPadded_basic", + "ConvolutionBackwardModule2DStatic_basic", + "ConvolutionBackwardModule2DStrided_basic", + "ConvolutionBackwardModule2D_basic", + "ConvolutionModule2DGroups_basic", + "ConvolutionModule2DTransposeNonUnitOutputPadding_basic", + "ConvolutionModule2DTransposeStridedStatic_basic", + "ConvolutionModule2DTransposeStrided_basic", + "ConvolutionModule2DTranspose_basic", + "CopyModule_basic", + "CopyWithDifferentDTypesAndSizesModule_basic", + "CopyWithDifferentDTypesModule_basic", + "CopyWithDifferentSizesModule_basic", + "CosineSimilarityStaticBroadcastModule_basic", + "CrossEntropyLossModule_basic", + "CrossEntropyLossNoReductionModule_basic", + "CumsumInputDtypeInt32Module_basic", + "CumsumModule_basic", + "CumsumStaticModule_basic", + "CumsumStaticNegativeDimModule_basic", + "CumprodModule_basic", + "CumprodInputDtypeInt32Module_basic", + "CumprodStaticModule_basic", + "CumprodStaticNegativeDimModule_basic", + "DeformConv2D_basic", + "DeterminantModule_F32", + "DeterminantBatchedModule_F32", + "DeterminantDynamicModule_F32", + "DeterminantModule_F32", + "DiagonalModule_basic", + "DiagonalModule_nonsquare", + "DiagonalModule_transposed", + "DiagonalModule_with_dims", + "DiagonalModule_with_dims_and_offset", + "DiagonalModule_with_negative_dims", + "DiagonalModule_with_offset", + "DiagonalWithStaticShapeModule_basic", + "DivFloatModule_basic", + "DivIntModule_basic", + "DropoutTrainModule_basic", + "DropoutTrainStaticShapeModule_basic", + "ElementwiseAcosIntModule_basic", + "ElementwiseAcosModule_basic", + "ElementwiseAcoshIntModule_basic", + "ElementwiseAcoshModule_basic", + "ElementwiseAddScalarInt64Module_basic", + "ElementwiseAddScalarIntModule_basic", + "ElementwiseAsinIntModule_basic", + "ElementwiseAsinModule_basic", + "ElementwiseAsinhIntModule_basic", + "ElementwiseAsinhModule_basic", + "ElementwiseAtan2FloatIntModule_basic", + "ElementwiseAtan2FloatIntStaticModule_basic", + "ElementwiseAtan2TensorFloatModule_basic", + "ElementwiseAtan2TensorFloatStaticModule_basic", + "ElementwiseAtan2TensorIntModule_basic", + "ElementwiseAtan2TensorIntStaticModule_basic", + "ElementwiseAtanTensorFloatModule_basic", + "ElementwiseAtanTensorIntModule_basic", + "ElementwiseAtanhIntModule_basic", + "ElementwiseAtanhModule_basic", + "ElementwiseAtenDivIntScalarModule_basic", + "ElementwiseAtenFloorDivideBroadcastModule_basic", + "ElementwiseAtenFloorDivideScalarModule_basic", + "ElementwiseAtenFloorDivideScalarNegativeModule_basic", + "ElementwiseAtenFloorDivideTensorNegativeModule_basic", + "ElementwiseAtenFloorDivideTensorPositiveModule_basic", + "ElementwiseAtenIsneginfOpModule_basic", + "ElementwiseAtenIsposinfOpModule_basic", + "ElementwiseAtenLogicalAndOpPromoteBroadcastModule_basic", + "ElementwiseAtenLogicalNotOpPromoteModule_basic", + "ElementwiseAtenLogicalOrOpBrodcastModule_basic", + 
"ElementwiseAtenLogicalOrOpDiffArgs1Module_basic", + "ElementwiseAtenLogicalOrOpDiffArgs2Module_basic", + "ElementwiseAtenLogicalOrOpDiffArgs3Module_basic", + "ElementwiseAtenLogicalOrOpNegativeModule_basic", + "ElementwiseAtenLogicalOrOpRandomFloatModule_basic", + "ElementwiseAtenLogicalOrOpRandomModule_basic", + "ElementwiseAtenLogicalXorOpPromoteBroadcastModule_basic", + "ElementwiseBitwiseAndModule_basic", + "ElementwiseBitwiseLeftShiftInt32Module_basic", + "ElementwiseBitwiseLeftShiftInt64Module_basic", + "ElementwiseBitwiseLeftShiftInt8Module_basic", + "ElementwiseBitwiseNotInt32Module_basic", + "ElementwiseBitwiseNotInt64Module_basic", + "ElementwiseBitwiseOrModule_basic", + "ElementwiseBitwiseOrStaticShapeModule_basic", + "ElementwiseBitwiseRightShiftInt32Module_basic", + "ElementwiseBitwiseRightShiftInt64Module_basic", + "ElementwiseBitwiseRightShiftInt8Module_basic", + "ElementwiseBitwiseXorModule_basic", + "ElementwiseBitwiseXorStaticShapeModule_basic", + "ElementwiseCosIntModule_basic", + "ElementwiseCoshIntModule_basic", + "ElementwiseCoshModule_basic", + "ElementwiseDequantizePerChannelModule_basic", + "ElementwiseDequantizePerTensorModule_basic", + "ElementwiseDivScalarRoundingModeTruncModule_basic", + "ElementwiseDivScalarRoundingModeTruncStaticModule_basic", + "ElementwiseDivTensorFloatModule_basic", + "ElementwiseDivTensorIntegerModule_basic", + "ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic", + "ElementwiseDivTensorRoundingModeFloorModule_basic", + "ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic", + "ElementwiseDivTensorRoundingModeTruncModule_basic", + "ElementwiseDivTensorRoundingModeTruncStaticModule_basic", + "ElementwiseDivTensorUnsignedIntegerModule_basic", + "ElementwiseEluNonDefaultModule_basic", + "ElementwiseEqBoolScalarModule_basic", + "ElementwiseEqDiffWidthScalarModule_basic", + "ElementwiseErfIntModule_basic", + "ElementwiseExpIntModule_basic", + "ElementwiseExpm1IntModule_basic", + "ElementwiseExpm1Module_basic", + "ElementwiseFlattenBroadcastModule_basic", + "ElementwiseFloatTensorGtIntTensorModule_basic", + "ElementwiseFmodTensor_Int_Float_basic", + "ElementwiseFmodTensor_Int_basic", + "ElementwiseGeMixedIntScalarModule_basic", + "ElementwiseGtMixed2ScalarModule_basic", + "ElementwiseIntTensorLtFloatScalarModule_basic", + "ElementwiseIsinfModule_basic", + "ElementwiseLeMixedIntScalarModule_basic", + "ElementwiseLog10IntModule_basic", + "ElementwiseLog2IntModule_basic", + "ElementwiseLogIntModule_basic", + "ElementwiseLtDiffWidthScalarModule_basic", + "ElementwiseMulScalarModule_basic", + "ElementwiseMulTensorComplexDiffModule_basic", + "ElementwiseMulTensorComplexModule_basic", + "ElementwiseMulTensorFloatModule_basic", + "ElementwiseMulTensorIntModule_basic", + "ElementwiseOrTensorModule_basic", + "ElementwiseOrTensorStaticShapeModule_basic", + "ElementwiseQuantizePerTensorModule_basic", + "ElementwiseQuantizePerTensorUIntModule_basic", + "ElementwiseReciprocalIntModule_basic", + "ElementwiseRemainderScalarModule_Bool_basic", + "ElementwiseRemainderScalarModule_Int_Float_basic", + "ElementwiseRemainderTensorModule_Int_Float_basic", + "ElementwiseRemainderTensorModule_Int_basic", + "ElementwiseRsqrtIntModule_basic", + "ElementwiseSgnModule_basic", + "ElementwiseSigmoidIntModule_basic", + "ElementwiseSinIntModule_basic", + "ElementwiseSinhIntModule_basic", + "ElementwiseSinhModule_basic", + "ElementwiseSpecialExpm1IntModule_basic", + "ElementwiseSpecialExpm1Module_basic", + "ElementwiseSqrtIntModule_basic", + 
"ElementwiseSubScalarIntModule_basic", + "ElementwiseTanIntModule_basic", + "ElementwiseTernaryModule_basic", + "ElementwiseToDtypeF32ToI64Module_basic", + "ElementwiseToDtypeI64ToI8Module_basic", + "ElementwiseToDtypeI64ToUI8Module_basic", + "ElementwiseTruncIntModule_basic", + "ElementwiseTruncModule_basic", + "ElementwiseUnaryIntModule_basic", + "ElementwiseUnsqueezeNegDimsModule_basic", + "ElementwiseWhereScalarOtherModule_basic", + "ElementwiseWhereScalarSelfModule_basic", + "ElementwiseWhereSelfModule_basic", + "EmbeddingModule1DIndices_basic", + "EmbeddingModuleF16_basic", + "EmbeddingModuleI32_basic", + "EmbeddingModuleI64_basic", + "EmptyLikeMemoryFormatModule_basic", + "EmptyLikeModule_defaultDtype", + "EmptyLikeModule_falsePinMemory", + "EmptyLikeModule_float", + "EmptyLikeModule_int", + "EmptyStridedModule_basic", + "EmptyStridedSizeIntStrideModule_basic", + "EqIntModule_basic", + "ExpandAsFloatModule_basic", + "ExpandAsIntModule_basic", + "ExpandModule_basic", + "ExponentialModule_basic", + "FakeQuantizePerTensorAffineDynamicShapeModule_basic", + "FakeQuantizePerTensorAffineModule_basic", + "FakeQuantizePerTensorAffineRoundToEvenModule_basic", + "Fill_TensorFloat32WithFloat32_basic", + "Fill_TensorFloat32WithFloat64_basic", + "Fill_TensorFloat32WithInt64_basic", + "Fill_TensorFloat64WithFloat32_basic", + "Fill_TensorFloat64WithFloat64_basic", + "Fill_TensorFloat64WithInt64_basic", + "FlattenDynamicModuleCollapseAll_basic", + "FlattenDynamicModule_basic", + "FlattenRank0Module_basic", + "FlipModuleStaticShape_basic", + "FlipModule_basic", + "FlipNegativeIndexModule_basic", + "FloatImplicitModule_basic", + "FullLikeModuleDefaultDtype_basic", + "FullLikeModuleFalsePinMemory_basic", + "FullLikeModuleFloat2D_basic", + "FullLikeModuleFloat3D_basic", + "FullLikeModuleInt2D_basic", + "FullLikeModuleInt3D_basic", + "Gather2DInputModdule_basic", + "GatherModule_basic", + "GatherNegativeDimModule_basic", + "GatherRandomIndexModule_basic", + "GeFloatIntModule_basic", + "GeFloatModule_basic", + "GeIntModule_basic", + "GeluBackwardModule_basic", + "GridSamplerBasic1_basic", + "GridSamplerBasic2_basic", + "GridSamplerBasic3_basic", + "GridSamplerBasic4_basic", + "GtFloatIntModule_basic", + "GtIntModule_basic", + "HBC_basic", + "HardtanhBackward_basic", + "IndexPut1DFloatAccumulateModule_basic", + "IndexPut1DFloatNonAccumulateModule_basic", + "IndexPut1DIntAccumulateModule_basic", + "IndexPut1DIntNonAccumulateModule_basic", + "IndexPut2DFloatAccumulateModule_basic", + "IndexPut2DFloatNonAccumulateModule_basic", + "IndexPut2DIntAccumulateModule_basic", + "IndexPut2DIntNonAccumulateModule_basic", + "IndexPut3DFloatAccumulateModule_basic", + "IndexPut3DFloatNonAccumulateModule_basic", + "IndexPut3DIntAccumulateModule_basic", + "IndexPut3DIntNonAccumulateModule_basic", + "IndexPutHackedTwin1DFloatAccumulateModule_basic", + "IndexPutHackedTwin1DFloatNonAccumulateModule_basic", + "IndexPutHackedTwin1DIntAccumulateModule_basic", + "IndexPutHackedTwin1DIntNonAccumulateModule_basic", + "IndexPutHackedTwin2DFloatAccumulateModule_basic", + "IndexPutHackedTwin2DFloatNonAccumulateModule_basic", + "IndexPutHackedTwin2DIntAccumulateModule_basic", + "IndexPutHackedTwin2DIntNonAccumulateModule_basic", + "IndexPutHackedTwin3DFloatAccumulateModule_basic", + "IndexPutHackedTwin3DFloatNonAccumulateModule_basic", + "IndexPutHackedTwin3DIntAccumulateModule_basic", + "IndexPutHackedTwin3DIntNonAccumulateModule_basic", + "IndexPutImpl1DFloatAccumulateModule_basic", + "IndexPutImpl1DFloatNonAccumulateModule_basic", + 
"IndexPutImpl1DIntAccumulateModule_basic", + "IndexPutImpl1DIntNonAccumulateModule_basic", + "IndexPutImpl2DFloatAccumulateModule_basic", + "IndexPutImpl2DFloatNonAccumulateModule_basic", + "IndexPutImpl2DImplicitModule_basic", + "IndexPutImpl2DIndexModule_basic", + "IndexPutImpl2DNoneIndexStaticModule_basic", + "IndexPutImpl3DFloatAccumulateModule_basic", + "IndexPutImpl3DFloatNonAccumulateModule_basic", + "IndexPutImplIndexWithNoneModule_basic", + "IndexSelectDynamicIndexSizeModule_basic", + "IndexSelectDynamicInputSizeModule_basic", + "IndexSelectDynamicModulebasic", + "IndexTensorDyanmicInputContiguousWithNoneModule_basic", + "IndexTensorDyanmicInputNonContiguousWithNoneModule_basic", + "IndexTensorHackedTwinModule3dInput_basic", + "IndexTensorHackedTwinModule_basic", + "IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic", + "IndexTensorModule3dInput_basic", + "IndexTensorModule_basic", + "IndexTensorMultiInputContiguousCenter_basic", + "IndexTensorMultiInputContiguousOneDimDynamic_basic", + "IndexTensorMultiInputNonContiguousDynamic_basic", + "IndexTensorMultiInputNonContiguousMultipleStaticDims_basic", + "IndexTensorMultiInputNonContiguousOneDimDynamic_basic", + "IndexTensorMultiInputNonContiguous_basic", + "IndexTensorMultiInputOneDim_basic", + "IndexTensorMultiInputThreeIndexers_basic", + "IndexTensorMultiInput_basic", + "IndexTensorSelectDimModule_basic", + "InterpolateDynamicModule_sizes_bilinear", + "InterpolateDynamicModule_sizes_nearest", + "InterpolateDynamicModule_scales_recompute_bilinear", + "IntFloatModule_basic", + "IntImplicitModule_basic", + "IouOfModule_basic", + "IsFloatingPointFloat_True", + "IsFloatingPointInt_False", + "IscloseStaticModuleTrue_basic", + "IscloseStaticModule_basic", + "LeakyReluBackwardModule_basic", + "LeakyReluBackwardStaticModule_basic", + "LenStrModule_basic", + "LiftFreshCopyModule_basic", + "LinalgNormKeepDimComplexModule_basic", + "LinalgNormModule_basic", + "LinalgVectorNormComplexModule_basic", + "LogSoftmaxBackwardModule_basic", + "LogSoftmaxIntModule_basic", + "MaskedFillTensorFloatValueModule_basic", + "MatmulBroadcastBatchDim_basic", + "MatmulSingleDynamicBatchDim_basic", + "Matmul_2d", + "Matmul_4d", + "Matmul_matvec", + "Matmul_vecmat", + "MaxPool1dCeilModeTrueModule_basic", + "MaxPool1dModule_basic", + "MaxPool2dCeilModeTrueModule_basic", + "MaxPool2dModule_basic", + "MaxPool2dWithIndicesAllNegativeValuesModule_basic", + "MaxPool2dWithIndicesAllOnesModule_basic", + "MaxPool2dWithIndicesBackwardDynamic3DModule_basic", + "MaxPool2dWithIndicesBackwardDynamic4DModule_basic", + "MaxPool2dWithIndicesBackwardStatic3DModule_basic", + "MaxPool2dWithIndicesBackwardStatic4DModule_basic", + "MaxPool2dWithIndicesCeilModeTrueModule_basic", + "MaxPool2dWithIndicesFullSizeKernelModule_basic", + "MaxPool2dWithIndicesModule_basic", + "MaxPool2dWithIndicesNonDefaultDilationModule_basic", + "MaxPool2dWithIndicesNonDefaultPaddingModule_basic", + "MaxPool2dWithIndicesNonDefaultParamsModule_basic", + "MaxPool2dWithIndicesNonDefaultStrideModule_basic", + "MaxPool2dWithIndicesStaticModule_basic", + "MaxPool3dCeilModeTrueModule_basic", + "MaxPool3dEmptyStrideStaticModule_basic", + "MaxPool3dLargeDatadModule_basic", + "MaxPool3dModuleRandomSimple_basic", + "MaxPool3dModule_basic", + "MaxPool3dStaticCeilModeTrueModule_basic", + "MaxPool3dStaticModule_basic", + "MaxPool3dWithIndicesAllNegativeValuesModule_basic", + "MaxPool3dWithIndicesAllOnesModule_basic", + "MaxPool3dWithIndicesCeilModeTrueModule_basic", + 
"MaxPool3dWithIndicesFullSizeKernelModule_basic", + "MaxPool3dWithIndicesModule_basic", + "MaxPool3dWithIndicesNonDefaultDilationModule_basic", + "MaxPool3dWithIndicesNonDefaultPaddingModule_basic", + "MaxPool3dWithIndicesNonDefaultParamsModule_basic", + "MaxPool3dWithIndicesNonDefaultStrideModule_basic", + "MaxPool3dWithIndicesStaticModule_basic", + "MeanDimAllReduceKeepdimModule_basic", + "MeanDimAllReduceModule_basic", + "MeanDimDtypeModule_basic", + "MeanDimEmptyDimModule_basic", + "MeanDimKeepdimModule_basic", + "MeanDimModule_basic", + "MeanDimNegativeModule_basic", + "MeanDimNoneDimModule_basic", + "MeanDtypeModule_basic", + "MeanDynamicSizesModule_basic", + "Mlp1LayerModule_basic", + "Mlp2LayerModuleNoBias_basic", + "Mlp2LayerModule_basic", + "MmModule_basic", + "MmModule_chained", + "MmTanhModule_basic", + "MobilenetV3Module_basic", + "MoveDimIntNegativeIndexModule_basic", + "MseLossMeanReductionModule_basic", + "MseLossSumReductionWithDifferentElemTypeModule_basic", + "MulFloatModule_basic", + "MulIntModule_basic", + "Mv_basic", + "NarrowHorizontalTest2_basic", + "NarrowHorizontalTest_basic", + "NarrowTensorHorizontalModule_basic", + "NarrowTensorVerticalModule_basic", + "NarrowVerticalTest2_basic", + "NarrowVerticalTest_basic", + "NativeBatchNorm1DModule_basic", + "NativeBatchNorm2DModule_basic", + "NativeBatchNorm3DModule_basic", + "NativeBatchNormNoneWeightModule_basic", + "NativeDropoutEvalFloatModule_basic", + "NativeDropoutTrainModule_basic", + "NativeDropoutTrainStaticShapeModule_basic", + "NativeGroupNormBackwardModule_basic", + "NativeGroupNormModule_basic", + "NativeLayerNormDynamicModule_basic", + "NeFloatIntModule_basic", + "NeIntModule_basic", + "NewEmptyStridedModuleDefaultDtype_basic", + "NllLossModuleBackward1DMeanWeight_basic", + "NllLossModuleBackward1DMean_basic", + "NllLossModuleBackward1DSumWeight_basic", + "NllLossModuleBackward1DSum_basic", + "NllLossModuleBackward1DWeight_basic", + "NllLossModuleBackward1D_basic", + "NllLossModuleBackwardMeanWeight_basic", + "NllLossModuleBackwardMean_basic", + "NllLossModuleBackwardSumWeight_basic", + "NllLossModuleBackwardSum_basic", + "NllLossModuleBackwardWeight_basic", + "NllLossModuleBackward_basic", + "NllLossModuleBackward_ignore_index", + "NllLossModule_1D_basic", + "NllLossModule_basic", + "NllLossModule_ignore_index_out_of_bounds_basic", + "NllLossModule_mean_basic", + "NllLossModule_sum_basic", + "NormScalarComplexModule_basic", + "NormScalarModule_basic", + "NormScalarOptDimKeepDimComplexModule_basic", + "NormScalarOptDimKeepDimModule_basic", + "NormScalarOptDimModule_basic", + "NormalFunctionalModule_basic", + "NumToTensorFloatModule_basic", + "NumToTensorIntModule_basic", + "NumelModule_basic", + "NumelZeroRankModule_basic", + "OneHotModule_basic", + "OnesLikeModule_defaultDtype", + "OnesLikeModule_falsePinMemory", + "OnesLikeModule_float", + "OnesLikeModule_int", + "PermuteNegativeIndexModule_basic", + "PixelShuffleModuleFullDynamic_basic", + "PixelShuffleModuleSpatiallyDynamic_basic", + "PixelShuffleModuleSpatiallyStatic_basic", + "PixelShuffleModuleStaticRank3Int64_basic", + "PixelShuffleModuleStaticRank4Float32_basic", + "PrimMaxIntModule_basic", + "PrimMinIntDynamicModule_basic", + "PrimMinIntModule_basic", + "PrimsConvertElementTypeModule_basic", + "PrimsIotaModule_basic", + "PrimsSqueezeEmptyDimensionsModule_basic", + "PrimsSqueezeModule_basic", + "PrimsViewOfModule_basic", + "PrimsViewOfZeroRankModule_basic", + "QuantizedBatchedInputSingleLayer_basic", + "QuantizedMLP_basic", + 
"QuantizedNoLayer_basic", + "QuantizedReluInt32_basic", + "QuantizedReluInt8_basic", + "QuantizedReluUint8_basic", + "QuantizedSingleLayer_basic", + "RandIntDtypeModule_basic", + "RandIntModule_basic", + "RandIntPinMemoryModule_basic", + "RandLikeDtypeModule_basic", + "RandLikeModule_basic", + "RandModule_basic", + "RandnDtypeDeviceModule_basic", + "RandnGeneratorF64Module_basic", + "RandnGeneratorModule_basic", + "RandnLikeDtypeModule_basic", + "RandnLikeModule_basic", + "RandnModule_basic", + "ReduceAllBoolModule_basic", + "ReduceAllDimBool_basic", + "ReduceAllDimEmpty_basic", + "ReduceAllDimFloat_basic", + "ReduceAllDimInt_basic", + "ReduceAllFloatModule_basic", + "ReduceAllIntModule_basic", + "ReduceAmaxKeepDim_basic", + "ReduceAmaxMultiDim_basic", + "ReduceAmaxOutOfOrderDim_basic", + "ReduceAmaxSingleDim_basic", + "ReduceAnyBoolModule_basic", + "ReduceAnyFloatModule_basic", + "ReduceAnyIntModule_basic", + "ReduceFrobeniusNormComplexModule_basic", + "ReduceL1NormComplexModule_basic", + "ReduceL1NormModule_basic", + "ReduceL1NormWithDTypeModule_basic", + "ReduceL2NormComplexModule_basic", + "ReduceL2NormModule_basic", + "ReduceL3NormAllDimsModule_basic", + "ReduceL3NormKeepDimComplexModule_basic", + "ReduceL3NormKeepDimModule_basic", + "ReduceLN3NormModule_basic", + "ReduceMaxAllDims_basic", + "ReduceMaxAlongDimNegative_basic", + "ReduceMaxAlongDimSignedInt_basic", + "ReduceMaxAlongDimUnsignedInt_basic", + "ReduceMaxAlongDim_basic", + "ReduceMaxFloatModule_basic", + "ReduceMaxKeepDimReturnBoth_basic", + "ReduceMaxKeepDim_basic", + "ReduceMaxNegativeDim_basic", + "ReduceMaxSignedIntModule_basic", + "ReduceMaxUnsignedIntModule_basic", + "ReduceMinAlongDimNegative_basic", + "ReduceMinAlongDimSignedInt_basic", + "ReduceMinAlongDimUnsignedInt_basic", + "ReduceMinAlongDim_basic", + "ReduceMinFloatModule_basic", + "ReduceMinKeepDimReturnBoth_basic", + "ReduceMinKeepDim_basic", + "ReduceMinSignedIntModule_basic", + "ReduceMinUnsignedIntModule_basic", + "ReduceProdDimIntFloatModule_basic", + "ReduceProdDtypeFloatModule_basic", + "ReduceProdDtypeIntModule_basic", + "ReduceProdElementTypeBoolModule_basic", + "ReduceProdFloatModule_basic", + "ReduceProdSignedIntModule_basic", + "ReduceProdUnsignedIntModule_basic", + "ReduceSumDimIntListDtypeFloatModule_basic", + "ReduceSumDimIntListDtypeIntModule_basic", + "ReduceSumDimIntListElementTypeBoolModule_basic", + "ReduceSumDtypeFloatModule_basic", + "ReduceSumDtypeIntModule_basic", + "ReduceSumElementTypeBoolModule_basic", + "ResNet18Module_basic", + "ReshapeAliasCollapseModule_basic", + "ReshapeAliasExpandModule_basic", + "ReshapeCollapseModule_basic", + "ReshapeDynamicModule_basic", + "ReshapeExpandModule_basic", + "RollModule_basic", + "RsubIntModule_noalpha_basic", + "ScalarConstantTupleModule_basic", + "ScalarImplicitFloatModule_basic", + "ScalarImplicitIntModule_basic", + # REMOVE WHEN ENABLE_GQA IS ADDED + "ScaledDotProductAttentionBoolMaskModule_basic", + "ScaledDotProductAttentionSameCausalModule_basic", + "ScaledDotProductAttentionSameDynamicModule_basic", + "ScatterAddDynamicModule_basic", + "ScatterReduceFloatMaxModule", + "ScatterReduceFloatMaxModuleIncludeSelf", + "ScatterReduceFloatMeanModule", + "ScatterReduceFloatMeanModuleIncludeSelf", + "ScatterReduceFloatMinModule", + "ScatterReduceFloatMinModuleIncludeSelf", + "ScatterReduceFloatProdModule", + "ScatterReduceFloatProdModuleIncludeSelf", + "ScatterReduceFloatSumModule", + "ScatterReduceFloatSumModuleIncludeSelf", + "ScatterReduceIntMaxModule", + "ScatterReduceIntMaxModuleIncludeSelf", 
+ "ScatterReduceIntMeanModule", + "ScatterReduceIntMeanModuleIncludeSelf", + "ScatterReduceIntMinModule", + "ScatterReduceIntMinModuleIncludeSelf", + "ScatterReduceIntProdModule", + "ScatterReduceIntProdModuleIncludeSelf", + "ScatterReduceIntSumModule", + "ScatterReduceIntSumModuleIncludeSelf", + "ScatterSrcModule_basic", + "ScatterSrcStaticModule_basic", + "ScatterValueFloatModule_basic", + "ScatterValueIntModule_basic", + "SelectIntModule_basic", + "SelectScattertModule_basic", + "SelectScattertStaticModule_basic", + "SignAndLogarithmOfDeterminantModule_F32", + "SignAndLogarithmOfDeterminantBatchedModule_F32", + "SignAndLogarithmOfDeterminantDynamicModule_F32", + "SliceCopyEndGreaterThanDimSize_Module_basic", + "SliceCopyNegative_Module_basic", + "SliceCopyNonZeroDim_Module_basic", + "SliceCopy_Module_basic", + "SliceEndSleStartModule_basic", + "SliceModule_basic", + "SliceStaticComplexInputModule_basic", + "SliceNegIdxModule_basic", + "SliceOutOfLowerBoundEndIndexModule_basic", + "SliceOutOfLowerBoundStartIndexModule_basic", + "SliceOutOfUpperBoundIndexModule_basic", + "SliceScatterModule_basic", + "SliceScatterNegativeDimModule_basic", + "SliceScatterNegativeEndModule_basic", + "SliceScatterStaticModule_basic", + "SliceScatterStepVariationModule_basic", + "SliceScatterZeroDimModule_basic", + "SliceSingleIdxModule_basic", + "SliceSizeTwoStepModule_basic", + "SliceStartEqEndModule_basic", + "SoftmaxBackwardModule_basic", + "SoftmaxIntArgTypeF64Module_basic", + "SoftmaxIntModule_basic", + "SoftmaxIntNegDimModule_basic", + "SoftmaxIntNonNoneDtypeModule_basic", + "SortIntListReverse_basic", + "SortIntList_basic", + "SortTensorDescending_basic", + "SortTensorInteger_basic", + "SortTensorNegativeDimension_basic", + "SortTensorSpecificDimension_basic", + "SortTensor_basic", + "SplitDimDynamicModule_basic", + "SplitDimStaticModule_basic", + "SplitWithSizes_Module_basic", + "SqrtIntConstantModule_basic", + "SqrtIntModule_basic", + "SqueezeDimModule_dynamic", + "SqueezeDimModule_negDim", + "StdBiasedModule_basic", + "StdCorrectionAllDimReduceModule_basic", + "StdCorrectionEmptyDimModule_basic", + "StdCorrectionKeepDimModule_basic", + "StdCorrectionLargeInputModule_basic", + "StdCorrectionModule_basic", + "StdCorrectionNoneModule_basic", + "StdCorrectionSingleDimReduceModule_basic", + "StdDimBiasedModule_basic", + "StdDimEmptyDimModule_basic", + "StdDimKeepDimFalseModule_basic", + "StdDimKeepDimTrueModule_basic", + "StdDimNoneDimModule_basic", + "StdUnbiasedModule_basic", + "SubFloatModule_basic", + "SubIntModule_basic", + "TanhBackward_basic", + "TensorToBoolZeroRank_basic", + "TensorToBool_basic", + "TensorToFloatZeroRank_basic", + "TensorToFloat_basic", + "TensorToIntZeroRank_basic", + "TensorToInt_basic", + "TensorsConcatModule_basic", + "TensorsConcatNegativeDimModule_basic", + "TensorsConcatPromoteDTypeModule_basic", + "TensorsStackModule_basic", + "TensorsStackNegativeDimModule_basic", + "TensorsStackPromoteDTypeModule_basic", + "TensorsStackSingleElementListModule_basic", + "TestMultipleTensorAndPrimitiveTypesReturn_basic", + "Threshold1dFloatModule_basic", + "Threshold1dIntI32Module_basic", + "Threshold1dIntModule_basic", + "Threshold2dFloatModule_basic", + "Threshold2dIntModule_basic", + "Threshold3dFloatModule_basic", + "Threshold3dIntModule_basic", + "ThresholdBackward1dFloatModule_basic", + "ThresholdBackward1dIntModule_basic", + "ThresholdBackward1dMixedModule_basic", + "ThresholdBackward2dFloatModule_basic", + "ThresholdBackward2dIntModule_basic", + 
"ThresholdBackward2dMixedModule_basic", + "ThresholdBackward3dFloatModule_basic", + "ThresholdBackward3dIntModule_basic", + "ThresholdBackward3dMixedModule_basic", + "TileBigDimsSizeModule_basic", + "TileSmallDimsSizeModule_basic", + "ToCopyBoolDTypeStaticModule_basic", + "ToCopyModule_basic", + "ToCopyWithDTypeFalsePinMemoryModule_basic", + "ToCopyWithDTypeModule_basic", + "ToDtypeLayoutCPUModule_basic", + "ToDtypeLayoutNoneModule_basic", + "ToDtypeLayoutStridedModule_basic", + "ToDtypeIntFromFloatModule_basic", + "ToDtypeFloatFromIntModule_basic", + "TorchPrimLoopForLikeModule_basic", + "TorchPrimLoopWhileLikeModule_basic", + "TraceModule_basic", + "TraceModule_empty", + "TraceModule_nonsquare", + "TraceSignedIntModule_basic", + "TraceUnsignedIntModule_basic", + "TraceUnsignedIntModule_empty", + "TupleModule_basic", + "TypeAsDifferentModule_basic", + "TypeConversionF32ToF64Module_basic", + "TypeConversionF64ToF32Module_basic", + "TypeConversionI1ToF32Module_basic", + "TypeConversionI1ToF64Module_basic", + "TypeConversionI1ToI32Module_basic", + "TypeConversionI1ToI64Module_basic", + "TypeConversionI32ToI64Module_basic", + "TypeConversionI64ToI32Module_basic", + "TypePromotionDifferentCategoryModule_basic", + "TypePromotionSameCategoryDifferentWidthModule_basic", + "TypePromotionZeroRankHigherCategoryModule_basic", + "UniformModule_basic", + "UniformNoCorrelationModule_basic", + "UniformStaticShapeModule_basic", + "UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic", + "UnsafeView1DFoldModule_basic", + "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic", + "UnsafeViewCollapseModule_basic", + "UnsafeViewDynamicExpandModule_basic", + "UnsafeViewDynamicExpandWithAtenSizeIntModule_basic", + "UnsafeViewExpandModule_basic", + "UpSampleNearest2dBackwardScalesNone_basic", + "UpSampleNearest2dBackward_basic", + "UpSampleNearest2dDynamicFactor_basic", + "UpSampleNearest2dDynamicSize_basic", + "UpSampleNearest2dStaticFactor_basic", + "UpSampleNearest2dStaticSize_basic", + "UpSampleNearest2d_basic", + "VarBiasedModule_basic", + "VarCorrectionAllDimReduceModule_basic", + "VarCorrectionEmptyDimModule_basic", + "VarCorrectionKeepDimModule_basic", + "VarCorrectionLargeInputModule_basic", + "VarCorrectionModule_basic", + "VarCorrectionNoneModule_basic", + "VarCorrectionSingleDimReduceModule_basic", + "VarDimAllDimReduceModule_basic", + "VarDimBiasedModule_basic", + "VarDimEmptyDimModule_basic", + "VarDimModule_basic", + "VarDimMultiDimModule_basic", + "VarDimNegativeModule_basic", + "VarDimNoneDimModule_basic", + "VarDimSingleDimModule_basic", + "VarDimUnbiasedModule_basic", + "VarMeanBiasedModule_basic", + "VarMeanCorrectionModule_basic", + "VarMeanCorrectionNoneModule_basic", + "VarMeanDimBiasedModule_basic", + "VarMeanDimModule_basic", + "VarMeanUnbiasedModule_basic", + "VarUnbiasedModule_basic", + "View1DFoldModule_basic", + "ViewCollapseDynamicWithAtenSizeIntModule_basic", + "ViewCollapseModule_basic", + "ViewDynamicExpandCollapseModule_basic", + "ViewDynamicExpandCollapseWithAtenIntModule_basic", + "ViewDynamicExpandCollapseWithParallelUnknownDimModule_basic", + "ViewDynamicExpandModule_basic", + "ViewDynamicExpandWithAtenSizeIntModule_basic", + "ViewExpandDynamicDimModule_basic", + "ViewFlattenAndExpandModule_basic", + "ViewNoChange1dModule_basic", + "ViewNoChange2dModule_basic", + "ViewNoChange3dModule_basic", + "ViewSizeDimFollowedByCollapsedOnesModule_basic", + "ViewSizeDimFollowedByExpandedOnesModule_basic", + "ViewSizeDimLedAndFollowedByCollapsedOnesModule_basic", + 
"ViewSizeDimLedAndFollowedByExpandedOnesModule_basic", + "ViewSizeDimLedByCollapsedOnesModule_basic", + "ViewSizeDimLedByExpandedOnesModule_basic", + "ViewSizeFromOtherTensor_basic", + "ZeroFloat32Module_basic", + "ZeroInt32Module_basic", + "ZeroInt64Module_basic", + "ZerosLikeModule_defaultDtype", + "ZerosLikeModule_falsePinMemory", + "ZerosLikeModule_float", + "ZerosLikeModule_int", + "_Convolution2DAllFalseModule_basic", + "_Convolution2DBenchmarkModule_basic", + "_Convolution2DCudnnModule_basic", + "_Convolution2DDeterministicModule_basic", + "_Convolution2DTF32Module_basic", + "_ConvolutionDeprecated2DAllFalseModule_basic", + "_ConvolutionDeprecated2DBenchmarkModule_basic", + "_ConvolutionDeprecated2DCudnnModule_basic", + "_ConvolutionDeprecated2DDeterministicModule_basic", + "_LogSoftmaxModule_basic", + "_SoftmaxModule_basic", +} + +if torch_version_for_comparison() > version.parse("2.5.1"): + ONNX_TOSA_XFAIL_SET = ONNX_TOSA_XFAIL_SET | { + # error: 'memref.cast' op operand type 'memref<2x6x4x3xf32>' and result type 'memref<2x6x5x3xf32>' are cast incompatible + # torch.onnx.export produces onnx.MaxPool op with incorrect output shape of 2x6x5x3 instead of 2x6x4x3 + "MaxPool2dStaticCeilModeTrueReduceOutputModule_basic", + } diff --git a/projects/pt1/examples/_example_utils.py b/projects/pt1/examples/_example_utils.py new file mode 100644 index 000000000000..8f63b4fd4a63 --- /dev/null +++ b/projects/pt1/examples/_example_utils.py @@ -0,0 +1,52 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +from PIL import Image +import requests + +import torch +from torchvision import transforms + + +DEFAULT_IMAGE_URL = ( + "https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg" +) +DEFAULT_LABEL_URL = ( + "https://raw.githubusercontent.com/cathyzhyi/ml-data/main/imagenet-classes.txt" +) + + +def load_and_preprocess_image(url: str = DEFAULT_IMAGE_URL): + headers = { + "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36" + } + img = Image.open(requests.get(url, headers=headers, stream=True).raw).convert("RGB") + # preprocessing pipeline + preprocess = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) + img_preprocessed = preprocess(img) + return torch.unsqueeze(img_preprocessed, 0) + + +def load_labels(url: str = DEFAULT_LABEL_URL): + classes_text = requests.get( + url=url, + stream=True, + ).text + labels = [line.strip() for line in classes_text.splitlines()] + return labels + + +def top3_possibilities(res, labels): + _, indexes = torch.sort(res, descending=True) + percentage = torch.nn.functional.softmax(res, dim=1)[0] * 100 + top3 = [(labels[idx], percentage[idx].item()) for idx in indexes[0][:3]] + return top3 diff --git a/projects/pt1/examples/fximporter_resnet18.py b/projects/pt1/examples/fximporter_resnet18.py new file mode 100644 index 000000000000..8776c42fa7e4 --- /dev/null +++ b/projects/pt1/examples/fximporter_resnet18.py @@ -0,0 +1,59 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +import sys +from pathlib import Path + +import torch +import torch.utils._pytree as pytree +import torchvision.models as models +from torch_mlir import fx +from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend +from torch_mlir_e2e_test.configs.utils import ( + recursively_convert_to_numpy, +) + +sys.path.append(str(Path(__file__).absolute().parent)) +from _example_utils import ( + top3_possibilities, + load_and_preprocess_image, + load_labels, + DEFAULT_IMAGE_URL, +) + + +print("load image from " + DEFAULT_IMAGE_URL, file=sys.stderr) +img = load_and_preprocess_image(DEFAULT_IMAGE_URL) +labels = load_labels() + +resnet18 = models.resnet18(pretrained=True).eval() +module = fx.export_and_import( + resnet18, + torch.ones(1, 3, 224, 224), + output_type="linalg-on-tensors", + func_name=resnet18.__class__.__name__, +) +backend = refbackend.RefBackendLinalgOnTensorsBackend() +compiled = backend.compile(module) +fx_module = backend.load(compiled) + +params = { + **dict(resnet18.named_buffers(remove_duplicate=False)), +} +params_flat, params_spec = pytree.tree_flatten(params) +params_flat = list(params_flat) +with torch.no_grad(): + numpy_inputs = recursively_convert_to_numpy(params_flat + [img]) + +golden_prediction = top3_possibilities(resnet18.forward(img), labels) +print("PyTorch prediction") +print(golden_prediction) + +prediction = top3_possibilities( + torch.from_numpy(getattr(fx_module, resnet18.__class__.__name__)(*numpy_inputs)), + labels, +) +print("torch-mlir prediction") +print(prediction) diff --git a/projects/pt1/examples/torchscript_resnet18.py b/projects/pt1/examples/torchscript_resnet18.py index 0cc5b5dda96a..ea56653ca6f6 100644 --- a/projects/pt1/examples/torchscript_resnet18.py +++ b/projects/pt1/examples/torchscript_resnet18.py @@ -4,71 +4,36 @@ # Also available under a BSD-style license. See LICENSE. 
import sys - -from PIL import Image -import requests +from pathlib import Path import torch import torchvision.models as models -from torchvision import transforms - from torch_mlir import torchscript from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend - -def load_and_preprocess_image(url: str): - headers = { - "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.102 Safari/537.36" - } - img = Image.open(requests.get(url, headers=headers, stream=True).raw).convert("RGB") - # preprocessing pipeline - preprocess = transforms.Compose( - [ - transforms.Resize(256), - transforms.CenterCrop(224), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), - ] - ) - img_preprocessed = preprocess(img) - return torch.unsqueeze(img_preprocessed, 0) - - -def load_labels(): - classes_text = requests.get( - "https://raw.githubusercontent.com/cathyzhyi/ml-data/main/imagenet-classes.txt", - stream=True, - ).text - labels = [line.strip() for line in classes_text.splitlines()] - return labels - - -def top3_possibilities(res): - _, indexes = torch.sort(res, descending=True) - percentage = torch.nn.functional.softmax(res, dim=1)[0] * 100 - top3 = [(labels[idx], percentage[idx].item()) for idx in indexes[0][:3]] - return top3 +sys.path.append(str(Path(__file__).absolute().parent)) +from _example_utils import ( + top3_possibilities, + load_and_preprocess_image, + load_labels, + DEFAULT_IMAGE_URL, +) def predictions(torch_func, jit_func, img, labels): - golden_prediction = top3_possibilities(torch_func(img)) + golden_prediction = top3_possibilities(torch_func(img), labels) print("PyTorch prediction") print(golden_prediction) - prediction = top3_possibilities(torch.from_numpy(jit_func(img.numpy()))) + prediction = top3_possibilities(torch.from_numpy(jit_func(img.numpy())), labels) print("torch-mlir prediction") print(prediction) -image_url = ( - "https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg" -) - -print("load image from " + image_url, file=sys.stderr) -img = load_and_preprocess_image(image_url) +print("load image from " + DEFAULT_IMAGE_URL, file=sys.stderr) +img = load_and_preprocess_image(DEFAULT_IMAGE_URL) labels = load_labels() -resnet18 = models.resnet18(pretrained=True) -resnet18.train(False) +resnet18 = models.resnet18(pretrained=True).eval() module = torchscript.compile( resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors" ) diff --git a/projects/pt1/python/test/lit.cfg.py b/projects/pt1/python/test/lit.cfg.py index 0e6d132faa00..ddac7b7dc596 100644 --- a/projects/pt1/python/test/lit.cfg.py +++ b/projects/pt1/python/test/lit.cfg.py @@ -18,6 +18,24 @@ # Configuration file for the 'lit' test runner. + +# Find path to the ASan runtime required for the Python interpreter. +def find_asan_runtime(): + if not "asan" in config.available_features or not "Linux" in config.host_os: + return "" + # Find the asan rt lib + return ( + subprocess.check_output( + [ + config.host_cxx.strip(), + f"-print-file-name=libclang_rt.asan-{config.host_arch}.so", + ] + ) + .decode("utf-8") + .strip() + ) + + # name: The name of this test suite. config.name = "TORCH_MLIR_PYTHON" @@ -37,10 +55,15 @@ # test_exec_root: The root path where tests should be run. config.test_exec_root = os.path.join(config.torch_mlir_obj_root, "test") +# Python configuration with sanitizer requires some magic preloading. This will only work on clang/linux. 
+# TODO: detect Darwin/Windows situation (or mark these tests as unsupported on these platforms). +if "asan" in config.available_features and "Linux" in config.host_os: + _asan_rt = find_asan_runtime() + config.python_executable = f"env LD_PRELOAD={_asan_rt} {config.python_executable}" # On Windows the path to python could contains spaces in which case it needs to # be provided in quotes. This is the equivalent of how %python is setup in # llvm/utils/lit/lit/llvm/config.py. -if "Windows" in config.host_os: +elif "Windows" in config.host_os: config.python_executable = '"%s"' % (config.python_executable) config.substitutions.append(("%PATH%", config.environment["PATH"])) diff --git a/projects/pt1/python/torch_mlir/dynamo.py b/projects/pt1/python/torch_mlir/dynamo.py index 2c339be987b1..1c202ed3a382 100644 --- a/projects/pt1/python/torch_mlir/dynamo.py +++ b/projects/pt1/python/torch_mlir/dynamo.py @@ -65,6 +65,7 @@ def _get_decomposition_table(): aten.sigmoid_backward, aten._native_batch_norm_legit, aten.squeeze, + aten._scaled_dot_product_flash_attention_for_cpu, ] # TODO: enable test once 2.1.0 is stable if torch_version_for_comparison() >= version.parse("2.1.0.dev"): diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 1cf0c2c7696a..ec9005f784b6 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -85,6 +85,26 @@ def aten〇triu〡shape(self: List[int], diagonal: int = 0) -> List[int]: def aten〇tril〡shape(self: List[int], diagonal: int = 0) -> List[int]: return upstream_shape_functions.unary(self) + +def torchvision〇roi_align〡shape(input: List[int], rois: List[int], spatial_scale: float, pooled_height: int, pooled_width: int, sampling_ratio: int, aligned: bool) -> List[int]: + return [rois[0], input[1], pooled_height, pooled_width] + +def torchvision〇roi_align〡dtype(input_rank_dtype: Tuple[int, int], rois_rank_dtype: Tuple[int, int], spatial_scale: float, pooled_height: int, pooled_width: int, sampling_ratio: int, aligned: bool) -> int: + return input_rank_dtype[1] + +def torchvision〇roi_pool〡shape(input: List[int], rois: List[int], spatial_scale: float, pooled_height: int, pooled_width: int) -> Tuple[List[int], List[int]]: + output = [rois[0], input[1], pooled_height, pooled_width] + return (output, output) + +def torchvision〇roi_pool〡dtype(input_rank_dtype: Tuple[int, int], rois_rank_dtype: Tuple[int, int], spatial_scale: float, pooled_height: int, pooled_width: int) -> Tuple[int, int]: + return (input_rank_dtype[1], torch.int64) + +def torchvision〇nms〡shape(dets: List[int], scores: List[int], iou_threshold: float) -> List[int]: + return [hacky_get_unknown_dimension_size(), len(dets)] + +def torchvision〇nms〡dtype(dets_rank_dtype: Tuple[int, int], scores_rank_dtype: Tuple[int, int], iou_threshold: float) -> int: + return torch.int + @check_shape_function([ Invocation(TensorOfShape(2, 3, 4)), # Basic case. Invocation(TensorOfShape(2, 3, 4), dim1=1, dim2=2), # Test explicit `dim1` and `dim2`. 
@@ -118,6 +138,24 @@ def aten〇diagonal〡shape(self: List[int], offset: int = 0, dim1: int = 0, dim def aten〇fake_quantize_per_tensor_affine〡shape(self: List[int], scale: float, zero_point: int, quant_min: int, quant_max: int) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇fake_quantize_per_tensor_affine_cachemask〡shape(self: List[int], scale: float, zero_point: int, quant_min: int, quant_max: int) -> Tuple[List[int], List[int]]: + return (upstream_shape_functions.unary(self), upstream_shape_functions.unary(self)) + +def aten〇fake_quantize_per_tensor_affine〇tensor_qparams〡shape(self: List[int], scale: List[int], zero_point: List[int], quant_min: int, quant_max: int) -> List[int]: + return upstream_shape_functions.unary(self) + +def aten〇_fake_quantize_per_tensor_affine_cachemask_tensor_qparams〡shape(self: List[int], scale: List[int], zero_point: List[int], fake_quant_enabled: List[int], quant_min: int, quant_max: int) -> Tuple[List[int], List[int]]: + return (upstream_shape_functions.unary(self), upstream_shape_functions.unary(self)) + +def aten〇fake_quantize_per_channel_affine〡shape(self: List[int], scale: List[int], zero_point: List[int], axis: int, quant_min: int, quant_max: int) -> List[int]: + return upstream_shape_functions.unary(self) + +def aten〇fake_quantize_per_channel_affine_cachemask〡shape(self: List[int], scale: List[int], zero_point: List[int], axis: int, quant_min: int, quant_max: int) -> Tuple[List[int], List[int]]: + return (upstream_shape_functions.unary(self), upstream_shape_functions.unary(self)) + +def aten〇rad2deg〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇sin〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -178,9 +216,18 @@ def aten〇silu〡shape(self: List[int]) -> List[int]: def aten〇exp〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇exp2〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇expm1〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇special_expm1〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + +def aten〇isfinite〡shape(self: List[int]) -> List[int]: + return self + def aten〇cosine_similarity〡shape(x1: List[int], x2: List[int], dim: int = 1, eps: float = 1e-08) -> List[int]: broadcast = upstream_shape_functions.broadcast(x1, x2) return broadcast[:dim] + broadcast[dim + 1:] @@ -206,6 +253,27 @@ def aten〇sign〡shape(self: List[int]) -> List[int]: def aten〇sgn〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇linalg_det〡shape(A: List[int]) -> List[int]: + assert len(A) == 2 or len(A) == 3 + assert A[-1] == A[-2] + if len(A) == 3: + return A[:1] + return upstream_shape_functions.zero_dim_tensor(A) + +def aten〇_linalg_det〡shape(A: List[int]) -> Tuple[List[int], List[int], List[int]]: + return (aten〇linalg_det〡shape(A), A, A[:-1]) + +def aten〇_linalg_det〡dtype(A_rank_dtype: Tuple[int, int]) -> Tuple[int, int, int]: + return (A_rank_dtype[1], A_rank_dtype[1], A_rank_dtype[1]) + +def aten〇linalg_slogdet〡shape(A: List[int]) -> Tuple[List[int], List[int]]: + assert len(A) == 2 or len(A) == 3 + assert A[-1] == A[-2] + if len(A) == 3: + return A[:1], A[:1] + shape = upstream_shape_functions.zero_dim_tensor(A) + return shape, shape + def aten〇detach〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -239,6 +307,9 @@ def 
aten〇gelu_backward〡shape(grad_output: List[int], self: List[int], approx def aten〇leaky_relu_backward〡shape(grad_output: List[int], self: List[int], negative_slope: float, self_is_result: bool) -> List[int]: return upstream_shape_functions.unary(grad_output) +def aten〇rrelu_with_noise_backward〡shape(grad_output: List[int], self: List[int], noise: List[int], lower: float, upper: float, training: bool, self_is_result: bool) -> List[int]: + return upstream_shape_functions.unary(grad_output) + def aten〇hardtanh_backward〡shape(grad_output: List[int], self: List[int], min_val: float, max_val: float) -> List[int]: return upstream_shape_functions.unary(grad_output) @@ -254,12 +325,18 @@ def aten〇log〡shape(self: List[int]) -> List[int]: def aten〇log_sigmoid〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇hann_window〇periodic〡shape(window_length: int, periodic: bool, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]: + return [window_length] + def aten〇hardshrink〡shape(self: List[int], lambd: float = 0.5) -> List[int]: return upstream_shape_functions.unary(self) def aten〇softshrink〡shape(self: List[int], lambd: float = 0.5) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇polar〡shape(abs: List[int], angle: List[int]) -> List[int]: + return upstream_shape_functions.unary(abs) + def aten〇mish〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -283,6 +360,9 @@ def aten〇glu〡shape(self: List[int], dim: int = -1) -> List[int]: def aten〇_softmax〡shape(self: List[int], dim: int, half_to_float: bool) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇_safe_softmax〡shape(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇softmax〇int〡shape(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]: return upstream_shape_functions.unary(self) @@ -468,6 +548,14 @@ def aten〇linalg_cross〡shape(self: List[int], other: List[int], dim: int = -1 assert (self[i] == other[i]) or self[i] == 1 or other[i] == 1, f"the size of first tensor ({self[i]}) must match the size of second tensor ({other[i]}) at dimension {i}" return upstream_shape_functions.broadcast(self, other) +@check_shape_function([ + Invocation(TensorOfShape(2, 4, 3, device="cpu"), k=2, dim=1, keepdim=True), # keep dim, + Invocation(TensorOfShape(2, 4, 3, device="cpu"), k=2, dim=1, keepdim=False), # don't keep dim +]) +def aten〇kthvalue〡shape(self: List[int], k: int, dim: int = -1, keepdim: bool = False) -> Tuple[List[int], List[int]]: + new_shape = upstream_shape_functions.argmax(self, dim, keepdim) + return (new_shape, new_shape) + def aten〇_log_softmax_backward_data〡shape(grad_output: List[int], output: List[int], dim: int, input_dtype: int) -> List[int]: return upstream_shape_functions.unary(grad_output) @@ -555,6 +643,15 @@ def aten〇prelu〡shape(self: List[int], weight: List[int]) -> List[int]: def aten〇celu〡shape(self: List[int], alpha: float = 1.) 
-> List[int]: return upstream_shape_functions.unary(self) +def aten〇rrelu〡shape(self: List[int], lower: float = 0.125, upper: float = 0.33333333333333331, training: bool = False, generator: Any = None) -> List[int]: + return upstream_shape_functions.unary(self) + +def aten〇rrelu_with_noise〡shape(self: List[int], noise: List[int], lower: float = 0.125, upper: float = 0.33333333333333331, training: bool = False, generator: Any = None) -> List[int]: + return upstream_shape_functions.unary(self) + +def aten〇rrelu_with_noise_functional〡shape(self: List[int], noise: List[int], lower: float = 0.125, upper: float = 0.33333333333333331, training: bool = False, generator: Any = None) -> Tuple[List[int], List[int]]: + return upstream_shape_functions.unary(self), upstream_shape_functions.unary(noise) + def aten〇selu〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -597,7 +694,7 @@ def aten〇mean〡shape(self: List[int], dtype: Optional[int] = None) -> List[in def aten〇var〡shape(self: List[int], unbiased: bool = True) -> List[int]: return [] -def prims〇var〡shape(inp: List[int], dims: Optional[List[int]], correction: float, output_dtype: Optional[int] = None) -> List[int]: +def prims〇var〡shape(inp: List[int], dims: Optional[List[int]], correction: Optional[float] = 1, output_dtype: Optional[int] = None) -> List[int]: return upstream_shape_functions.sum_mean_dim(inp, dims, False, None) def aten〇var〇dim〡shape(self: List[int], dim: Optional[List[int]], unbiased: bool = True, keepdim: bool = False) -> List[int]: @@ -639,8 +736,19 @@ def aten〇trace〡shape(self: List[int]) -> List[int]: assert len(self) == 2, "input must have rank 2" return [] +# TODO: replace this patched function with `upstream_shape_functions.argmax` when upstream fix it +# see https://github.com/pytorch/pytorch/pull/129838 +def patched_argmax_shape_func(self: List[int], dim: Optional[int] = None, keepdim: bool = False): + if dim is None and keepdim: + out: List[int] = [] + for i in self: + out.append(1) + return out + return upstream_shape_functions.argmax(self, dim, keepdim) + @check_shape_function([ Invocation(TensorOfShape(2, 3, 4)), # Basic case. + Invocation(TensorOfShape(2, 3, 4), keepdim=True), # `keepdim`. Invocation(TensorOfShape(2, 3, 4), dim=0), # Test explicit `dim`. Invocation(TensorOfShape(2, 3, 4), dim=0, keepdim=True), # `keepdim`. Invocation(TensorOfShape(2, 3, 4), dim=-3), # Negative `dim`. @@ -649,11 +757,11 @@ def aten〇trace〡shape(self: List[int]) -> List[int]: ErrorInvocation(TensorOfShape(2, 3, 4), dim=3), # `dim` out of bounds. ]) def aten〇argmax〡shape(self: List[int], dim: Optional[int] = None, keepdim: bool = False) -> List[int]: - return upstream_shape_functions.argmax(self, dim, keepdim) + return patched_argmax_shape_func(self, dim, keepdim) def aten〇argmin〡shape(self: List[int], dim: Optional[int] = None, keepdim: bool = False) -> List[int]: # There is no shape function for argmin in pytorch, but the one for argmax does exactly what is needed here. - return upstream_shape_functions.argmax(self, dim, keepdim) + return patched_argmax_shape_func(self, dim, keepdim) # TODO: The result shape when num_classes=-1 depends on the runtime values of the input tensor, # making it impossible to add support for it using the current design of the shape library. 
@@ -678,12 +786,32 @@ def aten〇min〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) - def aten〇amax〡shape(self: List[int], dim: List[int] = (), keepdim: bool = False) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None) +def aten〇amin〡shape(self: List[int], dim: List[int] = (), keepdim: bool = False) -> List[int]: + return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None) + +@check_shape_function([ + Invocation(TensorOfShape(2, 3, 4)), # Basic case. + Invocation(TensorOfShape(2, 3, 4), keepdim=True), # `keepdim`. + Invocation(TensorOfShape(2, 3, 4), dim=0), # Test explicit `dim`. + Invocation(TensorOfShape(2, 3, 4), dim=0, keepdim=True), # `keepdim`. + Invocation(TensorOfShape(2, 3, 4), dim=-3), # Negative `dim`. + Invocation(TensorOfShape(2, 3, 4), dim=2), # Maximum valid `dim`. + ErrorInvocation(TensorOfShape(2, 3, 4), dim=-4), # `dim` out of bounds. + ErrorInvocation(TensorOfShape(2, 3, 4), dim=3), # `dim` out of bounds. +]) +def aten〇aminmax〡shape(self: List[int], dim: Optional[int] = None, keepdim: bool = False) -> Tuple[List[int], List[int]]: + reduced_shape = patched_argmax_shape_func(self, dim, keepdim) + return reduced_shape, reduced_shape + def aten〇mean〇dim〡shape(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype) def aten〇sum〇dim_IntList〡shape(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype) +def prims〇sum〡shape(inp: List[int], dims: Optional[List[int]], output_dtype: Optional[int] = None) -> List[int]: + return upstream_shape_functions.sum_mean_dim(inp, dims, False, output_dtype) + def aten〇prod〇dim_int〡shape(self: List[int], dim: int, keepdim: bool = False, dtype: Optional[int] = None) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, [dim], keepdim, dtype) @@ -724,6 +852,13 @@ def aten〇numpy_T〡shape(self: List[int]) -> List[int]: result_shape.insert(0, i) return result_shape +def aten〇outer〡shape(self: List[int], vec2: List[int]) -> List[int]: + return [self[0], vec2[0]] + +@check_shape_function([Invocation(TensorOfShape(3), TensorOfShape(3))]) +def aten〇dot〡shape(self: List[int], tensor: List[int]) -> List[int]: + return [] + def aten〇matmul〡shape(self: List[int], other: List[int]) -> List[int]: return upstream_shape_functions.matmul(self, other) @@ -733,6 +868,9 @@ def aten〇mv〡shape(self: List[int], vec: List[int]) -> List[int]: def aten〇mm〡shape(self: List[int], mat2: List[int]) -> List[int]: return upstream_shape_functions.mm(self, mat2) +def aten〇_int_mm〡shape(self: List[int], mat2: List[int]) -> List[int]: + return upstream_shape_functions.mm(self, mat2) + def aten〇addmm〡shape(self: List[int], mat1: List[int], mat2: List[int], beta: float = 1, alpha: float = 1) -> List[int]: return upstream_shape_functions.addmm(self, mat1, mat2, beta, alpha) @@ -938,9 +1076,85 @@ def aten〇max_pool2d_with_indices〡shape(self: List[int], kernel_size: List[in def aten〇max_pool2d_with_indices_backward〡shape(grad_output: List[int], self: List[int], kernel_size: List[int], stride: List[int], padding: List[int], dilation: List[int], ceil_mode: bool, indices: List[int]) -> List[int]: return self +def aten〇max_pool3d_with_indices〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0, 0,), dilation: List[int] = (1, 1, 1,), 
ceil_mode: bool = False) -> Tuple[List[int], List[int]]: + maxpool3d = indices = _max_pool3d(self, kernel_size, stride, padding, dilation, ceil_mode) + return maxpool3d, indices + +def aten〇max_unpool3d〡shape(self: List[int], indices: List[int], output_size: List[int], stride: List[int], padding: List[int]) -> List[int]: + assert (len(self) == 5 or len(self) == 4), "Input must be of rank 4 or 5" + assert (len(output_size) == 3), "output_size must have 3 elements" + assert (len(self) == len(indices)), "Input and indices must be of the same rank" + if len(self) == 5: + return [self[0], self[1], output_size[0], output_size[1], output_size[2]] + else: + return [self[0], output_size[0], output_size[1], output_size[2]] + def aten〇upsample_nearest2d_backward〡shape(grad_output: List[int], output_size: List[int], input_size: List[int], scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> List[int]: return input_size +# TODO: This should be upstreamed. +# See https://github.com/pytorch/pytorch/pull/76889 for an example. +def avg_pool3d(input: List[int], kernel_size: List[int], stride: List[int], padding: List[int], ceil_mode: bool, count_include_pad: bool, divisor_override: Optional[int]): + assert ( + len(kernel_size) == 1 or len(kernel_size) == 3 + ), "avg_pool3d: kernel_size must either be a single int, or a tuple of three ints" + (kD, kH, kW) = (kernel_size[0], kernel_size[0], kernel_size[0]) if len(kernel_size) == 1 else (kernel_size[0], kernel_size[1], kernel_size[2]) + + assert ( + len(stride) == 0 or len(stride) == 1 or len(stride) == 3 + ), "avg_pool3d: stride must either be omitted, a single int, or a tuple of three ints" + + if len(stride) == 0: + (dD, dH, dW) = (kD, kD, kD) + elif len(stride) == 1: + (dD, dH, dW) = (stride[0], stride[0], stride[0]) + else: # len(stride) == 3 + (dD, dH, dW) = (stride[0], stride[1], stride[2]) + + assert ( + len(padding) == 1 or len(padding) == 3 + ), "avg_pool3d: padding must either be a single int, or a tuple of three ints" + (padD, padH, padW) = (padding[0], padding[0], padding[0]) if len(padding) == 1 else (padding[0], padding[1], padding[2]) + + dilationD = 1 + dilationH = 1 + dilationW = 1 + + assert len(input) == 4 or len(input) == 5 + nbatch = input[-5] if len(input) == 5 else 1 + nInputPlane = input[-4] + inputDepth = input[-3] + inputHeight = input[-2] + inputWidth = input[-1] + + outputDepth = upstream_shape_functions.pooling_output_shape(inputDepth, kD, padD, dD, dilationD, ceil_mode) + outputHeight = upstream_shape_functions.pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode) + outputWidth = upstream_shape_functions.pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode) + + _pool3d_shape_check( + input, + kD, + kH, + kW, + dD, + dH, + dW, + padD, + padH, + padW, + dilationD, + dilationH, + dilationW, + outputDepth, + outputHeight, + outputWidth, + ) + + if len(input) == 4: + return [nInputPlane, outputDepth, outputHeight, outputWidth] + else: + return [nbatch, nInputPlane, outputDepth, outputHeight, outputWidth] + # TODO: This should be upstreamed. # See https://github.com/pytorch/pytorch/pull/76889 for an example.
def avg_pool2d(input: List[int], kernel_size: List[int], stride: List[int], padding: List[int], ceil_mode: bool, count_include_pad: bool, divisor_override: Optional[int]): @@ -1041,6 +1255,9 @@ def aten〇adaptive_avg_pool1d〡shape(self: List[int], output_size: List[int]) def aten〇avg_pool2d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> List[int]: return avg_pool2d(self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) +def aten〇avg_pool3d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0, 0,), ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> List[int]: + return avg_pool3d(self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override) + def aten〇adaptive_avg_pool2d〡shape(self: List[int], output_size: List[int]) -> List[int]: return upstream_shape_functions.adaptive_avg_pool2d(self, output_size) @@ -1090,11 +1307,49 @@ def aten〇unflatten〇int〡shape(self: List[int], dim: int, sizes: List[int]) def aten〇linear〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None) -> List[int]: return upstream_shape_functions.linear(input, weight, bias) +@check_shape_function([ + Invocation(TensorOfShape(3, 3, 3), TensorOfShape(3, 3, 3), TensorOfShape(3, 3, 3), [], [], [], [], 0), # Basic case + Invocation(TensorOfShape(4, 5, 6), TensorOfShape(4, 5, 6), TensorOfShape(4, 5, 6), [1], [0], [0], [], 2), # Expansions w/ Non-Zero unroll_dim + Invocation(TensorOfShape(3, 3, 3), TensorOfShape(3, 3, 3), TensorOfShape(3, 3, 3), [1, 2], [1, 2], [1, 2], [1, 2], 0), # Multiple expansions + Invocation(TensorOfShape(3, 3, 3), TensorOfShape(3, 3, 3), TensorOfShape(3, 3, 3), [1, 2], [2, 1], [1, 2], [1, 2], 0), # Unordered expansion + ErrorInvocation(TensorOfShape(4, 5, 1), TensorOfShape(4, 5, 3), TensorOfShape(1, 5, 3), [], [], [0], [2], 0), # Num dimensions don't match +]) +def aten〇_trilinear〡shape(i1: List[int], i2: List[int], i3: List[int], expand1: List[int], expand2: List[int], expand3: List[int], sumdim: List[int], unroll_dim: int = 1) -> List[int]: + total_dims = len(i1) + len(expand1) + + assert unroll_dim >= 0 and unroll_dim < total_dims, f"unroll_dim must be in [0, {total_dims - 1}]" + + i1_copy = upstream_shape_functions._copy(i1) + i2_copy = upstream_shape_functions._copy(i2) + i3_copy = upstream_shape_functions._copy(i3) + + # Expand dimensions based on args + inputs = [i1_copy, i2_copy, i3_copy] + expands = [expand1, expand2, expand3] + for index, expand in enumerate(expands): + size = len(inputs[index]) + for dim in expand: + assert dim <= size, f"expand dimension {dim} is out of bounds for input of shape {inputs[index]}" + inputs[index].insert(dim, 1) + + assert len(i1_copy) == len(i2_copy) == len(i3_copy), "number of dimensions must match" + + output_shape = upstream_shape_functions.broadcast_three(i1_copy, i2_copy, i3_copy) + sumdim_bools = [False] * len(output_shape) + for dim in sumdim: + sumdim_bools[dim] = True + + for i in range(len(output_shape) - 1, -1, -1): + if sumdim_bools[i]: + output_shape = upstream_shape_functions._reduce_along_dim(output_shape, i, False) + + return output_shape + @check_shape_function([ Invocation(TensorOfShape(3, 2, 8, 4), TensorOfShape(3, 2, 8, 4), TensorOfShape(3, 2, 8, 4)), # Same shape Invocation(TensorOfShape(3, 2, 16, 8), TensorOfShape(3, 2, 8, 8), 
TensorOfShape(3, 2, 8, 4)), # Different shape ]) -def aten〇scaled_dot_product_attention〡shape(query: List[int], key: List[int], value: List[int], attn_mask: Optional[List[int]] = None, dropout_p: float = 0., is_causal: bool = False, scale: Optional[float] = None) -> List[int]: +def aten〇scaled_dot_product_attention〡shape(query: List[int], key: List[int], value: List[int], attn_mask: Optional[List[int]] = None, dropout_p: float = 0., is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False) -> List[int]: outshape = query outshape[-1] = value[-1] return outshape @@ -1162,6 +1417,25 @@ def aten〇new_empty_strided〡shape(self: List[int], size: List[int], stride: L def aten〇diag_embed〡shape(self: List[int], offset: int = 0, dim1: int = -2, dim2: int = -1) -> List[int]: return _diag_embed_shape_helper(self, offset, dim1, dim2) +@check_shape_function([ + Invocation(TensorOfShape(2, 3, 4)), # Basic case. + Invocation(TensorOfShape(5, 3, 4), k = 5, dims=(1, 2,)), # multiple rotations + Invocation(TensorOfShape(3, 5, 2), k = -2), # negative direction, remainder=2 + Invocation(TensorOfShape(7, 2, 6, 3), k = -5), # negative direction, remainder=3 + ErrorInvocation(TensorOfShape(2, 3, 4), dims=(0,)), # total length of the dims is < 2 + ErrorInvocation(TensorOfShape(2)), # the input is one-dimensional +]) +def aten〇rot90〡shape(self: List[int], k: int = 1, dims: List[int] = (0, 1,)) -> List[int]: + assert len(self) >= 2, "expected total dims >= 2 but got {}".format(len(self)) + assert len(dims) == 2, "expected total rotation dims == 2, but got dims = {}".format(len(dims)) + + k = (k % 4 + 4) % 4 # equal to k % 4, but 'k % 4' cannot handle negative values for k. + + if k == 1 or k == 3: + self[dims[0]], self[dims[1]] = self[dims[1]], self[dims[0]] + + return self + def aten〇_to_copy〡shape(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, non_blocking: bool = False, memory_format: Optional[int] = None) -> List[int]: return upstream_shape_functions.unary(self) @@ -1208,9 +1482,23 @@ def aten〇_index_put_impl〡shape(self: List[int], indices: List[Optional[List[ def aten〇bernoulli〡shape(self: List[int], generator: Any = None) -> List[int]: return self +@check_shape_function([ + Invocation(TensorOfShape(5), num_samples=3), # Vector + Invocation(TensorOfShape(4, 5), num_samples=3), # Matrix +]) +def aten〇multinomial〡shape(self: List[int], num_samples: int, replacement: bool = False, generator: Any = None) -> List[int]: + assert len(self) == 1 or len(self) == 2 + if len(self) == 1: + return [num_samples] + num_rows = self[0] + return [num_rows, num_samples] + def aten〇cumsum〡shape(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]: return self +def aten〇cumprod〡shape(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]: + return self + def aten〇rand_like〡shape(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]: return self @@ -1274,6 +1562,18 @@ def aten〇floor_divide〡shape(self: List[int], other: List[int]) -> List[int]: def aten〇atan2〡shape(self: List[int], other: List[int]) -> List[int]: return upstream_shape_functions.broadcast(self, other) +def aten〇frac〡shape(self: List[int]) -> List[int]: + return self + +def aten〇signbit〡shape(self: List[int]) -> List[int]: + return self + +def aten〇ldexp〇Tensor〡shape(self:
List[int], other: List[int]) -> List[int]: + return upstream_shape_functions.broadcast(self, other) + +def aten〇copysign〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]: + return upstream_shape_functions.broadcast(self, other) + def aten〇__and__〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]: return upstream_shape_functions.broadcast(self, other) @@ -1286,6 +1586,12 @@ def aten〇minimum〡shape(self: List[int], other: List[int]) -> List[int]: def aten〇maximum〡shape(self: List[int], other: List[int]) -> List[int]: return upstream_shape_functions.broadcast(self, other) +def aten〇fmin〡shape(self: List[int], other: List[int]) -> List[int]: + return upstream_shape_functions.broadcast(self, other) + +def aten〇fmax〡shape(self: List[int], other: List[int]) -> List[int]: + return upstream_shape_functions.broadcast(self, other) + def aten〇bitwise_or〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]: return upstream_shape_functions.broadcast(self, other) @@ -1425,6 +1731,45 @@ def aten〇addcmul〡shape(self: List[int], tensor1: List[int], tensor2: List[in def aten〇addcdiv〡shape(self: List[int], tensor1: List[int], tensor2: List[int], value: float = 1) -> List[int]: return upstream_shape_functions.broadcast(self, upstream_shape_functions.broadcast(tensor1, tensor2)) +@check_shape_function([ + Invocation(TensorOfShape(1,5,5), [5,5], [1,5], [1,1], [0,0], [1,1]), # basic case + Invocation(TensorOfShape(1,4,5), [6,6], [2,2], [1,5], [0,0], [1,1]), # dilation + Invocation(TensorOfShape(1,5,15), [5,5], [1,5], [1,1], [0,1], [1,1]), # padding + Invocation(TensorOfShape(1,9,4), [5,5], [3,3], [1,1], [0,0], [2,2]), # stride + ErrorInvocation(TensorOfShape(1,5,5), [5,5], [1,7], [1,1], [0,0], [1,1]), # mismatch of sliding blocks +]) +def aten〇col2im〡shape(self: List[int], output_size: List[int], kernel_size: List[int], dilation: List[int], padding: List[int], stride: List[int]) -> List[int]: + ndim = len(self) + assert (ndim == 2 and self[0] != 0 and self[1] != 0) or (ndim == 3 and self[1] != 0 and self[2] != 0), "Expected 2D or 3D (batch mode) tensor for input with possibly 0 batch size and non zero dimensions for input" + + assert len(output_size) == 2, "output_size is expected to have length 2" + assert len(kernel_size) == 2, "kernel_size is expected to have length 2" + assert len(dilation) == 2, "dilation is expected to have length 2" + assert len(stride) == 2, "stride is expected to have length 2" + assert len(padding) == 2, "padding is expected to have length 2" + + assert kernel_size[0] > 0 and kernel_size[1] > 0, "kernel size should be greater than 0" + assert dilation[0] > 0 and dilation[1] > 0, "dilation should be greater than 0" + assert padding[0] >= 0 and padding[1] >= 0, "padding should be non negative" + assert stride[0] > 0 and stride[1] > 0, "stride must be greater than 0" + + batch_dim = 0 if ndim == 3 else -1 + n_input_plane = self[batch_dim + 1] + + assert n_input_plane % (kernel_size[0] * kernel_size[1]) == 0, "Expected size of input's dimension 1 to be divisible by the product of kernel_size" + + input_length = self[batch_dim + 2] + n_blocks_height = (output_size[0] + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) // stride[0] + 1 + n_blocks_width = (output_size[1] + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[1] + 1 + + assert input_length == n_blocks_height * n_blocks_width, "Expected size of input's dimension 2 to match the calculated number of sliding blocks" + + # compute the shape of the output + num_channels = 
n_input_plane // (kernel_size[0] * kernel_size[1]) + out: List[int] = ([self[0], num_channels] if batch_dim == 0 else [num_channels]) + [elem for elem in output_size] + + return out + @check_shape_function([ Invocation(TensorOfShape(2, 3), 1), # Basic case. Invocation(TensorOfShape(2, 3), 2, dim=0), # Test explicit `dim`. @@ -1494,12 +1839,41 @@ def aten〇view_as_real〡dtype(self_rank_dtype: Tuple[int, int]) -> int: assert False, "Unsupported dtype" +def torchvision〇deform_conv2d〡shape(input: List[int], weight: List[int], offset: List[int], mask: List[int], bias: List[int], stride_h: int, stride_w: int, pad_h: int, pad_w: int, dilation_h: int, dilation_w: int, groups: int, offset_groups: int, use_mask: bool) -> List[int]: + return [input[0], weight[0], offset[2], offset[3]] + +def torchvision〇deform_conv2d〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], offset_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], bias_rank_dtype: Tuple[int, int], stride_h: int, stride_w: int, pad_h: int, pad_w: int, dilation_h: int, dilation_w: int, groups: int, offset_groups: int, use_mask: bool) -> int: + return input_rank_dtype[1] + def aten〇conv2d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1,), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), groups: int = 1) -> List[int]: return upstream_shape_functions.conv2d(input, weight, bias, stride, padding, dilation, groups) +def _conv_padding(weight: List[int], dilation: List[int], padding: str): + rank = len(weight) + # first 2 dimensions of weight correspond to out_channels and in_channels/groups + num_unpadded_dims = 2 + assert rank > num_unpadded_dims, "conv: weight must be at least 3 dimensional." + num_kernel_elems = rank - num_unpadded_dims + padding_int = [0] * num_kernel_elems + if padding == "same": + for d, i in zip( + dilation, range(num_kernel_elems - 1, -1, -1) + ): + padding_val = d * (weight[num_unpadded_dims+i] - 1) + padding_int[i] = padding_val // 2 + return padding_int + +def aten〇conv2d〇padding〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1,), padding: str = "valid", dilation: List[int] = (1, 1,), groups: int = 1) -> List[int]: + padding_int = _conv_padding(weight, dilation, padding) + return upstream_shape_functions.conv2d(input, weight, bias, stride, padding_int, dilation, groups) + def aten〇conv3d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1, 1,), padding: List[int] = (0, 0, 0,), dilation: List[int] = (1, 1, 1,), groups: int = 1) -> List[int]: return upstream_shape_functions.conv3d(input, weight, bias, stride, padding, dilation, groups) +def aten〇conv3d〇padding〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1, 1,), padding: str = "valid", dilation: List[int] = (1, 1, 1,), groups: int = 1) -> List[int]: + padding_int = _conv_padding(weight, dilation, padding) + return upstream_shape_functions.conv3d(input, weight, bias, stride, padding_int, dilation, groups) + def aten〇conv_transpose2d〇input〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1,), padding: List[int] = (0, 0,), output_padding: List[int] = (0, 0,), groups: int = 1, dilation: List[int] = (1, 1,)) -> List[int]: return upstream_shape_functions.conv_transpose2d_input(input, weight, bias, stride, padding, output_padding, groups, dilation) @@ -1541,6 +1915,16
@@ def aten〇convolution〡shape(input: List[int], weight: List[int], bias: Option def aten〇conv1d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1,), padding: List[int] = (0,), dilation: List[int] = (1,), groups: int = 1) -> List[int]: return upstream_shape_functions.conv_forwards(input, weight, bias, stride, padding, dilation, transposed=False, output_padding=[], groups=1) +def aten〇conv1d〇padding〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1,), padding: str = "valid", dilation: List[int] = (1,), groups: int = 1) -> List[int]: + padding_int = _conv_padding(weight, dilation, padding) + return upstream_shape_functions.conv_forwards(input, weight, bias, stride, padding_int, dilation, transposed=False, output_padding=[], groups=1) + +def aten〇conv_transpose1d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1,), padding: List[int] = (0,), output_padding: List[int] = (0,), groups: int = 1, dilation: List[int] = (1,)) -> List[int]: + return aten〇convolution〡shape(input, weight, bias, stride, padding, dilation, True, output_padding, groups) + +def aten〇conv_transpose3d〇input〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1, 1,), padding: List[int] = (0, 0, 0,), output_padding: List[int] = (0, 0, 0,), groups: int = 1, dilation: List[int] = (1, 1, 1,)) -> List[int]: + return aten〇convolution〡shape(input, weight, bias, stride, padding, dilation, True, output_padding, groups) + def aten〇_convolution〡shape(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool, allow_tf32: bool) -> List[int]: return aten〇convolution〡shape(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups) @@ -1565,8 +1949,40 @@ def aten〇native_group_norm〡shape(input: List[int], weight: Optional[List[int def aten〇instance_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], running_mean: Optional[List[int]], running_var: Optional[List[int]], use_input_stats: bool, momentum: float, eps: float, cudnn_enabled: bool) -> List[int]: return upstream_shape_functions.unary(input) +def aten〇_weight_norm_interface〡shape(v: List[int], g: List[int], dim: int = 0) -> Tuple[List[int], List[int]]: + return upstream_shape_functions.unary(v), upstream_shape_functions.unary(g) + def aten〇slice〇Tensor〡shape(self: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]: - return upstream_shape_functions.slice(self, dim, start, end, step) + start_val = start if start is not None else 0 + end_val = end if end is not None else upstream_shape_functions.max_int() + if (step < 0): + # Convert to equivalent positive-step parameters, which will require swapping start and end. + # If the parameters are in the normal range (0 <= start < d and -1 <= end <= start), then + # swapped_end = start + 1 and swapped_begin = end + 1. + # The shift of inclusion can cause issues if these parameters are not already resolved on the left. + # e.g. start = -1, end = -3. So valid start is actually d-1, and valid end is d-3.
Therefore, we + # should have swapped_end = d, but adding 1 to start before making it valid would result in an + # incorrect, but "valid", swapped_end = 0 for forward slicing. + # Additionally, if adding d doesn't make these values positive, but adding twice would, we need + # to clamp after resolving, otherwise the upstream function will try to resolve a second time. + if start_val < 0: + start_val += self[dim] + if start_val < 0: + start_val = 0 + if end_val < 0: + end_val += self[dim] + if end_val < 0: + end_val = -1 + + tmp = end_val + 1 + end_val = start_val + 1 + start_val = tmp + step = -step + return upstream_shape_functions.slice(self,dim,start_val,end_val,step) + + +def aten〇as_strided〡shape(self: List[int], size: List[int], stride: List[int], storage_offset: Optional[int] = None) -> List[int]: + return size def aten〇sort〡shape(self: List[int], dim: int = -1, descending: bool = False) -> Tuple[List[int], List[int]]: return self, self @@ -1604,6 +2020,9 @@ def aten〇scatter〇src〡shape(self: List[int], dim: int, index: List[int], sr def aten〇scatter〇value〡shape(self: List[int], dim: int, index: List[int], value: float) -> List[int]: return self +def aten〇scatter_add〡shape(self: List[int], dim: int, index: List[int], src: List[int]) -> List[int]: + return self + def aten〇index_select〡shape(self: List[int], dim: int, index: List[int]) -> List[int]: return upstream_shape_functions.index_select(self, dim, index) @@ -1624,6 +2043,64 @@ def aten〇_embedding_bag〡shape(weight: List[int], indices: List[int], offsets return _embedding_bag_helper(weight, indices, offsets, include_last_offset, mode, per_sample_weights, padding_idx) +@check_shape_function([ + Invocation(4, 3, 1), # Basic case. + Invocation(0, 0, 0), # All zeros case. + Invocation(7, 5, -2), # Negative offset case. + Invocation(35, 55, 16), # Larger values case. +]) +def aten〇triu_indices〡shape(row: int, col: int, offset: int = 0, dtype: Optional[int] = 4, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]: + if row == 0 or col == 0: + return [2, 0] + + # _get_tril_indices + offset_tril = offset - 1 + if row == 0 or col == 0: + trapezoid_size_tril = 0 + rectangle_size_tril = 0 + else: + m_first_row = min(col, 1 + offset_tril) if offset_tril > 0 else int(row + offset_tril > 0) + m_last_row = max(0, min(col, row + offset_tril)) + n_row_all = max(0, min(row, row + offset_tril)) + n_row_trapezoid = m_last_row - m_first_row + 1 + + # Number of elements in top trapezoid + trapezoid_size_tril = (m_first_row + m_last_row) * n_row_trapezoid // 2 + # Number of elements in bottom rectangle + diff_row = n_row_all - n_row_trapezoid + rectangle_size_tril = max(0, diff_row * col) + + # Number of elements in the upper triangle + triu_size = row * col - (trapezoid_size_tril + rectangle_size_tril) + + return [2, triu_size] + +@check_shape_function([ + Invocation(4, 3, 1), # Basic case. + Invocation(0, 0, 0), # All zeros case. + Invocation(7, 5, -2), # Negative offset case. + Invocation(35, 55, 16), # Larger values case.
+]) +def aten〇tril_indices〡shape(row: int, col: int, offset: int = 0, dtype: Optional[int] = 4, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]: + if row == 0 or col == 0: + return [2, 0] + + m_first_row = min(col, 1 + offset) if offset > 0 else int(row + offset > 0) + m_last_row = max(0, min(col, row + offset)) + n_row_all = max(0, min(row, row + offset)) + n_row_trapezoid = m_last_row - m_first_row + 1 + + # Number of elements in top trapezoid + trapezoid_size = (m_first_row + m_last_row) * n_row_trapezoid // 2 + # Number of elements in bottom rectangle + diff_row = n_row_all - n_row_trapezoid + rectangle_size = max(0, diff_row * col) + + return [2, trapezoid_size + rectangle_size] + +def aten〇deg2rad〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + @check_shape_function([ Invocation(TensorOfShape(2, 3), LongTensorOfShape(2), None, 1, -100), # Basic case. Invocation(TensorOfShape(3), LongTensorOfShape(), None, 1, -100), # No batch dim. @@ -1642,9 +2119,22 @@ def aten〇mse_loss〡shape(self: List[int], target: List[int], reduction: int = return upstream_shape_functions.unary(self) return [] +def aten〇l1_loss〡shape(self: List[int], target: List[int], reduction: int = 1) -> List[int]: + if reduction == 0: + return upstream_shape_functions.unary(self) + return [] + def aten〇cross_entropy_loss〡shape(self: List[int], target: List[int], weight: Optional[List[int]] = None, reduction: int = 1, ignore_index: int = -100, label_smoothing: float = 0.) -> List[int]: return upstream_shape_functions.cross_entropy_loss(self, target, weight, reduction, ignore_index, label_smoothing) +def aten〇binary_cross_entropy_with_logits〡shape(self: List[int], target: List[int], weight: Optional[List[int]] = None, pos_weight: Optional[List[int]] = None, reduction: int = 1) -> List[int]: + scalar_shape: List[int] = [] + if reduction == 0: + result_shape = upstream_shape_functions._copy(self) + else: + result_shape = scalar_shape + return result_shape + @check_shape_function([ Invocation(TensorOfShape(2, 5, 2, 2, 3), [2, 2, 3], None, None, 1e-6), # Basic case. ]) @@ -1666,12 +2156,14 @@ def aten〇native_batch_norm〡shape(input: List[int], weight: Optional[List[int # TODO: This should be upstreamed. # See https://github.com/pytorch/pytorch/pull/76889 for an example. -def pad_shape_fn(input: List[int], pad: List[int]): +def pad_shape_fn(input: List[int], pad: List[int], validate_pad : bool = False): assert len(pad) % 2 == 0, "Must have paired low-high pad amount values" assert len(pad) // 2 <= len(input), "Number of padded dimensions must be less than or equal to the input dimension" # The `pad` list takes the form of Low-high pairs starting at the # *rightmost* dimension of `self`. 
for i in range(len(pad) // 2): + if validate_pad: + assert pad[2*i] < input[-(i+1)] and pad[2 * i + 1] < input[-(i+1)] input[-(i + 1)] += pad[2 * i] + pad[2 * i + 1] return input @@ -1706,11 +2198,7 @@ def aten〇pad〡shape(self: List[int], pad: List[int], mode: str = "constant", ErrorInvocation(TensorOfShape(1, 4), padding=[1,4])]) def aten〇reflection_pad1d〡shape(self: List[int], padding: List[int]) -> List[int]: assert len(self) >= 2 - hdim = self[-1] - padding_left = padding[0] - padding_right = padding[1] - assert padding_left < hdim and padding_right < hdim - return pad_shape_fn(self, padding) + return pad_shape_fn(self, padding, validate_pad=True) # Padding size must be smaller than corresponding dimension @@ -1723,18 +2211,21 @@ def aten〇reflection_pad1d〡shape(self: List[int], padding: List[int]) -> List ErrorInvocation(TensorOfShape(2, 2, 2), padding=[1,1,2,2])]) def aten〇reflection_pad2d〡shape(self: List[int], padding: List[int]) -> List[int]: assert len(self) >= 2 - vdim = self[-2] - hdim = self[-1] - assert len(padding) == 4, 'padding size expected to be 4' - padding_left = padding[0] - padding_right = padding[1] - padding_top = padding[2] - padding_bottom = padding[3] - assert padding_left < hdim and padding_right < hdim - assert padding_top < vdim and padding_bottom < vdim + return pad_shape_fn(self, padding, validate_pad=True) - return pad_shape_fn(self, padding) +# Padding size must be smaller than corresponding dimension +@check_shape_function([ErrorInvocation(TensorOfShape(2, 2, 2, 2), padding=[2,2,1,1,1,1]), + ErrorInvocation(TensorOfShape(2, 2, 2, 2), padding=[2,1,1,1,1,1]), + ErrorInvocation(TensorOfShape(2, 2, 2, 2), padding=[2,1,1,1,1,3]), + ErrorInvocation(TensorOfShape(2, 2, 2, 2), padding=[2,1]), + Invocation(TensorOfShape(2, 2, 2, 2), padding=[1,1,1,1,1,1]), + ErrorInvocation(TensorOfShape(2, 2, 2, 2), padding=[1,1,1,1,1,2]), + ErrorInvocation(TensorOfShape(2, 2, 2, 2), padding=[1,1,2,1,1,1])]) +def aten〇reflection_pad3d〡shape(self: List[int], padding: List[int]) -> List[int]: + assert len(self) >= 3 + assert len(padding) == 6, 'padding size expected to be 6' + return pad_shape_fn(self, padding, validate_pad=True) # TODO: upstream this def index_tensor_like(self: List[int], indices: List[Optional[List[int]]]) -> List[int]: @@ -1802,12 +2293,98 @@ def aten〇index〇Tensor_hacked_twin〡shape(self: List[int], indices: List[Lis def aten〇cat〡shape(tensors: List[List[int]], dim: int = 0) -> List[int]: return upstream_shape_functions.cat(tensors, dim) +def aten〇atleast_1d〡shape(self: List[int]) -> List[int]: + if len(self) == 0: + return [1] + else: + return self + +def aten〇atleast_2d〡shape(self: List[int]) -> List[int]: + if len(self) == 0: + return [1, 1] + elif len(self) == 1: + x = self[0] + return [1, x] + else: + return self + def aten〇stack〡shape(tensors: List[List[int]], dim: int = 0) -> List[int]: return upstream_shape_functions.stack(tensors, dim) + +@check_shape_function([ + Invocation([LongTensorOfShape(2, 4, 3), LongTensorOfShape(2, 5, 3)]), # Basic case. +]) +def aten〇hstack〡shape(tensors: List[List[int]]) -> List[int]: + + tensors_atleast1d = [aten〇atleast_1d〡shape(tensor) for tensor in tensors] + + if len(tensors_atleast1d[0]) == 1: + return upstream_shape_functions.cat(tensors_atleast1d, dim=0) + + return upstream_shape_functions.cat(tensors_atleast1d, dim=1) + +@check_shape_function([ + Invocation([LongTensorOfShape(2, 4, 3), LongTensorOfShape(2, 5, 3)]), # Basic case. 
+]) +def aten〇column_stack〡shape(tensors: List[List[int]]) -> List[int]: + tensors2d: List[List[int]] = [] + for tensor in tensors: + if len(tensor) == 0: + tensor = [1, 1] + elif len(tensor) == 1: + tensor.append(1) + tensors2d.append(tensor) + + return upstream_shape_functions.cat(tensors2d, dim=1) + def aten〇fft_fft〡shape(self: List[int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> List[int]: return self +@check_shape_function([ + Invocation(TensorOfShape(3, 9, 5), None, -2, None) # Second-last dim +]) +def aten〇fft_rfft〡shape(self: List[int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> List[int]: + dim = (dim + len(self)) if dim < 0 else dim + assert dim >= 0 and dim < len(self), "Expected dim in [-rank, rank-1]" + out: List[int] = [] + for s in self: + out.append(s) + out[dim] = self[dim] // 2 + 1 + return out + +@check_shape_function([ + Invocation(TensorOfShape(1, 128), 16, None, 16, TensorOfShape(16), False, None, True) # With an explicit 1-D window. +]) +def aten〇stft〡shape(self: List[int], n_fft: int, hop_length: Optional[int] = None, win_length: Optional[int] = None, window: Optional[List[int]] = None, normalized: bool = False, onesided: Optional[bool] = None, return_complex: Optional[bool] = None, align_to_window: Optional[bool] = None) -> List[int]: + assert len(self) == 1 or len(self) == 2, "Expected input tensor to be of shape (B?,L), where B is an optional batch dimension" + + batch = None if len(self) == 1 else self[0] + length = self[0] if len(self) == 1 else self[1] + hop_length = (n_fft // 4) if hop_length is None else hop_length + assert n_fft > 0 and n_fft <= length, "Expected that 0 < n_fft <= len" + assert hop_length > 0, "Expected hop_length to be greater than 0" + + out: List[int] = [] + if batch is not None: + out.append(batch) # (B?,) + + if onesided is None or onesided == True: + out.append(n_fft//2 + 1) + else: + out.append(n_fft) # (B?,N,) + + # For this operator, center=False by default + out.append(1 + (length - n_fft)//hop_length) #(B?,N,T,) + + if return_complex is not None and bool(return_complex) == False: + out.append(2) # a length-2 dimension of real and imaginary components. This gives output shape (B?,N,T,C?). 
+ + return out + +def aten〇fft_ifft〡shape(self: List[int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> List[int]: + return self + class DummyClassType: def __init__(self): pass @@ -1868,15 +2445,69 @@ def aten〇linalg_norm〡shape(self: List[int], ord: Optional[float] = None, dim def aten〇frobenius_norm〇dim〡shape(self: List[int], dim: List[int], keepdim: bool = False) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, 0) +def aten〇renorm〡shape(self: List[int], p: float, dim: int, maxnorm: float) -> List[int]: + return self + def aten〇norm〇Scalar〡shape(self: List[int], p: float = 2) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, None, False, None) def aten〇norm〇ScalarOpt_dim〡shape(self: List[int], p: Optional[float], dim: List[int], keepdim: bool = False) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, 0) +@check_shape_function([ + Invocation(TensorOfShape(1, 3, 10), [11]) +]) +def aten〇upsample_nearest1d〡shape(self: List[int], output_size: List[int], scales: Optional[float] = None) -> List[int]: + return [self[0], self[1], output_size[0]] + +@check_shape_function([ + Invocation(TensorOfShape(1, 3, 10), [11], None), + Invocation(TensorOfShape(1, 3, 10), None, [2.0]), + Invocation(TensorOfShape(1, 3, 5), None, [2.5]) +]) +def aten〇upsample_nearest1d〇vec〡shape(input: List[int], output_size: Optional[List[int]], scale_factors: Optional[List[float]]) -> List[int]: + assert output_size is None or scale_factors is None + assert not (output_size is None and scale_factors is None) + if output_size is not None: + return [input[0], input[1], output_size[0]] + else: + assert scale_factors is not None + return [input[0], input[1], int(input[2] * scale_factors[0])] + +@check_shape_function([ + Invocation(TensorOfShape(1, 3, 10, 10), [11, 12]) +]) def aten〇upsample_nearest2d〡shape(self: List[int], output_size: List[int], scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> List[int]: return [self[0], self[1], output_size[0], output_size[1]] +@check_shape_function([ + Invocation(TensorOfShape(1, 3, 10, 10), [11, 12], None), + Invocation(TensorOfShape(1, 3, 10, 9), None, [2.0, 2.3]), + Invocation(TensorOfShape(1, 3, 5, 6), None, [2.5, 1.0]) +]) +def aten〇upsample_nearest2d〇vec〡shape(input: List[int], output_size: Optional[List[int]], scale_factors: Optional[List[float]]) -> List[int]: + assert output_size is None or scale_factors is None + assert not (output_size is None and scale_factors is None) + if output_size is not None: + return [input[0], input[1], output_size[0], output_size[1]] + else: + assert scale_factors is not None + return [input[0], input[1], int(input[2] * scale_factors[0]), int(input[3] * scale_factors[1])] + +@check_shape_function([ + Invocation(TensorOfShape(1, 3, 10, 10), [11, 12], True) +]) +def aten〇upsample_bilinear2d〡shape(self: List[int], output_size: List[int], align_corners: bool, scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> List[int]: + return [self[0], self[1], output_size[0], output_size[1]] + +@check_shape_function([ + Invocation(TensorOfShape(1, 3, 10, 10), [11, 12], True, None), + Invocation(TensorOfShape(1, 3, 10, 9), None, True, [2.0, 2.3]), + Invocation(TensorOfShape(1, 3, 5, 6), None, True, [2.5, 1.0]) +]) +def aten〇upsample_bilinear2d〇vec〡shape(input: List[int], output_size: Optional[List[int]], align_corners: bool, scale_factors: Optional[List[float]]) -> List[int]: + return aten〇upsample_nearest2d〇vec〡shape(input, 
output_size, scale_factors) + # ============================================================================== # Dtype Functions # ============================================================================== @@ -2019,13 +2650,51 @@ def prims〇split_dim〡dtype(a_rank_dtype: Tuple[int, int], dim: int, outer_len return a_dtype # note: fake_quantize_per_tensor_affine doesn't support "meta" device, use "cpu" instead. -@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, tensor_device="cpu", scale=0.1, zero_point=0, quant_min=0, quant_max=255, error_types={torch.complex128, torch.complex64, torch.bfloat16, torch.int64, torch.int32, torch.int16, torch.int8, torch.uint8, torch.bool})) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, tensor_device="cpu", scale=0.1, zero_point=0, quant_min=0, quant_max=255, error_types={torch.complex128, torch.complex64, torch.int64, torch.int32, torch.int16, torch.int8, torch.uint8, torch.bool})) def aten〇fake_quantize_per_tensor_affine〡dtype(self_rank_dtype: Tuple[int, int], scale: float, zero_point: int, quant_min: int, quant_max: int) -> int: + self_rank, self_dtype = self_rank_dtype + assert is_float_dtype(self_dtype) + return self_dtype + +# note: fake_quantize_per_tensor_affine doesn't support "meta" device, use "cpu" instead. +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, tensor_device="cpu", scale=0.1, zero_point=0, quant_min=0, quant_max=255, error_types={torch.complex128, torch.complex64, torch.int64, torch.int32, torch.int16, torch.int8, torch.uint8, torch.bool})) +def aten〇fake_quantize_per_tensor_affine_cachemask〡dtype(self_rank_dtype: Tuple[int, int], scale: float, zero_point: int, quant_min: int, quant_max: int) -> Tuple[int, int]: + self_rank, self_dtype = self_rank_dtype + assert is_float_dtype(self_dtype) + return (self_rank_dtype[1], torch.bool) + +# note: fake_quantize_per_tensor_affine.tensor_qparams doesn't support "meta" device, use "cpu" instead. +@check_dtype_function(Invocation(TensorOfShape(3, 3, dtype=dtype, device="cpu"), TensorOfShape(1, dtype=torch.float32, device="cpu"), TensorOfShape(1, dtype=torch.int32, device="cpu"), 0, 255) for dtype in [torch.float64, torch.float32, torch.float16]) +def aten〇fake_quantize_per_tensor_affine〇tensor_qparams〡dtype(self_rank_dtype: Tuple[int, int], scale_rank_dtype: Tuple[int, int], zero_point_rank_dtype: Tuple[int, int], quant_min: int, quant_max: int) -> int: self_rank, self_dtype = self_rank_dtype assert is_float_dtype(self_dtype) assert self_dtype != torch.bfloat16 return self_dtype +# note: _fake_quantize_per_tensor_affine_cachemask_tensor_qparams doesn't support "meta" device, use "cpu" instead. 
+@check_dtype_function(Invocation(TensorOfShape(3, 3, dtype=dtype, device="cpu"), TensorOfShape(1, dtype=torch.float32, device="cpu"), TensorOfShape(1, dtype=torch.int32, device="cpu"), TensorOfShape(1, dtype=torch.int32, device="cpu"), 0, 255) for dtype in [torch.float64, torch.float32, torch.float16]) +def aten〇_fake_quantize_per_tensor_affine_cachemask_tensor_qparams〡dtype(self_rank_dtype: Tuple[int, int], scale_rank_dtype: Tuple[int, int], zero_point_rank_dtype: Tuple[int, int], fake_quant_enabled_rank_dtype: Tuple[int, int], quant_min: int, quant_max: int) -> Tuple[int, int]: + self_rank, self_dtype = self_rank_dtype + assert is_float_dtype(self_dtype) + assert self_dtype != torch.bfloat16 + return (self_rank_dtype[1], torch.bool) + +# note: fake_quantize_per_channel_affine doesn't support "meta" device, use "cpu" instead. +@check_dtype_function(Invocation(TensorOfShape(3, 3, dtype=dtype, device="cpu"), TensorOfShape(3, dtype=torch.float32, device="cpu"), TensorOfShape(3, dtype=torch.int32, device="cpu"), 0, 0, 255) for dtype in [torch.float64, torch.float32, torch.float16]) +def aten〇fake_quantize_per_channel_affine〡dtype(self_rank_dtype: Tuple[int, int], scale_rank_dtype: Tuple[int, int], zero_point_rank_dtype: Tuple[int, int], axis: int, quant_min: int, quant_max: int) -> int: + self_rank, self_dtype = self_rank_dtype + assert is_float_dtype(self_dtype) + assert self_dtype != torch.bfloat16 + return self_dtype + +# note: fake_quantize_per_channel_affine_cachemask doesn't support "meta" device, use "cpu" instead. +@check_dtype_function(Invocation(TensorOfShape(3, 3, dtype=dtype, device="cpu"), TensorOfShape(3, dtype=torch.float32, device="cpu"), TensorOfShape(3, dtype=torch.int32, device="cpu"), 0, 0, 255) for dtype in [torch.float64, torch.float32, torch.float16]) +def aten〇fake_quantize_per_channel_affine_cachemask〡dtype(self_rank_dtype: Tuple[int, int], scale_rank_dtype: Tuple[int, int], zero_point_rank_dtype: Tuple[int, int], axis: int, quant_min: int, quant_max: int) -> Tuple[int, int]: + self_rank, self_dtype = self_rank_dtype + assert is_float_dtype(self_dtype) + assert self_dtype != torch.bfloat16 + return (self_rank_dtype[1], torch.bool) + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇cosh〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype @@ -2048,11 +2717,30 @@ def aten〇exp〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype return _get_dtype_of_floating_point_op(self_dtype) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇exp2〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇expm1〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype return _get_dtype_of_floating_point_op(self_dtype) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇special_expm1〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + +def aten〇isfinite〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + return torch.bool + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128})) +def aten〇rad2deg〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, 
self_dtype = self_rank_dtype + assert is_integer_dtype(self_dtype) or is_float_dtype(self_dtype) + return _get_dtype_of_floating_point_op(self_dtype) + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇sin〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype @@ -2124,6 +2812,15 @@ def aten〇log_sigmoid〡dtype(self_rank_dtype: Tuple[int, int]) -> int: assert not self_dtype == torch.bool return self_dtype +@check_dtype_function([Invocation(10, False), Invocation(10, True), + Invocation(10, False, dtype=torch.float32), Invocation(10, True, dtype=torch.float32), + Invocation(10, False, dtype=torch.float64), Invocation(10, True, dtype=torch.float64)]) +def aten〇hann_window〇periodic〡dtype(window_length: int, periodic: bool, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + result_dtype = torch.float32 if dtype is None else dtype + assert is_float_dtype(result_dtype) + return result_dtype + + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, lambd=0.5)) def aten〇hardshrink〡dtype(self_rank_dtype: Tuple[int, int], lambd: Union[int, float, complex] = 0.5) -> int: self_rank, self_dtype = self_rank_dtype @@ -2136,6 +2833,17 @@ def aten〇softshrink〡dtype(self_rank_dtype: Tuple[int, int], lambd: Union[int self_rank, self_dtype = self_rank_dtype return _get_dtype_of_floating_point_op(self_dtype) + +def aten〇polar〡dtype(abs_rank_dtype: Tuple[int, int], angle_rank_dtype: Tuple[int, int]) -> int: + _, abs_dtype = abs_rank_dtype + _, angle_dtype = angle_rank_dtype + assert (abs_dtype == angle_dtype) + if abs_dtype == torch.float64: + return torch.complex128 + elif abs_dtype == torch.float32: + return torch.complex64 + return abs_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇logit〡dtype(self_rank_dtype: Tuple[int, int], eps: Optional[float] = None) -> int: self_rank, self_dtype = self_rank_dtype @@ -2177,6 +2885,15 @@ def prims〇sqrt〡dtype(self_rank_dtype: Tuple[int, int]) -> int: return self_dtype return _get_dtype_of_floating_point_op(self_dtype) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dims=[0])) +def prims〇sum〡dtype(inp_rank_dtype: Tuple[int, int], dims: Optional[List[int]], output_dtype: Optional[int] = None) -> int: + # When invoking prims.sum() with the output_dtype argument, pytorch + # complains that the argument is not known. 
+ # See https://github.com/pytorch/pytorch/issues/102610 + assert output_dtype is None + inp_rank, inp_dtype = inp_rank_dtype + return inp_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇abs〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype @@ -2196,9 +2913,10 @@ def aten〇pixel_shuffle〡dtype(self_rank_dtype: Tuple[int, int], upscale_facto self_rank, self_dtype = self_rank_dtype return self_dtype -@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 7)], kernel_size=[2])) +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 7)], kernel_size=[2], error_types={torch.uint8})) def aten〇avg_pool1d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), ceil_mode: bool = False, count_include_pad: bool = True) -> int: self_rank, self_dtype = self_rank_dtype + assert self_dtype != torch.uint8 return self_dtype @check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[2, 2])) @@ -2211,18 +2929,20 @@ def aten〇adaptive_avg_pool3d〡dtype(self_rank_dtype: Tuple[int, int], output_ self_rank, self_dtype = self_rank_dtype return self_dtype -@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2])) +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2], error_types={torch.uint8})) def aten〇avg_pool2d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> int: self_rank, self_dtype = self_rank_dtype + assert self_dtype != torch.uint8 return self_dtype -@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7, 8)], kernel_size=[2, 2, 2])) +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7, 8)], kernel_size=[2, 2, 2], error_types={torch.uint8})) def aten〇avg_pool3d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0, 0,), ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> int: self_rank, self_dtype = self_rank_dtype + assert self_dtype != torch.uint8 return self_dtype -@check_dtype_function(_check_tensors_with_the_same_dtype( - tensor_shapes=[(2, 3, 5), (3,), (3,), (3,), (3,)], training=False, momentum=0.1, eps=1e-5, cudnn_enabled=True)) +# @check_dtype_function(_check_tensors_with_the_same_dtype( +# tensor_shapes=[(2, 3, 5), (3,), (3,), (3,), (3,)], tensor_device="cpu", error_types={torch.complex128}, training=False, momentum=0.1, eps=1e-5, cudnn_enabled=True)) def aten〇batch_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]], bias_rank_dtype: Optional[Tuple[int, int]], running_mean_rank_dtype: Optional[Tuple[int, int]], running_var_rank_dtype: Optional[Tuple[int, int]], training: bool, momentum: float, eps: float, cudnn_enabled: bool) -> int: input_rank, input_dtype = input_rank_dtype return input_dtype @@ -2244,6 +2964,20 @@ def aten〇instance_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_ input_rank, input_dtype = input_rank_dtype return input_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2, error_types={*all_integer_dtypes()})) +def 
aten〇_weight_norm_interface〡dtype(v_rank_dtype: Tuple[int, int], g_rank_dtype: Tuple[int, int], dim: int = 0) -> Tuple[int, int]: + v_rank, v_dtype = v_rank_dtype + g_rank, g_dtype = g_rank_dtype + assert v_dtype == g_dtype + assert not is_integer_dtype(g_dtype) + if g_dtype == torch.complex128: + return v_dtype, torch.float64 + elif g_dtype == torch.complex64: + return v_dtype, torch.float32 + elif g_dtype == torch.bfloat16: + return v_dtype, torch.float32 + return v_dtype, g_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇bernoulli_〇float〡dtype(self_rank_dtype: Tuple[int, int], p: float = 0.5, generator: Any = None) -> int: self_rank, self_dtype = self_rank_dtype @@ -2259,6 +2993,10 @@ def aten〇bernoulli〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], p_rank_d self_rank, self_dtype = self_rank_dtype return self_dtype +@check_dtype_function([Invocation(TensorOfShape(5, dtype=dtype), 3) for dtype in _SORTED_TORCH_TYPES]) +def aten〇multinomial〡dtype(self_rank_dtype: Tuple[int, int], num_samples: int, replacement: bool = False, generator: Any = None) -> int: + return torch.int64 + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇bitwise_not〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype @@ -2397,11 +3135,29 @@ def aten〇cumsum〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Opt return torch.int64 return self_dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.float32)) +def aten〇cumprod〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Optional[int] = None) -> int: + if dtype is not None: + return dtype + self_rank, self_dtype = self_rank_dtype + if is_integer_dtype(self_dtype): + return torch.int64 + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇detach〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(4,4),], error_types={*all_integer_dtypes()})) +def aten〇linalg_det〡dtype(A_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = A_rank_dtype + assert not is_integer_dtype(self_dtype) + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, p=0.5, train=False)) def aten〇dropout〡dtype(input_rank_dtype: Tuple[int, int], p: float, train: bool) -> int: input_rank, input_dtype = input_rank_dtype @@ -2565,6 +3321,15 @@ def aten〇leaky_relu_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], promoted_dtype = promote_dtypes(ranks, dtypes) return promoted_dtype +@check_dtype_function([Invocation(TensorOfShape(3, 3, dtype=dtype), TensorOfShape(3, 3, dtype=dtype), TensorOfShape(3, 3, dtype=dtype), 0.1, 0.9, False, False) for dtype in _SORTED_TORCH_TYPES]) +def aten〇rrelu_with_noise_backward〡dtype(grad_output_rank_dtype: Tuple[int, int], self_rank_dtype: Tuple[int, int], noise_rank_dtype: Tuple[int, int], lower: Union[int, float, complex], upper: Union[int, float, complex], training: bool, self_is_result: bool) -> int: + grad_output_rank, grad_output_dtype = grad_output_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [grad_output_rank, self_rank] + dtypes = [grad_output_dtype, self_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + return promoted_dtype + 
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇lift_fresh_copy〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype @@ -2583,6 +3348,13 @@ def aten〇linalg_cross〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dty dtypes = [self_dtype, other_dtype] return promote_dtypes(ranks, dtypes) +@check_dtype_function([ + Invocation(TensorOfShape(2, 4, 3, dtype=torch.int32, device="cpu"), k=2, dim=-1, keepdim=False) +]) +def aten〇kthvalue〡dtype(self_rank_dtype: Tuple[int, int], k: int, dim: int = -1, keepdim: bool = False) -> Tuple[int, int]: + _, self_dtype = self_rank_dtype + return (self_dtype, torch.int64) + @check_dtype_function( _check_two_tensor_op(dim=0, input_dtype=torch.float32) + _check_two_tensor_op(dim=0, input_dtype=torch.float64)) @@ -2627,11 +3399,20 @@ def aten〇max_pool3d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: Lis self_rank, self_dtype = self_rank_dtype return self_dtype -@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2])) -def aten〇max_pool2d_with_indices〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), ceil_mode: bool = False) -> Tuple[int, int]: +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2])) +def aten〇max_pool2d_with_indices〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), ceil_mode: bool = False) -> Tuple[int, int]: + self_rank, self_dtype = self_rank_dtype + return self_dtype, torch.int64 + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7, 8)], kernel_size=[2, 2, 2])) +def aten〇max_pool3d_with_indices〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0, 0,), dilation: List[int] = (1, 1, 1,), ceil_mode: bool = False) -> Tuple[int, int]: self_rank, self_dtype = self_rank_dtype return self_dtype, torch.int64 +def aten〇max_unpool3d〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: Tuple[int, int], output_size: List[int], stride: List[int], padding: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 7)], output_size=[2])) def aten〇adaptive_max_pool1d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int]) -> Tuple[int, int]: self_rank, self_dtype = self_rank_dtype @@ -2710,6 +3491,25 @@ def aten〇celu〡dtype(self_rank_dtype: Tuple[int, int], alpha: Union[int, floa self_rank, self_dtype = self_rank_dtype return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇rrelu〡dtype(self_rank_dtype: Tuple[int, int], lower: Union[int, float, complex] = 0.125, upper: Union[int, float, complex] = 0.33333333333333331, training: bool = False, generator: Any = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2)) +def aten〇rrelu_with_noise〡dtype(self_rank_dtype: Tuple[int, int], noise_rank_dtype: Tuple[int, int], lower: Union[int, float, complex] = 0.125, upper: Union[int, float, complex] = 0.33333333333333331, training: bool = False, generator: Any = None) -> int: + self_rank, self_dtype = self_rank_dtype + noise_rank, noise_dtype = 
noise_rank_dtype + assert self_rank == noise_rank + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2)) +def aten〇rrelu_with_noise_functional〡dtype(self_rank_dtype: Tuple[int, int], noise_rank_dtype: Tuple[int, int], lower: Union[int, float, complex] = 0.125, upper: Union[int, float, complex] = 0.33333333333333331, training: bool = False, generator: Any = None) -> Tuple[int, int]: + self_rank, self_dtype = self_rank_dtype + noise_rank, noise_dtype = noise_rank_dtype + assert self_rank == noise_rank + return self_dtype, noise_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool})) def aten〇relu6〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype @@ -2799,6 +3599,12 @@ def aten〇scatter〇value〡dtype(self_rank_dtype: Tuple[int, int], dim: int, i self_rank, self_dtype = self_rank_dtype return self_dtype +@check_dtype_function( + [Invocation(TensorOfShape(3, dtype=dtype), 0, TensorOfShape(3, dtype=torch.int64), TensorOfShape(3, dtype=dtype)) for dtype in _SORTED_TORCH_TYPES]) +def aten〇scatter_add〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], src_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function( [Invocation(TensorOfShape(3, dtype=dtype), TensorOfShape(3, dtype=torch.bool), TensorOfShape(3, dtype=dtype)) for dtype in _SORTED_TORCH_TYPES]) def aten〇masked_scatter〡dtype(self_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], source_rank_dtype: Tuple[int, int]) -> int: @@ -2820,6 +3626,10 @@ def aten〇slice〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], dim: int = 0 self_rank, self_dtype = self_rank_dtype return self_dtype +def aten〇as_strided〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], stride: List[int], storage_offset: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function( _check_tensors_with_the_same_dtype(num_of_tensors=2, dim=0, input_dtype=torch.float32) + _check_tensors_with_the_same_dtype(num_of_tensors=2, dim=0, input_dtype=torch.float64) + @@ -2828,6 +3638,18 @@ def aten〇slice〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], dim: int = 0 def aten〇_softmax_backward_data〡dtype(grad_output_rank_dtype: Tuple[int, int], output_rank_dtype: Tuple[int, int], dim: int, input_dtype: int) -> int: return input_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(4,4),], error_types={*all_integer_dtypes(), torch.float16, torch.bfloat16})) +def aten〇linalg_slogdet〡dtype(A_rank_dtype: Tuple[int, int]) -> Tuple[int, int]: + self_rank, self_dtype = A_rank_dtype + assert not is_integer_dtype(self_dtype) + assert self_dtype != torch.float16 and self_dtype != torch.bfloat16 + det_type = self_dtype + if self_dtype == torch.complex32 or self_dtype == torch.complex64: + det_type = torch.float32 + if self_dtype == torch.complex128: + det_type = torch.float64 + return self_dtype, det_type + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇square〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype @@ -2921,11 +3743,36 @@ def aten〇upsample_nearest2d_backward〡dtype(grad_output_rank_dtype: Tuple[int grad_output_rank, grad_output_dtype = grad_output_rank_dtype return grad_output_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5)], output_size=[11])) +def 
aten〇upsample_nearest1d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int], scales: Optional[float] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5)], output_size=[11], scale_factors=None)) +def aten〇upsample_nearest1d〇vec〡dtype(input_rank_dtype: Tuple[int, int], output_size: Optional[List[int]], scale_factors: Optional[List[float]]) -> int: + self_rank, self_dtype = input_rank_dtype + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[11, 13])) def aten〇upsample_nearest2d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int], scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> int: self_rank, self_dtype = self_rank_dtype return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[11, 13], scale_factors=None)) +def aten〇upsample_nearest2d〇vec〡dtype(input_rank_dtype: Tuple[int, int], output_size: Optional[List[int]], scale_factors: Optional[List[float]]) -> int: + self_rank, self_dtype = input_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[11, 13], align_corners=True)) +def aten〇upsample_bilinear2d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int], align_corners: bool, scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[11, 13], align_corners=True, scale_factors=None)) +def aten〇upsample_bilinear2d〇vec〡dtype(input_rank_dtype: Tuple[int, int], output_size: Optional[List[int]], align_corners: bool, scale_factors: Optional[List[float]]) -> int: + self_rank, self_dtype = input_rank_dtype + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, size=[1])) def aten〇view〡dtype(self_rank_dtype: Tuple[int, int], size: List[int]) -> int: self_rank, self_dtype = self_rank_dtype @@ -3029,7 +3876,7 @@ def aten〇isclose〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: T return torch.bool @check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(3, 4, 32, 16), (3, 4, 32, 16), (3, 4, 32, 16)])) -def aten〇scaled_dot_product_attention〡dtype(query_rank_dtype: Tuple[int, int], key_rank_dtype: Tuple[int, int], value_rank_dtype: Tuple[int, int], attn_mask_rank_dtype: Optional[Tuple[int, int]] = None, dropout_p: float = 0., is_causal: bool = False, scale: Optional[float] = None) -> int: +def aten〇scaled_dot_product_attention〡dtype(query_rank_dtype: Tuple[int, int], key_rank_dtype: Tuple[int, int], value_rank_dtype: Tuple[int, int], attn_mask_rank_dtype: Optional[Tuple[int, int]] = None, dropout_p: float = 0., is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False) -> int: _, query_dtype = query_rank_dtype return query_dtype @@ -3115,6 +3962,71 @@ def aten〇fft_fft〡dtype(self_rank_dtype: Tuple[int, int], n: Optional[int] = else: assert False, "Unsupported dtype" +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex32, torch.complex64, torch.complex128, torch.bfloat16})) +def aten〇fft_rfft〡dtype(self_rank_dtype: Tuple[int, int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> int: + self_rank, 
self_dtype = self_rank_dtype + if self_dtype == torch.float16: + return torch.complex32 + elif self_dtype == torch.float32: + return torch.complex64 + elif self_dtype == torch.float64: + return torch.complex128 + elif is_integer_dtype(self_dtype): + return torch.complex64 + else: + assert False, "Unsupported dtype" + + + +@check_dtype_function([ + Invocation(TensorOfShape(1,128, dtype=torch.complex64), n_fft=16, return_complex=False), # output dtype = torch.float32 + Invocation(TensorOfShape(1,128, dtype=torch.complex64), n_fft=16, return_complex=True), # output dtype = torch.complex64 + Invocation(TensorOfShape(1,128, dtype=torch.float32), n_fft=16, return_complex=True), # output dtype = torch.complex64 + Invocation(TensorOfShape(1,128, dtype=torch.float32), n_fft=16, return_complex=False), # output dtype = torch.float32 +]) +def aten〇stft〡dtype(self_rank_dtype: Tuple[int, int], n_fft: int, hop_length: Optional[int] = None, win_length: Optional[int] = None, window_rank_dtype: Optional[Tuple[int, int]] = None, normalized: bool = False, onesided: Optional[bool] = None, return_complex: Optional[bool] = None, align_to_window: Optional[bool] = None) -> int: + self_rank, self_dtype = self_rank_dtype + if is_complex_dtype(self_dtype) and return_complex is not None and return_complex: + return self_dtype + elif is_complex_dtype(self_dtype) and return_complex is not None and return_complex != True: + if self_dtype == torch.complex32: + return torch.float16 + elif self_dtype == torch.complex64: + return torch.float32 + elif self_dtype == torch.complex128: + return torch.float64 + elif is_float_dtype(self_dtype) and return_complex is not None and return_complex: + if self_dtype == torch.float16: + return torch.complex32 + elif self_dtype == torch.float32: + return torch.complex64 + elif self_dtype == torch.float64: + return torch.complex128 + elif is_float_dtype(self_dtype) and return_complex is not None and return_complex != True: + return self_dtype + elif is_integer_dtype(self_dtype): + return torch.complex64 + + assert False, "Unsupported dtype" + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bfloat16})) +def aten〇fft_ifft〡dtype(self_rank_dtype: Tuple[int, int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> int: + self_rank, self_dtype = self_rank_dtype + if is_complex_dtype(self_dtype): + return self_dtype + elif self_dtype == torch.float16: + return torch.complex32 + elif self_dtype == torch.float32: + return torch.complex64 + elif self_dtype == torch.float64: + return torch.complex128 + elif is_integer_dtype(self_dtype): + return torch.complex64 + else: + assert False, "Unsupported dtype" + @check_dtype_function( _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) @@ -3122,6 +4034,42 @@ def aten〇rsub〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[ self_rank, self_dtype = self_rank_dtype return promote_dtypes([self_rank, None], [self_dtype, get_dtype_of_scalar(other)]) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool})) +def aten〇frac〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + assert self_dtype != torch.bool + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇signbit〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + return torch.bool + 
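The fft_rfft, fft_ifft and stft rules above all encode the same real-to-complex result mapping; a compact restatement, offered only as an illustration of what those branches compute (integer and bool inputs first promote to the default float dtype, hence complex64):

import torch

_REAL_TO_COMPLEX = {
    torch.float16: torch.complex32,
    torch.float32: torch.complex64,
    torch.float64: torch.complex128,
}

def _complex_result_dtype(input_dtype: torch.dtype) -> torch.dtype:
    # Complex inputs are handled separately by each rule (fft_ifft passes them
    # through, fft_rfft rejects them); this covers the real/integer cases only.
    return _REAL_TO_COMPLEX.get(input_dtype, torch.complex64)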
+@check_dtype_function(_check_two_tensor_op()) +def aten〇ldexp〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + other_rank, other_dtype = other_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + if self_dtype == torch.double and is_complex_dtype(other_dtype): + return other_dtype + elif is_complex_dtype(self_dtype) and other_dtype == torch.double: + return self_dtype + elif is_integer_dtype(self_dtype) and is_integer_dtype(other_dtype): + return torch.float + else: + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_two_tensor_op()) +def aten〇copysign〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + other_rank, other_dtype = other_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + if is_integer_dtype(self_dtype) and is_integer_dtype(other_dtype): + return torch.float + else: + return promote_dtypes(ranks, dtypes) + @check_dtype_function( _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) @@ -3303,6 +4251,13 @@ def aten〇div〇Scalar_mode〡dtype(self_rank_dtype: Tuple[int, int], other: Un else: return torch.float32 +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(4,), (4,)])) +def aten〇dot〡dtype(self_rank_dtype: Tuple[int, int], tensor_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = tensor_rank_dtype + self_rank, self_dtype = self_rank_dtype + assert self_dtype == other_dtype + return self_dtype + @check_dtype_function( _check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 4), (2, 4, 3)]) + # Different width @@ -3337,6 +4292,30 @@ def aten〇minimum〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: T dtypes = [self_dtype, other_dtype] return promote_dtypes(ranks, dtypes) +@check_dtype_function(_check_two_tensor_op()) +def aten〇fmax〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_two_tensor_op()) +def aten〇fmin〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(3,), (4,)])) +def aten〇outer〡dtype(self_rank_dtype: Tuple[int, int], vec2_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + vec2_rank, vec2_dtype = vec2_rank_dtype + ranks: List[Optional[int]] = [self_rank, vec2_rank] + dtypes = [self_dtype, vec2_dtype] + return promote_dtypes(ranks, dtypes) + @check_dtype_function( _check_tensors_with_the_same_dtype(tensor_shapes=[(3, 4), (4, 3)]) + # Different width @@ -3360,6 +4339,13 @@ def aten〇mm〡dtype(self_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tuple[i dtypes = [self_dtype, mat2_dtype] return promote_dtypes(ranks, dtypes) +def aten〇_int_mm〡dtype(self_rank_dtype: Tuple[int, int], mat2_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype 
= self_rank_dtype + mat2_rank, mat2_dtype = mat2_rank_dtype + assert self_dtype == torch.int8 + assert mat2_dtype == torch.int8 + return torch.int32 + @check_dtype_function(_check_two_tensor_op( output_error_types={torch.bool, torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64})) def aten〇mse_loss〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], reduction: int = 1) -> int: @@ -3371,6 +4357,15 @@ def aten〇mse_loss〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: assert not is_integer_dtype(promoted_dtype) return promoted_dtype +def aten〇l1_loss〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], reduction: int = 1) -> int: + self_rank, self_dtype = self_rank_dtype + target_rank, target_dtype = target_rank_dtype + ranks: List[Optional[int]] = [self_rank, target_rank] + dtypes = [self_dtype, target_dtype] + promoted_dtype = promote_dtypes(ranks, dtypes) + assert not is_integer_dtype(promoted_dtype) + return promoted_dtype + @check_dtype_function(_check_two_tensor_op()) def aten〇mul〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: other_rank, other_dtype = other_rank_dtype @@ -3396,7 +4391,7 @@ def aten〇mv〡dtype(self_rank_dtype: Tuple[int, int], vec_rank_dtype: Tuple[in dtypes = [self_dtype, vec_dtype] return promote_dtypes(ranks, dtypes) -@check_dtype_function(_check_two_tensor_op()) +# @check_dtype_function(_check_two_tensor_op()) def aten〇sub〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], alpha: Union[int, float, complex] = 1) -> int: other_rank, other_dtype = other_rank_dtype self_rank, self_dtype = self_rank_dtype @@ -3524,6 +4519,10 @@ def aten〇conv3d〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: input_rank, input_dtype = input_rank_dtype return input_dtype +def aten〇conv_transpose1d〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None, stride: List[int] = (1,), padding: List[int] = (0,), output_padding: List[int] = (0,), groups: int = 1, dilation: List[int] = (1,)) -> int: + input_rank, input_dtype = input_rank_dtype + return input_dtype + @check_dtype_function( _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1)]) + [Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool), TensorOfShape(1, 1, 1, 1, dtype=torch.float32)), @@ -3535,6 +4534,10 @@ def aten〇conv_transpose2d〇input〡dtype(input_rank_dtype: Tuple[int, int], w input_rank, input_dtype = input_rank_dtype return input_dtype +def aten〇conv_transpose3d〇input〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None, stride: List[int] = (1, 1, 1,), padding: List[int] = (0, 0, 0,), output_padding: List[int] = (0, 0, 0,), groups: int = 1, dilation: List[int] = (1, 1, 1,)) -> int: + input_rank, input_dtype = input_rank_dtype + return input_dtype + convolution_kwargs = { "stride" : [1, 1], "padding" : [0, 0], "dilation" : [1, 1], "transposed" : False, "output_padding" : [0, 0], "groups" : 1} @check_dtype_function( @@ -3605,6 +4608,16 @@ def aten〇bincount〡dtype(self_rank_dtype: Tuple[int, int], weights_rank_dtype return torch.int64 return torch.float64 +@check_dtype_function([ + Invocation(TensorOfShape(1, 5, 5, dtype=torch.int64), [5,5], [1,5], [1,1], [0,0], [1,1]), # int type + Invocation(TensorOfShape(1, 5, 5, dtype=torch.float64), [5,5], [1,5], [1,1], [0,0], [1,1]), # float type + Invocation(TensorOfShape(1, 5, 
5, dtype=torch.complex64), [5,5], [1,5], [1,1], [0,0], [1,1]), # complex type + Invocation(TensorOfShape(1, 5, 5, dtype=torch.bool), [5,5], [1,5], [1,1], [0,0], [1,1]), # boolean type +]) +def aten〇col2im〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int], kernel_size: List[int], dilation: List[int], padding: List[int], stride: List[int]) -> int: + _, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, tensor_device=torch.device("cpu"))) def aten〇nonzero〡dtype(self_rank_dtype: Tuple[int, int]) -> int: return torch.int64 @@ -3636,18 +4649,7 @@ def aten〇addmm〡dtype(self_rank_dtype: Tuple[int, int], mat1_rank_dtype: Tupl return promote_dtypes(ranks, dtypes) @check_dtype_function( - _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1), (1, 1)]) + - # Different width - [Invocation(TensorOfShape(4, 3, dtype=torch.float32), - TensorOfShape(4, 3, dtype=torch.float64), - TensorOfShape(4, 3, dtype=torch.float32)), - # Different type - Invocation(TensorOfShape(4, 3, dtype=torch.float32), - TensorOfShape(4, 3, dtype=torch.float32), - TensorOfShape(4, 3, dtype=torch.int32)), - Invocation(TensorOfShape(4, 3, dtype=torch.int32), - TensorOfShape(4, 3, dtype=torch.float32), - TensorOfShape(4, 3, dtype=torch.float32))]) + _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1), (1, 1)])) def aten〇lerp〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], end_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype end_rank, end_dtype = end_rank_dtype @@ -3658,28 +4660,17 @@ def aten〇lerp〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], end_rank_dtyp return promote_dtypes(ranks, dtypes) @check_dtype_function( - _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1)], weight=0.5) + - # Different width - [Invocation(TensorOfShape(4, 3, dtype=torch.float32), - TensorOfShape(4, 3, dtype=torch.float64), - weight=0.5), - # Different type - Invocation(TensorOfShape(4, 3, dtype=torch.int32), - TensorOfShape(4, 3, dtype=torch.float32), - weight=0.5), - Invocation(TensorOfShape(4, 3, dtype=torch.float32), - TensorOfShape(4, 3, dtype=torch.float32), - weight=2)]) + _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1)], weight=0.5)) def aten〇lerp〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], end_rank_dtype: Tuple[int, int], weight: Union[int, float, complex]) -> int: self_rank, self_dtype = self_rank_dtype end_rank, end_dtype = end_rank_dtype - ranks: List[Optional[int]] = [self_rank, end_rank, None] - dtypes = [self_dtype, end_dtype, get_dtype_of_scalar(weight)] + ranks: List[Optional[int]] = [self_rank, end_rank] + dtypes = [self_dtype, end_dtype] return promote_dtypes(ranks, dtypes) @check_dtype_function( - _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1), (1, 1)], error_types={torch.bool}) + + _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1), (1, 1)]) + # Different width [Invocation(TensorOfShape(3, 3, dtype=torch.float32), TensorOfShape(3, 3, dtype=torch.float64), @@ -3696,16 +4687,11 @@ def aten〇addcmul〡dtype(self_rank_dtype: Tuple[int, int], tensor1_rank_dtype: tensor1_rank, tensor1_dtype = tensor1_rank_dtype tensor2_rank, tensor2_dtype = tensor2_rank_dtype - assert self_dtype != torch.bool - assert tensor1_dtype != torch.bool - assert tensor2_dtype != torch.bool - ranks: List[Optional[int]] = [self_rank, tensor1_rank, tensor2_rank] dtypes = [self_dtype, tensor1_dtype, tensor2_dtype] return 
promote_dtypes(ranks, dtypes) @check_dtype_function( - _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1), (1, 1)]) + # Different width [Invocation(TensorOfShape(3, 3, dtype=torch.float32), TensorOfShape(3, 3, dtype=torch.float64), @@ -3725,8 +4711,6 @@ def aten〇addcdiv〡dtype(self_rank_dtype: Tuple[int, int], tensor1_rank_dtype: ranks: List[Optional[int]] = [self_rank, tensor1_rank, tensor2_rank] dtypes = [self_dtype, tensor1_dtype, tensor2_dtype] result = promote_dtypes(ranks, dtypes) - if is_integer_dtype(result): - return torch.float32 return result @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) + @@ -4151,10 +5135,19 @@ def aten〇amax〡dtype(self_rank_dtype: Tuple[int, int], dim: List[int] = (), k def aten〇max〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim: bool = False) -> Tuple[int, int]: return aten〇max〡dtype(self_rank_dtype), torch.int64 +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇amin〡dtype(self_rank_dtype: Tuple[int, int], dim: List[int] = (), keepdim: bool = False) -> int: + return aten〇min〡dtype(self_rank_dtype) + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0)) def aten〇min〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim: bool = False) -> Tuple[int, int]: return aten〇min〡dtype(self_rank_dtype), torch.int64 +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇aminmax〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[int] = None, keepdim: bool = False) -> Tuple[int, int]: + self_rank, self_dtype = self_rank_dtype + return self_dtype, self_dtype + @check_dtype_function( _check_tensors_with_the_same_dtype( num_of_tensors=1, @@ -4195,7 +5188,7 @@ def aten〇var〇correction〡dtype(self_rank_dtype: Tuple[int, int], dim: Optio return aten〇std〡dtype(self_rank_dtype) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dims=[], correction=0.0)) -def prims〇var〡dtype(inp_rank_dtype: Tuple[int, int], dims: Optional[List[int]], correction: float, output_dtype: Optional[int] = None) -> int: +def prims〇var〡dtype(inp_rank_dtype: Tuple[int, int], dims: Optional[List[int]], correction: Optional[float] = 1, output_dtype: Optional[int] = None) -> int: return aten〇std〡dtype(inp_rank_dtype) @check_dtype_function( @@ -4244,6 +5237,24 @@ def aten〇linalg_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Optional[U return dtype return aten〇std〡dtype(self_rank_dtype) +def aten〇binary_cross_entropy_with_logits〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]] = None, pos_weight_rank_dtype: Optional[Tuple[int, int]] = None, reduction: int = 1) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype( + tensor_shapes=[(3,3)], + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}, + p=1, + dim=0, + maxnorm=5) +) +def aten〇renorm〡dtype(self_rank_dtype: Tuple[int, int], p: Union[int, float, complex], dim: int, maxnorm: Union[int, float, complex]) -> int: + self_rank, self_dtype = self_rank_dtype + assert not is_integer_dtype(self_dtype) + + return self_dtype + @check_dtype_function( _check_tensors_with_the_same_dtype( num_of_tensors=1, @@ -4424,6 +5435,10 @@ def aten〇diag_embed〡dtype(self_rank_dtype: Tuple[int, int], offset: int = 0, self_rank, self_dtype = self_rank_dtype return self_dtype +def 
aten〇rot90〡dtype(self_rank_dtype: Tuple[int, int], k: int = 1, dims: List[int] = (0, 1,)) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) + _check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) + @@ -4597,7 +5612,7 @@ def aten〇atanh〡dtype(self_rank_dtype: Tuple[int, int]) -> int: return torch.float32 return self_dtype -@check_dtype_function(_check_two_tensor_op()) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2)) def aten〇linear〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None) -> int: input_rank, input_dtype = input_rank_dtype weight_rank, weight_dtype = weight_rank_dtype @@ -4606,6 +5621,21 @@ def aten〇linear〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: promoted_dtype = promote_dtypes(ranks, dtypes) return promoted_dtype +@check_dtype_function( + _check_tensors_with_the_same_dtype(3, None, None, None, expand1 = [], expand2 = [], expand3 = [], sumdim = [], unroll_dim = 0), +) +def aten〇_trilinear〡dtype(i1_rank_dtype: Tuple[int, int], i2_rank_dtype: Tuple[int, int], i3_rank_dtype: Tuple[int, int], expand1: List[int], expand2: List[int], expand3: List[int], sumdim: List[int], unroll_dim: int = 1) -> int: + i1_rank, i1_dtype = i1_rank_dtype + i2_rank, i2_dtype = i2_rank_dtype + i3_rank, i3_dtype = i3_rank_dtype + + ranks: List[Optional[int]] = [i1_rank, i2_rank, i3_rank] + dtypes = [i1_dtype, i2_dtype, i3_dtype] + return promote_dtypes( + ranks, + dtypes, + ) + @check_dtype_function( [Invocation([NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32)]), Invocation([NonZeroDTensorWithDtype(torch.float16), NonZeroDTensorWithDtype(torch.float64)]), @@ -4621,6 +5651,50 @@ def aten〇cat〡dtype(tensors_rank_dtype: List[Tuple[int, int]], dim: int = 0) dtypes.append(tensor_dtype) return promote_dtypes(ranks, dtypes) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇atleast_1d〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇atleast_2d〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + +@check_dtype_function( + [Invocation([NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.int32), NonZeroDTensorWithDtype(torch.int64)]), + Invocation([NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32)]), + Invocation([NonZeroDTensorWithDtype(torch.float16), NonZeroDTensorWithDtype(torch.float64)]), + Invocation([NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32), + NonZeroDTensorWithDtype(torch.complex64)])]) +def aten〇hstack〡dtype(tensors_rank_dtype: List[Tuple[int, int]]) -> int: + ranks: List[Optional[int]] = [] + dtypes: List[int] = [] + assert len(tensors_rank_dtype) != 0 + for tensor_rank_dtype in tensors_rank_dtype: + tensor_rank, tensor_dtype = tensor_rank_dtype + ranks.append(tensor_rank) + dtypes.append(tensor_dtype) + + return promote_dtypes(ranks, dtypes) + +@check_dtype_function( + [Invocation([NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.int32), NonZeroDTensorWithDtype(torch.int64)]), + Invocation([NonZeroDTensorWithDtype(torch.float32), 
NonZeroDTensorWithDtype(torch.int32)]), + Invocation([NonZeroDTensorWithDtype(torch.float16), NonZeroDTensorWithDtype(torch.float64)]), + Invocation([NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32), + NonZeroDTensorWithDtype(torch.complex64)])]) +def aten〇column_stack〡dtype(tensors_rank_dtype: List[Tuple[int, int]]) -> int: + ranks: List[Optional[int]] = [] + dtypes: List[int] = [] + assert len(tensors_rank_dtype) != 0 + for tensor_rank_dtype in tensors_rank_dtype: + tensor_rank, tensor_dtype = tensor_rank_dtype + ranks.append(tensor_rank) + dtypes.append(tensor_dtype) + + return promote_dtypes(ranks, dtypes) + @check_dtype_function( [Invocation("i,j->ij", [TensorOfShape(1, dtype=torch.float32), TensorOfShape(1, dtype=torch.int32)]),]) @@ -4663,8 +5737,7 @@ def aten〇ScalarImplicit〡dtype(a_rank_dtype: Tuple[int, int]) -> int: def prim〇NumToTensor〇Scalar〡dtype(a: Union[int, float, complex]) -> int: return get_dtype_of_scalar(a) -@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0) + - _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.int32) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.int32) + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.float16) + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.complex64)) def aten〇softmax〇int〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Optional[int] = None) -> int: @@ -4674,7 +5747,7 @@ def aten〇softmax〇int〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dty return dtype @check_dtype_function( - _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, half_to_float=False) + + # _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, half_to_float=False) + _check_tensors_with_the_same_dtype( num_of_tensors=1, error_types=(all_integer_dtypes() + all_complex_dtypes() + [torch.bfloat16, torch.float32, torch.float64]), @@ -4686,8 +5759,14 @@ def aten〇_softmax〡dtype(self_rank_dtype: Tuple[int, int], dim: int, half_to_ return torch.float32 return self_dtype +def aten〇_safe_softmax〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Optional[int] = None) -> int: + if dtype is not None: + return dtype + self_rank, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function( - _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, half_to_float=False) + + # _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, half_to_float=False) + _check_tensors_with_the_same_dtype( num_of_tensors=1, error_types=(all_integer_dtypes() + all_complex_dtypes() + [torch.bfloat16, torch.float32, torch.float64]), @@ -4699,8 +5778,7 @@ def aten〇_log_softmax〡dtype(self_rank_dtype: Tuple[int, int], dim: int, half return torch.float32 return self_dtype -@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0) + - _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.int32) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.int32) + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.float16) + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.complex64)) def aten〇log_softmax〇int〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Optional[int] = None) -> int: @@ -4754,6 +5832,16 @@ def aten〇dequantize〇self〡dtype(self_rank_dtype: Tuple[int, int]) -> int: def aten〇dequantize〇tensor〡dtype(qtensor_rank_dtype: Tuple[int, int]) -> 
int: return torch.float32 +def aten〇triu_indices〡dtype(row: int, col: int, offset: int = 0, dtype: Optional[int] = 4, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + return torch.int64 if dtype is None else dtype + +def aten〇tril_indices〡dtype(row: int, col: int, offset: int = 0, dtype: Optional[int] = 4, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + return torch.int64 if dtype is None else dtype + +def aten〇deg2rad〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + def aten〇int_repr〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype if (self_dtype == torch.quint8): @@ -4778,7 +5866,45 @@ def aten〇_make_per_tensor_quantized_tensor〡dtype(self_rank_dtype: Tuple[int, return torch.qint8 return torch.qint32 +@check_shape_function([ + Invocation(TensorOfShape(), 0, 1, 1), # Rank Zero. + Invocation(TensorOfShape(), 0, 0, 1), # Rank Zero, size of 0. + Invocation(TensorOfShape(6, 4), 0, 2, 1), # Basic case. + Invocation(TensorOfShape(6, 4, 2), 0, 2, 1), # Basic case. + Invocation(TensorOfShape(6, 4), -1, 2, 1), # Negative Dimension. + Invocation(TensorOfShape(6, 4, 2), -1, 2, 1), # Negative Dimension. +]) +def aten〇unfold〡shape(self: List[int], dimension: int, size: int, step: int) -> List[int]: + ndim = len(self) + + # Rank zero tensor + if ndim == 0: + assert dimension == 0, f"dimension out of range of {ndim}" + assert size <= 1, "size must be less than or equal to 1" + return [size] + + dim = dimension + if dim < 0: + dim += ndim + + assert (dim >= 0 and dim < ndim), f"dimension out of range of {ndim}" + + size_dim = self[dim] + assert size <= size_dim, f"size must be less than or equal to {size_dim}" + num_blocks = (size_dim - size) // step + 1 + + out = upstream_shape_functions._copy(self) + out[dim] = num_blocks + out.append(size) + return out + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, dimension=0, size=1, step=1) +) +def aten〇unfold〡dtype(self_rank_dtype: Tuple[int, int], dimension: int, size: int, step: int) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype @@ -4797,6 +5923,9 @@ def _maybe_import_op_extensions(args: argparse.Namespace): def main(args): _maybe_import_op_extensions(args) + # importing torchvision will register torchvision ops with the JITOperatorRegistry + import torchvision + asm = generate_library(globals()) # We're about to put quotes around the string, so escape the `"` characters. asm = asm.replace("\"", "\\\"") diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index c847e42d844a..ca7c24becad9 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -301,6 +301,9 @@ def emit_with_mutating_variants(key, **kwargs): "aten::relu : (Tensor) -> (Tensor)", "aten::relu6 : (Tensor) -> (Tensor)", "aten::leaky_relu : (Tensor, Scalar) -> (Tensor)", + "aten::rrelu : (Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)", + "aten::rrelu_with_noise : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) 
-> (Tensor)", + "aten::celu : (Tensor, Scalar) -> (Tensor)", "aten::selu : (Tensor) -> (Tensor)", "aten::sigmoid : (Tensor) -> (Tensor)", "aten::sinh : (Tensor) -> (Tensor)", @@ -314,6 +317,7 @@ def emit_with_mutating_variants(key, **kwargs): "aten::asin : (Tensor) -> (Tensor)", "aten::asinh : (Tensor) -> (Tensor)", "aten::exp : (Tensor) -> (Tensor)", + "aten::exp2 : (Tensor) -> (Tensor)", "aten::expm1 : (Tensor) -> (Tensor)", "aten::cos : (Tensor) -> (Tensor)", "aten::cosh : (Tensor) -> (Tensor)", @@ -325,6 +329,7 @@ def emit_with_mutating_variants(key, **kwargs): "aten::atanh : (Tensor) -> (Tensor)", "aten::atan2 : (Tensor, Tensor) -> (Tensor)", "aten::neg : (Tensor) -> (Tensor)", + "aten::frac : (Tensor) -> (Tensor)", "aten::bitwise_not : (Tensor) -> (Tensor)", "aten::div.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::logical_or : (Tensor, Tensor) -> (Tensor)", @@ -366,6 +371,7 @@ def emit_with_mutating_variants(key, **kwargs): "aten::zero : (Tensor) -> (Tensor)", "aten::fill.Scalar : (Tensor, Scalar) -> (Tensor)", "aten::fill.Tensor : (Tensor, Tensor) -> (Tensor)", + "aten::copysign.Tensor : (Tensor, Tensor) -> (Tensor)", ]: emit_with_mutating_variants(key) # Shape manipulations: @@ -377,6 +383,7 @@ def emit_with_mutating_variants(key, **kwargs): # variants. emit_with_mutating_variants( "aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)", + has_folder=True, has_canonicalizer=True, ) emit_with_mutating_variants( @@ -413,6 +420,12 @@ def emit_with_mutating_variants(key, **kwargs): has_canonicalizer=True, has_folder=True, ) + emit( + "aten::ldexp.Tensor : (Tensor, Tensor) -> (Tensor)", + ) + emit( + "aten::signbit : (Tensor) -> (Tensor)", + ) emit_with_mutating_variants( "aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)", has_folder=True ) @@ -439,6 +452,7 @@ def emit_with_mutating_variants(key, **kwargs): emit_with_mutating_variants("aten::ceil : (Tensor) -> (Tensor)", has_folder=True) emit_with_mutating_variants("aten::round : (Tensor) -> (Tensor)", has_folder=True) emit_with_mutating_variants("aten::trunc : (Tensor) -> (Tensor)", has_folder=True) + emit("aten::special_expm1 : (Tensor) -> (Tensor)") emit_with_mutating_variants( "aten::sign : (Tensor) -> (Tensor)", has_canonicalizer=True ) @@ -456,23 +470,44 @@ def emit_with_mutating_variants(key, **kwargs): emit( "aten::fake_quantize_per_tensor_affine : (Tensor, float, int, int, int) -> (Tensor)" ) + emit( + "aten::fake_quantize_per_tensor_affine_cachemask : (Tensor, float, int, int, int) -> (Tensor, Tensor)" + ) + emit( + "aten::fake_quantize_per_tensor_affine.tensor_qparams : (Tensor, Tensor, Tensor, int, int) -> (Tensor)" + ) + emit( + "aten::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams : (Tensor, Tensor, Tensor, Tensor, int, int) -> (Tensor, Tensor)" + ) + emit( + "aten::fake_quantize_per_channel_affine : (Tensor, Tensor, Tensor, int, int, int) -> (Tensor)" + ) + emit( + "aten::fake_quantize_per_channel_affine_cachemask : (Tensor, Tensor, Tensor, int, int, int) -> (Tensor, Tensor)" + ) + emit("aten::isfinite : (Tensor) -> (Tensor)") emit("aten::maximum : (Tensor, Tensor) -> (Tensor)") emit("aten::minimum : (Tensor, Tensor) -> (Tensor)") + emit("aten::fmax : (Tensor, Tensor) -> (Tensor)") + emit("aten::fmin : (Tensor, Tensor) -> (Tensor)") emit("aten::mish : (Tensor) -> (Tensor)") emit("aten::xlogy.Tensor : (Tensor, Tensor) -> (Tensor)") emit( "aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", + has_folder=True, has_canonicalizer=True, ) emit("aten::gelu : (Tensor, str) -> (Tensor)") 
emit("aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)") emit("aten::pow.Tensor_Tensor : (Tensor, Tensor) -> (Tensor)") emit("aten::pow.Scalar : (Scalar, Tensor) -> (Tensor)") + emit("aten::float_power.Tensor_Tensor : (Tensor, Tensor) -> (Tensor)") emit("aten::threshold_backward : (Tensor, Tensor, Scalar) -> (Tensor)") emit("aten::floor_divide : (Tensor, Tensor) -> (Tensor)") emit("aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)") emit("aten::prelu : (Tensor, Tensor) -> (Tensor)") - emit_with_mutating_variants("aten::celu : (Tensor, Scalar) -> (Tensor)") + emit("aten::rad2deg : (Tensor) -> (Tensor)") + emit("aten::complex : (Tensor, Tensor) -> (Tensor)") emit("aten::real : (Tensor) -> (Tensor)") emit("aten::imag : (Tensor) -> (Tensor)") emit("aten::view_as_complex : (Tensor) -> (Tensor)") @@ -482,6 +517,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::log_sigmoid : (Tensor) -> (Tensor)") emit("aten::hardshrink : (Tensor, Scalar) -> (Tensor)") emit("aten::softshrink : (Tensor, Scalar) -> (Tensor)") + emit("aten::polar : (Tensor, Tensor) -> (Tensor)") # Ops with dynamic number of outputs emit("aten::unbind_copy.int : (Tensor, int) -> (Tensor[])") @@ -529,19 +565,31 @@ def emit_with_mutating_variants(key, **kwargs): # Non-elementwise tensor compute ops emit("aten::linear : (Tensor, Tensor, Tensor?) -> (Tensor)") emit("aten::mm : (Tensor, Tensor) -> (Tensor)") + emit("aten::_int_mm : (Tensor, Tensor) -> (Tensor)") emit("aten::addmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)") emit("aten::matmul : (Tensor, Tensor) -> (Tensor)") emit("aten::mv : (Tensor, Tensor) -> (Tensor)") + emit("aten::dot : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True) + emit("aten::outer : (Tensor, Tensor) -> (Tensor)") emit("aten::cosine_similarity : (Tensor, Tensor, int, float) -> (Tensor)") emit( "aten::conv3d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)" ) + emit( + "aten::conv3d.padding : (Tensor, Tensor, Tensor?, int[], str, int[], int) -> (Tensor)" + ) emit( "aten::conv2d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)" ) + emit( + "aten::conv2d.padding : (Tensor, Tensor, Tensor?, int[], str, int[], int) -> (Tensor)" + ) emit( "aten::conv1d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)" ) + emit( + "aten::conv1d.padding : (Tensor, Tensor, Tensor?, int[], str, int[], int) -> (Tensor)" + ) emit( "aten::conv_transpose1d : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)" ) @@ -585,6 +633,7 @@ def emit_with_mutating_variants(key, **kwargs): emit( "aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)" ) + emit("aten::renorm : (Tensor, Scalar, int, Scalar) -> (Tensor)", has_verifier=True) emit("aten::norm.Scalar : (Tensor, Scalar) -> (Tensor)", has_verifier=True) emit("aten::norm.ScalarOpt_dim : (Tensor, Scalar?, int[], bool) -> (Tensor)") emit( @@ -594,7 +643,11 @@ def emit_with_mutating_variants(key, **kwargs): "aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)" ) emit("aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)") + emit( + "aten::max_pool1d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)" + ) emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)") + emit("aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)") emit( "aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)", 
has_canonicalizer=True, @@ -603,8 +656,10 @@ def emit_with_mutating_variants(key, **kwargs): "aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)" ) emit("aten::max_pool3d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)") + emit("aten::max_unpool3d : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)") emit( - "aten::max_pool3d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)" + "aten::max_pool3d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)", + has_canonicalizer=True, ) emit( "aten::max_pool3d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)" @@ -639,7 +694,10 @@ def emit_with_mutating_variants(key, **kwargs): ) emit("aten::adaptive_avg_pool1d : (Tensor, int[]) -> (Tensor)") emit("aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)") - emit("aten::_adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)") + emit( + "aten::_adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)", + has_canonicalizer=True, + ) emit("aten::_adaptive_avg_pool2d_backward : (Tensor, Tensor) -> (Tensor)") emit("aten::adaptive_avg_pool3d : (Tensor, int[]) -> (Tensor)") emit("aten::_adaptive_avg_pool3d : (Tensor, int[]) -> (Tensor)") @@ -648,7 +706,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::adaptive_max_pool2d : (Tensor, int[]) -> (Tensor, Tensor)") emit("aten::adaptive_max_pool3d : (Tensor, int[]) -> (Tensor, Tensor)") emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)") - emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)") + emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)", has_folder=True) emit("aten::pixel_shuffle : (Tensor, int) -> (Tensor)") emit("aten::permute : (Tensor, int[]) -> (Tensor)", has_verifier=True) emit("aten::movedim.int : (Tensor, int, int) -> (Tensor)") @@ -661,7 +719,10 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::__and__.Tensor : (Tensor, Tensor) -> (Tensor)") emit("aten::__and__.Scalar : (Tensor, Scalar) -> (Tensor)", has_canonicalizer=True) emit("aten::__or__.Tensor : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True) + emit("aten::__lshift__.Scalar : (Tensor, Scalar) -> (Tensor)") + emit("aten::__rshift__.Scalar : (Tensor, Scalar) -> (Tensor)") emit("aten::_softmax : (Tensor, int, bool) -> (Tensor)") + emit("aten::_safe_softmax : (Tensor, int, int?) -> (Tensor)") emit("aten::mean : (Tensor, int?) -> (Tensor)") emit("aten::std : (Tensor, bool) -> (Tensor)") emit("aten::std.dim : (Tensor, int[]?, bool, bool) -> (Tensor)") @@ -690,9 +751,13 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)") emit("aten::linalg_norm : (Tensor, Scalar?, int[]?, bool, int?) -> (Tensor)") emit("aten::linalg_qr : (Tensor, str) -> (Tensor, Tensor)") + emit("aten::linalg_det : (Tensor) -> (Tensor)") + emit("aten::_linalg_det : (Tensor) -> (Tensor, Tensor, Tensor)") + emit("aten::linalg_slogdet : (Tensor) -> (Tensor, Tensor)") emit("aten::frobenius_norm.dim : (Tensor, int[], bool) -> (Tensor)") emit("aten::mse_loss : (Tensor, Tensor, int) -> (Tensor)") emit("aten::mse_loss_backward : (Tensor, Tensor, Tensor, int) -> (Tensor)") + emit("aten::l1_loss : (Tensor, Tensor, int) -> (Tensor)") emit( "aten::upsample_nearest2d_backward : (Tensor, int[], int[], float?, float?) 
-> (Tensor)" ) @@ -706,6 +771,9 @@ def emit_with_mutating_variants(key, **kwargs): emit( "aten::binary_cross_entropy_backward : (Tensor, Tensor, Tensor, Tensor?, int) -> (Tensor)" ) + emit( + "aten::binary_cross_entropy_with_logits : (Tensor, Tensor, Tensor?, Tensor?, int) -> (Tensor)" + ) emit("aten::log_sigmoid_forward : (Tensor) -> (Tensor, Tensor)") emit("aten::log_sigmoid_backward : (Tensor, Tensor, Tensor) -> (Tensor)") emit("aten::sigmoid_backward : (Tensor, Tensor) -> (Tensor)") @@ -713,17 +781,24 @@ def emit_with_mutating_variants(key, **kwargs): "aten::cosine_embedding_loss : (Tensor, Tensor, Tensor, float, int) -> (Tensor)" ) emit("aten::diag_embed : (Tensor, int, int, int) -> (Tensor)") + emit("aten::_weight_norm_interface : (Tensor, Tensor, int) -> (Tensor, Tensor)") + emit("aten::rot90 : (Tensor, int, int[]) -> (Tensor)", has_verifier=True) # Misc tensor ops. emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)") emit("aten::replication_pad2d : (Tensor, int[]) -> (Tensor)") emit("aten::reflection_pad1d : (Tensor, int[]) -> (Tensor)") emit("aten::reflection_pad2d : (Tensor, int[]) -> (Tensor)") + emit("aten::reflection_pad3d : (Tensor, int[]) -> (Tensor)") emit("aten::pad : (Tensor, int[], str, float?) -> (Tensor)") emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True) emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True) - emit("aten::flatten.using_ints : (Tensor, int, int) -> (Tensor)") - emit("aten::unflatten.int : (Tensor, int, int[]) -> (Tensor)") + emit("aten::flatten.using_ints : (Tensor, int, int) -> (Tensor)", has_folder=True) + emit( + "aten::unflatten.int : (Tensor, int, int[]) -> (Tensor)", + has_canonicalizer=True, + has_folder=True, + ) emit("aten::dim : (Tensor) -> (int)", has_folder=True) emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True) emit("aten::Bool.Tensor : (Tensor) -> (bool)") @@ -763,6 +838,8 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::argmax : (Tensor, int?, bool) -> (Tensor)") emit("aten::argmin : (Tensor, int?, bool) -> (Tensor)") emit("aten::one_hot : (Tensor, int) -> (Tensor)") + emit("aten::atleast_1d : (Tensor) -> (Tensor)") + emit("aten::atleast_2d : (Tensor) -> (Tensor)") emit("aten::einsum : (str, Tensor[], int[]?) -> (Tensor)") emit("aten::trace : (Tensor) -> (Tensor)") emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)") @@ -809,7 +886,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::repeat : (Tensor, int[]) -> (Tensor)") emit("aten::repeat_interleave.self_int : (Tensor, int, int?, int?) -> (Tensor)") emit("aten::tile : (Tensor, int[]) -> (Tensor)") - emit("aten::reshape : (Tensor, int[]) -> (Tensor)") + emit("aten::reshape : (Tensor, int[]) -> (Tensor)", has_folder=True) emit("aten::reshape_as : (Tensor, Tensor) -> (Tensor)") emit("aten::_reshape_alias : (Tensor, int[], int[]) -> (Tensor)") emit("aten::resize : (Tensor, int[], int?) -> (Tensor)") @@ -828,6 +905,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::min.other : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True) emit("aten::min.dim : (Tensor, int, bool) -> (Tensor, Tensor)") emit("aten::amin : (Tensor, int[], bool) -> (Tensor)") + emit("aten::aminmax : (Tensor, int?, bool) -> (Tensor, Tensor)") emit( "aten::to.dtype : (Tensor, int, bool, bool, int?) 
-> (Tensor)", has_folder=True ) @@ -846,6 +924,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::_cast_Long : (Tensor, bool) -> (Tensor)", has_canonicalizer=True) emit("aten::type_as : (Tensor, Tensor) -> (Tensor)") emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True) + emit("aten::view.dtype : (Tensor, int) -> (Tensor)") emit("aten::_unsafe_view : (Tensor, int[]) -> (Tensor)") emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)", has_folder=True) emit( @@ -879,7 +958,9 @@ def emit_with_mutating_variants(key, **kwargs): emit( "aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)", has_folder=True ) - emit("aten::Int.Tensor : (Tensor) -> (int)", has_canonicalizer=True) + emit( + "aten::Int.Tensor : (Tensor) -> (int)", has_folder=True, has_canonicalizer=True + ) emit("aten::Float.Tensor : (Tensor) -> (float)", has_folder=True) emit_with_mutating_variants("aten::dropout : (Tensor, float, bool) -> (Tensor)") emit("aten::native_dropout : (Tensor, float, bool?) -> (Tensor, Tensor)") @@ -898,20 +979,41 @@ def emit_with_mutating_variants(key, **kwargs): emit_with_mutating_variants( "aten::baddbmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)" ) + emit( + "aten::hann_window.periodic : (int, bool, int?, int?, Device?, bool?) -> (Tensor)" + ) emit("aten::fft_fft : (Tensor, int?, int, str?) -> (Tensor)") + emit("aten::fft_rfft : (Tensor, int?, int, str?) -> (Tensor)") + emit("aten::fft_ifft : (Tensor, int?, int, str?) -> (Tensor)") emit("aten::fmod.Tensor : (Tensor, Tensor) -> (Tensor)") emit( "aten::unique_consecutive : (Tensor, bool, bool, int?) -> (Tensor, Tensor, Tensor)" ) + emit( + "aten::unique_dim : (Tensor, int, bool, bool, bool) -> (Tensor, Tensor, Tensor)" + ) emit( "aten::linspace : (Scalar, Scalar, int, int?, int?, Device?, bool?) -> (Tensor)" ) emit("aten::linalg_cross : (Tensor, Tensor, int) -> (Tensor)", has_verifier=True) + emit("aten::col2im : (Tensor, int[], int[], int[], int[], int[]) -> (Tensor)") + emit( + "aten::kthvalue : (Tensor, int, int, bool) -> (Tensor, Tensor)", + has_verifier=True, + ) + emit( + "aten::stft : (Tensor, int, int?, int?, Tensor?, bool, bool?, bool?, bool?) -> (Tensor)" + ) # Functionalization ops emit("aten::alias_copy : (Tensor) -> (Tensor)") emit("aten::alias : (Tensor) -> (Tensor)", has_folder=True) emit("aten::as_strided_copy : (Tensor, int[], int[], int?) -> (Tensor)") + emit("aten::as_strided : (Tensor, int[], int[], int?) -> (Tensor)") + emit( + "aten::_assert_tensor_metadata : (Tensor, int[]?, int[]?, int?, Device?, int?) 
-> ()", + has_folder=True, + ) emit("aten::diagonal : (Tensor, int, int, int) -> (Tensor)") emit("aten::diagonal_copy : (Tensor, int, int, int) -> (Tensor)") emit("aten::expand_copy : (Tensor, int[], bool) -> (Tensor)") @@ -927,6 +1029,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::unsqueeze_copy : (Tensor, int) -> (Tensor)") emit("aten::view_copy : (Tensor, int[]) -> (Tensor)") emit("aten::view_copy.dtype : (Tensor, int) -> (Tensor)") + emit("aten::unfold : (Tensor, int, int, int) -> (Tensor)") emit("aten::unfold_copy : (Tensor, int, int, int) -> (Tensor)") emit("aten::im2col : (Tensor, int[], int[], int[], int[]) -> (Tensor)") emit("aten::scatter.reduce : (Tensor, int, Tensor, Tensor, str) -> (Tensor)") @@ -934,11 +1037,21 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)") emit("aten::diagonal_scatter : (Tensor, Tensor, int, int, int) -> (Tensor)") emit("aten::as_strided_scatter : (Tensor, Tensor, int[], int[], int?) -> (Tensor)") + emit("aten::upsample_nearest1d : (Tensor, int[], float?) -> (Tensor)") + emit("aten::upsample_nearest1d.vec : (Tensor, int[]?, float[]?) -> (Tensor)") emit("aten::upsample_nearest2d : (Tensor, int[], float?, float?) -> (Tensor)") + emit("aten::upsample_nearest2d.vec : (Tensor, int[]?, float[]?) -> (Tensor)") emit( - "aten::scaled_dot_product_attention : (Tensor, Tensor, Tensor, Tensor?, float, bool, float?) -> (Tensor)" + "aten::upsample_bilinear2d : (Tensor, int[], bool, float?, float?) -> (Tensor)" + ) + emit("aten::upsample_bilinear2d.vec : (Tensor, int[]?, bool, float[]?) -> (Tensor)") + emit( + "aten::scaled_dot_product_attention : (Tensor, Tensor, Tensor, Tensor?, float, bool, float?, bool) -> (Tensor)" ) emit("aten::grid_sampler : (Tensor, Tensor, int, int, bool) -> (Tensor)") + emit( + "aten::_trilinear : (Tensor, Tensor, Tensor, int[], int[], int[], int[], int) -> (Tensor)" + ) # Dict ops. emit("aten::__contains__.str : (Dict(str, t), str) -> (bool)", has_folder=True) @@ -956,6 +1069,8 @@ def emit_with_mutating_variants(key, **kwargs): has_folder=True, ) emit("aten::stack : (Tensor[], int) -> (Tensor)") + emit("aten::hstack : (Tensor[]) -> (Tensor)") + emit("aten::column_stack : (Tensor[]) -> (Tensor)") emit("aten::append.t : (t[], t) -> (t[])") emit("aten::add.t : (t[], t[]) -> (t[])", has_canonicalizer=True) emit("aten::eq.int_list : (int[], int[]) -> (bool)", has_folder=True) @@ -968,9 +1083,14 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::sort : (Tensor, int, bool) -> (Tensor, Tensor)", has_folder=True) emit("aten::split.Tensor : (Tensor, int, int) -> (Tensor[])") emit("aten::split_with_sizes : (Tensor, int[], int) -> (Tensor[])") - emit("aten::split.sizes : (Tensor, int[], int) -> (Tensor[])") + emit( + "aten::split.sizes : (Tensor, int[], int) -> (Tensor[])", has_canonicalizer=True + ) + emit("aten::tensor_split.sections : (Tensor, int, int) -> (Tensor[])") emit("aten::unbind.int : (Tensor, int) -> (Tensor[])") emit("aten::chunk : (Tensor, int, int) -> (Tensor[])") + emit("aten::meshgrid : (Tensor[]) -> (Tensor[])", has_canonicalizer=True) + emit("aten::meshgrid.indexing : (Tensor[], str) -> (Tensor[])") # Str ops. 
emit("aten::add.str : (str, str) -> (str)") @@ -999,17 +1119,27 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::le.int : (int, int) -> (bool)", has_folder=True) emit("aten::ne.int : (int, int) -> (bool)", has_folder=True) emit("aten::eq.int : (int, int) -> (bool)", has_folder=True) - emit("aten::floordiv.int : (int, int) -> (int)", has_folder=True) + emit( + "aten::floordiv.int : (int, int) -> (int)", + has_folder=True, + has_canonicalizer=True, + ) emit("aten::remainder.int : (int, int) -> (int)", has_folder=True) - emit("aten::remainder.Scalar : (Tensor, Scalar) -> (Tensor)") + emit("aten::remainder.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True) emit("aten::remainder.Tensor : (Tensor, Tensor) -> (Tensor)") emit("aten::add.int : (int, int) -> (int)", has_folder=True) emit("aten::sub.int : (int, int) -> (int)", has_folder=True) - emit("aten::mul.int : (int, int) -> (int)", has_folder=True) + emit( + "aten::mul.int : (int, int) -> (int)", + has_folder=True, + has_canonicalizer=True, + ) + emit("aten::mul.int_float : (int, float) -> (float)", has_folder=True) emit("aten::div.int : (int, int) -> (float)", has_folder=True) emit("aten::neg.int : (int) -> (int)", has_folder=True) emit("aten::log.int : (int) -> (float)") emit("aten::add.float_int : (float, int) -> (float)", has_folder=True) + emit("aten::mul.float_int : (float, int) -> (float)", has_folder=True) emit("aten::sub.float : (float, float) -> (float)", has_folder=True) emit("aten::mul.float : (float, float) -> (float)", has_folder=True) emit("aten::div.float : (float, float) -> (float)", has_folder=True) @@ -1024,11 +1154,14 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::gt.float_int : (float, int) -> (bool)") emit("aten::pow.int_float : (int, float) -> (float)", has_folder=True) emit("aten::__and__.bool : (bool, bool) -> (bool)") + emit("aten::eq.bool : (bool, bool) -> (bool)", has_folder=True) emit("aten::ne.bool : (bool, bool) -> (bool)", has_folder=True) emit("aten::__is__ : (t1, t2) -> (bool)", has_folder=True) emit("aten::__isnot__ : (t1, t2) -> (bool)", has_folder=True) emit("aten::__not__ : (bool) -> (bool)", has_folder=True) + emit("aten::__or__.bool : (bool, bool) -> (bool)", has_folder=True) emit("aten::len.t : (t[]) -> (int)", has_folder=True, has_canonicalizer=True) + emit("aten::mul.left_t : (t[], int) -> (t[])", has_canonicalizer=True) emit("aten::__getitem__.t : (t[], int) -> (t)", has_canonicalizer=True) emit("aten::_set_item.t : (t[], int, t) -> (t[])") emit("aten::mul : (Scalar, Scalar) -> (Scalar)", has_folder=True) @@ -1046,6 +1179,18 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::narrow.Tensor : (Tensor, int, Tensor, int) -> (Tensor)") emit("aten::ScalarImplicit : (Tensor) -> (Scalar)", has_canonicalizer=True) + emit( + "aten::triu_indices : (int, int, int, int?, int?, Device?, bool?) -> (Tensor)", + has_verifier=True, + ) + + emit( + "aten::tril_indices : (int, int, int, int?, int?, Device?, bool?) 
-> (Tensor)", + has_verifier=True, + ) + + emit("aten::deg2rad : (Tensor) -> (Tensor)") + # backprop ops emit("aten::_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)") emit("aten::tanh_backward : (Tensor, Tensor) -> (Tensor)") @@ -1069,6 +1214,12 @@ def emit_with_mutating_variants(key, **kwargs): "aten::elu_backward : (Tensor, Scalar, Scalar, Scalar, bool, Tensor) -> (Tensor)" ) emit("aten::leaky_relu_backward : (Tensor, Tensor, Scalar, bool) -> (Tensor)") + emit( + "aten::rrelu_with_noise_backward : (Tensor, Tensor, Tensor, Scalar, Scalar, bool, bool) -> (Tensor)" + ) + emit( + "aten::rrelu_with_noise_functional : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor, Tensor)" + ) # quantized ops emit("aten::quantize_per_channel : (Tensor, Tensor, Tensor, int, int) -> (Tensor)") @@ -1081,6 +1232,11 @@ def emit_with_mutating_variants(key, **kwargs): ) emit("aten::_make_per_tensor_quantized_tensor : (Tensor, float, int) -> (Tensor)") + # Constraint ops + emit("aten::sym_constrain_range : (Scalar, int?, int?) -> ()") + emit("aten::sym_constrain_range_for_size : (Scalar, int?, int?) -> ()") + emit("aten::_assert_scalar : (Scalar, str) -> ()") + # ========================================================================== # `prim::` namespace. # ========================================================================== @@ -1111,11 +1267,12 @@ def emit_with_mutating_variants(key, **kwargs): # ========================================================================== emit("prims::convert_element_type : (Tensor, int) -> (Tensor)", has_folder=True) - emit("prims::var : (Tensor, int[]?, float, int?) -> (Tensor)") + emit("prims::var : (Tensor, int[]?, float?, int?) -> (Tensor)") emit("prims::sqrt : (Tensor) -> (Tensor)") emit("prims::collapse : (Tensor, int, int) -> (Tensor)") emit("prims::split_dim : (Tensor, int, int) -> (Tensor)") emit("prims::squeeze : (Tensor, int[]) -> (Tensor)") + emit("prims::sum : (Tensor, int[]?, int?) -> (Tensor)") emit("prims::view_of : (Tensor) -> (Tensor)", has_folder=True) emit("prims::iota : (int, int, int, int, Device, bool) -> (Tensor)") @@ -1128,6 +1285,21 @@ def emit_with_mutating_variants(key, **kwargs): traits=["HasValueSemantics"], ) + # ========================================================================== + # `torchvision::` namespace. 
+ # ========================================================================== + + emit( + "torchvision::deform_conv2d : (Tensor, Tensor, Tensor, Tensor, Tensor, int, int, int, int, int, int, int, int, bool) -> (Tensor)" + ) + emit( + "torchvision::roi_align : (Tensor, Tensor, float, int, int, int, bool) -> (Tensor)" + ) + emit( + "torchvision::roi_pool : (Tensor, Tensor, float, int, int) -> (Tensor, Tensor)" + ) + emit("torchvision::nms : (Tensor, Tensor, float) -> (Tensor)") + def dump_registered_ops(outfile: TextIO, registry: Registry): for _, v in sorted(registry.by_unique_key.items()): @@ -1146,6 +1318,9 @@ def _maybe_import_op_extensions(args: argparse.Namespace): def main(args: argparse.Namespace): _maybe_import_op_extensions(args) + # importing torchvision will register torchvision ops with the JITOperatorRegistry + import torchvision + registry = Registry.load() if args.debug_registry_dump: with open(args.debug_registry_dump, "w") as debug_registry_dump: diff --git a/projects/pt1/python/torch_mlir/torchscript.py b/projects/pt1/python/torch_mlir/torchscript.py index ef224776fde2..c6cf625e4fe1 100644 --- a/projects/pt1/python/torch_mlir/torchscript.py +++ b/projects/pt1/python/torch_mlir/torchscript.py @@ -14,66 +14,17 @@ from torch._functorch.compile_utils import strip_overloads import torch import torch.fx -from torch_mlir.dynamo import _get_decomposition_table -from torch.fx.experimental.proxy_tensor import make_fx from torch_mlir.compiler_utils import ( run_pipeline_with_repro_report, OutputType, lower_mlir_module, + TensorPlaceholder, ) from torch_mlir.jit_ir_importer import ClassAnnotator, ImportOptions, ModuleBuilder from torch_mlir.jit_ir_importer.build_tools.library_generator import generate_library -class TensorPlaceholder: - """A class that represents a formal parameter of a given shape and dtype. - - This class can be constructed explicitly from a shape and dtype: - ```python - placeholder = TensorPlaceholder([3, 4], torch.float32) - ``` - - This class can also be constructed from a `torch.Tensor` which is already - known to be a valid input to the function. In this case, a set of - dynamic axes are allowed to be specified. - ```python - placeholder = TensorPlaceholder.like(torch.ones(3, 4), dynamic_axes=[1]) - # Equivalent to `TensorPlaceholder([3, -1], torch.float32)` - ``` - """ - - def __init__(self, shape: List[int], dtype: torch.dtype): - """Create a tensor with shape `shape` and dtype `dtype`. - - Args: - shape: The shape of the tensor. A size of `-1` indicates that the - dimension has an unknown size. - dtype: The dtype of the tensor. - """ - self.shape = shape - self.dtype = dtype - - @staticmethod - def like(tensor: torch.Tensor, dynamic_axes: List[int] = None): - """Create a tensor placeholder that is like the given tensor. - - Args: - tensor: The tensor to create a placeholder for. - dynamic_axes: A list of dynamic axes. If specified, the compiled - module will allow those axes to be any size at runtime. 
- """ - if dynamic_axes is None: - dynamic_axes = [] - shape = [] - for i, dim in enumerate(tensor.shape): - if i in dynamic_axes: - shape.append(-1) - else: - shape.append(dim) - return TensorPlaceholder(shape, tensor.dtype) - - _example_arg = Union[TensorPlaceholder, torch.Tensor] _example_args_for_one_method = Union[_example_arg, Sequence[_example_arg]] _example_args = Union[_example_args_for_one_method, "ExampleArgs"] @@ -212,7 +163,13 @@ def _get_for_tracing( "aten.adaptive_avg_pool2d", "aten.unflatten.int", ], - OutputType.STABLEHLO: [], + OutputType.STABLEHLO: [ + "aten.amax", + "aten.amin", + "aten.randn.generator", + "aten.normal_functional", + "aten.fmod.Tensor", + ], } @@ -244,7 +201,6 @@ def compile( backend_legal_ops: Optional[Sequence[str]] = None, extra_library: Iterable[Callable] = [], verbose: bool = False, - use_make_fx: bool = False, enable_ir_printing: bool = False, ): """Convert a PyTorch model to MLIR. @@ -307,12 +263,6 @@ def compile( else: backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, []) - if use_make_fx: - args = example_args._get_for_tracing( - use_tracing=True, ignore_traced_shapes=True - )["forward"] - model = make_fx(model, decomposition_table=_get_decomposition_table())(*args) - # For FX-based models, automatically strip overloads. if isinstance(model, torch.fx.GraphModule): strip_overloads(model) @@ -377,6 +327,12 @@ def compile( ) from None finally: sys.stderr = original_stderr + + if verbose: + print("\n====================") + print("TorchScript RAW IR") + print(mb.module) + if output_type == OutputType.RAW: return mb.module diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py index 2a63c06bdc37..91bc49ebb893 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py @@ -8,6 +8,7 @@ import torch.utils._pytree as pytree from torch.export.graph_signature import OutputSpec, OutputKind from torch.export import ExportedProgram +from torch._dynamo.backends.common import aot_autograd from torch_mlir import fx from torch_mlir_e2e_test.configs.utils import ( @@ -15,6 +16,7 @@ recursively_convert_from_numpy, ) from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem +from torch_mlir_e2e_test.annotations import TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME def refine_result_type(_result): @@ -31,15 +33,91 @@ def refine_result_type(_result): class FxImporterTestConfig(TestConfig): """TestConfig that runs the torch.nn.Module with Fx Importer""" - def __init__(self, backend, output_type="linalg-on-tensors"): + def __init__(self, backend, output_type="linalg-on-tensors", torch_compile=False): super().__init__() self._backend = backend + self._torch_compile = torch_compile self._output_type = output_type - def compile(self, program: torch.nn.Module) -> torch.nn.Module: + def compile( + self, program: torch.nn.Module, verbose: bool = False + ) -> torch.nn.Module: return program - def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace: + def run(self, artifact: torch.nn.Module, trace: Trace): + return ( + self._export_run(artifact, trace) + if not self._torch_compile + else self._stateless_run(artifact, trace) + ) + + def _stateless_run(self, artifact: torch.nn.Module, trace: Trace): + dynamic_argument_pos = None + dynamic_dim_pos = None + annotations = getattr(artifact.forward, TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME) + for i, annotation in 
enumerate(annotations): + if i == 0: # Skip the "self" annotation. + continue + if not annotation[2]: + raise ValueError( + "Can only compile inputs annotated as having value semantics." + ) + for dim_i, dim in enumerate(annotation[0]): + if dim == -1: + dynamic_argument_pos = i - 1 + dynamic_dim_pos = dim_i + break + if dynamic_argument_pos is not None: + break + result: Trace = [] + for item in trace: + + def _base_backend(gm: torch.fx.GraphModule, example_inputs): + for node in gm.graph.nodes: + if node.op == "placeholder": + if ( + isinstance(node.meta["val"], torch.SymInt) + and not node.users + ): + gm.graph.erase_node(node) + module = fx.stateless_fx_import( + gm, + output_type=self._output_type, + model_name=artifact.__class__.__name__, + ) + module = self._backend.compile(module) + backend_module = self._backend.load(module) + + def invoke_func(*torch_inputs): + torch_inputs = [ + x + for x in filter( + lambda i: isinstance(i, torch.Tensor), torch_inputs + ) + ] + with torch.no_grad(): + numpy_inputs = recursively_convert_to_numpy(torch_inputs) + return recursively_convert_from_numpy( + getattr(backend_module, artifact.__class__.__name__)( + *numpy_inputs + ) + ) + + return invoke_func + + fw_compiler = aot_autograd(fw_compiler=_base_backend) + if dynamic_argument_pos is not None: + torch._dynamo.mark_dynamic( + item.inputs[dynamic_argument_pos], dynamic_dim_pos + ) + module = torch.compile(artifact, backend=fw_compiler) + outputs = module(*item.inputs) + result.append( + TraceItem(symbol=item.symbol, inputs=item.inputs, output=outputs) + ) + return result + + def _export_run(self, artifact: torch.nn.Module, trace: Trace) -> Trace: result: Trace = [] for item in trace: prog: ExportedProgram = torch.export.export(artifact, tuple(item.inputs)) @@ -47,6 +125,9 @@ def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace: prog, output_type=self._output_type, func_name=artifact.__class__.__name__, + # While the current e2e tests don't exercise symbolic shapes, + # enabling this here ensures they don't regress either. 
+ import_symbolic_shape_expressions=True, ) module = self._backend.compile(module) backend_module = self._backend.load(module) diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/lazy_tensor_core.py b/projects/pt1/python/torch_mlir_e2e_test/configs/lazy_tensor_core.py index 4f2d9ec90221..04fb523d5ea5 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/lazy_tensor_core.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/lazy_tensor_core.py @@ -22,7 +22,9 @@ def __init__(self): super().__init__() lazy_backend._initialize() - def compile(self, program: torch.nn.Module) -> torch.nn.Module: + def compile( + self, program: torch.nn.Module, verbose: bool = False + ) -> torch.nn.Module: return program.to("lazy") def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace: diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/linalg_on_tensors_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/linalg_on_tensors_backend.py index bbc6e73ee770..059d43c55bfa 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/linalg_on_tensors_backend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/linalg_on_tensors_backend.py @@ -29,10 +29,10 @@ def __init__(self, backend: LinalgOnTensorsBackend): super().__init__() self.backend = backend - def compile(self, program: torch.nn.Module) -> Any: + def compile(self, program: torch.nn.Module, verbose: bool = False) -> Any: example_args = convert_annotations_to_placeholders(program.forward) module = torchscript.compile( - program, example_args, output_type="linalg-on-tensors" + program, example_args, output_type="linalg-on-tensors", verbose=verbose ) return self.backend.compile(module) diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/native_torch.py b/projects/pt1/python/torch_mlir_e2e_test/configs/native_torch.py index e7907cd14251..7ab251f02ae9 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/native_torch.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/native_torch.py @@ -14,7 +14,9 @@ class NativeTorchTestConfig(TestConfig): def __init__(self): super().__init__() - def compile(self, program: torch.nn.Module) -> torch.nn.Module: + def compile( + self, program: torch.nn.Module, verbose: bool = False + ) -> torch.nn.Module: return program def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace: diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py index 5402c7243e00..5461dc04c0d1 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py @@ -9,9 +9,9 @@ import io import onnx import torch +from torch.onnx._constants import ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET as max_opset_ver import torch_mlir -from torch_mlir_e2e_test.onnx_backends.abc import OnnxBackend from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem from torch_mlir_e2e_test.utils import convert_annotations_to_placeholders from .utils import ( @@ -22,13 +22,27 @@ from torch_mlir.extras import onnx_importer from torch_mlir.dialects import torch as torch_d from torch_mlir.ir import Context, Module +from torch_mlir.compiler_utils import ( + OutputType, + run_pipeline_with_repro_report, + lower_mlir_module, +) + +# The pipeline of func.func passes that lower the ONNX backend contract to the +# Linalg-on-Tensors backend contract accepted by RefBackend or another user +# defined backend. 
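
The hunks below fold the old standalone ONNX backend into this test config: the model is exported to ONNX, imported back as torch-dialect MLIR, and then lowered to the requested backend contract in place. A rough sketch of that path, assuming the helpers defined in this file (convert_onnx, _module_lowering, OutputType) and a backend object exposing compile(); the wrapper function itself is hypothetical:

    # Illustrative sketch of the new OnnxBackendTestConfig.compile flow.
    def compile_via_onnx(program, example_args, backend, verbose=False):
        # torch.onnx.export + torch_mlir.extras.onnx_importer under the hood.
        onnx_module = convert_onnx(program, example_args)
        # Apply torch-onnx-to-torch-backend-pipeline, then lower to the chosen
        # output type (linalg-on-tensors by default).
        backend_ir = _module_lowering(
            verbose, OutputType.get("linalg-on-tensors"), onnx_module
        )
        return backend.compile(backend_ir)
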
+ONNX_TO_TORCH_FUNC_PIPELINE = ",".join( + [ + "convert-torch-onnx-to-torch", + ] +) def import_onnx(contents): # Import the ONNX model proto from the file contents: raw_model = onnx.load_from_string(contents) # since it does not affect current e2e tests, data_prop is left false here - model_proto = onnx.shape_inference.infer_shapes(raw_model) + model_proto = onnx.shape_inference.infer_shapes(raw_model, data_prop=True) # Import the ONNX module into an MLIR module: context = Context() @@ -65,12 +79,49 @@ def convert_onnx(model, inputs): examples = tuple(examples) torch.onnx.export( - model, examples, buffer, input_names=input_names, dynamic_axes=dynamic_tensors + model, + examples, + buffer, + input_names=input_names, + dynamic_axes=dynamic_tensors, + opset_version=max_opset_ver, ) buffer = buffer.getvalue() return import_onnx(buffer) +def _module_lowering( + verbose, + output_type, + torch_mod, +): + if verbose: + print("\n====================") + print("ONNX RAW IR") + print(torch_mod) + + backend_legal_ops = [ + "aten.flatten.using_ints", + "aten.adaptive_avg_pool1d", + "aten.unflatten.int", + ] + option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + "}" + + # Lower from ONNX to Torch + run_pipeline_with_repro_report( + torch_mod, + f"builtin.module(torch-onnx-to-torch-backend-pipeline{option_string})", + "Lowering Onnx Raw IR -> Torch Backend IR", + ) + + if verbose: + print("\n====================") + print("Torch IR") + print(torch_mod) + + return lower_mlir_module(verbose, output_type, torch_mod) + + class OnnxBackendTestConfig(TestConfig): """Base class for TestConfig's that are implemented with ONNX. @@ -78,15 +129,22 @@ class OnnxBackendTestConfig(TestConfig): reaching the ONNX abstraction level. """ - def __init__(self, backend: OnnxBackend, use_make_fx: bool = False): + def __init__( + self, + backend, + output_type="linalg-on-tensors", + ): super().__init__() self.backend = backend - self.use_make_fx = use_make_fx + self.output_type = output_type - def compile(self, program: torch.nn.Module) -> Any: + def compile(self, program: torch.nn.Module, verbose: bool = False) -> Any: example_args = convert_annotations_to_placeholders(program.forward) onnx_module = convert_onnx(program, example_args) - compiled_module = self.backend.compile(onnx_module) + backend_module = _module_lowering( + verbose, OutputType.get(self.output_type), onnx_module + ) + compiled_module = self.backend.compile(backend_module) return compiled_module def run(self, artifact: Any, trace: Trace) -> Trace: diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/stablehlo_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/stablehlo_backend.py index 1ab8a8d22b4f..5e764855ec08 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/stablehlo_backend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/stablehlo_backend.py @@ -28,9 +28,11 @@ def __init__(self, backend: StablehloBackend): super().__init__() self.backend = backend - def compile(self, program: torch.nn.Module) -> Any: + def compile(self, program: torch.nn.Module, verbose: bool = False) -> Any: example_args = convert_annotations_to_placeholders(program.forward) - module = torchscript.compile(program, example_args, output_type="stablehlo") + module = torchscript.compile( + program, example_args, output_type="stablehlo", verbose=verbose + ) return self.backend.compile(module) diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py b/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py 
index 54dc7d3f98ff..fcea6d87de6f 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py @@ -170,7 +170,9 @@ def __init__(self, backend): super().__init__() self.backend = backend - def compile(self, program: torch.nn.Module) -> torch.nn.Module: + def compile( + self, program: torch.nn.Module, verbose: bool = False + ) -> torch.nn.Module: return program def run(self, artifact: torch.nn.Module, trace: Trace) -> Trace: diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/torchscript.py b/projects/pt1/python/torch_mlir_e2e_test/configs/torchscript.py index a40e06f01248..7057a01a735a 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/torchscript.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/torchscript.py @@ -17,7 +17,9 @@ class TorchScriptTestConfig(TestConfig): def __init__(self): super().__init__() - def compile(self, program: torch.nn.Module) -> torch.jit.ScriptModule: + def compile( + self, program: torch.nn.Module, verbose: bool = False + ) -> torch.jit.ScriptModule: return torch.jit.script(program) def run(self, artifact: torch.jit.ScriptModule, trace: Trace) -> Trace: diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/tosa_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/tosa_backend.py index 1b5c86bb64d4..2601b2b6a4d8 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/tosa_backend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/tosa_backend.py @@ -24,15 +24,17 @@ class TosaBackendTestConfig(TestConfig): reaching the TOSA abstraction level. """ - def __init__(self, backend: TosaBackend, use_make_fx: bool = False): + def __init__(self, backend: TosaBackend): super().__init__() self.backend = backend - self.use_make_fx = use_make_fx - def compile(self, program: torch.nn.Module) -> Any: + def compile(self, program: torch.nn.Module, verbose: bool = False) -> Any: example_args = convert_annotations_to_placeholders(program.forward) module = torchscript.compile( - program, example_args, output_type="tosa", use_make_fx=self.use_make_fx + program, + example_args, + output_type="tosa", + verbose=verbose, ) return self.backend.compile(module) diff --git a/projects/pt1/python/torch_mlir_e2e_test/framework.py b/projects/pt1/python/torch_mlir_e2e_test/framework.py index ee438cbbb167..c24af96f3e0e 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/framework.py +++ b/projects/pt1/python/torch_mlir_e2e_test/framework.py @@ -27,6 +27,7 @@ import os import sys import traceback +import signal import multiprocess as mp from multiprocess import set_start_method @@ -230,6 +231,7 @@ class Test(NamedTuple): # module, actually). # The secon parameter is a `TestUtils` instance for convenience. 
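
The framework.py hunk continuing below adds a per-test timeout: Test gains a timeout_seconds field, a small SIGALRM-based context manager raises TimeoutError when the alarm fires, and register_test_case (in registry.py, further down) forwards the value with a 120-second default. A hypothetical usage sketch for a slow test; MySlowModule is a made-up stand-in and the import paths are the ones used elsewhere in this test suite:

    from torch_mlir_e2e_test.framework import TestUtils
    from torch_mlir_e2e_test.registry import register_test_case

    # Opt a known-slow test into a 10-minute budget instead of the default 120s.
    @register_test_case(module_factory=lambda: MySlowModule(), timeout_seconds=600)
    def MySlowModule_basic(module, tu: TestUtils):
        module.forward(tu.rand(1024, 1024))
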
program_invoker: Callable[[Any, TestUtils], None] + timeout_seconds: int class TestResult(NamedTuple): @@ -305,43 +307,79 @@ def generate_golden_trace(test: Test) -> Trace: return trace +class timeout: + def __init__(self, seconds=1, error_message="Timeout"): + self.seconds = seconds + self.error_message = error_message + + def handle_timeout(self, signum, frame): + raise TimeoutError(self.error_message) + + def __enter__(self): + signal.signal(signal.SIGALRM, self.handle_timeout) + signal.alarm(self.seconds) + + def __exit__(self, type, value, traceback): + signal.alarm(0) + + def compile_and_run_test(test: Test, config: TestConfig, verbose=False) -> Any: - try: - golden_trace = generate_golden_trace(test) - if verbose: - print(f"Compiling {test.unique_name}...", file=sys.stderr) - compiled = config.compile(test.program_factory()) - except Exception as e: - return TestResult( - unique_name=test.unique_name, - compilation_error="".join( - traceback.format_exception(type(e), e, e.__traceback__) - ), - runtime_error=None, - trace=None, - golden_trace=None, - ) - try: - if verbose: - print(f"Running {test.unique_name}...", file=sys.stderr) - trace = config.run(compiled, golden_trace) - except Exception as e: + with timeout(seconds=test.timeout_seconds): + try: + golden_trace = generate_golden_trace(test) + if verbose: + print(f"Compiling {test.unique_name}...", file=sys.stderr) + compiled = config.compile(test.program_factory(), verbose=verbose) + except TimeoutError: + return TestResult( + unique_name=test.unique_name, + compilation_error=f"Test timed out during compilation (timeout={test.timeout_seconds}s)", + runtime_error=None, + trace=None, + golden_trace=None, + ) + except Exception as e: + return TestResult( + unique_name=test.unique_name, + compilation_error="".join( + traceback.format_exception(type(e), e, e.__traceback__) + ), + runtime_error=None, + trace=None, + golden_trace=None, + ) + try: + if verbose: + print(f"Running {test.unique_name}...", file=sys.stderr) + trace = config.run(compiled, golden_trace) + + # Disable the alarm + signal.alarm(0) + except TimeoutError: + return TestResult( + unique_name=test.unique_name, + compilation_error=None, + runtime_error="Test timed out during execution (timeout={test.timeout}s)", + trace=None, + golden_trace=None, + ) + except Exception as e: + return TestResult( + unique_name=test.unique_name, + compilation_error=None, + runtime_error="".join( + traceback.format_exception(type(e), e, e.__traceback__) + ), + trace=None, + golden_trace=None, + ) return TestResult( unique_name=test.unique_name, compilation_error=None, - runtime_error="".join( - traceback.format_exception(type(e), e, e.__traceback__) - ), - trace=None, - golden_trace=None, + runtime_error=None, + trace=clone_trace(trace), + golden_trace=clone_trace(golden_trace), ) - return TestResult( - unique_name=test.unique_name, - compilation_error=None, - runtime_error=None, - trace=clone_trace(trace), - golden_trace=clone_trace(golden_trace), - ) def run_tests( @@ -358,6 +396,15 @@ def run_tests( if env_concurrency > 0: num_processes = min(num_processes, env_concurrency) + try: + env_verbose = os.getenv("TORCH_MLIR_TEST_VERBOSE", "0") + if env_verbose is not None: + verbose = verbose or bool(int(env_verbose)) + except ValueError as e: + raise ValueError( + "Bad value for TORCH_MLIR_TEST_VERBOSE env var: " "Expected integer." 
+ ) from e + # TODO: We've noticed that on certain 2 core machine parallelizing the tests # makes the llvm backend legacy pass manager 20x slower than using a # single process. Need to investigate the root cause eventually. This is a @@ -375,7 +422,10 @@ def run_tests( # seems to cause a cascade of failures resulting in undecipherable error # messages. if num_processes == 1 or sequential: - return [compile_and_run_test(test, config, verbose) for test in tests] + print("Running tests sequentially with progress status") + for test in tests: + print(f"*** RUNNING TEST: {test.unique_name} ***") + compile_and_run_test(test, config, verbose) # This is needed because autograd does not support crossing process # boundaries. @@ -383,8 +433,15 @@ def run_tests( pool = mp.Pool(num_processes) arg_list = zip(tests, repeat(config)) + pool_copy = pool._pool[:] handles = pool.starmap_async(compile_and_run_test, arg_list) - results = handles.get(timeout=360) + while not handles.ready(): + if any(proc.exitcode for proc in pool_copy): + print("At least one of testing processes has exited with code != 0.") + exit(1) + handles.wait(timeout=1) + else: + results = handles.get(timeout=360) tests_with_results = {result.unique_name for result in results} all_tests = {test.unique_name for test in tests} diff --git a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py index 8935a2a060fd..7db53b8ca702 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py @@ -134,85 +134,84 @@ def invoke(*args): return invoke -LOWERING_PIPELINE = ( - "builtin.module(" - + ",".join( - [ - "func.func(refback-generalize-tensor-pad)", - "func.func(refback-generalize-tensor-concat)", - # Apply some optimizations. It would be great if MLIR had more useful - # optimizations that worked out of the box here. - # Note: When measured, this doesn't seem to actually help that much - # for the linalg-on-tensors backend. - # This is likely because if things are naturally fusable we usually already - # emit things in that form from the high level (e.g. single linalg-generic). - # Other backends are likely to benefit more. - "func.func(linalg-generalize-named-ops)", - "func.func(linalg-fuse-elementwise-ops)", - "convert-shape-to-std", - # MLIR Sparsifier mini-pipeline. Note that this is the bare minimum - # to ensure operations on sparse tensors are lowered to loops. - "sparse-assembler{direct-out}", - "sparsification-and-bufferization", - "sparse-storage-specifier-to-llvm", - # Buffer deallocation pass does not know how to handle realloc. - "func.func(expand-realloc)", - # Bufferize. - "func.func(scf-bufferize)", - "func.func(tm-tensor-bufferize)", - "func.func(empty-tensor-to-alloc-tensor)", - "func.func(linalg-bufferize)", - "func-bufferize", - "arith-bufferize", - "refback-mlprogram-bufferize", - "func.func(tensor-bufferize)", - "func.func(finalizing-bufferize)", - "func.func(buffer-deallocation)", - # Buffer-deallocation does not work with the inlined code generated - # by sparse tensor dialect. - "inline", # inline sparse helper methods where useful - # Munge to make it ExecutionEngine compatible. 
- # Specifically, we rewrite calling convention boundaries to be in terms - # of unranked memref, and we rewrite the return to actually be a - # callback that consumes the return (the final munged function always - # returns void at the C level -- we get the return value by providing the - # callback). - "refback-munge-calling-conventions", - # Insert global variable and instruction sequence for getting the next - # global seed used in stateful rng. - # Lower to LLVM - "func.func(tm-tensor-to-loops)", - "func.func(refback-munge-memref-copy)", - "func.func(convert-linalg-to-loops)", - "func.func(lower-affine)", - "convert-scf-to-cf", - "func.func(refback-expand-ops-for-llvm)", - "func.func(arith-expand)", - "func.func(convert-math-to-llvm)", - # Handle some complex mlir::math ops (e.g. atan2) - "convert-math-to-libm", - "expand-strided-metadata", - "finalize-memref-to-llvm", - "lower-affine", - "convert-bufferization-to-memref", - "finalize-memref-to-llvm", - "func.func(convert-arith-to-llvm)", - "convert-vector-to-llvm", - "convert-func-to-llvm", - "convert-cf-to-llvm", - "convert-complex-to-llvm", - "reconcile-unrealized-casts", - ] - ) - + ")" -) +def lowering_pipeline(generate_runtime_verification: bool): + passes = [ + # Apply some optimizations. It would be great if MLIR had more useful + # optimizations that worked out of the box here. + # Note: When measured, this doesn't seem to actually help that much + # for the linalg-on-tensors backend. + # This is likely because if things are naturally fusable we usually already + # emit things in that form from the high level (e.g. single linalg-generic). + # Other backends are likely to benefit more. + "func.func(linalg-generalize-named-ops)", + "func.func(linalg-fuse-elementwise-ops)", + "convert-shape-to-std", + # MLIR Sparsifier mini-pipeline. Note that this is the bare minimum + # to ensure operations on sparse tensors are lowered to loops. + "sparse-assembler{direct-out}", + "sparsification-and-bufferization", + "sparse-storage-specifier-to-llvm", + # Buffer deallocation pass does not know how to handle realloc. + "func.func(expand-realloc)", + # Generalize pad and concat after sparse compiler, as they are handled + # differently when the operations involve sparse operand. + "func.func(refback-generalize-tensor-pad)", + "func.func(refback-generalize-tensor-concat)", + # Bufferize. + "func.func(tm-tensor-bufferize)", + "one-shot-bufferize{copy-before-write bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map}", + "refback-mlprogram-bufferize", + # "func.func(finalizing-bufferize)", + "func.func(buffer-deallocation)", + # Buffer-deallocation does not work with the inlined code generated + # by sparse tensor dialect. + "inline", # inline sparse helper methods where useful + # Munge to make it ExecutionEngine compatible. + # Specifically, we rewrite calling convention boundaries to be in terms + # of unranked memref, and we rewrite the return to actually be a + # callback that consumes the return (the final munged function always + # returns void at the C level -- we get the return value by providing the + # callback). + "refback-munge-calling-conventions", + # Insert global variable and instruction sequence for getting the next + # global seed used in stateful rng. 
+ # Lower to LLVM + "func.func(tm-tensor-to-loops)", + "func.func(refback-munge-memref-copy)", + "func.func(convert-linalg-to-loops)", + "func.func(lower-affine)", + "convert-scf-to-cf", + ] + if generate_runtime_verification: + passes += ["generate-runtime-verification"] + passes += [ + "func.func(refback-expand-ops-for-llvm)", + "func.func(arith-expand)", + "func.func(convert-math-to-llvm)", + # Handle some complex mlir::math ops (e.g. atan2) + "convert-math-to-libm", + "expand-strided-metadata", + "finalize-memref-to-llvm", + "lower-affine", + "convert-bufferization-to-memref", + "finalize-memref-to-llvm", + "func.func(convert-arith-to-llvm)", + "convert-vector-to-llvm", + "convert-func-to-llvm", + "convert-cf-to-llvm", + "convert-complex-to-llvm", + "reconcile-unrealized-casts", + ] + + return "builtin.module(" + ",".join(passes) + ")" class RefBackendLinalgOnTensorsBackend(LinalgOnTensorsBackend): """Main entry-point for the reference backend.""" - def __init__(self): + def __init__(self, generate_runtime_verification: bool = True): super().__init__() + self.generate_runtime_verification = generate_runtime_verification def compile(self, imported_module: Module): """Compiles an imported module, with a flat list of functions. @@ -229,7 +228,7 @@ def compile(self, imported_module: Module): """ run_pipeline_with_repro_report( imported_module, - LOWERING_PIPELINE, + lowering_pipeline(self.generate_runtime_verification), "Lowering Linalg-on-Tensors IR to LLVM with RefBackend", enable_ir_printing=False, ) diff --git a/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/__init__.py b/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/abc.py b/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/abc.py deleted file mode 100644 index 7e12f8b15d7d..000000000000 --- a/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/abc.py +++ /dev/null @@ -1,50 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# Also available under a BSD-style license. See LICENSE. - -import abc -from typing import TypeVar - -import torch - -from torch_mlir.ir import Module - -# A type shared between the result of `OnnxBackend.compile` and the -# input to `OnnxBackend.load`. Each backend will likely have a -# different definition of this type. -CompiledArtifact = TypeVar("CompiledArtifact") - -# A wrapper around a backend-specific loaded program representation -# that uniformly translates the `x.method(...)` interface expected of -# Torch modules into appropriate lower-level operations. -Invoker = TypeVar("Invoker") - - -class OnnxBackend(abc.ABC): - """The interface to an ONNX backend. - - Backends are recommended to raise meaningful exceptions in case of error, - ideally with easy reproduction instructions. - """ - - @abc.abstractmethod - def compile(self, module: Module) -> CompiledArtifact: - """Compile the provided MLIR module into a compiled artifact. - - The module adheres to the ONNX backend contract - (see the VerifyOnnxBackendContract pass). - - The compiled artifact can be any type, but must be correctly - interpreted by the `load` method. - """ - - @abc.abstractmethod - def load(self, artifact: CompiledArtifact) -> Invoker: - """Load the compiled artifact into a uniformly invokable form. 
- - The compiled artifact is the result of a previous call to `compile`. - - See the description of `Invoker` for the requirements on the returned - type. - """ diff --git a/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/linalg_on_tensors.py b/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/linalg_on_tensors.py deleted file mode 100644 index 30129c7510ef..000000000000 --- a/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/linalg_on_tensors.py +++ /dev/null @@ -1,80 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# Also available under a BSD-style license. See LICENSE. - - -from torch_mlir.compiler_utils import ( - run_pipeline_with_repro_report, - lower_mlir_module, - OutputType, -) -from torch_mlir.ir import * -from torch_mlir.passmanager import * - -from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import ( - RefBackendLinalgOnTensorsBackend, -) - -from .abc import OnnxBackend - -__all__ = [ - "LinalgOnTensorsOnnxBackend", -] - -# The pipeline of func.func passes that lower the ONNX backend contract to the -# Linalg-on-Tensors backend contract accepted by RefBackend. -ONNX_TO_TORCH_FUNC_PIPELINE = ",".join( - [ - "convert-torch-onnx-to-torch", - ] -) - - -class LinalgOnTensorsOnnxBackend(OnnxBackend): - """Main entry-point for the linalg-on-tensors based ONNX backend. - - This currently uses the linalg-on-tensors RefBackend for actual execution. - """ - - def __init__(self): - super().__init__() - self.refbackend = RefBackendLinalgOnTensorsBackend() - - def compile(self, imported_module: Module): - """Compiles an imported module that satisfied the ONNX backend contract. - - Args: - imported_module: The MLIR module consisting of ONNX operations wrapped by - torch.operator. - Returns: - An opaque, backend specific compiled artifact object that can be - passed to `load`. - """ - run_pipeline_with_repro_report( - imported_module, - f"builtin.module(func.func({ONNX_TO_TORCH_FUNC_PIPELINE}))", - "Lowering Onnx backend contract to Linalg-on-Tensors backend contract", - ) - - backend_legal_ops = [ - "aten.flatten.using_ints", - "aten.adaptive_avg_pool1d", - "aten.unflatten.int", - ] - option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + "}" - run_pipeline_with_repro_report( - imported_module, - f"builtin.module(torch-lower-to-backend-contract{option_string})", - "Lowering TorchFX IR -> Torch Backend IR", - ) - - imported_module = lower_mlir_module( - False, OutputType.LINALG_ON_TENSORS, imported_module - ) - compiled_module = self.refbackend.compile(imported_module) - return compiled_module - - def load(self, module): - """Loads a compiled artifact into the runtime.""" - return self.refbackend.load(module) diff --git a/projects/pt1/python/torch_mlir_e2e_test/registry.py b/projects/pt1/python/torch_mlir_e2e_test/registry.py index d2116bafe939..a98a6d34e7f8 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/registry.py +++ b/projects/pt1/python/torch_mlir_e2e_test/registry.py @@ -15,7 +15,9 @@ _SEEN_UNIQUE_NAMES = set() -def register_test_case(module_factory: Callable[[], torch.nn.Module]): +def register_test_case( + module_factory: Callable[[], torch.nn.Module], timeout_seconds: int = 120 +): """Convenient decorator-based test registration. 
Adds a `framework.Test` to the global test registry based on the decorated @@ -38,6 +40,7 @@ def decorator(f): unique_name=f.__name__, program_factory=module_factory, program_invoker=f, + timeout_seconds=timeout_seconds, ) ) return f diff --git a/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py b/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py index 61050de8fd6c..79c743353474 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py +++ b/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py @@ -23,6 +23,7 @@ [ "func.func(stablehlo-aggressive-simplification)", "stablehlo-legalize-to-linalg", + "stablehlo-convert-to-signless", "canonicalize", ] ) @@ -36,7 +37,10 @@ class LinalgOnTensorsStablehloBackend(StablehloBackend): def __init__(self): super().__init__() - self.refbackend = RefBackendLinalgOnTensorsBackend() + # TOOD: Enable runtime verification and fix found bugs. + self.refbackend = RefBackendLinalgOnTensorsBackend( + generate_runtime_verification=False + ) def compile(self, imported_module: Module): """Compiles an imported module that satisfied the Stablehlo backend contract. diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py index dca86870f1ac..f161f59404c0 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py @@ -15,8 +15,10 @@ "QuantizedSingleLayer_basic", "QuantizedBatchedInputSingleLayer_basic", "ReduceMaxAlongDimUnsignedInt_basic", + "RepeatInterleaveModule_basic", "ReduceMinAlongDimUnsignedInt_basic", "ElementwiseToDtypeI64ToUI8Module_basic", + "TimeOutModule_basic", # This test is expected to time out } @@ -41,8 +43,10 @@ def register_all_tests(): from . import elementwise_comparison from . import squeeze from . import slice_like + from . import spectral from . import nll_loss from . import index_select + from . import linalg_algorithms from . import arange from . import constant_alloc from . import threshold @@ -57,3 +61,5 @@ def register_all_tests(): from . import padding from . import diagonal from . import gridsampler + from . import meshgrid + from . 
import timeout diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py index e209d15b2b0b..5e6e093902c4 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/backprop.py @@ -322,3 +322,164 @@ def forward(self, grad, input): @register_test_case(module_factory=lambda: LeakyReluBackwardStaticModule()) def LeakyReluBackwardStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5)) + + +# ============================================================================== + + +class RreluWithNoiseBackwardTrainModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, grad, input, noise): + return torch.ops.aten.rrelu_with_noise_backward( + grad, + input, + noise, + lower=0.1, + upper=0.9, + training=True, + self_is_result=False, + ) + + +@register_test_case(module_factory=lambda: RreluWithNoiseBackwardTrainModule()) +def RreluWithNoiseBackwardTrainModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5), tu.rand(3, 4, 5)) + + +class RreluWithNoiseBackwardTrainStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([3, 4, 5], torch.float32, True), + ([3, 4, 5], torch.float32, True), + ([3, 4, 5], torch.float32, True), + ] + ) + def forward(self, grad, input, noise): + return torch.ops.aten.rrelu_with_noise_backward( + grad, + input, + noise, + lower=0.1, + upper=0.9, + training=True, + self_is_result=False, + ) + + +@register_test_case(module_factory=lambda: RreluWithNoiseBackwardTrainStaticModule()) +def RreluWithNoiseBackwardTrainStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5), tu.rand(3, 4, 5)) + + +# ============================================================================== + + +class RreluWithNoiseBackwardEvalModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, grad, input, noise): + return torch.ops.aten.rrelu_with_noise_backward( + grad, + input, + noise, + lower=0.1, + upper=0.9, + training=False, + self_is_result=False, + ) + + +@register_test_case(module_factory=lambda: RreluWithNoiseBackwardEvalModule()) +def RreluWithNoiseBackwardEvalModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5), tu.rand(3, 4, 5)) + + +class RreluWithNoiseBackwardEvalStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([3, 4, 5], torch.float32, True), + ([3, 4, 5], torch.float32, True), + ([3, 4, 5], torch.float32, True), + ] + ) + def forward(self, grad, input, noise): + return torch.ops.aten.rrelu_with_noise_backward( + grad, + input, + noise, + lower=0.1, + upper=0.9, + training=False, + self_is_result=False, + ) + + +@register_test_case(module_factory=lambda: RreluWithNoiseBackwardEvalStaticModule()) +def RreluWithNoiseBackwardEvalStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5), tu.rand(3, 4, 5)) + + +class 
RreluWithNoiseForwardBackwardModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) + def forward(self, grad, input, noise): + res = torch.ops.aten.rrelu_with_noise_backward( + grad, + input, + noise, + lower=0.4, + upper=0.6, + training=True, + self_is_result=False, + ) + return torch.mean(res), torch.std(res) + + +@register_test_case(module_factory=lambda: RreluWithNoiseForwardBackwardModule()) +def RreluWithNoiseForwardBackwardModule_basic(module, tu: TestUtils): + grad = tu.rand(256, 244) + input = tu.rand(256, 244, low=-1.0, high=1.0) + noise = tu.rand(256, 244) + torch.ops.aten.rrelu_with_noise(input, noise, lower=0.4, upper=0.6, training=True) + module.forward(grad, input, noise) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index b483f9d3c689..ec58351f4a24 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -87,6 +87,29 @@ def BmmFloatModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5), tu.rand(3, 5, 4)) +class BmmFloat16Module(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float16, True), + ([-1, -1, -1], torch.float16, True), + ] + ) + def forward(self, lhs, rhs): + return torch.bmm(lhs, rhs) + + +@register_test_case(module_factory=lambda: BmmFloat16Module()) +def BmmFloat16Module_basic(module, tu: TestUtils): + module.forward( + tu.rand(3, 4, 5).to(torch.float16), tu.rand(3, 5, 4).to(torch.float16) + ) + + class BmmIntModule(torch.nn.Module): def __init__(self): super().__init__() @@ -1011,6 +1034,97 @@ def TensorsConcatModule_basic(module, tu: TestUtils): # ============================================================================== +class TensorsConcatComplex64FloatModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.complex64, True), + ([-1, -1, -1], torch.float64, True), + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float16, True), + ] + ) + def forward(self, a, b, c, d): + return torch.cat([a, b, c, d], 1) + + +@register_test_case(module_factory=lambda: TensorsConcatComplex64FloatModule()) +def TensorsConcatComplex64FloatModule_basic(module, tu: TestUtils): + module.forward( + tu.rand(2, 1, 4, low=1, high=10).to(torch.complex64), + tu.rand(2, 3, 4, low=1, high=10).to(torch.float64), + tu.rand(2, 3, 4, low=1, high=10).to(torch.float32), + tu.rand(2, 3, 4, low=1, high=10).to(torch.float16), + ) + + +# ============================================================================== + + +class TensorsConcatComplex128FloatModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.complex128, True), + ([-1, -1, -1], torch.float64, True), + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float16, True), + ] + ) + def forward(self, a, b, c, d): + return torch.cat([a, b, c, d], 1) + + +@register_test_case(module_factory=lambda: TensorsConcatComplex128FloatModule()) +def TensorsConcatComplex128FloatModule_basic(module, tu: TestUtils): + module.forward( + tu.rand(2, 1, 4, low=1, high=10).to(torch.complex128), + tu.rand(2, 3, 4, low=1, high=10).to(torch.float64), + 
tu.rand(2, 3, 4, low=1, high=10).to(torch.float32), + tu.rand(2, 3, 4, low=1, high=10).to(torch.float16), + ) + + +# ============================================================================== + + +class TensorsConcatComplex128IntModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.complex128, True), + ([-1, -1, -1], torch.int64, True), + ([-1, -1, -1], torch.int32, True), + ] + ) + def forward(self, a, b, c): + return torch.cat([a, b, c], 1) + + +@register_test_case(module_factory=lambda: TensorsConcatComplex128IntModule()) +def TensorsConcatComplex128IntModule_basic(module, tu: TestUtils): + module.forward( + tu.rand(2, 1, 4, low=1, high=10).to(torch.complex128), + tu.rand(2, 3, 4, low=1, high=10).to(torch.int64), + tu.rand(2, 3, 4, low=1, high=10).to(torch.int32), + ) + + +# ============================================================================== + + class TensorsConcatNegativeDimModule(torch.nn.Module): def __init__(self): super().__init__() @@ -1115,6 +1229,36 @@ def TensorsConcatNegativeDimStaticModule_basic(module, tu: TestUtils): # ============================================================================== +class TensorsConcatPromoteDTypeStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 2, 4], torch.bool, True), + ([2, 1, 4], torch.int32, True), + ([2, 3, 4], torch.int64, True), + ] + ) + def forward(self, x, y, z): + return torch.cat([x, y, z], dim=-2) + + +@register_test_case(module_factory=lambda: TensorsConcatPromoteDTypeStaticModule()) +def TensorsConcatPromoteDTypeStaticModule_basic(module, tu: TestUtils): + module.forward( + tu.randint(2, 2, 4, low=0, high=2).bool(), + tu.randint(2, 1, 4, low=0, high=100).int(), + tu.randint(2, 3, 4, low=0, high=100).long(), + ) + + +# ============================================================================== + + class TensorsStackModule(torch.nn.Module): def __init__(self): super().__init__() @@ -1217,6 +1361,184 @@ def TensorsStackPromoteDTypeModule_basic(module, tu: TestUtils): # ============================================================================== +class HstackBasicIntModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 3, 4], torch.bool, True), + ([2, 3, 4], torch.int32, True), + ([2, 3, 4], torch.int64, True), + ] + ) + def forward(self, x, y, z): + return torch.ops.aten.hstack([x, y, z]) + + +@register_test_case(module_factory=lambda: HstackBasicIntModule()) +def HstackBasicIntModule_basic(module, tu: TestUtils): + module.forward( + tu.randint(2, 3, 4, low=0, high=2).bool(), + tu.randint(2, 3, 4, low=0, high=100).int(), + tu.randint(2, 3, 4, low=0, high=100).long(), + ) + + +class HstackBasicFloatModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 6, 4], torch.int32, True), + ([2, 3, 4], torch.float64, True), + ] + ) + def forward(self, x, y): + return torch.ops.aten.hstack([x, y]) + + +@register_test_case(module_factory=lambda: HstackBasicFloatModule()) +def HstackBasicFloatModule_basic(module, tu: TestUtils): + module.forward( + tu.rand(2, 6, 4).int(), + tu.rand(2, 3, 4).double(), + ) + + +class HstackBasicIntFloatModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.int32, True), + ([-1, -1, -1, -1], torch.float32, True), + ] + ) + def 
forward(self, x, y): + return torch.ops.aten.hstack([x, y]) + + +@register_test_case(module_factory=lambda: HstackBasicIntFloatModule()) +def HstackBasicIntFloatModule_basic(module, tu: TestUtils): + module.forward( + tu.randint(4, 6, 4, 2, low=1, high=50).int(), + tu.rand(4, 3, 4, 2), + ) + + +class HstackBasicComplexModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.complex64, True), + ([-1, -1, -1, -1], torch.complex128, True), + ] + ) + def forward(self, x, y): + return torch.ops.aten.hstack([x, y]) + + +@register_test_case(module_factory=lambda: HstackBasicComplexModule()) +def HstackBasicComplexModule_basic(module, tu: TestUtils): + module.forward( + tu.rand(4, 6, 4, 2).type(torch.complex64), + tu.rand(4, 3, 4, 2).type(torch.complex128), + ) + + +# ============================================================================== + + +class ColumnStackBasicIntModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 3, 4], torch.bool, True), + ([2, 3, 4], torch.int32, True), + ([2, 3, 4], torch.int64, True), + ] + ) + def forward(self, x, y, z): + return torch.ops.aten.column_stack([x, y, z]) + + +@register_test_case(module_factory=lambda: ColumnStackBasicIntModule()) +def ColumnStackBasicIntModule_basic(module, tu: TestUtils): + module.forward( + tu.randint(2, 3, 4, low=0, high=2).bool(), + tu.randint(2, 3, 4, low=0, high=100).int(), + tu.randint(2, 3, 4, low=0, high=100).long(), + ) + + +# ============================================================================== + + +class ColumnStack1dModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([4], torch.float32, True), + ([4], torch.float32, True), + ] + ) + def forward(self, x, y): + return torch.ops.aten.column_stack([x, y]) + + +@register_test_case(module_factory=lambda: ColumnStack1dModule()) +def ColumnStack1dModule_basic(module, tu: TestUtils): + module.forward(tu.rand(4), tu.rand(4)) + + +# ============================================================================== + + +class ColumnStack0dModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([], torch.float32, True), + ([], torch.float32, True), + ] + ) + def forward(self, x, y): + return torch.ops.aten.column_stack([x, y]) + + +@register_test_case(module_factory=lambda: ColumnStack0dModule()) +def ColumnStack0dModule_basic(module, tu: TestUtils): + module.forward(torch.tensor(4.0), torch.tensor(1.0)) + + +# ============================================================================== + + class GatherModule(torch.nn.Module): def __init__(self): super().__init__() @@ -1715,7 +2037,7 @@ def _LogSoftmaxModuleStable_basic(module, tu: TestUtils): # ============================================================================== -class SoftplusModule(torch.nn.Module): +class SafeSoftmaxModule(torch.nn.Module): def __init__(self): super().__init__() @@ -1723,22 +2045,22 @@ def __init__(self): @annotate_args( [ None, - ([-1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), ] ) - def forward(self, x): - return torch.ops.aten.softplus(x) + def forward(self, tensor): + return torch.ops.aten._safe_softmax(tensor, dim=0) -@register_test_case(module_factory=lambda: SoftplusModule()) -def SoftplusModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 3)) +@register_test_case(module_factory=lambda: 
SafeSoftmaxModule()) +def SafeSoftmaxModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 2, 4)) # ============================================================================== -class HardsigmoidModule(torch.nn.Module): +class SafeSoftmaxNonNoneDtypeModule(torch.nn.Module): def __init__(self): super().__init__() @@ -1746,22 +2068,68 @@ def __init__(self): @annotate_args( [ None, - ([-1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), ] ) - def forward(self, x): - return torch.ops.aten.hardsigmoid(x) + def forward(self, tensor): + return torch.ops.aten._safe_softmax(tensor, dim=2, dtype=torch.float64) -@register_test_case(module_factory=lambda: HardsigmoidModule()) -def HardsigmoidModule_basic(module, tu: TestUtils): - module.forward(torch.tensor([[4.0, -5.0, 3.0], [2.9, -1.5, -3.0]])) +@register_test_case(module_factory=lambda: SafeSoftmaxNonNoneDtypeModule()) +def SafeSoftmaxNonNoneDtypeModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 2, 4)) # ============================================================================== -class HardsigmoidRandomModule(torch.nn.Module): +class SoftplusModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.softplus(x) + + +@register_test_case(module_factory=lambda: SoftplusModule()) +def SoftplusModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 3)) + + +# ============================================================================== + + +class HardsigmoidModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.hardsigmoid(x) + + +@register_test_case(module_factory=lambda: HardsigmoidModule()) +def HardsigmoidModule_basic(module, tu: TestUtils): + module.forward(torch.tensor([[4.0, -5.0, 3.0], [2.9, -1.5, -3.0]])) + + +# ============================================================================== + + +class HardsigmoidRandomModule(torch.nn.Module): def __init__(self): super().__init__() @@ -1807,6 +2175,54 @@ def BroadcastToModule_basic(module, tu: TestUtils): # ============================================================================== +class BroadcastToDifferentRankStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 8], torch.float32, True), + ] + ) + def forward(self, x): + return torch.broadcast_to(x, [1, 2, 8]) + + +@register_test_case(module_factory=lambda: BroadcastToDifferentRankStaticModule()) +def BroadcastToDifferentRankStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 8)) + + +# ============================================================================== + + +class BroadcastToDifferentRankNotOneStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 8], torch.float32, True), + ] + ) + def forward(self, x): + return torch.broadcast_to(x, [10, 2, 8]) + + +@register_test_case(module_factory=lambda: BroadcastToDifferentRankNotOneStaticModule()) +def BroadcastToDifferentRankNotOneStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 8)) + + +# ============================================================================== + + class BroadcastToSameRankStaticModule(torch.nn.Module): def __init__(self): super().__init__() @@ 
-1909,6 +2325,54 @@ def BroadcastDynamicDimModule_basic(module, tu: TestUtils): # ============================================================================== +class BroadcastDifferentRankWithMinusOneModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([3, 1, 8], torch.float32, True), + ] + ) + def forward(self, x): + return torch.broadcast_to(x, [10, -1, -1, -1]) + + +@register_test_case(module_factory=lambda: BroadcastDifferentRankWithMinusOneModule()) +def BroadcastDifferentRankWithMinusOneModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 1, 8)) + + +# ============================================================================== + + +class BroadcastDifferentRankSameFinalShapeModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([3, 1, 8], torch.float32, True), + ] + ) + def forward(self, x): + return torch.broadcast_to(x, [1, -1, -1, -1]) + + +@register_test_case(module_factory=lambda: BroadcastDifferentRankSameFinalShapeModule()) +def BroadcastDifferentRankSameFinalShapeModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 1, 8)) + + +# ============================================================================== + + class RollModule(torch.nn.Module): def __init__(self): super().__init__() @@ -1952,6 +2416,79 @@ def RepeatModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 1, 2)) +# ============================================================================== +class RepeatInterleaveModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([4], torch.int, True), + ] + ) + def forward(self, x): + z = torch.ops.aten.repeat_interleave(x, output_size=10) + y = torch.ops.aten.repeat_interleave(x) + return z, y + + +@register_test_case(module_factory=lambda: RepeatInterleaveModule()) +def RepeatInterleaveModule_basic(module, tu: TestUtils): + module.forward(torch.tensor([3, 1, 2, 4], dtype=torch.int)) + + +# ============================================================================== +class RepeatInterleaveFillModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([1], torch.int, True), + ] + ) + def forward(self, x): + x = torch.ops.aten.fill_(x, 2) + x = torch.ops.aten.expand(x, [16]) + return torch.ops.aten.repeat_interleave(x) + + +@register_test_case(module_factory=lambda: RepeatInterleaveFillModule()) +def RepeatInterleaveFillModule_basic(module, tu: TestUtils): + module.forward(torch.tensor([1], dtype=torch.int)) + + +# ============================================================================== + + +class RepeatInterleaveStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ] + ) + def forward(self): + x = torch.ones((10), dtype=torch.int).fill_(3) + z = torch.ops.aten.repeat_interleave(x, output_size=30) + return z + + +@register_test_case(module_factory=lambda: RepeatInterleaveStaticModule()) +def RepeatInterleaveStaticModule_basic(module, tu: TestUtils): + module.forward() + + # ============================================================================== @@ -4111,25 +4648,123 @@ def IntImplicitModule_basic(module, tu: TestUtils): # ============================================================================== -class PowIntFloat(torch.nn.Module): +class PowModule(torch.nn.Module): def __init__(self): super().__init__() - self.value = 2 
- self.power_value = 3.0 @export @annotate_args( [ None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), ] ) - def forward(self): - return torch.ops.aten.pow(self.value, self.power_value) + def forward(self, x, y): + return torch.ops.aten.pow(x, y) -@register_test_case(module_factory=lambda: IntFloatModule()) +@register_test_case(module_factory=lambda: PowModule()) +def PowFloatFloatModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5)) + + +# ============================================================================== + + +class PowIntFloatModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.int32, True), + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, x, y): + return torch.ops.aten.pow(x, y) + + +@register_test_case(module_factory=lambda: PowIntFloatModule()) def PowIntFloatModule_basic(module, tu: TestUtils): - module.forward() + module.forward(tu.randint(3, 4, 5, dtype=torch.int32), tu.rand(3, 4, 5)) + + +# ============================================================================== + + +class PowFloatIntModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.int32, True), + ] + ) + def forward(self, x, y): + return torch.ops.aten.pow(x, y) + + +@register_test_case(module_factory=lambda: PowFloatIntModule()) +def PowFloatIntModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5), tu.randint(3, 4, 5, dtype=torch.int32)) + + +# ============================================================================== + + +class PowIntIntModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.int32, True), + ([-1, -1, -1], torch.int32, True), + ] + ) + def forward(self, x, y): + return torch.ops.aten.pow(x, y) + + +@register_test_case(module_factory=lambda: PowIntIntModule()) +def PowIntIntModule_basic(module, tu: TestUtils): + module.forward( + tu.randint(3, 4, 5, high=10, dtype=torch.int32), + tu.randint(3, 4, 5, high=20, dtype=torch.int32), + ) + + +# ============================================================================== + + +class IsInfiniteModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.isfinite(x) + + +@register_test_case(module_factory=lambda: IsInfiniteModule()) +def IsInfiniteModule_basic(module, tu: TestUtils): + module.forward(torch.tensor([-torch.inf, torch.inf, torch.nan, -2.3, 0.0, 1.5])) # ============================================================================== @@ -4592,7 +5227,7 @@ def CumsumInputDtypeInt32Module_basic(module, tu: TestUtils): # ============================================================================== -class AtenToDeviceModule(torch.nn.Module): +class CumprodModule(torch.nn.Module): def __init__(self): super().__init__() @@ -4600,7 +5235,91 @@ def __init__(self): @annotate_args( [ None, - ([-1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, val): + ones = torch.ones([1], dtype=torch.int32) + return torch.ops.aten.cumprod(val, ones.item()) + + +@register_test_case(module_factory=lambda: CumprodModule()) +def CumprodModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 7, 4)) + + 
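+# In CumprodModule above, the `dim` is computed from a tensor via `ones.item()`
+# rather than written as a Python literal; the static variants below pass the
+# dim directly. As a rough illustrative sketch of the semantics being exercised
+# (not part of the test-suite API):
+#
+#   x = torch.tensor([1.0, 2.0, 3.0, 4.0])
+#   torch.cumprod(x, dim=0)   # tensor([ 1.,  2.,  6., 24.])
+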
+class CumprodStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 7, 4], torch.float32, True), + ] + ) + def forward(self, val): + return torch.ops.aten.cumprod(val, 1) + + +@register_test_case(module_factory=lambda: CumprodStaticModule()) +def CumprodStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 7, 4)) + + +class CumprodStaticNegativeDimModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 7, 4], torch.float32, True), + ] + ) + def forward(self, val): + return torch.ops.aten.cumprod(val, dim=-1) + + +@register_test_case(module_factory=lambda: CumprodStaticNegativeDimModule()) +def CumprodStaticNegativeDimModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 7, 4)) + + +class CumprodInputDtypeInt32Module(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 7, 4], torch.int32, True), + ] + ) + def forward(self, val): + return torch.ops.aten.cumprod(val, 1) + + +@register_test_case(module_factory=lambda: CumprodInputDtypeInt32Module()) +def CumprodInputDtypeInt32Module_basic(module, tu: TestUtils): + module.forward(tu.randint(2, 7, 4).to(torch.int32)) + + +# ============================================================================== + + +class AtenToDeviceModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), ] ) def forward(self, val): @@ -4614,6 +5333,27 @@ def AtenToDeviceModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 4)) +# ============================================================================== +class AtenToDtypeModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2], torch.bool, True), + ] + ) + def forward(self, val): + return torch.ops.aten.to(val, dtype=torch.int32, non_blocking=False) + + +@register_test_case(module_factory=lambda: AtenToDtypeModule()) +def AtenToDtypeModule_basic(module, tu: TestUtils): + module.forward(torch.tensor([True, False], dtype=torch.bool)) + + # ============================================================================== @@ -5012,6 +5752,31 @@ class ScaledDotProductAttentionSameModule(torch.nn.Module): def __init__(self): super().__init__() + @export + @annotate_args( + [ + None, + ([1, 5, 5], torch.float32, True), + ([1, 5, 5], torch.float32, True), + ([1, 5, 5], torch.float32, True), + ] + ) + def forward(self, query, key, value): + return torch.ops.aten.scaled_dot_product_attention(query, key, value) + + +@register_test_case(module_factory=lambda: ScaledDotProductAttentionSameModule()) +def ScaledDotProductAttentionSameModule_basic(module, tu: TestUtils): + query = torch.randn(1, 5, 5, dtype=torch.float32) + key = torch.randn(1, 5, 5, dtype=torch.float32) + value = torch.randn(1, 5, 5, dtype=torch.float32) + module.forward(query, key, value) + + +class ScaledDotProductAttentionSameDynamicModule(torch.nn.Module): + def __init__(self): + super().__init__() + @export @annotate_args( [ @@ -5025,8 +5790,35 @@ def forward(self, query, key, value): return torch.ops.aten.scaled_dot_product_attention(query, key, value) -@register_test_case(module_factory=lambda: ScaledDotProductAttentionSameModule()) -def ScaledDotProductAttentionSameModule_basic(module, tu: TestUtils): +@register_test_case(module_factory=lambda: ScaledDotProductAttentionSameDynamicModule()) +def 
ScaledDotProductAttentionSameDynamicModule_basic(module, tu: TestUtils): + query = torch.randn(1, 5, 5, dtype=torch.float32) + key = torch.randn(1, 5, 5, dtype=torch.float32) + value = torch.randn(1, 5, 5, dtype=torch.float32) + module.forward(query, key, value) + + +class ScaledDotProductAttentionSameCausalModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, query, key, value): + return torch.ops.aten.scaled_dot_product_attention( + query, key, value, is_causal=True + ) + + +@register_test_case(module_factory=lambda: ScaledDotProductAttentionSameCausalModule()) +def ScaledDotProductAttentionSameCausalModule_basic(module, tu: TestUtils): query = torch.randn(1, 5, 5, dtype=torch.float32) key = torch.randn(1, 5, 5, dtype=torch.float32) value = torch.randn(1, 5, 5, dtype=torch.float32) @@ -5041,9 +5833,9 @@ def __init__(self): @annotate_args( [ None, - ([2, 3, 8, 4], torch.float32, True), - ([2, 3, 16, 4], torch.float32, True), - ([2, 3, 16, 4], torch.float32, True), + ([2, 3, 8, 16], torch.float32, True), + ([2, 3, 12, 16], torch.float32, True), + ([2, 3, 12, 20], torch.float32, True), ] ) def forward(self, query, key, value): @@ -5052,9 +5844,148 @@ def forward(self, query, key, value): @register_test_case(module_factory=lambda: ScaledDotProductAttentionDifferentModule()) def ScaledDotProductAttentionDifferentModule_basic(module, tu: TestUtils): - query = torch.randn(2, 3, 8, 4, dtype=torch.float32) - key = torch.randn(2, 3, 16, 4, dtype=torch.float32) - value = torch.randn(2, 3, 16, 4, dtype=torch.float32) + query = torch.randn(2, 3, 8, 16, dtype=torch.float32) + key = torch.randn(2, 3, 12, 16, dtype=torch.float32) + value = torch.randn(2, 3, 12, 20, dtype=torch.float32) + module.forward(query, key, value) + + +class ScaledDotProductAttentionDifferentCausalModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 3, 8, 16], torch.float32, True), + ([2, 3, 12, 16], torch.float32, True), + ([2, 3, 12, 20], torch.float32, True), + ] + ) + def forward(self, query, key, value): + return torch.ops.aten.scaled_dot_product_attention( + query, key, value, is_causal=True + ) + + +@register_test_case( + module_factory=lambda: ScaledDotProductAttentionDifferentCausalModule() +) +def ScaledDotProductAttentionDifferentDynamicCausalModule_basic(module, tu: TestUtils): + query = torch.randn(2, 3, 8, 16, dtype=torch.float32) + key = torch.randn(2, 3, 12, 16, dtype=torch.float32) + value = torch.randn(2, 3, 12, 20, dtype=torch.float32) + module.forward(query, key, value) + + +class ScaledDotProductAttentionDifferentDynamicCausalModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 3, -1, 16], torch.float32, True), + ([2, 3, -1, 16], torch.float32, True), + ([2, 3, -1, 20], torch.float32, True), + ] + ) + def forward(self, query, key, value): + return torch.ops.aten.scaled_dot_product_attention( + query, key, value, is_causal=True + ) + + +@register_test_case( + module_factory=lambda: ScaledDotProductAttentionDifferentDynamicCausalModule() +) +def ScaledDotProductAttentionDifferentCausalModule_basic(module, tu: TestUtils): + query = torch.randn(2, 3, 8, 16, dtype=torch.float32) + key = torch.randn(2, 3, 12, 16, dtype=torch.float32) + value = torch.randn(2, 3, 12, 20, dtype=torch.float32) + 
module.forward(query, key, value) + + +class ScaledDotProductAttentionMaskModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 3, 8, 16], torch.float32, True), + ([2, 3, 12, 16], torch.float32, True), + ([2, 3, 12, 20], torch.float32, True), + ([2, 1, 8, 12], torch.float32, True), + ] + ) + def forward(self, query, key, value, mask): + return torch.ops.aten.scaled_dot_product_attention(query, key, value, mask) + + +@register_test_case(module_factory=lambda: ScaledDotProductAttentionMaskModule()) +def ScaledDotProductAttentionMaskModule_basic(module, tu: TestUtils): + query = torch.randn(2, 3, 8, 16, dtype=torch.float32) + key = torch.randn(2, 3, 12, 16, dtype=torch.float32) + value = torch.randn(2, 3, 12, 20, dtype=torch.float32) + mask = torch.randn(2, 1, 8, 12, dtype=torch.float32) + module.forward(query, key, value, mask) + + +class ScaledDotProductAttentionBoolMaskModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 3, 8, 16], torch.float32, True), + ([2, 3, 12, 16], torch.float32, True), + ([2, 3, 12, 20], torch.float32, True), + ([2, 3, 8, 12], torch.bool, True), + ] + ) + def forward(self, query, key, value, mask): + return torch.ops.aten.scaled_dot_product_attention(query, key, value, mask) + + +@register_test_case(module_factory=lambda: ScaledDotProductAttentionBoolMaskModule()) +def ScaledDotProductAttentionBoolMaskModule_basic(module, tu: TestUtils): + query = torch.randn(2, 3, 8, 16, dtype=torch.float32) + key = torch.randn(2, 3, 12, 16, dtype=torch.float32) + value = torch.randn(2, 3, 12, 20, dtype=torch.float32) + mask = torch.randn(2, 3, 8, 12, dtype=torch.float32) > 0.5 + module.forward(query, key, value, mask) + + +class ScaledDotProductAttentionGQAModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([4, 32, 3, 8], torch.float32, True), + ([4, 8, 3, 8], torch.float32, True), + ([4, 8, 3, 8], torch.float32, True), + ] + ) + def forward(self, query, key, value): + return torch.ops.aten.scaled_dot_product_attention( + query, key, value, enable_gqa=True + ) + + +@register_test_case(module_factory=lambda: ScaledDotProductAttentionGQAModule()) +def ScaledDotProductAttentionGQAModule_basic(module, tu: TestUtils): + query = torch.randn(4, 32, 3, 8, dtype=torch.float32) + key = torch.randn(4, 8, 3, 8, dtype=torch.float32) + value = torch.randn(4, 8, 3, 8, dtype=torch.float32) module.forward(query, key, value) @@ -5135,6 +6066,30 @@ def ConstantBoolParameterModule_basic(module, tu: TestUtils): # ============================================================================== +class TensorAlloc1dStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 4, 6], torch.int, True), + ] + ) + def forward(self, x): + res = torch.tensor([x.shape[0]]) + return res + + +@register_test_case(module_factory=lambda: TensorAlloc1dStaticModule()) +def TensorAlloc1dStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 4, 6)) + + +# ============================================================================== + + class ScalarTensorFloat32Module(torch.nn.Module): def __init__(self): super().__init__() @@ -5547,3 +6502,260 @@ def forward(self, x): @register_test_case(module_factory=lambda: CloneModule()) def CloneModule_basic(module, tu: TestUtils): module.forward(tu.rand(5, 5)) + + +# 
============================================================================== + + +class AtenKthvalueModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([2, 6, 3], torch.int32, True)]) + def forward(self, x): + return torch.ops.aten.kthvalue(x, k=4, dim=1, keepdim=False) + + +@register_test_case(module_factory=lambda: AtenKthvalueModule()) +def AtenKthvalueModule_basic(module, tu: TestUtils): + module.forward(torch.randperm(2 * 6 * 3, dtype=torch.int32).reshape(2, 6, 3)) + + +# ============================================================================== + + +class AtenKthvalueKeepDimModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([2, 6, 3], torch.int32, True)]) + def forward(self, x): + return torch.ops.aten.kthvalue(x, k=4, dim=1, keepdim=True) + + +@register_test_case(module_factory=lambda: AtenKthvalueKeepDimModule()) +def AtenKthvalueKeepDimModule_basic(module, tu: TestUtils): + module.forward(torch.randperm(2 * 6 * 3, dtype=torch.int32).reshape(2, 6, 3)) + + +# ============================================================================== + + +class AtenKthvalueDynamicDimsModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([-1, -1, -1, -1], torch.int32, True)]) + def forward(self, x): + return torch.ops.aten.kthvalue(x, k=6, dim=2, keepdim=True) + + +@register_test_case(module_factory=lambda: AtenKthvalueDynamicDimsModule()) +def AtenKthvalueDynamicDimsModule_basic(module, tu: TestUtils): + module.forward(torch.randperm(4 * 2 * 8 * 3, dtype=torch.int32).reshape(4, 2, 8, 3)) + + +# ============================================================================== + + +class AtenKthvalueFloat64Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([4, 2, 8, 3], torch.float64, True)]) + def forward(self, x): + return torch.ops.aten.kthvalue(x, k=3, dim=0, keepdim=True) + + +@register_test_case(module_factory=lambda: AtenKthvalueFloat64Module()) +def AtenKthvalueFloat64Module_basic(module, tu: TestUtils): + module.forward( + torch.randperm(4 * 2 * 8 * 3, dtype=torch.float64).reshape(4, 2, 8, 3) + ) + + +# ============================================================================== + + +class AtenKthvalueFloat64DynamicDimsModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([-1, -1, -1, -1], torch.float64, True)]) + def forward(self, x): + return torch.ops.aten.kthvalue(x, k=3, dim=3, keepdim=True) + + +@register_test_case(module_factory=lambda: AtenKthvalueFloat64DynamicDimsModule()) +def AtenKthvalueFloat64DynamicDimsModule_basic(module, tu: TestUtils): + module.forward( + torch.randperm(4 * 2 * 8 * 3, dtype=torch.float64).reshape(4, 2, 8, 3) + ) + + +# ============================================================================== + + +class UnfoldModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.unfold = torch.nn.Unfold(kernel_size=(2, 3)) + + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) + def forward(self, input): + return self.unfold(input) + + +@register_test_case(module_factory=lambda: UnfoldModule()) +def UnfoldModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 5, 3, 4)) + + +# ============================================================================== + + +class AtenPolarFloatModule(torch.nn.Module): + def __init__(self): + 
super().__init__() + self.unfold = torch.nn.Unfold(kernel_size=(2, 3)) + + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ] + ) + def forward(self, abs_, angle): + return torch.ops.aten.polar(torch.ops.aten.abs(abs_), angle) + + +@register_test_case(module_factory=lambda: AtenPolarFloatModule()) +def AtenPolarFloatModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 5, 3, 4), tu.rand(2, 5, 3, 4)) + + +# ============================================================================== + + +class AtenPolarDoubleModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.unfold = torch.nn.Unfold(kernel_size=(2, 3)) + + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float64, True), + ([-1, -1, -1, -1], torch.float64, True), + ] + ) + def forward(self, abs_, angle): + return torch.ops.aten.polar(torch.ops.aten.abs(abs_), angle) + + +@register_test_case(module_factory=lambda: AtenPolarDoubleModule()) +def AtenPolarDoubleModule_basic(module, tu: TestUtils): + module.forward( + tu.rand(2, 5, 3, 4).to(torch.float64), tu.rand(2, 5, 3, 4).to(torch.float64) + ) + + +# ============================================================================== + + +class AtenNonzero1DDynamicModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1], torch.bool, True), + ] + ) + def forward(self, x): + return torch.ops.aten.nonzero(x) + + +@register_test_case(module_factory=lambda: AtenNonzero1DDynamicModule()) +def AtenNonzero1DDynamicModule_basic(module, tu: TestUtils): + module.forward(torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.bool)) + + +# ============================================================================== + + +class AtenSymConstrainRange(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([-1], torch.int, True)]) + def forward(self, x): + a = x.item() + torch.ops.aten.sym_constrain_range(a, max=5) + return a + + +@register_test_case(module_factory=lambda: AtenSymConstrainRange()) +def AtenSymConstrainRange_basic(module, tu: TestUtils): + module.forward(torch.tensor(4)) + + +# ============================================================================== + + +class AtenSymConstrainRangeForSize(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([-1], torch.int, True)]) + def forward(self, x): + a = x.item() + torch.ops.aten.sym_constrain_range_for_size(a, min=0, max=10) + return a + + +@register_test_case(module_factory=lambda: AtenSymConstrainRangeForSize()) +def AtenSymConstrainRangeForSize_basic(module, tu: TestUtils): + module.forward(torch.tensor(4)) + + +# ============================================================================== +class Aten_AssertScalar(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([-1], torch.int, True)]) + def forward(self, x): + a = x.item() + assert_msg = "Assertion failed for condition x.item() > 3" + torch.ops.aten._assert_scalar(a > 3, assert_msg) + return a + + +@register_test_case(module_factory=lambda: Aten_AssertScalar()) +def Aten_AssertScalar_basic(module, tu: TestUtils): + module.forward(torch.tensor(4)) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py index 8ce0a44d7fd4..551845ff3e87 100644 --- 
a/projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py @@ -1546,6 +1546,27 @@ def ZeroInt64Module_basic(module, tu: TestUtils): # ============================================================================== +class NewEmptyModuleBool(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.bool, True), + ] + ) + def forward(self, a): + return torch.ops.aten.new_empty(a, [3, 4]).fill_(0) + + +@register_test_case(module_factory=lambda: NewEmptyModuleBool()) +def NewEmptyModuleBool_basic(module, tu: TestUtils): + module.forward(tu.randint(2, 3, high=2).to(dtype=torch.bool)) + + class NewEmptyModuleDefaultDtype(torch.nn.Module): def __init__(self): super().__init__() @@ -1928,7 +1949,7 @@ def __init__(self): @annotate_args( [ None, - ([2, 3, 4], torch.float32, True), + ([4, 3, 4], torch.float32, True), ] ) def forward(self, a): diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py index e99525c32d88..7cf6226890f6 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py @@ -11,6 +11,88 @@ # ============================================================================== +class Conv1dNoPaddingModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([1, 768, 768], torch.float32, True), + ([768, 768, 1], torch.float32, True), + ([768], torch.float32, True), + ] + ) + def forward(self, x, weights, bias): + return torch.ops.aten.convolution( + x, weights, bias, [1], [0], [1], False, [0], 1 + ) + + +@register_test_case(module_factory=lambda: Conv1dNoPaddingModule()) +def Conv1dNoPaddingModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 768, 768), tu.rand(768, 768, 1), torch.ones(768)) + + +# ============================================================================== + + +class Conv1dNoPaddingTransposeModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([1, 768, 768], torch.float32, True), + ([768, 768, 1], torch.float32, True), + ([768], torch.float32, True), + ] + ) + def forward(self, x, weights, bias): + return torch.ops.aten.convolution(x, weights, bias, [1], [0], [1], True, [0], 1) + + +@register_test_case(module_factory=lambda: Conv1dNoPaddingTransposeModule()) +def Conv1dNoPaddingTransposeModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 768, 768), tu.rand(768, 768, 1), torch.ones(768)) + + +# ============================================================================== + + +class Conv1dNoPaddingGroupModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([1, 3072, 12], torch.float32, True), + ([768, 768, 1], torch.float32, True), + ([768], torch.float32, True), + ] + ) + def forward(self, x, weights, bias): + return torch.ops.aten.convolution( + x, weights, bias, [1], [0], [1], False, [0], 4 + ) + + +@register_test_case(module_factory=lambda: Conv1dNoPaddingGroupModule()) +def Conv1dNoPaddingGroupModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 3072, 12), tu.rand(768, 768, 1), torch.ones(768)) + + +# ============================================================================== + + class Conv2dNoPaddingModule(torch.nn.Module): def __init__(self): 
super().__init__() @@ -191,6 +273,54 @@ def Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier( module.forward(tu.rand(5, 4, 10, 20)) +class Conv2dWithSamePaddingModule(torch.nn.Module): + def __init__(self): + super().__init__() + torch.manual_seed(0) + self.conv = torch.nn.Conv2d(2, 10, 3, bias=False, padding="same") + self.train(False) + + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) + def forward(self, x): + return self.conv(x) + + +@register_test_case(module_factory=lambda: Conv2dWithSamePaddingModule()) +def Conv2dWithSamePaddingModule_basic(module, tu: TestUtils): + t = tu.rand(5, 2, 10, 20) + module.forward(t) + + +class Conv2dWithValidPaddingModule(torch.nn.Module): + def __init__(self): + super().__init__() + torch.manual_seed(0) + self.conv = torch.nn.Conv2d(2, 10, 3, bias=False, padding="valid") + self.train(False) + + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) + def forward(self, x): + return self.conv(x) + + +@register_test_case(module_factory=lambda: Conv2dWithValidPaddingModule()) +def Conv2dWithValidPaddingModule_basic(module, tu: TestUtils): + t = tu.rand(5, 2, 10, 20) + module.forward(t) + + # ============================================================================== @@ -600,6 +730,41 @@ def _ConvolutionDeprecated2DCudnnModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 3, 10, 10), tu.rand(3, 3, 2, 2)) +# ============================================================================== + + +class Convolution2DGroupsStatic(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([1, 32, 4, 4], torch.float32, True), + ([32, 8, 3, 3], torch.float32, True), + ([32], torch.float32, True), + ] + ) + def forward(self, x, weight, bias): + return torch.ops.aten.convolution( + x, + weight, + bias=bias, + stride=[3, 3], + padding=[2, 2], + dilation=[1, 1], + transposed=False, + output_padding=[0, 0], + groups=4, + ) + + +@register_test_case(module_factory=lambda: Convolution2DGroupsStatic()) +def Convolution2DGroupsStatic_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 32, 4, 4), tu.rand(32, 8, 3, 3), torch.ones(32)) + + class ConvolutionModule2DGroups(torch.nn.Module): def __init__(self): super().__init__() @@ -760,6 +925,66 @@ def ConvolutionModule2DTransposeNonUnitOutputPadding_basic(module, tu: TestUtils module.forward(tu.rand(1, 2, 4, 4), tu.rand(2, 2, 3, 3)) +class Conv_Transpose1dModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, inputVec, weight): + return torch.ops.aten.conv_transpose1d( + inputVec, + weight, + bias=None, + stride=[2], + padding=[1], + dilation=[1], + output_padding=[0], + groups=1, + ) + + +@register_test_case(module_factory=lambda: Conv_Transpose1dModule()) +def Conv_Transpose1dModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 2, 6), tu.rand(2, 5, 2)) + + +class Conv_Transpose1dStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([5, 2, 6], torch.float32, True), + ([2, 5, 2], torch.float32, True), + ] + ) + def forward(self, inputVec, weight): + return torch.ops.aten.conv_transpose1d( + inputVec, + weight, + bias=None, + stride=[2], + padding=[1], + dilation=[1], + output_padding=[0], + groups=1, + ) + + 
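+# Sanity check for the conv_transpose1d tests in this group: with the usual
+# transposed-convolution size formula (illustrative, mirroring the arguments
+# used above),
+#
+#   L_out = (L_in - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1
+#
+# inputs of length L_in = 6 with stride=2, padding=1, dilation=1, kernel=2, and
+# output_padding=0 produce outputs of length 10.
+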
+@register_test_case(module_factory=lambda: Conv_Transpose1dStaticModule()) +def Conv_Transpose1dStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 2, 6), tu.rand(2, 5, 2)) + + class Conv_Transpose2dModule(torch.nn.Module): def __init__(self): super().__init__() @@ -790,6 +1015,96 @@ def Conv_Transpose2dModule_basic(module, tu: TestUtils): module.forward(tu.rand(5, 2, 5, 6), tu.rand(2, 5, 2, 2)) +class Conv_Transpose2dStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([5, 2, 5, 6], torch.float32, True), + ([2, 5, 2, 2], torch.float32, True), + ] + ) + def forward(self, inputVec, weight): + return torch.ops.aten.conv_transpose2d( + inputVec, + weight, + bias=None, + stride=[2, 2], + padding=[1, 1], + dilation=[1, 1], + output_padding=[0, 0], + groups=1, + ) + + +@register_test_case(module_factory=lambda: Conv_Transpose2dStaticModule()) +def Conv_Transpose2dStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 2, 5, 6), tu.rand(2, 5, 2, 2)) + + +class Conv_Transpose3dModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1, -1], torch.float32, True), + ] + ) + def forward(self, inputVec, weight): + return torch.ops.aten.conv_transpose3d( + inputVec, + weight, + bias=None, + stride=[2, 2, 2], + padding=[1, 1, 1], + dilation=[1, 1, 1], + output_padding=[0, 0, 0], + groups=1, + ) + + +@register_test_case(module_factory=lambda: Conv_Transpose3dModule()) +def Conv_Transpose3dModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 2, 5, 6, 7), tu.rand(2, 5, 2, 2, 2)) + + +class Conv_Transpose3dStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([5, 2, 5, 6, 7], torch.float32, True), + ([2, 5, 2, 2, 2], torch.float32, True), + ] + ) + def forward(self, inputVec, weight): + return torch.ops.aten.conv_transpose3d( + inputVec, + weight, + bias=None, + stride=[2, 2, 2], + padding=[1, 1, 1], + dilation=[1, 1, 1], + output_padding=[0, 0, 0], + groups=1, + ) + + +@register_test_case(module_factory=lambda: Conv_Transpose3dStaticModule()) +def Conv_Transpose3dStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 2, 5, 6, 7), tu.rand(2, 5, 2, 2, 2)) + + class UpSampleNearest2d(torch.nn.Module): def __init__(self): super().__init__() @@ -917,6 +1232,117 @@ def Conv1dModule_basic(module, tu: TestUtils): module.forward(inputVec, weight, bias) +class Conv1dDepthwiseWithPaddingDilationStrideStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 4, 6], torch.float32, True), + ([4, 1, 3], torch.float32, True), + ] + ) + def forward(self, inputVec, weight): + return torch.ops.aten.conv1d( + inputVec, weight, bias=None, stride=[1], padding=[4], dilation=[1], groups=4 + ) + + +@register_test_case( + module_factory=lambda: Conv1dDepthwiseWithPaddingDilationStrideStaticModule() +) +def Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic(module, tu: TestUtils): + inputVec = tu.rand(2, 4, 6) + weight = torch.randn(4, 1, 3) + module.forward(inputVec, weight) + + +class Conv1dWithSamePaddingModule(torch.nn.Module): + def __init__(self): + super().__init__() + torch.manual_seed(0) + self.conv = torch.nn.Conv1d(2, 10, 3, bias=False, padding="same") + self.train(False) + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, 
True), + ] + ) + def forward(self, x): + return self.conv(x) + + +@register_test_case(module_factory=lambda: Conv1dWithSamePaddingModule()) +def Conv1dWithSamePaddingModule_basic(module, tu: TestUtils): + t = tu.rand(5, 2, 10) + module.forward(t) + + +class Conv1dWithValidPaddingModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) + def forward(self, inputVec, weight, bias): + return torch.ops.aten.conv1d( + inputVec, + weight, + bias=bias, + stride=[1], + padding="valid", + dilation=[1], + groups=1, + ) + + +@register_test_case(module_factory=lambda: Conv1dWithValidPaddingModule()) +def Conv1dWithValidPaddingModule_basic(module, tu: TestUtils): + inputVec = tu.rand(2, 2, 6) + weight = torch.randn(8, 2, 3) + bias = torch.randn(8) + module.forward(inputVec, weight, bias) + + +class Conv1dGroupModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) + def forward(self, inputVec, weight, bias): + return torch.ops.aten.conv1d( + inputVec, weight, bias=bias, stride=[1], padding=[0], dilation=[1], groups=2 + ) + + +@register_test_case(module_factory=lambda: Conv1dGroupModule()) +def Conv1dGroupModule_basic(module, tu: TestUtils): + inputVec = tu.rand(2, 4, 6) + weight = torch.randn(8, 2, 3) + bias = torch.randn(8) + module.forward(inputVec, weight, bias) + + class Conv2dModule(torch.nn.Module): def __init__(self): super().__init__() @@ -983,6 +1409,72 @@ def Conv3dModule_basic(module, tu: TestUtils): module.forward(inputVec, weight, bias) +class Conv3dWithSamePaddingModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1, -1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) + def forward(self, inputVec, weight, bias): + return torch.ops.aten.conv3d( + inputVec, + weight, + bias=bias, + stride=[1, 1, 1], + padding="same", + dilation=[1, 1, 1], + groups=1, + ) + + +@register_test_case(module_factory=lambda: Conv3dWithSamePaddingModule()) +def Conv3dWithSamePaddingModule_basic(module, tu: TestUtils): + inputVec = tu.rand(2, 2, 6, 6, 6) + weight = torch.randn(8, 2, 3, 3, 3) + bias = torch.randn(8) + module.forward(inputVec, weight, bias) + + +class Conv3dWithValidPaddingModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1, -1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) + def forward(self, inputVec, weight, bias): + return torch.ops.aten.conv3d( + inputVec, + weight, + bias=bias, + stride=[1, 1, 1], + padding="valid", + dilation=[1, 1, 1], + groups=1, + ) + + +@register_test_case(module_factory=lambda: Conv3dWithValidPaddingModule()) +def Conv3dWithValidPaddingModule_basic(module, tu: TestUtils): + inputVec = tu.rand(2, 2, 6, 6, 6) + weight = torch.randn(8, 2, 3, 3, 3) + bias = torch.randn(8) + module.forward(inputVec, weight, bias) + + class ConvTbcModule(torch.nn.Module): def __init__(self): super().__init__() @@ -1006,55 +1498,116 @@ def ConvTbcModule_basic(module, tu: TestUtils): module.forward(tu.rand(9, 4, 5), tu.rand(3, 5, 6), tu.rand(6)) -class Conv2dQInt8Module(torch.nn.Module): - def 
__init__(self): +# For DQ-Q fake quantization ops +import torch.ao.quantization.fx._decomposed + + +class Conv2dQInt8ModuleBase(torch.nn.Module): + def __init__(self, groups=1): + self.groups = groups super().__init__() + def _forward(self, input, weight, bias): + input = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + input, 0.01, 7, -128, 127, torch.int8 + ) + weight = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + weight, 0.01, 3, -128, 127, torch.int8 + ) + bias = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + bias, 1, 0, -1000, 1000, torch.int32 + ) + + conv = torch.ops.aten.conv2d( + input, + weight, + bias=bias, + stride=[1, 1], + padding=[0, 0], + dilation=[1, 1], + groups=self.groups, + ) + + # Use int32 to avoid overflows + return torch.ops.quantized_decomposed.quantize_per_tensor.default( + conv, 1, 0, -(2**31), 2**31 - 1, torch.int32 + ) + + +class Conv2dQInt8ModuleDyn(Conv2dQInt8ModuleBase): @export @annotate_args( [ None, ([-1, -1, -1, -1], torch.int8, True), ([-1, -1, -1, -1], torch.int8, True), - ([-1], torch.float, True), + ([-1], torch.int32, True), ] ) def forward(self, inputVec, weight, bias): - inputVec = torch._make_per_tensor_quantized_tensor(inputVec, 0.01, 7) - inputVec = torch.dequantize(inputVec) + return self._forward(inputVec, weight, bias) - weight = torch._make_per_tensor_quantized_tensor(weight, 0.01, 3) - weight = torch.dequantize(weight) - bias = torch.quantize_per_tensor(bias, 0.0001, 0, torch.qint32) - bias = torch.dequantize(bias) +class Conv2dQInt8ModuleStatic(Conv2dQInt8ModuleBase): + @export + @annotate_args( + [ + None, + ([2, 3, 12, 12], torch.int8, True), + ([3, 1, 5, 3], torch.int8, True), + ([3], torch.int32, True), + ] + ) + def forward(self, inputVec, weight, bias): + return self._forward(inputVec, weight, bias) - return torch.ops.aten.conv2d( - inputVec, - weight, - bias=bias, - stride=[1, 1], - padding=[0, 0], - dilation=[1, 1], - groups=1, - ) + +class Conv2dQInt8ModuleStatic_MoreOutChannels(Conv2dQInt8ModuleBase): + @export + @annotate_args( + [ + None, + ([2, 3, 12, 12], torch.int8, True), + ([6, 1, 5, 3], torch.int8, True), + ([6], torch.int32, True), + ] + ) + def forward(self, inputVec, weight, bias): + return self._forward(inputVec, weight, bias) -@register_test_case(module_factory=lambda: Conv2dQInt8Module()) +@register_test_case(module_factory=lambda: Conv2dQInt8ModuleDyn()) def Conv2dQInt8Module_basic(module, tu: TestUtils): inputVec = tu.randint(2, 4, 7, 8, low=-128, high=127).to(torch.int8) weight = tu.randint(3, 4, 3, 2, low=-128, high=127).to(torch.int8) - bias = torch.rand(3) + bias = tu.randint(3, low=-1000, high=1000).to(torch.int32) module.forward(inputVec, weight, bias) -N = 10 -Cin = 5 -Cout = 7 -Hin = 10 -Win = 8 -Hker = 3 -Wker = 2 +@register_test_case(module_factory=lambda: Conv2dQInt8ModuleDyn(groups=2)) +def Conv2dQInt8Module_grouped(module, tu: TestUtils): + inputVec = tu.randint(2, 8, 7, 8, low=-128, high=127).to(torch.int8) + weight = tu.randint(6, 4, 3, 2, low=-128, high=127).to(torch.int8) + bias = tu.randint(6, low=-1000, high=1000).to(torch.int32) + module.forward(inputVec, weight, bias) + + +@register_test_case(module_factory=lambda: Conv2dQInt8ModuleStatic(groups=3)) +def Conv2dQInt8Module_depthwise(module, tu: TestUtils): + inputVec = tu.randint(2, 3, 12, 12, low=-128, high=127).to(torch.int8) + weight = tu.randint(3, 1, 5, 3, low=-128, high=127).to(torch.int8) + bias = tu.randint(3, low=-1000, high=1000).to(torch.int32) + module.forward(inputVec, weight, 
bias) + + +@register_test_case( + module_factory=lambda: Conv2dQInt8ModuleStatic_MoreOutChannels(groups=3) +) +def Conv2dQInt8Module_not_depthwise(module, tu: TestUtils): + inputVec = tu.randint(2, 3, 12, 12, low=-128, high=127).to(torch.int8) + weight = tu.randint(6, 1, 5, 3, low=-128, high=127).to(torch.int8) + bias = tu.randint(6, low=-1000, high=1000).to(torch.int32) + module.forward(inputVec, weight, bias) class ConvTranspose2DQInt8Module(torch.nn.Module): @@ -1072,16 +1625,17 @@ def __init__(self): ] ) def forward(self, input, weight, bias): - qinput = torch._make_per_tensor_quantized_tensor(input, 0.01, -25) - qinput = torch.dequantize(qinput) - qweight = torch._make_per_tensor_quantized_tensor(weight, 0.01, 50) - qweight = torch.dequantize(qweight) - qbias = torch.quantize_per_tensor(bias, 0.0001, 0, torch.qint32) - qbias = torch.dequantize(qbias) - qz = torch.ops.aten.convolution( - qinput, - qweight, - bias=qbias, + input = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + input, 0.01, -25, -128, 127, torch.int8 + ) + weight = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + weight, 0.01, 50, -128, 127, torch.int8 + ) + + res = torch.ops.aten.convolution( + input, + weight, + bias=bias, stride=[2, 1], padding=[1, 1], dilation=[1, 1], @@ -1089,13 +1643,202 @@ def forward(self, input, weight, bias): output_padding=[0, 0], groups=1, ) - return qz + + # Use int32 to avoid overflows + return torch.ops.quantized_decomposed.quantize_per_tensor.default( + res, 1, 0, -(2**31), 2**31 - 1, torch.int32 + ) @register_test_case(module_factory=lambda: ConvTranspose2DQInt8Module()) def ConvTranspose2DQInt8_basic(module, tu: TestUtils): + N = 10 + Cin = 5 + Cout = 7 + Hin = 10 + Win = 8 + Hker = 3 + Wker = 2 module.forward( tu.randint(N, Cin, Hin, Win, low=-128, high=127).to(torch.int8), tu.randint(Cin, Cout, Hker, Wker, low=-128, high=127).to(torch.int8), torch.rand(Cout), ) + + +class Conv2dQInt8PerChannelModuleBase(torch.nn.Module): + def __init__(self, groups=1): + self.groups = groups + super().__init__() + + def _forward(self, inputVec, weight, scales, zeropoints, bias): + inputVec = torch.ops.quantized_decomposed.dequantize_per_tensor.default( + inputVec, 0.01, 7, -128, 127, torch.int8 + ) + weight = torch.ops.quantized_decomposed.dequantize_per_channel.default( + weight, scales, zeropoints, 0, -128, 127, torch.int8 + ) + + conv = torch.ops.aten.conv2d( + inputVec, + weight, + bias=bias, + stride=[1, 1], + padding=[0, 0], + dilation=[1, 1], + groups=self.groups, + ) + + # Use int32 to avoid overflows + return torch.ops.quantized_decomposed.quantize_per_tensor.default( + conv, 1, 0, -(2**31), 2**31 - 1, torch.int32 + ) + + +class Conv2dQInt8PerChannelModuleDyn(Conv2dQInt8PerChannelModuleBase): + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.int8, True), + ([-1, -1, -1, -1], torch.int8, True), + ([-1], torch.float, True), + ([-1], torch.int8, True), + ([-1], torch.float, True), + ] + ) + def forward(self, inputVec, weight, scales, zeropoints, bias): + return self._forward(inputVec, weight, scales, zeropoints, bias) + + +class Conv2dQInt8PerChannelModuleStatic(Conv2dQInt8PerChannelModuleBase): + @export + @annotate_args( + [ + None, + ([2, 3, 12, 12], torch.int8, True), + ([3, 1, 5, 3], torch.int8, True), + ([3], torch.float, True), + ([3], torch.int8, True), + ([3], torch.float, True), + ] + ) + def forward(self, inputVec, weight, scales, zeropoints, bias): + return self._forward(inputVec, weight, scales, zeropoints, bias) + + 
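+# The per-channel modules above follow the same DQ/Q pattern as the per-tensor
+# ones; as a rough sketch of the affine math involved (illustrative only):
+#
+#   dequantize: real = (q.float() - zero_point) * scale   # scale/zero_point per output channel here
+#   quantize:   q = clamp(round(real / scale) + zero_point, quant_min, quant_max)
+#
+# and, as noted in the modules, the conv result is re-quantized with an int32
+# range so the accumulated values do not overflow an int8 range.
+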
+@register_test_case(module_factory=lambda: Conv2dQInt8PerChannelModuleDyn()) +def Conv2dQInt8PerChannelModule_basic(module, tu: TestUtils): + inputVec = tu.randint(2, 4, 7, 8, low=-128, high=127).to(torch.int8) + weight = tu.randint(3, 4, 3, 2, low=-128, high=127).to(torch.int8) + scales = tu.rand(3) + zeropoints = tu.rand(3).to(torch.int8) + bias = torch.rand(3) + module.forward(inputVec, weight, scales, zeropoints, bias) + + +@register_test_case(module_factory=lambda: Conv2dQInt8PerChannelModuleDyn(groups=2)) +def Conv2dQInt8PerChannelModule_grouped(module, tu: TestUtils): + inputVec = tu.randint(2, 8, 7, 8, low=-128, high=127).to(torch.int8) + weight = tu.randint(6, 4, 3, 2, low=-128, high=127).to(torch.int8) + scales = tu.rand(6) + zeropoints = tu.rand(6).to(torch.int8) + bias = torch.rand(6) + module.forward(inputVec, weight, scales, zeropoints, bias) + + +@register_test_case(module_factory=lambda: Conv2dQInt8PerChannelModuleStatic(groups=3)) +def Conv2dQInt8PerChannelModule_depthwise(module, tu: TestUtils): + inputVec = tu.randint(2, 3, 12, 12, low=-128, high=127).to(torch.int8) + weight = tu.randint(3, 1, 5, 3, low=-128, high=127).to(torch.int8) + scales = tu.rand(3) + zeropoints = tu.rand(3).to(torch.int8) + bias = torch.rand(3) + module.forward(inputVec, weight, scales, zeropoints, bias) + + +# torchvision.deform_conv2d + +import torchvision + +# This section defines a torch->onnx path for this torchvision op so we can test the onnx paths e2e. + +# Create symbolic function +from torch.onnx.symbolic_helper import parse_args, _get_tensor_sizes + + +@parse_args("v", "v", "v", "v", "v", "i", "i", "i", "i", "i", "i", "i", "i", "b") +def symbolic_deform_conv2d_forward( + g, + input, + weight, + offset, + mask, + bias, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + groups, + offset_groups, + use_mask, +): + args = [input, weight, offset, bias] + if use_mask: + args.append(mask) + weight_size = _get_tensor_sizes(weight) + kwargs = { + "dilations_i": [dilation_h, dilation_w], + "group_i": groups, + "kernel_shape_i": weight_size[2:], + "offset_group_i": offset_groups, + # NB: ONNX supports asymmetric padding, whereas PyTorch supports only + # symmetric padding + "pads_i": [pad_h, pad_w, pad_h, pad_w], + "strides_i": [stride_h, stride_w], + } + return g.op("DeformConv", *args, **kwargs) + + +# Register symbolic function +from torch.onnx import register_custom_op_symbolic + +register_custom_op_symbolic( + "torchvision::deform_conv2d", symbolic_deform_conv2d_forward, 19 +) + +N = 1 +Cin = 1 +Hin = 7 +Win = 6 +Cout = 1 +Hker = 2 +Wker = 2 +offset_groups = 1 +Hout = 6 +Wout = 5 +offset_dim1 = 2 * offset_groups * Hker * Wker + + +class DeformableConvModule(torch.nn.Module): + @export + @annotate_args( + [ + None, + ([N, Cin, Hin, Win], torch.float32, True), + ([N, offset_dim1, Hout, Wout], torch.float32, True), + ([Cout, Cin, Hker, Wker], torch.float32, True), + ] + ) + def forward(self, input, offset, weight): + return torchvision.ops.deform_conv2d(input, offset, weight) + + +@register_test_case(module_factory=lambda: DeformableConvModule()) +def DeformConv2D_basic(module, tu: TestUtils): + input = tu.rand(N, Cin, Hin, Win) + offset = tu.rand(N, offset_dim1, Hout, Wout) + weight = tu.rand(Cout, Cin, Hker, Wker) + module.forward(input, offset, weight) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index f5e3c9fc4b9b..fae9dbb4365e 100644 --- 
a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -414,6 +414,31 @@ def ElementwiseTernaryModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseTernaryStaticShapeModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([5, 4, 3], torch.float32, True), + ([4, 3], torch.float32, True), + ([3], torch.float32, True), + ] + ) + def forward(self, a, b, c): + return torch.lerp(a, b, c) + + +@register_test_case(module_factory=lambda: ElementwiseTernaryStaticShapeModule()) +def ElementwiseTernaryStaticShapeModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 4, 3), tu.rand(4, 3), tu.rand(3)) + + +# ============================================================================== + + class ElementwiseAtenWhereSelfModule(torch.nn.Module): def __init__(self): super().__init__() @@ -466,6 +491,29 @@ def ElementwiseWhereSelfModule_basic(module, tu: TestUtils): # ============================================================================== +class FloatPowerTensorTensorStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([3, 4], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.float_power(x, torch.tensor(2)) + + +@register_test_case(module_factory=lambda: FloatPowerTensorTensorStaticModule()) +def FloatPowerTensorTensorStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + + +# ============================================================================== + + class ElementwiseWhereScalarModule(torch.nn.Module): def __init__(self): super().__init__() @@ -585,6 +633,29 @@ def ElementwiseWhereScalarSelfStaticModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseNanToNumWithNoneModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([3, 4], torch.float32, True)]) + def forward(self, a): + return torch.ops.aten.nan_to_num(a) + + +@register_test_case(module_factory=lambda: ElementwiseNanToNumWithNoneModule()) +def ElementwiseNanToNumWithNoneModule_Basic(module, tu: TestUtils): + module.forward( + torch.tensor( + [ + [float("nan"), 0.0, float("nan"), 1.0], + [float("inf"), 2.0, float("inf"), 3.0], + [float("-inf"), -1.0, float("-inf"), 4.0], + ] + ) + ) + + class ElementwiseNanToNumModule(torch.nn.Module): def __init__(self): super().__init__() @@ -592,7 +663,7 @@ def __init__(self): @export @annotate_args([None, ([3, 4], torch.float32, True)]) def forward(self, a): - return torch.ops.aten.nan_to_num(a, 0.0, 1.0, -1.0) + return torch.ops.aten.nan_to_num(a, 0.1, 1.0, -1.0) @register_test_case(module_factory=lambda: ElementwiseNanToNumModule()) @@ -600,9 +671,9 @@ def ElementwiseNanToNumModule_Basic(module, tu: TestUtils): module.forward( torch.tensor( [ - [float("nan"), 0.0, float("nan"), 0.0], - [float("inf"), 0.0, float("inf"), 0.0], - [float("-inf"), 0.0, float("-inf"), 0.0], + [float("nan"), 0.0, float("nan"), 1.0], + [float("inf"), 2.0, float("inf"), 3.0], + [float("-inf"), -1.0, float("-inf"), 4.0], ] ) ) @@ -637,6 +708,35 @@ def ElementwiseAddModule_basic(module, tu: TestUtils): # ============================================================================== +# Addition is an interesting special case of a binary op, because 
under the hood +# it carries a third scalar "alpha" parameter, which needs special handling. +class ElementwiseAddBoolModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([4], torch.bool, True), + ([4], torch.bool, True), + ] + ) + def forward(self, a, b): + return a + b + + +@register_test_case(module_factory=lambda: ElementwiseAddBoolModule()) +def ElementwiseAddBoolModule_basic(module, tu: TestUtils): + module.forward( + torch.tensor([False, False, True, True]), + torch.tensor([False, True, False, False]), + ) + + +# ============================================================================== + + class ElementwiseUnsqueezeBroadcastModule(torch.nn.Module): def __init__(self): super().__init__() @@ -1037,6 +1137,196 @@ def ElementwiseCeluModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseRreluTrainModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) + def forward(self, x): + res = torch.ops.aten.rrelu(x, 0.4, 0.6, True) + return torch.mean(res), torch.std(res) + + +@register_test_case(module_factory=lambda: ElementwiseRreluTrainModule()) +def ElementwiseRreluTrainModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1024, 1536)) + + +# ============================================================================== + + +class ElementwiseRreluTrainStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([1024, 1536], torch.float32, True), + ] + ) + def forward(self, x): + res = torch.ops.aten.rrelu(x, 0.1, 0.9, True) + return torch.mean(res), torch.std(res) + + +@register_test_case(module_factory=lambda: ElementwiseRreluTrainStaticModule()) +def ElementwiseRreluTrainStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1024, 1536)) + + +# ============================================================================== + + +class ElementwiseRreluEvalModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.rrelu(x, 0.4, 0.6, False) + + +@register_test_case(module_factory=lambda: ElementwiseRreluEvalModule()) +def ElementwiseRreluEvalModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 3, low=-1, high=1)) + + +# ============================================================================== + + +class ElementwiseRreluEvalStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([5, 3], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.rrelu(x, 0.1, 0.9, False) + + +@register_test_case(module_factory=lambda: ElementwiseRreluEvalStaticModule()) +def ElementwiseRreluEvalStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 3, low=-1, high=1)) + + +# ============================================================================== + + +class ElementwiseRreluWithNoiseTrainModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [None, ([-1, -1], torch.float32, True), ([-1, -1], torch.float32, True)] + ) + def forward(self, x, noise): + out, out_noise = torch.ops.aten.rrelu_with_noise_functional( + x, noise, 0.2, 0.5, True + ) + return ( + torch.mean(out), + torch.std(out), + 
torch.mean(out_noise), + torch.std(out_noise), + ) + + +@register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseTrainModule()) +def ElementwiseRreluWithNoiseTrainModule_basic(module, tu: TestUtils): + module.forward(tu.rand(256, 256, low=-1, high=1), tu.rand(256, 256)) + + +# ============================================================================== + + +class ElementwiseRreluWithNoiseTrainStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [None, ([256, 256], torch.float32, True), ([256, 256], torch.float32, True)] + ) + def forward(self, x, noise): + out, out_noise = torch.ops.aten.rrelu_with_noise_functional( + x, noise, 0.4, 0.6, True + ) + return ( + torch.mean(out), + torch.std(out), + torch.mean(out_noise), + torch.std(out_noise), + ) + + +@register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseTrainStaticModule()) +def ElementwiseRreluWithNoiseTrainStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(256, 256, low=-1, high=1), tu.rand(256, 256)) + + +# ============================================================================== + + +class ElementwiseRreluWithNoiseEvalModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [None, ([-1, -1], torch.float32, True), ([-1, -1], torch.float32, True)] + ) + def forward(self, x, noise): + res = torch.ops.aten.rrelu_with_noise_functional(x, noise, 0.4, 0.6, False)[0] + return torch.mean(res), torch.std(res) + + +@register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseEvalModule()) +def ElementwiseRreluWithNoiseEvalModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 3, low=-1, high=1), tu.rand(5, 3)) + + +# ============================================================================== + + +class ElementwiseRreluWithNoiseEvalStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([5, 3], torch.float32, True), ([5, 3], torch.float32, True)]) + def forward(self, x, noise): + res = torch.ops.aten.rrelu_with_noise_functional(x, noise, 0.4, 0.6, False)[0] + return torch.mean(res), torch.std(res) + + +@register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseEvalStaticModule()) +def ElementwiseRreluWithNoiseEvalStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 3, low=-1, high=1), tu.rand(5, 3)) + + +# ============================================================================== + + class ElementwiseCeluStaticModule(torch.nn.Module): def __init__(self): super().__init__() @@ -1321,6 +1611,64 @@ def ElementwiseMaximumIntModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseFmaxModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) + def forward(self, x, y): + return torch.ops.aten.fmax(x, y) + + +@register_test_case(module_factory=lambda: ElementwiseFmaxModule()) +def ElementwiseFmaxModule_basic(module, tu: TestUtils): + module.forward(tu.rand(4), tu.rand(4)) + module.forward(tu.rand(4), torch.tensor([1.0, torch.nan, -0.5, -0.3])) + module.forward( + torch.tensor([0.8, torch.nan, torch.nan, -0.3]), + torch.tensor([1.0, torch.nan, -0.4, torch.nan]), + ) + + +# ============================================================================== + + +class ElementwiseFminModule(torch.nn.Module): + def 
__init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) + def forward(self, x, y): + return torch.ops.aten.fmin(x, y) + + +@register_test_case(module_factory=lambda: ElementwiseFminModule()) +def ElementwiseFminModule_basic(module, tu: TestUtils): + module.forward(tu.rand(4), tu.rand(4)) + module.forward(tu.rand(4), torch.tensor([1.0, torch.nan, -0.5, -0.3])) + module.forward( + torch.tensor([0.8, torch.nan, torch.nan, -0.3]), + torch.tensor([1.0, torch.nan, -0.4, torch.nan]), + ) + + +# ============================================================================== + + class ElementwiseMaxOtherModule(torch.nn.Module): def __init__(self): super().__init__() @@ -1458,8 +1806,8 @@ def __init__(self): [ None, ([-1, -1], torch.float32, True), - ([], torch.float32, True), - ([], torch.float32, True), + ([1], torch.float32, True), + ([1], torch.float32, True), ] ) def forward(self, x, min, max): @@ -1488,8 +1836,8 @@ def __init__(self): [ None, ([-1, -1], torch.int64, True), - ([], torch.int64, True), - ([], torch.int64, True), + ([1], torch.int64, True), + ([1], torch.int64, True), ] ) def forward(self, x, min, max): @@ -1541,7 +1889,7 @@ def __init__(self): [ None, ([-1, -1], torch.float32, True), - ([], torch.float32, True), + ([1], torch.float32, True), ] ) def forward(self, x, min): @@ -1565,7 +1913,7 @@ def __init__(self): [ None, ([-1, -1], torch.int64, True), - ([], torch.int64, True), + ([1], torch.int64, True), ] ) def forward(self, x, min): @@ -1672,7 +2020,31 @@ def RsubIntModule_noalpha_basic(module, tu: TestUtils): # ============================================================================== -class RsubInt0d_NumToTensor_Module(torch.nn.Module): +class RsubIntStaticModule_noalpha(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ] + ) + def forward(self, x): + return torch.rsub(x, 2.0) + + +@register_test_case(module_factory=lambda: RsubIntStaticModule_noalpha()) +def RsubIntStaticModule_noalpha_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, high=100)) + + +# ============================================================================== + + +class RsubInt0d_NumToTensor_Module(torch.nn.Module): def __init__(self): super().__init__() @@ -1812,6 +2184,33 @@ def ElementwiseMulTensorIntModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseCreateComplexModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) + def forward(self, a, b): + return torch.complex(a, b) + + +@register_test_case(module_factory=lambda: ElementwiseCreateComplexModule()) +def ElementwiseCreateComplexModule_basic(module, tu: TestUtils): + module.forward( + tu.randint(4, high=10).type(torch.float32), + tu.randint(4, high=10).type(torch.float32), + ) + + +# ============================================================================== + + class ElementwiseMulTensorComplexModule(torch.nn.Module): def __init__(self): super().__init__() @@ -2501,6 +2900,130 @@ def ElementwiseTruncIntModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseSignbitModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ 
+ None, + ([1, 8], torch.float32, True), + ] + ) + def forward(self, a): + return torch.signbit(a) + + +@register_test_case(module_factory=lambda: ElementwiseSignbitModule()) +def ElementwiseSignbitModule_basic(module, tu: TestUtils): + module.forward( + torch.tensor( + [[-torch.inf, torch.inf, torch.nan, -torch.nan, 2.3, -2.3, 0.0, -0.0]] + ) + ) + + +class ElementwiseSignbitIntModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([3, 4], torch.int32, True), + ] + ) + def forward(self, a): + return torch.signbit(a) + + +@register_test_case(module_factory=lambda: ElementwiseSignbitIntModule()) +def ElementwiseSignbitIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-100, high=100).to(torch.int32)) + + +# ============================================================================== + + +class ElementwiseFracModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([1, 6], torch.float32, True), + ] + ) + def forward(self, a): + return torch.frac(a) + + +@register_test_case(module_factory=lambda: ElementwiseFracModule()) +def ElementwiseFracModule_basic(module, tu: TestUtils): + module.forward(torch.tensor([[2.3, -2.3, 0.0, -0.0, 2.0, -2.0]])) + + +# ============================================================================== + + +class ElementwiseCopysignModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([1, 1], torch.float32, True), + ([1, 6], torch.float32, True), + ] + ) + def forward(self, a, b): + return torch.copysign(a, b) + + +@register_test_case(module_factory=lambda: ElementwiseCopysignModule()) +def ElementwiseCopysignModule_basic(module, tu: TestUtils): + module.forward( + torch.tensor([[1.0]]), + torch.tensor([[2.3, -2.3, 0.0, -0.0, torch.inf, -torch.inf]]), + ) + + +# ============================================================================== + + +class ElementwiseLdexpModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([1, 6], torch.float32, True), + ([1, 1], torch.int64, True), + ] + ) + def forward(self, a, b): + return torch.ldexp(a, b) + + +@register_test_case(module_factory=lambda: ElementwiseLdexpModule()) +def ElementwiseLdexpModule_basic(module, tu: TestUtils): + module.forward( + torch.tensor([[2.3, -2.3, 0.0, -0.0, 4.5, -4.5]]), + torch.tensor([[2]]), + ) + + +# ============================================================================== + + class ElementwiseSignModule(torch.nn.Module): def __init__(self): super().__init__() @@ -2572,6 +3095,49 @@ def ElementwiseSgnModule_basic(module, tu: TestUtils): # ============================================================================== +class Exp2StaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([3, 2], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.exp2(x) + + +@register_test_case(module_factory=lambda: Exp2StaticModule()) +def Exp2StaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 2)) + + +class Exp2StaticIntModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([3, 4], torch.int64, True), + ] + ) + def forward(self, x): + return torch.ops.aten.exp2(x) + + +@register_test_case(module_factory=lambda: Exp2StaticIntModule()) +def Exp2StaticIntModule_basic(module, tu: TestUtils): + 
module.forward(tu.randint(3, 4, low=-20, high=20)) + + +# ============================================================================== + + class ElementwisePowModule(torch.nn.Module): def __init__(self): super().__init__() @@ -3085,7 +3651,7 @@ def ElementwiseRemainderScalarModule_Int_Float_basic(module, tu: TestUtils): # ============================================================================== -class ElementwiseRemainderScalarModule_Float(torch.nn.Module): +class ElementwiseRemainderScalarModule_Int_Float_NegativeDividend(torch.nn.Module): def __init__(self): super().__init__() @@ -3093,22 +3659,26 @@ def __init__(self): @annotate_args( [ None, - ([-1, -1], torch.float32, True), + ([-1], torch.int32, True), ] ) def forward(self, x): - return torch.remainder(x, 2.0) + return torch.remainder(x, 5.0) -@register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Float()) -def ElementwiseRemainderScalarModule_Float_basic(module, tu: TestUtils): - module.forward(tu.rand(10, 3)) +@register_test_case( + module_factory=lambda: ElementwiseRemainderScalarModule_Int_Float_NegativeDividend() +) +def ElementwiseRemainderScalarModule_Int_Float_NegativeDividend_basic( + module, tu: TestUtils +): + module.forward(tu.randint(30, low=-10, high=10).to(torch.int32)) # ============================================================================== -class ElementwiseRemainderScalarModule_Int(torch.nn.Module): +class ElementwiseRemainderScalarModule_Int_Float_NegativeDivisor(torch.nn.Module): def __init__(self): super().__init__() @@ -3116,22 +3686,26 @@ def __init__(self): @annotate_args( [ None, - ([-1, -1], torch.int32, True), + ([-1], torch.int32, True), ] ) def forward(self, x): - return torch.remainder(x, 2) + return torch.remainder(x, -5.0) -@register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Int()) -def ElementwiseRemainderScalarModule_Int_basic(module, tu: TestUtils): - module.forward(tu.randint(3, 2, high=10).to(torch.int32)) +@register_test_case( + module_factory=lambda: ElementwiseRemainderScalarModule_Int_Float_NegativeDivisor() +) +def ElementwiseRemainderScalarModule_Int_Float_NegativeDivisor_basic( + module, tu: TestUtils +): + module.forward(tu.randint(30, low=-10, high=-1).to(torch.int32)) # ============================================================================== -class ElementwiseRemainderScalarModule_Bool(torch.nn.Module): +class ElementwiseRemainderScalarModule_Float(torch.nn.Module): def __init__(self): super().__init__() @@ -3139,61 +3713,74 @@ def __init__(self): @annotate_args( [ None, - ([-1], torch.bool, True), + ([-1, -1], torch.float32, True), ] ) def forward(self, x): - return torch.remainder(x, 2) + return torch.remainder(x, 2.0) -@register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Bool()) -def ElementwiseRemainderScalarModule_Bool_basic(module, tu: TestUtils): - module.forward(torch.tensor([True, False, True, True, True])) +@register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Float()) +def ElementwiseRemainderScalarModule_Float_basic(module, tu: TestUtils): + module.forward(tu.rand(10, 3)) # ============================================================================== -class ElementwiseFmodTensor_Float(torch.nn.Module): +class ElementwiseRemainderScalarModule_Float_NegativeDividend(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([None, ([-1], torch.float32, True), ([-1], torch.float32, True)]) - def forward(self, x, y): - return torch.fmod(x, y) + 
@annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) + def forward(self, x): + return torch.remainder(x, 5.0) -@register_test_case(module_factory=lambda: ElementwiseFmodTensor_Float()) -def ElementwiseFmodTensor_Float_basic(module, tu: TestUtils): - module.forward(tu.rand(100, low=-10, high=10), tu.rand(100, low=-10, high=10)) +@register_test_case( + module_factory=lambda: ElementwiseRemainderScalarModule_Float_NegativeDividend() +) +def ElementwiseRemainderScalarModule_Float_NegativeDividend_basic( + module, tu: TestUtils +): + module.forward(tu.rand(10, 3, low=-10.0, high=10.0)) # ============================================================================== -class ElementwiseFmodTensor_Int_Float(torch.nn.Module): +class ElementwiseRemainderScalarModule_Float_NegativeDivisor(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([None, ([-1], torch.int32, True), ([-1], torch.float32, True)]) - def forward(self, x, y): - return torch.fmod(x, y) + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) + def forward(self, x): + return torch.remainder(x, -5.0) -@register_test_case(module_factory=lambda: ElementwiseFmodTensor_Int_Float()) -def ElementwiseFmodTensor_Int_Float_basic(module, tu: TestUtils): - module.forward( - tu.randint(100, low=-10, high=10).to(torch.int32), - tu.rand(100, low=-10, high=10), - ) +@register_test_case( + module_factory=lambda: ElementwiseRemainderScalarModule_Float_NegativeDivisor() +) +def ElementwiseRemainderScalarModule_Float_NegativeDivisor_basic(module, tu: TestUtils): + module.forward(tu.rand(10, 3, low=-10.0, high=10.0)) # ============================================================================== -class ElementwiseFmodTensor_Int(torch.nn.Module): +class ElementwiseRemainderScalarModule_Int(torch.nn.Module): def __init__(self): super().__init__() @@ -3201,21 +3788,183 @@ def __init__(self): @annotate_args( [ None, - ([-1], torch.int32, True), - ([-1], torch.int32, True), + ([-1, -1], torch.int32, True), ] ) - def forward(self, x, y): - return torch.fmod(x, y) - + def forward(self, x): + return torch.remainder(x, 2) -@register_test_case(module_factory=lambda: ElementwiseFmodTensor_Int()) -def ElementwiseFmodTensor_Int_basic(module, tu: TestUtils): - module.forward( - tu.randint(100, low=0, high=1000).to(torch.int32), - tu.randint(100, low=1, high=1000).to(torch.int32), - ) - # ============================================================================== + +@register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Int()) +def ElementwiseRemainderScalarModule_Int_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 2, high=10).to(torch.int32)) + + +# ============================================================================== + + +class ElementwiseRemainderScalarModule_Int_NegativeDividend(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) + def forward(self, x): + return torch.remainder(x, 5) + + +@register_test_case( + module_factory=lambda: ElementwiseRemainderScalarModule_Int_NegativeDividend() +) +def ElementwiseRemainderScalarModule_Int_NegativeDividend_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 2, low=-10, high=10).to(torch.int32)) + + +# ============================================================================== + + +class ElementwiseRemainderScalarModule_Int_NegativeDivisor(torch.nn.Module): + def __init__(self): + super().__init__() + + @export 
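+    # torch.remainder follows the sign of the divisor, so the negative divisor used below yields non-positive results.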
+ @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ] + ) + def forward(self, x): + return torch.remainder(x, -5) + + +@register_test_case( + module_factory=lambda: ElementwiseRemainderScalarModule_Int_NegativeDivisor() +) +def ElementwiseRemainderScalarModule_Int_NegativeDivisor_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 2, low=-10, high=10).to(torch.int32)) + + +# ============================================================================== + + +class ElementwiseRemainderScalarModule_Bool(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1], torch.bool, True), + ] + ) + def forward(self, x): + return torch.remainder(x, 2) + + +@register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Bool()) +def ElementwiseRemainderScalarModule_Bool_basic(module, tu: TestUtils): + module.forward(torch.tensor([True, False, True, True, True])) + + +# ============================================================================== + + +class ElementwiseRemainderScalarModule_Bool_NegativeDivisor(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1], torch.bool, True), + ] + ) + def forward(self, x): + return torch.remainder(x, -3) + + +@register_test_case( + module_factory=lambda: ElementwiseRemainderScalarModule_Bool_NegativeDivisor() +) +def ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic(module, tu: TestUtils): + module.forward(torch.tensor([True, False, True, True, True])) + + +# ============================================================================== + + +class ElementwiseFmodTensor_Float(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([-1], torch.float32, True), ([-1], torch.float32, True)]) + def forward(self, x, y): + return torch.fmod(x, y) + + +@register_test_case(module_factory=lambda: ElementwiseFmodTensor_Float()) +def ElementwiseFmodTensor_Float_basic(module, tu: TestUtils): + module.forward(tu.rand(100, low=-10, high=10), tu.rand(100, low=-10, high=10)) + + +# ============================================================================== + + +class ElementwiseFmodTensor_Int_Float(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([-1], torch.int32, True), ([-1], torch.float32, True)]) + def forward(self, x, y): + return torch.fmod(x, y) + + +@register_test_case(module_factory=lambda: ElementwiseFmodTensor_Int_Float()) +def ElementwiseFmodTensor_Int_Float_basic(module, tu: TestUtils): + module.forward( + tu.randint(100, low=-10, high=10).to(torch.int32), + tu.rand(100, low=-10, high=10), + ) + + +# ============================================================================== + + +class ElementwiseFmodTensor_Int(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1], torch.int32, True), + ([-1], torch.int32, True), + ] + ) + def forward(self, x, y): + return torch.fmod(x, y) + + +@register_test_case(module_factory=lambda: ElementwiseFmodTensor_Int()) +def ElementwiseFmodTensor_Int_basic(module, tu: TestUtils): + module.forward( + tu.randint(100, low=0, high=1000).to(torch.int32), + tu.randint(100, low=1, high=1000).to(torch.int32), + ) + + +# ============================================================================== class ElementwiseRemainderTensorModule_Int_Float(torch.nn.Module): @@ -3242,6 +3991,67 @@ def 
ElementwiseRemainderTensorModule_Int_Float_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseRemainderTensorModule_Int_Float_NegativeDividend(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ([-1, -1], torch.float32, True), + ] + ) + def forward(self, a, b): + return torch.remainder(a, b) + + +@register_test_case( + module_factory=lambda: ElementwiseRemainderTensorModule_Int_Float_NegativeDividend() +) +def ElementwiseRemainderTensorModule_Int_Float_NegativeDividend_basic( + module, tu: TestUtils +): + module.forward( + tu.randint(3, 4, low=-10, high=10).to(torch.int32), tu.rand(3, 4, high=10) + ) + + +# ============================================================================== + + +class ElementwiseRemainderTensorModule_Int_Float_NegativeDivisor(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ([-1, -1], torch.float32, True), + ] + ) + def forward(self, a, b): + return torch.remainder(a, b) + + +@register_test_case( + module_factory=lambda: ElementwiseRemainderTensorModule_Int_Float_NegativeDivisor() +) +def ElementwiseRemainderTensorModule_Int_Float_NegativeDivisor_basic( + module, tu: TestUtils +): + module.forward( + tu.randint(3, 4, low=-10, high=10).to(torch.int32), + tu.rand(3, 4, low=-10, high=-1), + ) + + +# ============================================================================== + + class ElementwiseRemainderTensorModule_Float(torch.nn.Module): def __init__(self): super().__init__() @@ -3266,6 +4076,60 @@ def ElementwiseRemainderTensorModule_Float_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseRemainderTensorModule_Float_NegativeDividend(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) + def forward(self, a, b): + return torch.remainder(a, b) + + +@register_test_case( + module_factory=lambda: ElementwiseRemainderTensorModule_Float_NegativeDividend() +) +def ElementwiseRemainderTensorModule_Float_NegativeDividend_basic( + module, tu: TestUtils +): + module.forward(tu.rand(3, 4, high=10), tu.rand(3, 4, high=10)) + + +# ============================================================================== + + +class ElementwiseRemainderTensorModule_Float_NegativeDivisor(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) + def forward(self, a, b): + return torch.remainder(a, b) + + +@register_test_case( + module_factory=lambda: ElementwiseRemainderTensorModule_Float_NegativeDivisor() +) +def ElementwiseRemainderTensorModule_Float_NegativeDivisor_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, high=10), tu.rand(3, 4, low=-10, high=-1)) + + +# ============================================================================== + + class ElementwiseRemainderTensorModule_Int(torch.nn.Module): def __init__(self): super().__init__() @@ -3293,6 +4157,64 @@ def ElementwiseRemainderTensorModule_Int_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseRemainderTensorModule_Int_NegativeDividend(torch.nn.Module): + def 
__init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ([-1, -1], torch.int32, True), + ] + ) + def forward(self, a, b): + return torch.remainder(a, b) + + +@register_test_case( + module_factory=lambda: ElementwiseRemainderTensorModule_Int_NegativeDividend() +) +def ElementwiseRemainderTensorModule_Int_NegativeDividend_basic(module, tu: TestUtils): + module.forward( + tu.randint(3, 4, low=-10, high=10, dtype=torch.int32), + tu.randint(3, 4, high=10, dtype=torch.int32), + ) + + +# ============================================================================== + + +class ElementwiseRemainderTensorModule_Int_NegativeDivisor(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.int32, True), + ([-1, -1], torch.int32, True), + ] + ) + def forward(self, a, b): + return torch.remainder(a, b) + + +@register_test_case( + module_factory=lambda: ElementwiseRemainderTensorModule_Int_NegativeDivisor() +) +def ElementwiseRemainderTensorModule_Int_NegativeDivisor_basic(module, tu: TestUtils): + module.forward( + tu.randint(3, 4, low=-10, high=10, dtype=torch.int32), + tu.randint(3, 4, low=-10, high=-1, dtype=torch.int32), + ) + + +# ============================================================================== + + class ElementwiseDivTensorFloatModule(torch.nn.Module): def __init__(self): super().__init__() @@ -4194,7 +5116,101 @@ def ElementwiseCloneModule_basic(module, tu: TestUtils): # ============================================================================== -class ElementwiseCloneContiguousModule(torch.nn.Module): +class ElementwiseCloneContiguousModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, x): + return torch.clone(x, memory_format=torch.contiguous_format) + + +@register_test_case(module_factory=lambda: ElementwiseCloneContiguousModule()) +def ElementwiseCloneContiguousModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 4)) + + +# ============================================================================== + + +class ElementwiseCloneChannelsLastMemoryFormatModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) + def forward(self, x): + return torch.clone(x, memory_format=torch.channels_last) + + +@register_test_case( + module_factory=lambda: ElementwiseCloneChannelsLastMemoryFormatModule() +) +def ElementwiseCloneChannelsLastMemoryFormatModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 4, 5)) + + +# ============================================================================== + + +class LiftFreshCopyModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.lift_fresh_copy(x) + + +@register_test_case(module_factory=lambda: LiftFreshCopyModule()) +def LiftFreshCopyModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 4)) + + +# ============================================================================== + + +class ElementwiseExpModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) + def forward(self, a): + return torch.exp(a) + + 
+@register_test_case(module_factory=lambda: ElementwiseExpModule()) +def ElementwiseExpModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + + +# ============================================================================== + + +class ElementwiseExpIntModule(torch.nn.Module): def __init__(self): super().__init__() @@ -4202,22 +5218,22 @@ def __init__(self): @annotate_args( [ None, - ([-1, -1, -1], torch.float32, True), + ([-1, -1], torch.int32, True), ] ) - def forward(self, x): - return torch.clone(x, memory_format=torch.contiguous_format) + def forward(self, a): + return torch.exp(a) -@register_test_case(module_factory=lambda: ElementwiseCloneContiguousModule()) -def ElementwiseCloneContiguousModule_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 3, 4)) +@register_test_case(module_factory=lambda: ElementwiseExpIntModule()) +def ElementwiseExpIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) # ============================================================================== -class ElementwiseCloneChannelsLastMemoryFormatModule(torch.nn.Module): +class ElementwiseExpm1Module(torch.nn.Module): def __init__(self): super().__init__() @@ -4225,24 +5241,22 @@ def __init__(self): @annotate_args( [ None, - ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), ] ) - def forward(self, x): - return torch.clone(x, memory_format=torch.channels_last) + def forward(self, a): + return torch.expm1(a) -@register_test_case( - module_factory=lambda: ElementwiseCloneChannelsLastMemoryFormatModule() -) -def ElementwiseCloneChannelsLastMemoryFormatModule_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 3, 4, 5)) +@register_test_case(module_factory=lambda: ElementwiseExpm1Module()) +def ElementwiseExpm1Module_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) # ============================================================================== -class LiftFreshCopyModule(torch.nn.Module): +class ElementwiseExpm1IntModule(torch.nn.Module): def __init__(self): super().__init__() @@ -4250,22 +5264,22 @@ def __init__(self): @annotate_args( [ None, - ([-1, -1, -1], torch.float32, True), + ([-1, -1], torch.int32, True), ] ) - def forward(self, x): - return torch.ops.aten.lift_fresh_copy(x) + def forward(self, a): + return torch.expm1(a) -@register_test_case(module_factory=lambda: LiftFreshCopyModule()) -def LiftFreshCopyModule_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 3, 4)) +@register_test_case(module_factory=lambda: ElementwiseExpm1IntModule()) +def ElementwiseExpm1IntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) # ============================================================================== -class ElementwiseExpModule(torch.nn.Module): +class ElementwiseSpecialExpm1Module(torch.nn.Module): def __init__(self): super().__init__() @@ -4277,18 +5291,18 @@ def __init__(self): ] ) def forward(self, a): - return torch.exp(a) + return torch.special.expm1(a) -@register_test_case(module_factory=lambda: ElementwiseExpModule()) -def ElementwiseExpModule_basic(module, tu: TestUtils): +@register_test_case(module_factory=lambda: ElementwiseSpecialExpm1Module()) +def ElementwiseSpecialExpm1Module_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4)) # ============================================================================== -class ElementwiseExpIntModule(torch.nn.Module): +class ElementwiseSpecialExpm1IntModule(torch.nn.Module): 
def __init__(self): super().__init__() @@ -4300,18 +5314,18 @@ def __init__(self): ] ) def forward(self, a): - return torch.exp(a) + return torch.special.expm1(a) -@register_test_case(module_factory=lambda: ElementwiseExpIntModule()) -def ElementwiseExpIntModule_basic(module, tu: TestUtils): +@register_test_case(module_factory=lambda: ElementwiseSpecialExpm1IntModule()) +def ElementwiseSpecialExpm1IntModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) # ============================================================================== -class ElementwiseExpm1Module(torch.nn.Module): +class ElementwiseRad2DegModule(torch.nn.Module): def __init__(self): super().__init__() @@ -4323,18 +5337,18 @@ def __init__(self): ] ) def forward(self, a): - return torch.special.expm1(a) + return torch.ops.aten.rad2deg(a) -@register_test_case(module_factory=lambda: ElementwiseExpm1Module()) -def ElementwiseExpm1Module_basic(module, tu: TestUtils): +@register_test_case(module_factory=lambda: ElementwiseRad2DegModule()) +def ElementwiseRad2DegModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4)) # ============================================================================== -class ElementwiseExpm1IntModule(torch.nn.Module): +class ElementwiseRad2DegIntModule(torch.nn.Module): def __init__(self): super().__init__() @@ -4342,16 +5356,16 @@ def __init__(self): @annotate_args( [ None, - ([-1, -1], torch.int32, True), + ([-1, -1], torch.float32, True), ] ) def forward(self, a): - return torch.special.expm1(a) + return torch.ops.aten.rad2deg(a) -@register_test_case(module_factory=lambda: ElementwiseExpm1IntModule()) -def ElementwiseExpm1IntModule_basic(module, tu: TestUtils): - module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) +@register_test_case(module_factory=lambda: ElementwiseRad2DegIntModule()) +def ElementwiseRad2DegIntModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) # ============================================================================== @@ -5338,6 +6352,29 @@ def AtenTrilModule_basic(module, tu: TestUtils): # ============================================================================== +class AtenTrilStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([8, 8], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.tril(x) + + +@register_test_case(module_factory=lambda: AtenTrilStaticModule()) +def AtenTrilStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(8, 8)) + + +# ============================================================================== + + class AtenTrilWithPosDiagonalModule(torch.nn.Module): def __init__(self): super().__init__() @@ -5361,6 +6398,29 @@ def AtenTrilWithPosDiagonalModule_basic(module, tu: TestUtils): # ============================================================================== +class AtenTrilWithPosDiagonalStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([9, 4, 3], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.tril(x, diagonal=2) + + +@register_test_case(module_factory=lambda: AtenTrilWithPosDiagonalStaticModule()) +def AtenTrilWithPosDiagonalStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(9, 4, 3)) + + +# ============================================================================== + + class AtenTrilWithNegDiagonalModule(torch.nn.Module): def 
__init__(self): super().__init__() @@ -5384,6 +6444,29 @@ def AtenTrilWithNegDiagonalModule_basic(module, tu: TestUtils): # ============================================================================== +class AtenTrilWithNegDiagonalStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([3, 1, 5, 9], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.tril(x, diagonal=-4) + + +@register_test_case(module_factory=lambda: AtenTrilWithNegDiagonalStaticModule()) +def AtenTrilWithNegDiagonalStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 1, 5, 9)) + + +# ============================================================================== + + class AtenRoundFloatModule(torch.nn.Module): def __init__(self): super().__init__() @@ -6035,3 +7118,165 @@ def forward(self, x): ) def FakeQuantizePerTensorAffineRoundToEvenModule_basic(module, tu: TestUtils): module.forward(torch.FloatTensor([0.5, 1.5, -0.5, -1.5])) + + +# ============================================================================== + + +class TriuIndicesModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ] + ) + def forward(self): + return torch.ops.aten.triu_indices(4, 3, 1) + + +@register_test_case(module_factory=lambda: TriuIndicesModule()) +def TriuIndicesModule_basic(module, tu: TestUtils): + module.forward() + + +class TriuIndicesAllZerosModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ] + ) + def forward(self): + return torch.ops.aten.triu_indices(0, 0, 0) + + +@register_test_case(module_factory=lambda: TriuIndicesAllZerosModule()) +def TriuIndicesAllZerosModule_basic(module, tu: TestUtils): + module.forward() + + +class TriuIndicesNegativeOffsetModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ] + ) + def forward(self): + return torch.ops.aten.triu_indices(5, 16, -2) + + +@register_test_case(module_factory=lambda: TriuIndicesNegativeOffsetModule()) +def TriuIndicesNegativeOffsetModule_basic(module, tu: TestUtils): + module.forward() + + +# ============================================================================== + + +class TrilIndicesModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ] + ) + def forward(self): + return torch.ops.aten.tril_indices(4, 3, 1) + + +@register_test_case(module_factory=lambda: TrilIndicesModule()) +def TrilIndicesModule_basic(module, tu: TestUtils): + module.forward() + + +class TrilIndicesAllZerosModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ] + ) + def forward(self): + return torch.ops.aten.tril_indices(0, 0, 0) + + +@register_test_case(module_factory=lambda: TrilIndicesAllZerosModule()) +def TrilIndicesAllZerosModule_basic(module, tu: TestUtils): + module.forward() + + +class TrilIndicesNegativeOffsetModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ] + ) + def forward(self): + return torch.ops.aten.tril_indices(5, 16, -2) + + +@register_test_case(module_factory=lambda: TrilIndicesNegativeOffsetModule()) +def TrilIndicesNegativeOffsetModule_basic(module, tu: TestUtils): + module.forward() + + +class TrilIndicesOfssetGreaterThanRowModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + 
@annotate_args( + [ + None, + ] + ) + def forward(self): + return torch.ops.aten.tril_indices(7, 9, 8) + + +@register_test_case(module_factory=lambda: TrilIndicesOfssetGreaterThanRowModule()) +def TrilIndicesOfssetGreaterThanRowModule_basic(module, tu: TestUtils): + module.forward() + + +# ============================================================================== + + +class Deg2radModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([3, 4], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.deg2rad(x) + + +@register_test_case(module_factory=lambda: Deg2radModule()) +def Deg2radModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py index 7fdfb454d362..304bc422e4d2 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py @@ -599,6 +599,51 @@ def ElementwiseLtIntTensorModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 5, high=10), tu.randint(5, high=10)) +class ElementwiseIntTensorLtFloatTensorModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.int64, True), + ([-1], torch.float64, True), + ] + ) + def forward(self, x, y): + return torch.lt(x, y) + + +@register_test_case(module_factory=lambda: ElementwiseIntTensorLtFloatTensorModule()) +def ElementwiseIntTensorLtFloatTensorModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 5, high=10), tu.rand(5, high=10).to(torch.float64)) + + +class ElementwiseFloatTensorGtIntTensorModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1], torch.int32, True), + ] + ) + def forward(self, x, y): + return torch.gt(x, y) + + +@register_test_case(module_factory=lambda: ElementwiseIntTensorLtFloatTensorModule()) +def ElementwiseFloatTensorGtIntTensorModule_basic(module, tu: TestUtils): + module.forward( + tu.rand(3, 5, high=10).to(torch.float32), + tu.randint(5, high=10, dtype=torch.int32), + ) + + # ============================================================================== diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/linalg_algorithms.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/linalg_algorithms.py new file mode 100644 index 000000000000..9b761003349f --- /dev/null +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/linalg_algorithms.py @@ -0,0 +1,93 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. 
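+# E2E test cases for linear-algebra algorithms: torch.linalg.det and torch.linalg.slogdet.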
+ +import torch + +from torch_mlir_e2e_test.framework import TestUtils +from torch_mlir_e2e_test.registry import register_test_case +from torch_mlir_e2e_test.annotations import annotate_args, export + +# ============================================================================== + + +class DeterminantModule(torch.nn.Module): + @export + @annotate_args([None, [(4, 4), torch.float32, True]]) + def forward(self, A): + return torch.linalg.det(A) + + +@register_test_case(module_factory=lambda: DeterminantModule()) +def DeterminantModule_F32(module, tu: TestUtils): + A = tu.rand(4, 4).to(dtype=torch.float32) + module.forward(A) + + +class DeterminantBatchedModule(torch.nn.Module): + @export + @annotate_args([None, [(3, 4, 4), torch.float32, True]]) + def forward(self, A): + return torch.linalg.det(A) + + +@register_test_case(module_factory=lambda: DeterminantBatchedModule()) +def DeterminantBatchedModule_F32(module, tu: TestUtils): + A = tu.rand(3, 4, 4).to(dtype=torch.float32) + module.forward(A) + + +class DeterminantDynamicModule(torch.nn.Module): + @export + @annotate_args([None, [(-1, -1, -1), torch.float32, True]]) + def forward(self, A): + return torch.linalg.det(A) + + +@register_test_case(module_factory=lambda: DeterminantBatchedModule()) +def DeterminantDynamicModule_F32(module, tu: TestUtils): + A = tu.rand(3, 4, 4).to(dtype=torch.float32) + module.forward(A) + + +# ============================================================================== + + +class SignAndLogarithmOfDeterminantModule(torch.nn.Module): + @export + @annotate_args([None, [(4, 4), torch.float32, True]]) + def forward(self, A): + return torch.linalg.slogdet(A) + + +@register_test_case(module_factory=lambda: SignAndLogarithmOfDeterminantModule()) +def SignAndLogarithmOfDeterminantModule_F32(module, tu: TestUtils): + A = tu.rand(4, 4).to(dtype=torch.float32) + module.forward(A) + + +class SignAndLogarithmOfDeterminantBatchedModule(torch.nn.Module): + @export + @annotate_args([None, [(3, 4, 4), torch.float32, True]]) + def forward(self, A): + return torch.linalg.slogdet(A) + + +@register_test_case(module_factory=lambda: SignAndLogarithmOfDeterminantBatchedModule()) +def SignAndLogarithmOfDeterminantBatchedModule_F32(module, tu: TestUtils): + A = tu.rand(3, 4, 4).to(dtype=torch.float32) + module.forward(A) + + +class SignAndLogarithmOfDeterminantDynamicModule(torch.nn.Module): + @export + @annotate_args([None, [(-1, -1, -1), torch.float32, True]]) + def forward(self, A): + return torch.linalg.slogdet(A) + + +@register_test_case(module_factory=lambda: SignAndLogarithmOfDeterminantBatchedModule()) +def SignAndLogarithmOfDeterminantDynamicModule_F32(module, tu: TestUtils): + A = tu.rand(3, 4, 4).to(dtype=torch.float32) + module.forward(A) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py index 0093f13ce9e9..2067e13d5997 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py @@ -12,6 +12,30 @@ # ============================================================================== +class AtenDotModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1], torch.float32, True), + ([-1], torch.float32, True), + ] + ) + def forward(self, lhs, rhs): + return torch.dot(lhs, rhs) + + +@register_test_case(module_factory=lambda: AtenDotModule()) +def AtenDotModule_basic(module, tu: TestUtils): + 
module.forward(tu.rand(4), tu.rand(4)) + + +# ============================================================================== + + class MatmulDot(torch.nn.Module): def __init__(self): super().__init__() @@ -180,6 +204,30 @@ def Matmul4dStatic_basic(module, tu: TestUtils): # ============================================================================== +class Matmul4dStaticBroadcast(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([10, 6, 2], torch.float32, True), + ([10, 10, 2, 6], torch.float32, True), + ] + ) + def forward(self, lhs, rhs): + return torch.matmul(lhs, rhs) + + +@register_test_case(module_factory=lambda: Matmul4dStaticBroadcast()) +def Matmul4dStaticBroadcast_basic(module, tu: TestUtils): + module.forward(tu.rand(10, 6, 2), tu.rand(10, 10, 2, 6)) + + +# ============================================================================== + + class MatmulStaticBroadcast(torch.nn.Module): def __init__(self): super().__init__() @@ -313,6 +361,8 @@ def AtenMmIntTypes_basic(module, tu: TestUtils): # ============================================================================== +# For DQ-Q fake quantization ops +import torch.ao.quantization.fx._decomposed class AtenMmQint8(torch.nn.Module): @@ -328,12 +378,14 @@ def __init__(self): ] ) def forward(self, x, y): - qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25) - qx = torch.dequantize(qx) - qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18) - qy = torch.dequantize(qy) - qz = torch.mm(qx, qy) - return qz + x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + x, 0.0215, -25, -128, 127, torch.int8 + ) + y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + y, 0.0176, 18, -128, 127, torch.int8 + ) + z = torch.mm(x, y) + return z @register_test_case(module_factory=lambda: AtenMmQint8()) @@ -360,12 +412,14 @@ def __init__(self): ] ) def forward(self, x, y): - qx = torch._make_per_tensor_quantized_tensor(x, 0.199, 65) - qx = torch.dequantize(qx) - qy = torch._make_per_tensor_quantized_tensor(y, 0.0215, 160) - qy = torch.dequantize(qy) - qz = torch.mm(qx, qy) - return qz + x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + x, 0.199, 65, 0, 255, torch.uint8 + ) + y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + y, 0.0215, 160, 0, 255, torch.uint8 + ) + z = torch.mm(x, y) + return z @register_test_case(module_factory=lambda: AtenMmQuint8()) @@ -392,12 +446,14 @@ def __init__(self): ] ) def forward(self, x, y): - qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66) - qx = torch.dequantize(qx) - qy = torch._make_per_tensor_quantized_tensor(y, 0.025, 160) - qy = torch.dequantize(qy) - qz = torch.mm(qx, qy) - return qz + x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + x, 0.03, -66, -128, 127, torch.int8 + ) + y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + y, 0.025, 160, 0, 255, torch.uint8 + ) + z = torch.mm(x, y) + return z @register_test_case(module_factory=lambda: AtenMmQMixedSigni8()) @@ -411,6 +467,33 @@ def AtenMmQMixedSigni8_basic(module, tu: TestUtils): # ============================================================================== +class AtenIntMM(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([3, 4], torch.int8, True), + ([4, 3], torch.int8, True), + ] + ) + def forward(self, x, y): + return torch._int_mm(x, y) + + +@register_test_case(module_factory=lambda: 
AtenIntMM()) +def AtenIntMM_basic(module, tu: TestUtils): + module.forward( + tu.randint(3, 4, low=-128, high=127).to(torch.int8), + tu.randint(4, 3, low=-128, high=127).to(torch.int8), + ) + + +# ============================================================================== + + class AtenMatmulQint8VM(torch.nn.Module): def __init__(self): super().__init__() @@ -424,12 +507,14 @@ def __init__(self): ] ) def forward(self, x, y): - qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25) - qx = torch.dequantize(qx) - qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18) - qy = torch.dequantize(qy) - qz = torch.matmul(qx, qy) - return qz + x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + x, 0.0215, -25, -128, 127, torch.int8 + ) + y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + y, 0.0176, 18, -128, 127, torch.int8 + ) + z = torch.matmul(x, y) + return z @register_test_case(module_factory=lambda: AtenMatmulQint8VM()) @@ -454,12 +539,14 @@ def __init__(self): ] ) def forward(self, x, y): - qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25) - qx = torch.dequantize(qx) - qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18) - qy = torch.dequantize(qy) - qz = torch.matmul(qx, qy) - return qz + x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + x, 0.0215, -25, -128, 127, torch.int8 + ) + y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + y, 0.0176, 18, -128, 127, torch.int8 + ) + z = torch.matmul(x, y) + return z @register_test_case(module_factory=lambda: AtenMatmulQint8VV()) @@ -484,12 +571,14 @@ def __init__(self): ] ) def forward(self, x, y): - qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25) - qx = torch.dequantize(qx) - qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18) - qy = torch.dequantize(qy) - qz = torch.matmul(qx, qy) - return qz + x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + x, 0.0215, -25, -128, 127, torch.int8 + ) + y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + y, 0.0176, 18, -128, 127, torch.int8 + ) + z = torch.matmul(x, y) + return z @register_test_case(module_factory=lambda: AtenMatmulQint8MV()) @@ -514,12 +603,14 @@ def __init__(self): ] ) def forward(self, x, y): - qx = torch._make_per_tensor_quantized_tensor(x, 0.0215, -25) - qx = torch.dequantize(qx) - qy = torch._make_per_tensor_quantized_tensor(y, 0.0176, 18) - qy = torch.dequantize(qy) - qz = torch.matmul(qx, qy) - return qz + x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + x, 0.0215, -25, -128, 127, torch.int8 + ) + y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + y, 0.0176, 18, -128, 127, torch.int8 + ) + z = torch.matmul(x, y) + return z @register_test_case(module_factory=lambda: AtenMatmulQint8()) @@ -546,12 +637,14 @@ def __init__(self): ] ) def forward(self, x, y): - qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66) - qx = torch.dequantize(qx) - qy = torch._make_per_tensor_quantized_tensor(y, 0.025, 160) - qy = torch.dequantize(qy) - qz = torch.matmul(qx, qy) - return qz + x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + x, 0.03, -66, -128, 127, torch.int8 + ) + y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + y, 0.025, 160, 0, 255, torch.uint8 + ) + z = torch.matmul(x, y) + return z @register_test_case(module_factory=lambda: AtenMatmulQMixedSigni8()) @@ -578,13 +671,15 @@ def __init__(self): ] ) def 
forward(self, x, y): - qx = torch._make_per_tensor_quantized_tensor(x, 0.03, -66) - qx = torch.dequantize(qx) - qy = torch._make_per_tensor_quantized_tensor(y, 0.025, 160) - qy = torch.dequantize(qy) - qy = torch.transpose(qy, 1, 2) - qz = torch.matmul(qx, qy) - return qz + x = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + x, 0.03, -66, -128, 127, torch.int8 + ) + y = torch.torch.ops.quantized_decomposed.dequantize_per_tensor.default( + y, 0.025, 160, 0, 255, torch.uint8 + ) + y = torch.transpose(y, 1, 2) + z = torch.matmul(x, y) + return z @register_test_case(module_factory=lambda: AtenMatmulQMixedSigni8Transpose()) @@ -598,6 +693,131 @@ def AtenMatmulQMixedSigni8Transpose_basic(module, tu: TestUtils): # ============================================================================== +class AtenLinear1D(torch.nn.Module): + @export + @annotate_args( + [ + None, + ([3], torch.float32, True), + ([3], torch.float32, True), + ] + ) + def forward(self, a, b): + return torch.ops.aten.linear(a, b) + + +@register_test_case(module_factory=lambda: AtenLinear1D()) +def AtenLinear1D_basic(module, tu: TestUtils): + module.forward(tu.rand(3), tu.rand(3)) + + +# ============================================================================== + + +class AtenLinearMatVec(torch.nn.Module): + @export + @annotate_args( + [ + None, + ([3, 4], torch.float32, True), + ([4], torch.float32, True), + ] + ) + def forward(self, a, b): + return torch.ops.aten.linear(a, b) + + +@register_test_case(module_factory=lambda: AtenLinearMatVec()) +def AtenLinearMatVec_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4), tu.rand(4)) + + +# ============================================================================== + + +class AtenLinearVecMat(torch.nn.Module): + @export + @annotate_args( + [ + None, + ([4], torch.float32, True), + ([3, 4], torch.float32, True), + ] + ) + def forward(self, a, b): + return torch.ops.aten.linear(a, b) + + +@register_test_case(module_factory=lambda: AtenLinearVecMat()) +def AtenLinearVecMat_basic(module, tu: TestUtils): + module.forward(tu.rand(4), tu.rand(3, 4)) + + +class AtenLinearVecMatBias(torch.nn.Module): + @export + @annotate_args( + [ + None, + ([4], torch.float32, True), + ([3, 4], torch.float32, True), + ([3], torch.float32, True), + ] + ) + def forward(self, a, b, c): + return torch.ops.aten.linear(a, b, c) + + +@register_test_case(module_factory=lambda: AtenLinearVecMatBias()) +def AtenLinearVecMatBias_basic(module, tu: TestUtils): + module.forward(tu.rand(4), tu.rand(3, 4), tu.rand(3)) + + +# ============================================================================== + + +class AtenLinear2D(torch.nn.Module): + @export + @annotate_args( + [ + None, + ([3, 4], torch.float32, True), + ([5, 4], torch.float32, True), + ] + ) + def forward(self, a, b): + return torch.ops.aten.linear(a, b) + + +@register_test_case(module_factory=lambda: AtenLinear2D()) +def AtenLinear2D_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4), tu.rand(5, 4)) + + +# ============================================================================== + + +class AtenLinear3DBias(torch.nn.Module): + @export + @annotate_args( + [ + None, + ([3, 6, 4], torch.float32, True), + ([5, 4], torch.float32, True), + ([5], torch.float32, True), + ] + ) + def forward(self, a, b, c): + return torch.ops.aten.linear(a, b, c) + + +@register_test_case(module_factory=lambda: AtenLinear3DBias()) +def AtenLinear3DBias_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 6, 4), tu.rand(5, 4), 
tu.rand(5)) + + +# ============================================================================== + + class AtenLinalgCrossInt(torch.nn.Module): @export @annotate_args( diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/meshgrid.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/meshgrid.py new file mode 100644 index 000000000000..5cbd50473512 --- /dev/null +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/meshgrid.py @@ -0,0 +1,88 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +import torch + +from torch_mlir_e2e_test.framework import TestUtils +from torch_mlir_e2e_test.registry import register_test_case +from torch_mlir_e2e_test.annotations import annotate_args, export + +# ============================================================================== + + +class MeshgridIndexingIJ(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([3], torch.int64, True), + ([4], torch.int64, True), + ([5], torch.int64, True), + ] + ) + def forward(self, x, y, z): + x1, y1, z1 = torch.meshgrid(x, y, z, indexing="ij") + return x1, y1, z1 + + +@register_test_case(module_factory=lambda: MeshgridIndexingIJ()) +def MeshgridIndexingIJ_basic(module, tu: TestUtils): + x = torch.tensor([1, 2, 3]) + y = torch.tensor([4, 5, 6, 7]) + z = torch.tensor([8, 9, 10, 11, 12]) + module.forward(x, y, z) + + +class MeshgridIndexingXY(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([3], torch.int64, True), + ([4], torch.int64, True), + ([5], torch.int64, True), + ] + ) + def forward(self, x, y, z): + x1, y1, z1 = torch.meshgrid(x, y, z, indexing="xy") + return x1, y1, z1 + + +@register_test_case(module_factory=lambda: MeshgridIndexingXY()) +def MeshgridIndexingXY_basic(module, tu: TestUtils): + x = torch.tensor([1, 2, 3]) + y = torch.tensor([4, 5, 6, 7]) + z = torch.tensor([8, 9, 10, 11, 12]) + module.forward(x, y, z) + + +class Meshgrid(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([3], torch.int64, True), + ([4], torch.int64, True), + ] + ) + def forward(self, x, y): + x1, y1 = torch.meshgrid(x, y) + return x1, y1 + + +@register_test_case(module_factory=lambda: Meshgrid()) +def Meshgrid_basic(module, tu: TestUtils): + x = torch.tensor([1, 2, 3]) + y = torch.tensor([4, 5, 6, 7]) + module.forward(x, y) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/nll_loss.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/nll_loss.py index 675d04249b90..58c6dfdb90aa 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/nll_loss.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/nll_loss.py @@ -36,6 +36,57 @@ def NllLossModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3)) +class NllLossStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 3], torch.float32, True), + ([2], torch.int64, True), + ] + ) + # Here the 2nd index is ignored. 
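+    # reduction=0 means no reduction; targets equal to ignore_index (2) produce a zero loss entry.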
+ def forward(self, x, y): + return torch.ops.aten.nll_loss_forward( + x, target=y, weight=None, reduction=0, ignore_index=2 + ) + + +@register_test_case(module_factory=lambda: NllLossStaticModule()) +def NllLossStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3)) + + +class NllLossStaticModule_weight(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 3], torch.float32, True), + ([2], torch.int64, True), + ([3], torch.float32, True), + ] + ) + # Here the 2nd index is ignored. + def forward(self, x, y, z): + return torch.ops.aten.nll_loss_forward( + x, target=y, weight=z, reduction=2, ignore_index=2 + ) + + +@register_test_case(module_factory=lambda: NllLossStaticModule_weight()) +def NllLossStaticModule_weight_basic(module, tu: TestUtils): + module.forward( + tu.rand(2, 3), tu.randint(2, low=0, high=3), torch.tensor([0.3, 0.3, 0.4]) + ) + + class NllLossModule_mean(torch.nn.Module): def __init__(self): super().__init__() @@ -60,6 +111,30 @@ def NllLossModule_mean_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3)) +class NllLossStaticModule_mean(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 3], torch.float32, True), + ([2], torch.int64, True), + ] + ) + # Here the 2nd index is ignored. + def forward(self, x, y): + return torch.ops.aten.nll_loss_forward( + x, target=y, weight=None, reduction=1, ignore_index=2 + ) + + +@register_test_case(module_factory=lambda: NllLossStaticModule_mean()) +def NllLossStaticModule_mean_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3)) + + class NllLossModule_sum(torch.nn.Module): def __init__(self): super().__init__() @@ -84,6 +159,30 @@ def NllLossModule_sum_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3)) +class NllLossStaticModule_sum(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 3], torch.float32, True), + ([2], torch.int64, True), + ] + ) + # Here the 2nd index is ignored. 
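+    # As above, targets equal to 2 are ignored; reduction=2 sums the per-element losses.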
+ def forward(self, x, y): + return torch.ops.aten.nll_loss_forward( + x, target=y, weight=None, reduction=2, ignore_index=2 + ) + + +@register_test_case(module_factory=lambda: NllLossStaticModule_sum()) +def NllLossStaticModule_sum_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3), tu.randint(2, low=0, high=3)) + + class NllLossModule_1D(torch.nn.Module): def __init__(self): super().__init__() diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py index f4c9e39d1790..60c4ee144dfa 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py @@ -633,3 +633,121 @@ def forward(self, x, w, b): @register_test_case(module_factory=lambda: AtenInstanceNormModule()) def AtenInstanceNormModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 2, 1, 3), tu.rand(2), tu.rand(2)) + + +# ============================================================================== +class RenormModuleFloat32(torch.nn.Module): + def __init__(self): + super().__init__() + self.p = 2 + self.dim = 1 + self.maxnorm = 10 + + @export + @annotate_args( + [ + None, + ([3, 3], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.renorm(x, self.p, self.dim, self.maxnorm) + + +@register_test_case(module_factory=lambda: RenormModuleFloat32()) +def RenormModuleFloat32_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 3)) + + +class RenormModuleFloat16(torch.nn.Module): + def __init__(self): + super().__init__() + self.p = 2.1 + self.dim = 1 + self.maxnorm = 10 + + @export + @annotate_args( + [ + None, + ([3, 4, 5], torch.float16, True), + ] + ) + def forward(self, x): + return torch.ops.aten.renorm(x, self.p, self.dim, self.maxnorm) + + +@register_test_case(module_factory=lambda: RenormModuleFloat16()) +def RenormModuleFloat16_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5).to(torch.float16)) + + +class RenormModuleFloat32NegativeDim(torch.nn.Module): + def __init__(self): + super().__init__() + self.p = 2.3 + self.dim = -1 + self.maxnorm = 5.2 + + @export + @annotate_args( + [ + None, + ([1, 4, 5, 2], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.renorm(x, self.p, self.dim, self.maxnorm) + + +@register_test_case(module_factory=lambda: RenormModuleFloat32NegativeDim()) +def RenormModuleFloat32NegativeDim_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 4, 5, 2).to(torch.float32)) + + +class RenormModuleFloat32DynamicDims(torch.nn.Module): + def __init__(self): + super().__init__() + self.p = 2 + self.dim = 1 + self.maxnorm = 10 + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.renorm(x, self.p, self.dim, self.maxnorm) + + +@register_test_case(module_factory=lambda: RenormModuleFloat32DynamicDims()) +def RenormModuleFloat32DynamicDims_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 3)) + + +# ============================================================================== +class WeightNormInterfaceModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.dim = 2 + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, v, g): + return torch.ops.aten._weight_norm_interface(v, g, self.dim) + + +@register_test_case(module_factory=lambda: 
WeightNormInterfaceModule()) +def WeightNormInterfaceModule_basic(module, tu: TestUtils): + g = tu.rand(3, 10, 10) + v = tu.rand(1, 1, 10) + module.forward(g, v) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py index a97d7f09eda6..b9c58551d657 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py @@ -123,3 +123,164 @@ def forward(self, x): @register_test_case(module_factory=lambda: ReflectionPad2dModuleRight()) def ReflectionPad2dModule_Right(module, tu: TestUtils): module.forward(tu.rand(2, 3, 20, 20)) + + +# ============================================================================== + + +class ReflectionPad3dModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([1, 20, 20, 20, 20], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.reflection_pad3d(x, (10, 10, 10, 10, 10, 10)) + + +@register_test_case(module_factory=lambda: ReflectionPad3dModule()) +def ReflectionPad3dModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 20, 20, 20, 20, low=-1)) + + +# ============================================================================== + + +class ReflectionPad3dModuleTop(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([1, 3, 4, 5, 6], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.reflection_pad3d(x, (0, 0, 2, 0, 0, 0)) + + +@register_test_case(module_factory=lambda: ReflectionPad3dModuleTop()) +def ReflectionPad3dModuleTop_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 3, 4, 5, 6)) + + +# ============================================================================== + + +class ReflectionPad3dModuleBottom(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 3, 10, 10, 6], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.reflection_pad3d(x, (0, 0, 0, 5, 0, 0)) + + +@register_test_case(module_factory=lambda: ReflectionPad3dModuleBottom()) +def ReflectionPad3dModuleBottom_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 10, 10, 6)) + + +# ============================================================================== + + +class ReflectionPad3dModuleLeft(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 3, 20, 20, 10], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.reflection_pad3d(x, (9, 0, 0, 0, 0, 0)) + + +@register_test_case(module_factory=lambda: ReflectionPad3dModuleLeft()) +def ReflectionPad3dModuleLeft_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 20, 20, 10)) + + +# ============================================================================== + + +class ReflectionPad3dModuleRight(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 3, 20, 20, 12], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.reflection_pad3d(x, (0, 11, 0, 0, 0, 0)) + + +@register_test_case(module_factory=lambda: ReflectionPad3dModuleRight()) +def ReflectionPad3dModuleRight_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 20, 20, 12)) + + +# ============================================================================== + + +class 
ReflectionPad3dModuleFront(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 3, 20, 20, 12], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.reflection_pad3d(x, (0, 0, 0, 0, 5, 0)) + + +@register_test_case(module_factory=lambda: ReflectionPad3dModuleFront()) +def ReflectionPad3dModuleFront_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 20, 20, 12)) + + +# ============================================================================== + + +class ReflectionPad3dModuleBack(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 3, 20, 20, 12], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.reflection_pad3d(x, (0, 0, 0, 0, 0, 7)) + + +@register_test_case(module_factory=lambda: ReflectionPad3dModuleBack()) +def ReflectionPad3dModuleBack_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 20, 20, 12)) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py index 69d813c917f0..e2eaa4cfd0fe 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py @@ -108,6 +108,29 @@ def AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic( module.forward(tu.rand(1, 512, 15, 14)) +class AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.aap2d = torch.nn.AdaptiveAvgPool2d((2, 2)) + + @export + @annotate_args( + [ + None, + ([1, 3, 7, 7], torch.float32, True), + ] + ) + def forward(self, x): + return self.aap2d(x) + + +@register_test_case( + module_factory=lambda: AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule() +) +def AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 3, 7, 7)) + + class AdaptiveAvgPool2dUnitOutputSizeStaticModule(torch.nn.Module): def __init__(self): super().__init__() @@ -397,6 +420,35 @@ def MaxPool2dCeilModeTrueModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 1, 20, 20, low=0.5, high=1.0)) +class MaxPool2dStaticCeilModeTrueReduceOutputModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.mp2d = torch.nn.MaxPool2d( + kernel_size=6, + stride=6, + padding=3, + dilation=1, + ceil_mode=True, + ) + + @export + @annotate_args( + [ + None, + ([2, 6, 20, 10], torch.float32, True), + ] + ) + def forward(self, x): + return self.mp2d(x) + + +@register_test_case( + module_factory=lambda: MaxPool2dStaticCeilModeTrueReduceOutputModule() +) +def MaxPool2dStaticCeilModeTrueReduceOutputModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 6, 20, 10, low=0.5, high=1.0)) + + # ============================================================================== @@ -933,6 +985,252 @@ def MaxPool2dWithIndicesBackwardDynamic3DModule_basic(module, tu: TestUtils): # ============================================================================== +class MaxPool3dWithIndicesModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.max_pool3d_with_indices( + x, + kernel_size=[2, 2, 2], + stride=[1, 1, 1], + padding=[0, 0, 0], + dilation=[1, 1, 1], + ) + + +@register_test_case(module_factory=lambda: MaxPool3dWithIndicesModule()) +def 
MaxPool3dWithIndicesModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 8, 8, 8, low=0.5, high=1.0)) + + +class MaxPool3dWithIndicesFullSizeKernelModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.max_pool3d_with_indices( + x, kernel_size=[4, 4, 4], stride=1, padding=0, dilation=1 + ) + + +@register_test_case(module_factory=lambda: MaxPool3dWithIndicesFullSizeKernelModule()) +def MaxPool3dWithIndicesFullSizeKernelModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 4, 4, 4, low=0.5, high=1.0)) + + +class MaxPool3dWithIndicesNonDefaultPaddingModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.max_pool3d_with_indices( + x, kernel_size=[4, 8, 4], stride=[1, 1, 1], padding=[2, 4, 2], dilation=1 + ) + + +@register_test_case( + module_factory=lambda: MaxPool3dWithIndicesNonDefaultPaddingModule() +) +def MaxPool3dWithIndicesNonDefaultPaddingModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 4, 16, 16, 16, low=-1.5, high=1.0)) + + +class MaxPool3dWithIndicesNonDefaultStrideModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.max_pool3d_with_indices( + x, kernel_size=[4, 4, 4], stride=[1, 2, 1], padding=0, dilation=1 + ) + + +@register_test_case(module_factory=lambda: MaxPool3dWithIndicesNonDefaultStrideModule()) +def MaxPool3dWithIndicesNonDefaultStrideModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 4, 16, 80, 16, low=0.5, high=2.0)) + + +class MaxPool3dWithIndicesNonDefaultDilationModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.max_pool3d_with_indices( + x, kernel_size=[4, 4, 4], stride=[1, 1, 1], padding=0, dilation=[2, 2, 2] + ) + + +@register_test_case( + module_factory=lambda: MaxPool3dWithIndicesNonDefaultDilationModule() +) +def MaxPool3dWithIndicesNonDefaultDilationModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 4, 16, 80, 16, low=0.5, high=2.0)) + + +class MaxPool3dWithIndicesNonDefaultParamsModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.max_pool3d_with_indices( + x, + kernel_size=[8, 4, 8], + stride=[2, 2, 2], + padding=[1, 2, 1], + dilation=[2, 2, 2], + ) + + +@register_test_case(module_factory=lambda: MaxPool3dWithIndicesNonDefaultParamsModule()) +def MaxPool3dWithIndicesNonDefaultParamsModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 4, 16, 80, 16, low=-0.5, high=4.0)) + + +class MaxPool3dWithIndicesAllNegativeValuesModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.max_pool3d_with_indices( + x, kernel_size=[4, 8, 4], stride=[1, 1, 1], padding=[2, 4, 2], dilation=1 + ) + + +@register_test_case( + module_factory=lambda: 
MaxPool3dWithIndicesAllNegativeValuesModule() +) +def MaxPool3dWithIndicesAllNegativeValuesModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 4, 16, 16, 16, low=-4.5, high=-1.0)) + + +class MaxPool3dWithIndicesStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 4, 16, 16, 16], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.max_pool3d_with_indices( + x, kernel_size=[4, 8, 4], stride=[1, 1, 1], padding=[2, 4, 2], dilation=1 + ) + + +@register_test_case(module_factory=lambda: MaxPool3dWithIndicesStaticModule()) +def MaxPool3dWithIndicesStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 4, 16, 16, 16, low=-4.5, high=-1.0)) + + +class MaxPool3dWithIndicesAllOnesModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.max_pool3d_with_indices( + x, + kernel_size=[2, 2, 2], + stride=[1, 1, 1], + padding=[0, 0, 0], + dilation=[1, 1, 1], + ) + + +@register_test_case(module_factory=lambda: MaxPool3dWithIndicesAllOnesModule()) +def MaxPool3dWithIndicesAllOnesModule_basic(module, tu: TestUtils): + module.forward(torch.ones(1, 1, 8, 8, 8)) + + +class MaxPool3dWithIndicesCeilModeTrueModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.max_pool3d_with_indices( + x, + kernel_size=[2, 2, 2], + stride=[1, 1, 1], + padding=[0, 0, 0], + dilation=[1, 1, 1], + ceil_mode=True, + ) + + +@register_test_case(module_factory=lambda: MaxPool3dWithIndicesCeilModeTrueModule()) +def MaxPool3dWithIndicesCeilModeTrueModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 8, 8, 8, low=0.5, high=1.0)) + + +# ============================================================================== + + class AvgPool2dFloatModule(torch.nn.Module): def __init__(self): super().__init__() @@ -1017,6 +1315,35 @@ def AvgPool2dStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 2, 10, 20, low=-1)) +class AvgPool2dCountIncludePadFalseStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d( + kernel_size=[3, 3], + stride=[1, 1], + padding=[1, 1], + ceil_mode=False, + count_include_pad=False, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([32, 384, 25, 25], torch.float32, True), + ] + ) + def forward(self, x): + return self.ap2d(x) + + +@register_test_case(module_factory=lambda: AvgPool2dCountIncludePadFalseStaticModule()) +def AvgPool2dCountIncludePadFalseStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(32, 384, 25, 25, low=-1)) + + class AvgPool2dDivisorOverrideModule(torch.nn.Module): def __init__(self): super().__init__() @@ -1104,6 +1431,38 @@ def AvgPool2dWithoutPadModule_basic(module, tu: TestUtils): # ============================================================================== +class AvgPool3dStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool3d( + kernel_size=[2, 2, 2], + stride=[2, 2, 2], + padding=[0, 0, 0], + ceil_mode=False, + count_include_pad=True, + divisor_override=None, + ) + + @export + @annotate_args( + [ + None, + ([2, 2, 4, 4, 4], torch.float32, True), + ] + ) + def forward(self, x): + return 
self.ap2d(x) + + +@register_test_case(module_factory=lambda: AvgPool3dStaticModule()) +def AvgPool3dStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 2, 4, 4, 4, low=-1)) + + +# ============================================================================== + + class AvgPool1dFloatModule(torch.nn.Module): def __init__(self): super().__init__() @@ -1453,6 +1812,22 @@ def AdaptiveMaxPool1dStatic_basic(module, tu: TestUtils): module.forward(tu.rand(1, 512, 10)) +class AdaptiveMaxPool1dDimOneStatic(torch.nn.Module): + def __init__(self): + super().__init__() + self.amp1d = torch.nn.AdaptiveMaxPool1d(output_size=(1), return_indices=False) + + @export + @annotate_args([None, ([1, 512, 7], torch.float32, True)]) + def forward(self, x): + return self.amp1d(x) + + +@register_test_case(module_factory=lambda: AdaptiveMaxPool1dDimOneStatic()) +def AdaptiveMaxPool1dDimOneStatic_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 512, 7)) + + # AdaptiveMaxPool2d @@ -1637,3 +2012,61 @@ def forward(self, x): @register_test_case(module_factory=lambda: AdaptiveMaxPool3dStaticWithIndices()) def AdaptiveMaxPool3dStaticWithIndices_basic(module, tu: TestUtils): module.forward(tu.rand(1, 512, 10, 16, 17)) + + +# ============================================================================== + + +class MaxUnpool3dModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, 2, 2, 4], torch.float32, True), + ([-1, -1, 2, 2, 4], torch.int64, True), + ] + ) + def forward(self, x, indices): + return torch.ops.aten.max_unpool3d(x, indices, (4, 5, 6), (2, 3, 2), (0, 0, 1)) + + +@register_test_case(module_factory=lambda: MaxUnpool3dModule()) +def MaxUnpool3dModule_basic(module, tu: TestUtils): + input = tu.rand(2, 2, 4, 5, 6) + pool = torch.nn.MaxPool3d( + kernel_size=(2, 2, 2), stride=(2, 3, 2), padding=(0, 0, 1), return_indices=True + ) + output, indices = pool(input) + + module.forward(output, indices) + + +# We have a special case for all-zeros padding, test it too. 
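+# For reference, the positional arguments to aten.max_unpool3d are
+# (input, indices, output_size, stride, padding).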
+class MaxUnpool3dModulePad0(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, 2, 2, 3], torch.float32, True), + ([-1, -1, 2, 2, 3], torch.int64, True), + ] + ) + def forward(self, x, indices): + return torch.ops.aten.max_unpool3d(x, indices, (4, 5, 6), (2, 3, 2), (0, 0, 0)) + + +@register_test_case(module_factory=lambda: MaxUnpool3dModulePad0()) +def MaxUnpool3dModulePad0_basic(module, tu: TestUtils): + input = tu.rand(2, 2, 4, 5, 6) + pool = torch.nn.MaxPool3d( + kernel_size=(2, 2, 2), stride=(2, 3, 2), padding=(0, 0, 0), return_indices=True + ) + output, indices = pool(input) + + module.forward(output, indices) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/quantized_models.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/quantized_models.py index 5114f78d5ca7..3c9c3073525b 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/quantized_models.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/quantized_models.py @@ -181,3 +181,28 @@ def get_quantized_mlp(): @register_test_case(module_factory=get_quantized_mlp) def QuantizedMLP_basic(module, tu: TestUtils): module.forward(get_quant_model_input()) + + +# ============================================================================== + + +class FakeQuantizePerTensorAffineCachemaskModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([6, 4], torch.float32, True), + ] + ) + def forward(self, a): + return torch.ops.aten.fake_quantize_per_tensor_affine_cachemask( + a, 2.0, 0, -128, 127 + )[0] + + +@register_test_case(module_factory=lambda: FakeQuantizePerTensorAffineCachemaskModule()) +def FakeQuantizePerTensorAffineCachemaskModule_basic(module, tu: TestUtils): + module.forward(tu.rand(6, 4)) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py index 9e0869dd998a..b72bf64dbcfa 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py @@ -81,6 +81,29 @@ def ReduceSumElementTypeBoolModule_basic(module, tu: TestUtils): # ============================================================================== +class PrimsSumFloatModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, a): + return torch.ops.prims.sum(a, (0, 1)) + + +@register_test_case(module_factory=lambda: PrimsSumFloatModule()) +def PrimsSumFloatModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5)) + + +# ============================================================================== + + class ReduceProdFloatModule(torch.nn.Module): def __init__(self): super().__init__() @@ -170,6 +193,26 @@ def ReduceAllFloatModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) +class ReduceAllDimFloatModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, a): + return torch.ops.aten.all(a, dim=0) + + +@register_test_case(module_factory=lambda: ReduceAllDimFloatModule()) +def ReduceAllDimFloatModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5)) + + # ============================================================================== @@ -239,6 +282,26 @@ def 
ReduceAnyFloatModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) +class ReduceAnyDimFloatModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, a): + return torch.ops.aten.any(a, dim=0) + + +@register_test_case(module_factory=lambda: ReduceAnyDimFloatModule()) +def ReduceAnyDimFloatModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5)) + + # ============================================================================== @@ -1184,6 +1247,29 @@ def ReduceAmaxMultiDim_basic(module, tu: TestUtils): # ============================================================================== +class ReduceAmaxEmptyDim(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, a): + return torch.ops.aten.amax(a, dim=()) + + +@register_test_case(module_factory=lambda: ReduceAmaxEmptyDim()) +def ReduceAmaxEmptyDim_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5, high=100)) + + +# ============================================================================== + + class ReduceAmaxOutOfOrderDim(torch.nn.Module): def __init__(self): super().__init__() @@ -1230,6 +1316,75 @@ def ReduceAmaxKeepDim_basic(module, tu: TestUtils): # ============================================================================== +class ReduceAminSingleDim(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, a): + return torch.ops.aten.amin(a, 1) + + +@register_test_case(module_factory=lambda: ReduceAminSingleDim()) +def ReduceAminSingleDim_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5, high=100)) + + +# ============================================================================== + + +class ReduceAminmaxSingleDim(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, a): + return torch.ops.aten.aminmax(a, dim=1) + + +@register_test_case(module_factory=lambda: ReduceAminmaxSingleDim()) +def ReduceAminmaxSingleDim_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5, high=100)) + + +# ============================================================================== + + +class ReduceAminmaxAllDims(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, a): + return torch.ops.aten.aminmax(a) + + +@register_test_case(module_factory=lambda: ReduceAminmaxAllDims()) +def ReduceAminmaxAllDims_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5, high=100)) + + +# ============================================================================== + + class ReduceMinFloatModule(torch.nn.Module): def __init__(self): super().__init__() @@ -1421,6 +1576,29 @@ def ArgmaxModule_basic(module, tu: TestUtils): # ============================================================================== +class ArgmaxKeepdimModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) + def forward(self, a): + return torch.ops.aten.argmax(a, keepdim=True) + + +@register_test_case(module_factory=lambda: ArgmaxKeepdimModule()) +def 
ArgmaxKeepdimModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + + +# ============================================================================== + + class ArgmaxIntModule(torch.nn.Module): def __init__(self): super().__init__() @@ -2105,6 +2283,78 @@ def MseLossSumReductionWithDifferentElemTypeModule_basic(module, tu: TestUtils): # ============================================================================== +class L1LossNoReductionModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 4], torch.float32, True), + ([2, 4], torch.float32, True), + ] + ) + def forward(self, x, y): + return torch.ops.aten.l1_loss(x, y, reduction=0) + + +@register_test_case(module_factory=lambda: L1LossNoReductionModule()) +def L1LossNoReductionModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 4), tu.rand(2, 4)) + + +# ============================================================================== + + +class L1LossMeanReductionModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 4], torch.float32, True), + ([2, 4], torch.float32, True), + ] + ) + def forward(self, x, y): + return torch.ops.aten.l1_loss(x, y, reduction=1) + + +@register_test_case(module_factory=lambda: L1LossMeanReductionModule()) +def L1LossMeanReductionModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 4), tu.rand(2, 4)) + + +# ============================================================================== + + +class L1LossSumReductionModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 4], torch.float32, True), + ([2, 4], torch.float32, True), + ] + ) + def forward(self, x, y): + return torch.ops.aten.l1_loss(x, y, reduction=2) + + +@register_test_case(module_factory=lambda: L1LossSumReductionModule()) +def L1LossSumReductionModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 4), tu.rand(2, 4)) + + +# ============================================================================== + + class CrossEntropyLossModule(torch.nn.Module): def __init__(self): super().__init__() @@ -2159,6 +2409,29 @@ def CrossEntropyLossNoReductionModule_basic(module, tu: TestUtils): module.forward(tu.rand(8, 2), tu.randint(8, high=2)) +class BinaryCrossEntropyWithLogitsStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([8, 2], torch.float32, True), + ([8, 2], torch.float32, True), + ] + ) + def forward(self, input, target): + return torch.ops.aten.binary_cross_entropy_with_logits( + input, target, reduction=0 + ) + + +@register_test_case(module_factory=lambda: BinaryCrossEntropyWithLogitsStaticModule()) +def BinaryCrossEntropyWithLogitsStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(8, 2), tu.rand(8, 2)) + + # ============================================================================== diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py index 7b569529bc1a..d1ddc42b39b1 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py @@ -1174,6 +1174,30 @@ def ReshapeDynamicModule_basic(module, tu: TestUtils): # ============================================================================== +class ViewDtypeStaticModule(torch.nn.Module): + def 
__init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([12, 1], torch.float32, True), + ] + ) + def forward(self, a): + res = a.view(torch.int8) + return res + + +@register_test_case(module_factory=lambda: ViewDtypeStaticModule()) +def ViewDtypeStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(12, 1)) + + +# ============================================================================== + + class ReshapeAliasCollapseModule(torch.nn.Module): def __init__(self): super().__init__() @@ -1303,6 +1327,27 @@ def EinsumStaticFourDimensionModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5, 6), tu.rand(3, 7, 5, 6)) +class EinsumStaticDiagonalDimensionModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([5, 5, 4, 4], torch.float32, True), + ([5, 4, 5, 4], torch.float32, True), + ] + ) + def forward(self, tensor1, tensor2): + return torch.ops.aten.einsum("iijj,ijij->ji", [tensor1, tensor2]) + + +@register_test_case(module_factory=lambda: EinsumStaticDiagonalDimensionModule()) +def EinsumStaticDiagonalDimensionModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 5, 4, 4), tu.rand(5, 4, 5, 4)) + + class EinsumStaticContractRhsModule(torch.nn.Module): def __init__(self): super().__init__() @@ -1367,3 +1412,536 @@ def forward(self, tensor1, tensor2): ) def EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 6, 4, 5), tu.rand(6, 5)) + + +class InterpolateModule(torch.nn.Module): + def __init__( + self, + size=None, + scale_factor=None, + mode="nearest", + align_corners=None, + recompute_scale_factor=None, + antialias=False, + ): + self.size = size + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + self.recompute_scale_factor = recompute_scale_factor + self.antialias = antialias + super().__init__() + + def _forward(self, input): + return torch.nn.functional.interpolate( + input, + size=self.size, + scale_factor=self.scale_factor, + mode=self.mode, + align_corners=self.align_corners, + recompute_scale_factor=self.recompute_scale_factor, + antialias=self.antialias, + ) + + +class InterpolateStaticModule(InterpolateModule): + @export + @annotate_args( + [ + None, + ([1, 1, 4, 5], torch.float32, True), + ] + ) + def forward(self, input): + return self._forward(input) + + +class InterpolateDynamicModule(InterpolateModule): + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) + def forward(self, input): + return self._forward(input) + + +@register_test_case( + module_factory=lambda: InterpolateStaticModule( + scale_factor=0.41, mode="bilinear", align_corners=True + ) +) +def InterpolateStaticModule_scales_bilinear_align_corners(module, tu: TestUtils): + input = torch.arange(20).to(dtype=torch.float32) + input = input.reshape((1, 1, 4, 5)) + module.forward(input) + + +@register_test_case( + module_factory=lambda: InterpolateDynamicModule(size=(2, 7), mode="nearest") +) +def InterpolateDynamicModule_sizes_nearest(module, tu: TestUtils): + input = torch.arange(20).to(dtype=torch.float32) + input = input.reshape((1, 1, 4, 5)) + module.forward(input) + + +@register_test_case( + module_factory=lambda: InterpolateDynamicModule(size=(2, 7), mode="bilinear") +) +def InterpolateDynamicModule_sizes_bilinear(module, tu: TestUtils): + input = torch.arange(20).to(dtype=torch.float32) + input = input.reshape((1, 1, 4, 5)) + module.forward(input) + + 
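+# With recompute_scale_factor=True, interpolate derives the output size from the given
+# scale factors and then recomputes the effective scales from that output size.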
+@register_test_case( + module_factory=lambda: InterpolateDynamicModule( + scale_factor=(1.9, 2.4), mode="bilinear", recompute_scale_factor=True + ) +) +def InterpolateDynamicModule_scales_recompute_bilinear(module, tu: TestUtils): + input = torch.arange(20).to(dtype=torch.float32) + input = input.reshape((1, 1, 4, 5)) + module.forward(input) + + +# ============================================================================== + + +class Atleast1dModule0dInput(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.atleast_1d(x) + + +@register_test_case(module_factory=lambda: Atleast1dModule0dInput()) +def Atleast1dModule0dInput_basic(module, tu: TestUtils): + module.forward(tu.rand()) + + +class Atleast1dModule1dInput(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([4], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.atleast_1d(x) + + +@register_test_case(module_factory=lambda: Atleast1dModule1dInput()) +def Atleast1dModule1dInput_basic(module, tu: TestUtils): + module.forward(tu.rand(4)) + + +# ============================================================================== + + +class Rot90BasicModule(torch.nn.Module): + @export + @annotate_args( + [ + None, + ([4, 5], torch.float32, True), + ] + ) + def forward(self, a): + return torch.ops.aten.rot90( + a, + k=1, + dims=( + 0, + 1, + ), + ) + + +@register_test_case(module_factory=lambda: Rot90BasicModule()) +def Rot90BasicModule_basic(module, tu: TestUtils): + module.forward(tu.rand(4, 5)) + + +class Rot90DynamicDimsModule(torch.nn.Module): + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, a): + return torch.ops.aten.rot90( + a, + k=1, + dims=( + 0, + 1, + ), + ) + + +@register_test_case(module_factory=lambda: Rot90DynamicDimsModule()) +def Rot90DynamicDimsModule_basic(module, tu: TestUtils): + module.forward(tu.rand(6, 2, 4)) + + +class Rot90MultipleRotationsModule(torch.nn.Module): + @export + @annotate_args( + [ + None, + ([7, 4, 6], torch.float32, True), + ] + ) + def forward(self, a): + return torch.ops.aten.rot90( + a, + k=6, + dims=( + 1, + 2, + ), + ) + + +@register_test_case(module_factory=lambda: Rot90MultipleRotationsModule()) +def Rot90MultipleRotationsModule_basic(module, tu: TestUtils): + module.forward(tu.rand(7, 4, 6)) + + +class Rot90NegativeOddRotationsModule(torch.nn.Module): + @export + @annotate_args( + [ + None, + ([7, 4, 6, 5, 3], torch.float32, True), + ] + ) + def forward(self, a): + return torch.ops.aten.rot90( + a, + k=-5, + dims=( + 1, + 2, + ), + ) + + +@register_test_case(module_factory=lambda: Rot90NegativeOddRotationsModule()) +def Rot90NegativeOddRotationsModule_basic(module, tu: TestUtils): + module.forward(tu.rand(7, 4, 6, 5, 3)) + + +class Rot90NegativeEvenRotationsModule(torch.nn.Module): + @export + @annotate_args( + [ + None, + ([6, 5, 1, 7, 3], torch.float32, True), + ] + ) + def forward(self, a): + return torch.ops.aten.rot90( + a, + k=-6, + dims=( + 1, + -2, + ), + ) + + +@register_test_case(module_factory=lambda: Rot90NegativeEvenRotationsModule()) +def Rot90NegativeEvenRotationsModule_basic(module, tu: TestUtils): + module.forward(tu.rand(6, 5, 1, 7, 3)) + + +# ============================================================================== + + +class Unfold_Module(torch.nn.Module): + def __init__(self): + super().__init__() 
+ + @export + @annotate_args( + [ + None, + ([6, 4], torch.float32, True), + ] + ) + def forward(self, x): + return x.unfold(0, 2, 2) + + +@register_test_case(module_factory=lambda: Unfold_Module()) +def Unfold_Module_basic(module, tu: TestUtils): + module.forward(tu.rand(6, 4)) + + +class Unfold_Module_Negative_Dim(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([6, 4, 4, 4], torch.float32, True), + ] + ) + def forward(self, x): + return x.unfold(-1, 2, 1) + + +@register_test_case(module_factory=lambda: Unfold_Module_Negative_Dim()) +def Unfold_Module_Rank_4(module, tu: TestUtils): + module.forward(tu.rand(6, 4, 4, 4)) + + +class Unfold_Module_Rank_Zero(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([], torch.float32, True), + ] + ) + def forward(self, x): + return x.unfold(0, 1, 1) + + +@register_test_case(module_factory=lambda: Unfold_Module_Rank_Zero()) +def Unfold_Module_Rank_Zero_basic(module, tu: TestUtils): + module.forward(tu.rand()) + + +class Unfold_Module_Rank_Zero_Size_Zero(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([], torch.float32, True), + ] + ) + def forward(self, x): + return x.unfold(0, 0, 1) + + +@register_test_case(module_factory=lambda: Unfold_Module_Rank_Zero_Size_Zero()) +def Unfold_Module_Rank_Zero_Size_Zero_basic(module, tu: TestUtils): + module.forward(tu.rand()) + + +class Unfold_Module_Dynamic(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float32, True), + ] + ) + def forward(self, x): + return x.unfold(1, 2, 1) + + +@register_test_case(module_factory=lambda: Unfold_Module_Dynamic()) +def Unfold_Module_Dynamic_basic(module, tu: TestUtils): + module.forward(tu.rand(6, 4, 4, 4)) + + +# ============================================================================== + + +class Aten_TrilinearModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([3, 3, 3], torch.float32, True), + ([3, 3, 3], torch.float32, True), + ([3, 3, 3], torch.float32, True), + ] + ) + def forward(self, i1, i2, i3): + return torch.ops.aten._trilinear( + i1, i2, i3, expand1=[], expand2=[], expand3=[], sumdim=[], unroll_dim=0 + ) + + +@register_test_case(module_factory=lambda: Aten_TrilinearModule()) +def Aten_TrilinearModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 3, 3), tu.rand(3, 3, 3), tu.rand(3, 3, 3)) + + +class Aten_TrilinearModuleSumdims(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 6], torch.float32, True), + ([2, 3, 6], torch.float32, True), + ([2, 3, 6], torch.float32, True), + ] + ) + def forward(self, i1, i2, i3): + return torch.ops.aten._trilinear( + i1, i2, i3, expand1=[1], expand2=[], expand3=[], sumdim=[0, 2], unroll_dim=0 + ) + + +@register_test_case(module_factory=lambda: Aten_TrilinearModuleSumdims()) +def Aten_TrilinearModuleSumdims_basic(module, tu: TestUtils): + return module.forward(tu.rand(2, 6), tu.rand(2, 3, 6), tu.rand(2, 3, 6)) + + +class Aten_TrilinearModuleSumAllDims(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 6], torch.float32, True), + ([2, 3, 6], torch.float32, True), + ([2, 3, 6], torch.float32, True), + ] + ) + def forward(self, i1, i2, i3): + return torch.ops.aten._trilinear( + i1, + i2, + i3, + 
expand1=[1], + expand2=[], + expand3=[], + sumdim=[0, 1, 2], + unroll_dim=0, + ) + + +@register_test_case(module_factory=lambda: Aten_TrilinearModuleSumAllDims()) +def Aten_TrilinearModuleSumAllDims_basic(module, tu: TestUtils): + return module.forward(tu.rand(2, 6), tu.rand(2, 3, 6), tu.rand(2, 3, 6)) + + +class Aten_TrilinearModuleVaryingRanks(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 6], torch.float32, True), + ([2, 3, 6], torch.float32, True), + ([6], torch.float32, True), + ] + ) + def forward(self, i1, i2, i3): + return torch.ops.aten._trilinear( + i1, + i2, + i3, + expand1=[1], + expand2=[], + expand3=[0, 1], + sumdim=[0], + unroll_dim=0, + ) + + +@register_test_case(module_factory=lambda: Aten_TrilinearModuleVaryingRanks()) +def Aten_TrilinearModuleVaryingRanks_basic(module, tu: TestUtils): + return module.forward(tu.rand(2, 6), tu.rand(2, 3, 6), tu.rand(6)) + + +class Aten_TrilinearModuleVaryingRanksUnorderedExpands(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 6], torch.float32, True), + ([2, 3, 6], torch.float32, True), + ([6], torch.float32, True), + ] + ) + def forward(self, i1, i2, i3): + return torch.ops.aten._trilinear( + i1, + i2, + i3, + expand1=[1], + expand2=[], + expand3=[1, 0], + sumdim=[2, 0], + unroll_dim=0, + ) + + +@register_test_case( + module_factory=lambda: Aten_TrilinearModuleVaryingRanksUnorderedExpands() +) +def Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic(module, tu: TestUtils): + return module.forward(tu.rand(2, 6), tu.rand(2, 3, 6), tu.rand(6)) + + +class Aten_TrilinearModuleZerodDimBug(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 3, 6], torch.float32, True), + ([2, 3, 6], torch.float32, True), + ([2, 3, 6], torch.float32, True), + ] + ) + def forward(self, i1, i2, i3): + return torch.ops.aten._trilinear( + i1, i2, i3, expand1=[0], expand2=[0], expand3=[0], sumdim=[2], unroll_dim=0 + ) + + +@register_test_case(module_factory=lambda: Aten_TrilinearModuleZerodDimBug()) +def Aten_TrilinearModuleZerodDimBug_basic(module, tu: TestUtils): + return module.forward(tu.rand(2, 3, 6), tu.rand(2, 3, 6), tu.rand(2, 3, 6)) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/rng.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/rng.py index b2d41a422682..cce93d9d64b8 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/rng.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/rng.py @@ -196,7 +196,7 @@ def forward(self, x): @register_test_case(module_factory=lambda: ExponentialModule()) def ExponentialModule_basic(module, tu: TestUtils): - module.forward(tu.rand(512, 512, 16).double()) + module.forward(tu.rand(1024, 1024, 16).double()) # ============================================================================== @@ -377,6 +377,80 @@ def BernoulliPModule_basic(module, tu: TestUtils): # ============================================================================== +def generate_sample_distr(sizes: list[int], torchdtype, tu: TestUtils): + assert len(sizes) == 1 or len(sizes) == 2 + init = tu.rand(*sizes).to(dtype=torchdtype).abs() + normalized = init / (init.sum(-1, True, dtype=torchdtype)) + return normalized + + +class MultinomialBase(torch.nn.Module): + def _forward(self, x): + a = torch.ops.aten.multinomial(x, 1024 * 1024, replacement=True) + return a + + +class MultinomialModule(MultinomialBase): + @export + @annotate_args( + [ 
+ None, + ([-1], torch.float64, True), + ] + ) + def forward(self, x): + return self._forward(x).mean(dtype=torch.double) + + +@register_test_case(module_factory=lambda: MultinomialModule()) +def MultinomialModule_basic(module, tu: TestUtils): + x = generate_sample_distr([100], torch.float64, tu) + module.forward(x) + + +class MultinomialModule2DF32(MultinomialBase): + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) + def forward(self, x): + # note: this should really call mean(-1) + # for some reason, doing this causes a torchscript numerics error? + return self._forward(x).mean(dtype=torch.double) + + +@register_test_case(module_factory=lambda: MultinomialModule2DF32()) +def MultinomialModule2D_F32(module, tu: TestUtils): + x = generate_sample_distr([10, 100], torch.float32, tu) + module.forward(x) + + +class MultinomialModule2D(MultinomialBase): + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float64, True), + ] + ) + def forward(self, x): + # note: this should really call mean(-1) + # for some reason, doing this causes a torchscript numerics error? + return self._forward(x).mean(dtype=torch.double) + + +@register_test_case(module_factory=lambda: MultinomialModule2D()) +def MultinomialModule2D_basic(module, tu: TestUtils): + x = generate_sample_distr([10, 100], torch.float64, tu) + module.forward(x) + + +# ============================================================================== + + class RandLikeModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar.py index 5576e850a9a6..3157a0fdee4f 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar.py @@ -36,6 +36,30 @@ def AddIntModule_basic(module, tu: TestUtils): # ============================================================================== +class AddFloatIntModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([], torch.float32, True), + ([], torch.int64, True), + ] + ) + def forward(self, lhs, rhs): + return float(lhs) + int(rhs) + + +@register_test_case(module_factory=lambda: AddFloatIntModule()) +def AddFloatIntModule_basic(module, tu: TestUtils): + module.forward(tu.rand(), tu.randint(low=-100, high=100)) + + +# ============================================================================== + + class SubIntModule(torch.nn.Module): def __init__(self): super().__init__() @@ -518,7 +542,7 @@ def __init__(self): @annotate_args( [ None, - ([], torch.float, True), + ([1], torch.float, True), ] ) def forward(self, val): @@ -528,3 +552,21 @@ def forward(self, val): @register_test_case(module_factory=lambda: AtenItemFpOpModule()) def AtenItemFpOpModule_basic(module, tu: TestUtils): module.forward(tu.rand(1)) + + +# ============================================================================== + + +class TrueFalseOrBoolOpModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([], torch.bool, True), ([], torch.bool, True)]) + def forward(self, a, b): + return a | b + + +@register_test_case(module_factory=lambda: TrueFalseOrBoolOpModule()) +def TrueFalseOrBoolOpModule_basic(module, tu: TestUtils): + module.forward(tu.randint(low=0, high=1).bool(), tu.randint(low=1, high=2).bool()) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py 
b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py index cc4970573f07..13ef6a818e82 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py @@ -12,6 +12,31 @@ # ============================================================================== +class MaskedScatterStaticBasic(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([4, 4], torch.float32, True), + ([4, 4], torch.bool, True), + ([8, 8], torch.float32, True), + ] + ) + def forward(self, x, mask, y): + return torch.masked_scatter(x, mask, y) + + +@register_test_case(module_factory=lambda: MaskedScatterStaticBasic()) +def MaskedScatterStaticBasic_basic(module, tu: TestUtils): + x = torch.rand(4, 4) + mask = torch.rand(4, 4) > 0.5 + y = torch.rand(8, 8) + module.forward(x, mask, y) + + class IndexPutImpl1DFloatNonAccumulateModule(torch.nn.Module): def __init__(self): super().__init__() @@ -110,6 +135,39 @@ def IndexPutImpl2DNoneIndexStaticModule_basic(module, tu: TestUtils): ) +# ============================================================================== + + +class IndexPutImpl2DNoneIndexBroadcastStaticModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([1, 4], torch.int64, True), + ([3], torch.int64, True), + ([], torch.int64, True), + ] + ) + def forward(self, input, index, value): + return torch.ops.aten._index_put_impl_( + input, (None, index), value, accumulate=False, unsafe=False + ) + + +@register_test_case( + module_factory=lambda: IndexPutImpl2DNoneIndexBroadcastStaticModule() +) +def IndexPutImpl2DNoneIndexBroadcastStaticModule_basic(module, tu: TestUtils): + module.forward(tu.randint(1, 4, high=3), tu.randint(3, high=3), torch.tensor(0)) + + +# ============================================================================== + + class IndexPutImpl3DFloatNonAccumulateModule(torch.nn.Module): def __init__(self): super().__init__() @@ -995,6 +1053,56 @@ def ScatterValueIntModule_basic(module, tu: TestUtils): # ============================================================================== +class ScatterAddStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([10, 8, 6], torch.float32, True), + ([2, 4, 3], torch.int64, True), + ([5, 8, 6], torch.float32, True), + ] + ) + def forward(self, input, index, src): + return torch.ops.aten.scatter_add(input, 0, index, src) + + +@register_test_case(module_factory=lambda: ScatterAddStaticModule()) +def ScatterAddStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), tu.rand(5, 8, 6)) + + +# ============================================================================== + + +class ScatterAddDynamicModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.int64, True), + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, input, index, src): + return torch.ops.aten.scatter_add(input, 0, index, src) + + +@register_test_case(module_factory=lambda: ScatterAddDynamicModule()) +def ScatterAddDynamicModule_basic(module, tu: TestUtils): + module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), tu.rand(5, 8, 6)) + + +# ============================================================================== + + class 
ScatterReduceFloatModule(torch.nn.Module): include_self: bool reduce_type: str @@ -1244,3 +1352,36 @@ def IndexPutImplIndexWithNoneModule_basic(module, tu: TestUtils): tu.randint(7, high=5), tu.rand(2, 3, 6, 7), ) + + +# ============================================================================== + + +class IndexPutWithNoneAndBroadcastModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 3, 4, 5], torch.float32, True), + ([6, 1], torch.int64, True), + ([7], torch.int64, True), + ([1, 6, 7], torch.float32, True), + ] + ) + def forward(self, input, index1, index2, value): + return torch.ops.aten.index_put( + input, (None, None, index1, index2), value, accumulate=True + ) + + +@register_test_case(module_factory=lambda: IndexPutWithNoneAndBroadcastModule()) +def IndexPutWithNoneAndBroadcastModule_basic(module, tu: TestUtils): + module.forward( + tu.rand(2, 3, 4, 5), + tu.randint(6, 1, high=4), + tu.randint(7, high=5), + tu.rand(1, 6, 7), # broadcasted to (2, 3, 6, 7) + ) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py index be2a80d84427..2af7fd440bdf 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py @@ -58,6 +58,29 @@ def SliceStaticModule_basic(module, tu: TestUtils): # ============================================================================== +class SliceStaticComplexInputModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([6, 4, 7], torch.complex64, True), + ] + ) + def forward(self, x): + return x[0:5:1, 1:3:1, 2:4:1] + + +@register_test_case(module_factory=lambda: SliceStaticComplexInputModule()) +def SliceStaticComplexInputModule_basic(module, tu: TestUtils): + module.forward(tu.rand(6, 4, 7).to(torch.complex64)) + + +# ============================================================================== + + class SliceOutOfUpperBoundIndexModule(torch.nn.Module): def __init__(self): super().__init__() @@ -156,6 +179,29 @@ def SliceOutOfLowerBoundStartIndexModule_basic(module, tu: TestUtils): # ============================================================================== +class SliceOutOfLowerBoundStartIndexStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([6, 4, 7], torch.float32, True), + ] + ) + def forward(self, x): + return x[-8:3:1, :, :] + + +@register_test_case(module_factory=lambda: SliceOutOfLowerBoundStartIndexStaticModule()) +def SliceOutOfLowerBoundStartIndexStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(6, 4, 7)) + + +# ============================================================================== + + class SliceEndSleStartModule(torch.nn.Module): def __init__(self): super().__init__() @@ -182,6 +228,32 @@ def SliceEndSleStartModule_basic(module, tu: TestUtils): # ============================================================================== +class SliceEndSleStartStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([6, 4, 7], torch.float32, True), + ] + ) + def forward(self, x): + # TODO: remove hacky cat tensor once refbackend supports 0 size dim + result = x[:, 4:3, :] + cat_tensor = torch.ones((6, 1, 7), dtype=torch.float32) + return torch.cat((result, cat_tensor), dim=1) + + 
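+# The slice above has end <= start, so it produces a (6, 0, 7) tensor; concatenating
+# the ones tensor keeps the module's output non-empty for comparison.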
+@register_test_case(module_factory=lambda: SliceEndSleStartStaticModule()) +def SliceEndSleStartStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(6, 4, 7)) + + +# ============================================================================== + + class SliceStartEqEndModule(torch.nn.Module): def __init__(self): super().__init__() @@ -208,6 +280,29 @@ def SliceStartEqEndModule_basic(module, tu: TestUtils): # ============================================================================== +class SliceSizeTwoStepDivisibleStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([10, 6, 16], torch.float32, True), + ] + ) + def forward(self, x): + return x[0:5:2, 0:3:2, 0:4:2] + + +@register_test_case(module_factory=lambda: SliceSizeTwoStepDivisibleStaticModule()) +def SliceSizeTwoStepDivisibleStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(10, 6, 16)) + + +# ============================================================================== + + class SliceSizeTwoStepModule(torch.nn.Module): def __init__(self): super().__init__() @@ -696,6 +791,33 @@ def SliceCopyNegative_Module_basic(module, tu: TestUtils): # ============================================================================== +class SliceCopyMax_Module(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, x, y): + # A slice without specified end uses the max. value of int64_t + xslice = torch.ops.aten.slice(x, 0, 0, 9223372036854775807, 1) + xslice.copy_(y) + return x + + +@register_test_case(module_factory=lambda: SliceCopyMax_Module()) +def SliceCopyMax_Module_basic(module, tu: TestUtils): + module.forward(tu.rand(4, 4, 4), tu.rand(4, 4, 4)) + + +# ============================================================================== + + class SliceCopyStartGreaterThanDimSize_Module(torch.nn.Module): def __init__(self): super().__init__() @@ -850,6 +972,71 @@ def UnbindIntGetItem_Module_basic(module, tu: TestUtils): # ============================================================================== +class TensorsSplitTensorModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([6, 10, 12], torch.float32, True)]) + def forward(self, x): + s0, s1, s2 = torch.ops.aten.split(x, 2, dim=0) + return s1 + + +@register_test_case(module_factory=lambda: TensorsSplitTensorModule()) +def TensorsSplitTensorModule_basic(module, tu: TestUtils): + module.forward(tu.rand(6, 10, 12)) + + +# ============================================================================== + + +class TensorsSplitTensorLastSmallerModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([8, 10, 12], torch.float32, True)]) + def forward(self, x): + s0, s1, s2 = torch.ops.aten.split(x, 3, dim=0) + return s2 + + +@register_test_case(module_factory=lambda: TensorsSplitTensorLastSmallerModule()) +def TensorsSplitTensorLastSmallerModule_basic(module, tu: TestUtils): + # Splitting the first dimension with 8 elements into chunks of 3 + # will leave the last result to have 2 elements in that dimension. 
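    # That is, torch.ops.aten.split(x, 3, dim=0) on an (8, 10, 12) input yields
    # pieces of shapes (3, 10, 12), (3, 10, 12) and (2, 10, 12); the module's
    # forward returns only the last (smaller) piece, s2.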
+ module.forward(tu.rand(8, 10, 12)) + + +# ============================================================================== + + +class TensorsSplitTensorNegativeDimModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([10, 12, 6], torch.float32, True)]) + def forward(self, x): + s0, s1, s2 = torch.ops.aten.split(x, 2, -1) + return s1 + + +@register_test_case(module_factory=lambda: TensorsSplitTensorNegativeDimModule()) +def TensorsSplitTensorNegativeDimModule_basic(module, tu: TestUtils): + module.forward(tu.rand(10, 12, 6)) + + +# ============================================================================== + + +# ============================================================================== + + class SplitTensorGetItem_Module(torch.nn.Module): def __init__(self): super().__init__() @@ -995,8 +1182,8 @@ def __init__(self): ] ) def forward(self, x): - chunk_0, chunk_1, chunk_2 = torch.chunk(x, 3, 1) - return torch.ops.aten.add(chunk_0, chunk_1), chunk_2 + a0, a1, a2, a3, a4 = torch.chunk(x, 6, 1) + return a0, a1, a2, a3, a4 @register_test_case(module_factory=lambda: ChunkListUnpackUneven_Module()) @@ -1076,3 +1263,48 @@ def forward(self, x): @register_test_case(module_factory=lambda: SplitWithSizes_Module()) def SplitWithSizes_Module_basic(module, tu: TestUtils): module.forward(tu.rand(5, 2, 2)) + + +# ============================================================================== + + +class TensorSplitSections_GetItemModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 5], torch.float32, True), + ] + ) + def forward(self, x): + split = torch.tensor_split(x, 3, dim=1) + return split[0], split[1], split[2] + + +@register_test_case(module_factory=lambda: TensorSplitSections_GetItemModule()) +def TensorSplitSections_GetItemModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 5)) + + +class TensorSplitSections_ListUnpackModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 5], torch.float32, True), + ] + ) + def forward(self, x): + a, b, c, d = torch.tensor_split(x, 4, dim=1) + return a, b, c, d + + +@register_test_case(module_factory=lambda: TensorSplitSections_ListUnpackModule()) +def TensorSplitSections_ListUnpackModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 5)) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/spectral.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/spectral.py new file mode 100644 index 000000000000..57a7270f9d09 --- /dev/null +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/spectral.py @@ -0,0 +1,93 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. 
+ +import torch + +from torch_mlir_e2e_test.framework import TestUtils +from torch_mlir_e2e_test.registry import register_test_case +from torch_mlir_e2e_test.annotations import annotate_args, export + +# ============================================================================== + + +class AtenHannWindowPeriodicFalseModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ] + ) + def forward(self): + return torch.ops.aten.hann_window(20, False) + + +@register_test_case(module_factory=lambda: AtenHannWindowPeriodicFalseModule()) +def AtenHannWindowPeriodicFalseModule_basic(module, tu: TestUtils): + module.forward() + + +# ============================================================================== + + +class AtenHannWindowPeriodicTrueModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ] + ) + def forward(self): + return torch.ops.aten.hann_window(20, True) + + +@register_test_case(module_factory=lambda: AtenHannWindowPeriodicTrueModule()) +def AtenHannWindowPeriodicTrueModule_basic(module, tu: TestUtils): + module.forward() + + +# ============================================================================== + + +class AtenFftRfft2DLastDim(torch.nn.Module): + @export + @annotate_args( + [ + None, + ([16, 9], torch.float32, True), + ] + ) + def forward(self, input): + return torch.fft.rfft(input, dim=-1) + + +@register_test_case(module_factory=lambda: AtenFftRfft2DLastDim()) +def AtenFftRfft2DLastDim_basic(module, tu: TestUtils): + module.forward(tu.rand(16, 9)) + + +# ============================================================================== + + +class AtenFftRfft2DMiddleDim(torch.nn.Module): + @export + @annotate_args( + [ + None, + ([36, 10], torch.float32, True), + ] + ) + def forward(self, input): + return torch.fft.rfft(input, dim=0) + + +@register_test_case(module_factory=lambda: AtenFftRfft2DMiddleDim()) +def AtenFftRfft2DMiddleDim_basic(module, tu: TestUtils): + module.forward(tu.rand(36, 10)) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/timeout.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/timeout.py new file mode 100644 index 000000000000..387ff6cfc8de --- /dev/null +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/timeout.py @@ -0,0 +1,47 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +import torch + +from torch_mlir_e2e_test.framework import TestUtils +from torch_mlir_e2e_test.registry import register_test_case +from torch_mlir_e2e_test.annotations import annotate_args, export + + +# ============================================================================== +class TimeOutModule(torch.nn.Module): + """ + This test ensures that the timeout mechanism works as expected. + + The module runs an infinite loop that will never terminate, + and the test is expected to time out and get terminated + """ + + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([-1, -1], torch.int64, True)]) + def forward(self, x): + """ + Run an infinite loop. + + This may loop in the compiler or the runtime depending on whether + fx or torchscript is used. 
+ """ + # input_arg_2 is going to be 2 + # but we can't just specify it as a + # constant because the compiler will + # attempt to get rid of the whole loop + input_arg_2 = x.size(0) + sum = 100 + while input_arg_2 < sum: # sum will always > 2 + sum += 1 + return sum + + +@register_test_case(module_factory=lambda: TimeOutModule(), timeout_seconds=10) +def TimeOutModule_basic(module, tu: TestUtils): + module.forward(torch.ones((42, 42))) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py index 5d3d085d5e2b..f8deda462905 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py @@ -136,6 +136,26 @@ def TypeConversionI1ToF64Module_basic(module, tu: TestUtils): module.forward(tensor) +class TypeConversionUint8ToF32Module(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([3], torch.uint8, True), + ] + ) + def forward(self, x): + return x.to(torch.float) + + +@register_test_case(module_factory=lambda: TypeConversionUint8ToF32Module()) +def TypeConversionUint8ToF32Module_basic(module, tu: TestUtils): + module.forward(torch.tensor([0, 1, 255]).to(torch.uint8)) + + # ============================================================================== @@ -235,6 +255,45 @@ def ToDtypeBoolLayoutNoneStaticModule_basic(module, tu: TestUtils): module.forward(tu.randint(3, 5)) +class ToDtypeFloatFromIntModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([-1, -1], torch.int64, True)]) + def forward(self, x): + return torch.ops.aten.to( + x, + dtype=torch.float32, + ) + + +@register_test_case(module_factory=lambda: ToDtypeFloatFromIntModule()) +def ToDtypeFloatFromIntModule_basic(module, tu: TestUtils): + input = torch.randint(low=-5, high=5, size=(2, 2)).to(torch.int64) + module.forward(input) + + +class ToDtypeIntFromFloatModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([-1, -1], torch.float64, True)]) + def forward(self, x): + return torch.ops.aten.to( + x, + dtype=torch.int64, + ) + + +@register_test_case(module_factory=lambda: ToDtypeIntFromFloatModule()) +def ToDtypeIntFromFloatModule_basic(module, tu: TestUtils): + input = tu.rand(2, 2, low=-5, high=5) + input[1][1] = tu.randint(1, 1) + 0.7 + module.forward(input) + + class TypeAsSameModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/projects/pt1/python/torch_mlir_e2e_test/utils.py b/projects/pt1/python/torch_mlir_e2e_test/utils.py index dd9f8d8f8170..0ab47efa9284 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/utils.py +++ b/projects/pt1/python/torch_mlir_e2e_test/utils.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Also available under a BSD-style license. See LICENSE. -from torch_mlir.torchscript import TensorPlaceholder +from torch_mlir.compiler_utils import TensorPlaceholder from torch_mlir_e2e_test.annotations import TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME diff --git a/projects/pt1/test/lit.cfg.py b/projects/pt1/test/lit.cfg.py index 2f2cfe656eae..938b05b53977 100644 --- a/projects/pt1/test/lit.cfg.py +++ b/projects/pt1/test/lit.cfg.py @@ -18,6 +18,24 @@ # Configuration file for the 'lit' test runner. + +# Find path to the ASan runtime required for the Python interpreter. 
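# It does so by asking the host C++ compiler for the runtime's location, e.g.
# (illustrative only; the actual path depends on the toolchain):
#   clang++ -print-file-name=libclang_rt.asan-x86_64.so
#   -> /usr/lib/clang/18/lib/linux/libclang_rt.asan-x86_64.so
# The path found here is later injected via LD_PRELOAD (see further down).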
+def find_asan_runtime(): + if not "asan" in config.available_features or not "Linux" in config.host_os: + return "" + # Find the asan rt lib + return ( + subprocess.check_output( + [ + config.host_cxx.strip(), + f"-print-file-name=libclang_rt.asan-{config.host_arch}.so", + ] + ) + .decode("utf-8") + .strip() + ) + + # name: The name of this test suite. config.name = "TORCH_MLIR_PT1" @@ -66,10 +84,15 @@ "PATH", os.path.join(config.llvm_build_dir, "bin"), append_path=True ) +# Python configuration with sanitizer requires some magic preloading. This will only work on clang/linux. +# TODO: detect Darwin/Windows situation (or mark these tests as unsupported on these platforms). +if "asan" in config.available_features and "Linux" in config.host_os: + _asan_rt = find_asan_runtime() + config.python_executable = f"env LD_PRELOAD={_asan_rt} {config.python_executable}" # On Windows the path to python could contains spaces in which case it needs to # be provided in quotes. This is the equivalent of how %python is setup in # llvm/utils/lit/lit/llvm/config.py. -if "Windows" in config.host_os: +elif "Windows" in config.host_os: config.python_executable = '"%s"' % (config.python_executable) tool_dirs = [ diff --git a/projects/pt1/test/lit.site.cfg.py.in b/projects/pt1/test/lit.site.cfg.py.in index 3b3ef59bd7aa..6f277e1a67ac 100644 --- a/projects/pt1/test/lit.site.cfg.py.in +++ b/projects/pt1/test/lit.site.cfg.py.in @@ -6,6 +6,8 @@ config.enable_bindings_python = @MLIR_ENABLE_BINDINGS_PYTHON@ config.torch_mlir_obj_root = "@TORCH_MLIR_BINARY_DIR@" config.torch_mlir_python_packages_dir = "@TORCH_MLIR_PYTHON_PACKAGES_DIR@" config.host_os = "@HOST_OS@" +config.host_cxx = "@HOST_CXX@" +config.host_arch = "@HOST_ARCH@" config.llvm_src_root = "@LLVM_SOURCE_DIR@" config.llvm_obj_root = "@LLVM_BINARY_DIR@" config.llvm_tools_dir = "@LLVM_TOOLS_DIR@" diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 76cdbcca41eb..2ab12d3dd6fb 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -53,6 +53,7 @@ declare_mlir_python_sources(TorchMLIRPythonSources.Tools ADD_TO_PARENT TorchMLIRPythonSources SOURCES tools/import_onnx/__main__.py + tools/opt/__main__.py ) declare_mlir_python_sources(TorchMLIRSiteInitialize @@ -96,12 +97,18 @@ set(_source_components TorchMLIRPythonSources TorchMLIRPythonExtensions TorchMLIRSiteInitialize - - # Sources related to optional Torch extension dependent features. Typically - # empty unless if project features are enabled. - TorchMLIRPythonTorchExtensionsSources ) +if(TORCH_MLIR_ENABLE_STABLEHLO) + list(APPEND _source_components StablehloPythonExtensions) +endif() + +# Sources related to optional Torch extension dependent features. Typically +# empty unless if project features are enabled. 
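# (Assumed usage: these sources are pulled in only when the build is configured
# with something like -DTORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS=ON.)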
+if(TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS) + list(APPEND _source_components TorchMLIRPythonTorchExtensionsSources) +endif() + add_mlir_python_common_capi_library(TorchMLIRAggregateCAPI INSTALL_COMPONENT TorchMLIRPythonModules INSTALL_DESTINATION python_packages/torch_mlir/torch_mlir/_mlir_libs @@ -116,4 +123,6 @@ add_mlir_python_modules(TorchMLIRPythonModules DECLARED_SOURCES ${_source_components} COMMON_CAPI_LINK_LIBS TorchMLIRAggregateCAPI - ) +) + +add_dependencies(TorchMLIRPythonModules torch-mlir-opt) diff --git a/python/TorchMLIRModule.cpp b/python/TorchMLIRModule.cpp index 73abf5cd5577..36e391867533 100644 --- a/python/TorchMLIRModule.cpp +++ b/python/TorchMLIRModule.cpp @@ -28,4 +28,8 @@ PYBIND11_MODULE(_torchMlir, m) { } }, py::arg("context"), py::arg("load") = true); + + m.def("get_int64_max", []() { return INT64_MAX; }); + + m.def("get_int64_min", []() { return INT64_MIN; }); } diff --git a/python/torch_mlir/compiler_utils.py b/python/torch_mlir/compiler_utils.py index 4e5a2f8f8c07..cf07526efceb 100644 --- a/python/torch_mlir/compiler_utils.py +++ b/python/torch_mlir/compiler_utils.py @@ -7,10 +7,59 @@ import os import sys import tempfile -from typing import Union +from typing import Union, List -from torch_mlir.passmanager import PassManager -from torch_mlir.ir import StringAttr +import torch +from .passmanager import PassManager +from .ir import StringAttr + + +class TensorPlaceholder: + """A class that represents a formal parameter of a given shape and dtype. + + This class can be constructed explicitly from a shape and dtype: + ```python + placeholder = TensorPlaceholder([3, 4], torch.float32) + ``` + + This class can also be constructed from a `torch.Tensor` which is already + known to be a valid input to the function. In this case, a set of + dynamic axes are allowed to be specified. + ```python + placeholder = TensorPlaceholder.like(torch.ones(3, 4), dynamic_axes=[1]) + # Equivalent to `TensorPlaceholder([3, -1], torch.float32)` + ``` + """ + + def __init__(self, shape: List[int], dtype: torch.dtype): + """Create a tensor with shape `shape` and dtype `dtype`. + + Args: + shape: The shape of the tensor. A size of `-1` indicates that the + dimension has an unknown size. + dtype: The dtype of the tensor. + """ + self.shape = shape + self.dtype = dtype + + @staticmethod + def like(tensor: torch.Tensor, dynamic_axes: List[int] = None): + """Create a tensor placeholder that is like the given tensor. + + Args: + tensor: The tensor to create a placeholder for. + dynamic_axes: A list of dynamic axes. If specified, the compiled + module will allow those axes to be any size at runtime. + """ + if dynamic_axes is None: + dynamic_axes = [] + shape = [] + for i, dim in enumerate(tensor.shape): + if i in dynamic_axes: + shape.append(-1) + else: + shape.append(dim) + return TensorPlaceholder(shape, tensor.dtype) def get_module_name_for_debug_dump(module): @@ -40,6 +89,9 @@ def run_pipeline_with_repro_report( ) # Lower module in place to make it ready for compiler backends. with module.context as ctx: + # TODO(#3506): Passes can emit errors but not signal failure, + # which causes a native assert. + ctx.emit_error_diagnostics = True pm = PassManager.parse(pipeline) if enable_ir_printing: ctx.enable_multithreading(False) @@ -79,12 +131,12 @@ def run_pipeline_with_repro_report( class OutputType(Enum): - # Output torch dialect. When converting from FX, this will be immediately - # after the import from FX to MLIR. 
When converting from torchscript, - # this will come after some cleanup passes which attempt to de-alias, - # decompose and infer shapes. These should be roughly the same level of - # abstraction since those steps are done within PyTorch itself - # when coming directly from Dynamo/FX. + # Output torch dialect in backend form. When converting from TorchDynamo, + # this comes after some decomposition and reduce op variants passes are + # applied to the raw torch dialect. When converting from TorchScript, this + # comes after some cleanup passes which attempt to de-alias, decompose and infer shapes. + # These should be roughly the same level of abstraction since those + # steps are done within PyTorch itself when coming directly from Dynamo/FX. TORCH = "torch" # The output type contains a mix of `linalg`-on-tensors ops, `scf`, and @@ -101,7 +153,8 @@ class OutputType(Enum): # as taking the `TORCH` output type and lowering it to StableHLO. STABLEHLO = "stablehlo" - # Raw output of the JIT IR importer. This is not expected to be useful + # Raw output of the JIT IR importer in the TorchScript frontend or that of + # the FX IR importer in the TorchDynamo frontend. This is not expected to be useful # for end-users, but can be convenient for development or reporting bugs. RAW = "raw" diff --git a/python/torch_mlir/extras/fx_decomp_util.py b/python/torch_mlir/extras/fx_decomp_util.py index 7a6f67b2254f..dec11d5c2b37 100644 --- a/python/torch_mlir/extras/fx_decomp_util.py +++ b/python/torch_mlir/extras/fx_decomp_util.py @@ -47,7 +47,15 @@ torch.ops.aten.linspace.default, torch.ops.aten.triu.default, torch.ops.aten.nan_to_num.default, + torch.ops.aten.unbind, + torch.ops.aten.diag, + torch.ops.aten.cumsum, + torch.ops.aten.index_select, ] +if hasattr(torch.ops.aten, "_scaled_dot_product_flash_attention_for_cpu"): + DEFAULT_DECOMPOSITIONS.append( + torch.ops.aten._scaled_dot_product_flash_attention_for_cpu + ) def get_decomposition_table(): diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 381f8f9ad88f..8840055744e7 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -14,6 +14,8 @@ import logging import operator import re +import sympy +import math from dataclasses import dataclass from types import BuiltinMethodType, BuiltinFunctionType from typing import ( @@ -76,11 +78,29 @@ # conditional. ml_dtypes = None +try: + from torch.utils._sympy.numbers import int_oo, IntInfinity, NegativeIntInfinity +except ModuleNotFoundError: + # This commit on PyTorch repo introduced IntInfinity and NegativeIntInfinity: + # https://github.com/pytorch/pytorch/commit/2229884102ac95c9dda0aeadbded1b04295d892e + # Required module may not be present in the stable version of PyTorch. 
+ int_oo = None + IntInfinity = None + NegativeIntInfinity = None + from torch.fx.node import ( Argument as NodeArgument, ) from ..ir import ( + AffineAddExpr, + AffineConstantExpr, + AffineExpr, + AffineMap, + AffineMapAttr, + AffineModExpr, + AffineMulExpr, + AffineSymbolExpr, Attribute, Block, Context, @@ -89,6 +109,10 @@ FloatAttr, BF16Type, ComplexType, + Float8E5M2Type, + Float8E4M3FNType, + Float8E5M2FNUZType, + Float8E4M3FNUZType, F16Type, F32Type, F64Type, @@ -111,6 +135,7 @@ func as func_dialect, ) + __all__ = [ "FxImporter", ] @@ -138,6 +163,16 @@ torch.complex64: "complex", torch.complex128: "complex", } +# Type entries added only in torch with higher version +OPTIONAL_TORCH_DTYPE_TO_MLIR_TYPE_ASM = { + "float8_e5m2": "f8E5M2", + "float8_e4m3fn": "f8E4M3FN", + "float8_e5m2fnuz": "f8E5M2FNUZ", + "float8_e4m3fnuz": "f8E4M3FNUZ", +} +for dtype_str, dtype_asm in OPTIONAL_TORCH_DTYPE_TO_MLIR_TYPE_ASM.items(): + if hasattr(torch, dtype_str): + TORCH_DTYPE_TO_MLIR_TYPE_ASM[getattr(torch, dtype_str)] = dtype_asm TORCH_DTYPE_TO_MLIR_TYPE: Dict[torch.dtype, Callable[[], IrType]] = { torch.float16: lambda: F16Type.get(), @@ -156,6 +191,16 @@ torch.complex64: lambda: ComplexType.get(F32Type.get()), torch.complex128: lambda: ComplexType.get(F64Type.get()), } +# Type entries added only in torch with higher version +OPTIONAL_TORCH_DTYPE_TO_MLIR_TYPE = { + "float8_e5m2": lambda: Float8E5M2Type.get(), + "float8_e4m3fn": lambda: Float8E4M3FNType.get(), + "float8_e5m2fnuz": lambda: Float8E5M2FNUZType.get(), + "float8_e4m3fnuz": lambda: Float8E4M3FNUZType.get(), +} +for dtype_str, mlir_type in OPTIONAL_TORCH_DTYPE_TO_MLIR_TYPE.items(): + if hasattr(torch, dtype_str): + TORCH_DTYPE_TO_MLIR_TYPE[getattr(torch, dtype_str)] = mlir_type TORCH_DTYPE_TO_NPY_TYPE = { # torch.qint8: None, # no equivalent np datatype @@ -194,6 +239,16 @@ # torch.qint32 14 torch.bfloat16: 15, } +# Type entries added only in torch with higher version +OPTIONAL_TORCH_DTYPE_TO_INT = { + "float8_e5m2": 23, + "float8_e4m3fn": 24, + "float8_e5m2fnuz": 25, + "float8_e4m3fnuz": 26, +} +for dtype_str, dtype_int in OPTIONAL_TORCH_DTYPE_TO_INT.items(): + if hasattr(torch, dtype_str): + TORCH_DTYPE_TO_INT[getattr(torch, dtype_str)] = dtype_int TORCH_MEMORY_FORMAT_TO_INT = { torch.contiguous_format: 0, @@ -221,6 +276,9 @@ "ge": torch.ops.aten.ge, "ne": torch.ops.aten.ne, "gt": torch.ops.aten.gt, + "mod": torch.ops.aten.fmod, + "eq": torch.ops.aten.eq, + "floordiv": torch.ops.aten.floordiv, } # torch with cuda has a __version__ that looks like "2.1.0+cu113", @@ -258,63 +316,112 @@ SYMBOLIC_TORCH_OPS = {key for key in SYMBOLIC_OP_TO_TORCH_OP} -@dataclass(frozen=True) -class SparsityMeta: - """ - Class for keeping track of sparsity meta data. +@dataclass +class RangeConstraint: + min_val: int + max_val: int - NOTE: this will be fully replaced by - torch.fx.passes.shape_prop.SparseTensorMetadata - """ - layout: torch.layout - batch_dim: int - sparse_dim: int - dense_dim: int - blocksize: Optional[Tuple[int, int]] - pos_dtype: torch.dtype - crd_dtype: torch.dtype +def sympy_expr_to_semi_affine_expr( + expr: sympy.Expr, symbols_map: Dict[str, AffineSymbolExpr] +) -> AffineExpr: + """Translate sympy expressions to MLIR (semi-)affine expressions. + + Recursively traverse the sympy expr AST and build the affine expr. + This is not a perfect translation. Sympy expressions are much more + expressive and not as constrained as affine (linear) expressions are. + However, for the most part, we don't need to support all of sympy. 
+ PyTorch only uses a subset of sympy for capturing and expressing + symbolic shapes, and among what's supported, we expect the semi-affine + expressions (https://mlir.llvm.org/docs/Dialects/Affine/#semi-affine-maps) + to be sufficient. + """ + if isinstance(expr, sympy.Symbol): + return symbols_map[str(expr)] + elif isinstance(expr, (int, sympy.Integer)): + return AffineConstantExpr.get(expr) + # This handles both add (`s0 + c`) and subtract (`s0 - c`). + # The expression is `sympy.Add` in both cases but with args + # (s0, c) in first case and (s0, -c) in the second case. + elif isinstance(expr, sympy.Add): + affine_expr = AffineConstantExpr.get(0) + for arg in expr.args: + affine_expr = AffineAddExpr.get( + affine_expr, sympy_expr_to_semi_affine_expr(arg, symbols_map) + ) + return affine_expr + elif isinstance(expr, sympy.Mul): + affine_expr = AffineConstantExpr.get(1) + for arg in expr.args: + affine_expr = AffineMulExpr.get( + affine_expr, sympy_expr_to_semi_affine_expr(arg, symbols_map) + ) + return affine_expr + elif isinstance(expr, sympy.Pow): + base, exp = expr.args + # Only integer exponent is supported + # So, s1 ** s0 isn't allowed. + assert isinstance(exp, (int, sympy.Integer)) + assert exp > 0, "Only positive exponents supported in sympy.Pow" + affine_expr = AffineConstantExpr.get(1) + for _ in range(exp): + affine_expr = AffineMulExpr.get( + affine_expr, sympy_expr_to_semi_affine_expr(base, symbols_map) + ) + return affine_expr + elif isinstance(expr, sympy.Mod): + dividend, divisor = expr.args + return AffineModExpr.get( + sympy_expr_to_semi_affine_expr(dividend, symbols_map), + sympy_expr_to_semi_affine_expr(divisor, symbols_map), + ) + else: + raise NotImplementedError( + f"Translation of sympy.Expr of type {type(expr)} not implemented yet." + ) -def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str: - """Returns sparse tensor encoding for the given sparse layout as string.""" - assert sparsity is not None +def sparsity_encoding(t: torch.Tensor) -> str: + """Returns sparse tensor encoding for the given tensor as string.""" # Sparse tensors have the form # [ , , ] # which map directly to MLIR types. 
- batch_dim, sparse_dim, dense_dim = ( - sparsity.batch_dim, - sparsity.sparse_dim, - sparsity.dense_dim, + dim, batch_dim, sparse_dim, dense_dim = ( + t.ndim, + t.ndim - t.sparse_dim() - t.dense_dim(), + t.sparse_dim(), + t.dense_dim(), ) - dim = batch_dim + sparse_dim + dense_dim - assert dim == len(shape) - blocksize = sparsity.blocksize - dims = ",".join(f"d{d}" for d in range(dim)) - if sparsity.layout is torch.sparse_coo: - assert sparse_dim >= 2 and blocksize is None + if t.layout is torch.sparse_coo: + assert sparse_dim >= 2 trail_dim = batch_dim + sparse_dim - 1 coords = ",".join( f"d{d}:singleton(nonunique,soa)" for d in range(batch_dim + 1, trail_dim) ) sep = "," if sparse_dim > 2 else "" lvls = f"d{batch_dim}:compressed(nonunique),{coords}{sep}d{trail_dim}:singleton(soa)" - elif sparsity.layout is torch.sparse_csr: - assert sparse_dim == 2 and blocksize is None + idx_dtype = t._indices().dtype # supports uncoalesced COO tensors + elif t.layout is torch.sparse_csr: + assert sparse_dim == 2 lvls = f"d{batch_dim}:dense,d{batch_dim+1}:compressed" - elif sparsity.layout is torch.sparse_csc: - assert sparse_dim == 2 and blocksize is None + idx_dtype = t.col_indices().dtype + elif t.layout is torch.sparse_csc: + assert sparse_dim == 2 lvls = f"d{batch_dim+1}:dense,d{batch_dim}:compressed" + idx_dtype = t.row_indices().dtype else: - assert sparse_dim == 2 and blocksize is not None - if sparsity.layout is torch.sparse_bsr: + assert sparse_dim == 2 + blocksize = t.values().shape[batch_dim + 1 : batch_dim + 3] + if t.layout is torch.sparse_bsr: i, j = batch_dim, batch_dim + 1 + idx_dtype = t.col_indices().dtype else: - assert sparsity.layout is torch.sparse_bsc + assert t.layout is torch.sparse_bsc j, i = batch_dim, batch_dim + 1 + idx_dtype = t.row_indices().dtype m, n = blocksize lvls = ( f"d{i} floordiv {m}:dense,d{j} floordiv {n}:compressed," @@ -329,8 +436,7 @@ def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str: dense = ",".join(f"d{d}:dense" for d in range(batch_dim + sparse_dim, dim)) lvls = f"{lvls},{dense}" - posw = torch.iinfo(sparsity.pos_dtype).bits - crdw = torch.iinfo(sparsity.crd_dtype).bits + posw = crdw = torch.iinfo(idx_dtype).bits return f"#sparse_tensor.encoding<{{map=({dims})->({lvls}),posWidth={posw},crdWidth={crdw}}}>" @@ -364,7 +470,7 @@ def prepare_module(self, module_op: Operation): ... def resolve_literal( - self, gni: "GraphNodeImporter", literal: Any + self, gni: "GraphNodeImporter", literal: Any, info: Optional[InputInfo] ) -> Optional[Value]: """User overridable hook to resolve a literal value.""" return None @@ -478,6 +584,7 @@ def import_program( *, func_name: str = "main", func_visibility: Optional[str] = None, + import_symbolic_shape_expressions: bool = False, ) -> Operation: """Imports an ExportedProgram according to our chosen canonical representation. @@ -527,6 +634,10 @@ def import_program( sig = prog.graph_signature + # Populate symbolic guards for dynamic shapes (if any) + if import_symbolic_shape_expressions: + self._cc.set_symbolic_guards(prog) + # Invert the (producer, node_name) maps for mutated user inputs and mutated # buffers. This is because we hit-detect based on the input node name. mutated_user_inputs = { @@ -612,10 +723,17 @@ def import_program( # on a symbolic or other non-SSA association. As such, they # are not modeled with mutable IR but will trigger an output # store hook when the final value is produced. 
- value = prog.state_dict.get(input_spec.target) - assert ( - not input_spec.persistent or value is not None - ), "Expected state_dict value for persistent value" + if input_spec.persistent: + value = prog.state_dict.get(input_spec.target) + assert ( + value is not None + ), "Expected state_dict value for persistent buffer" + else: + value = prog.constants.get(input_spec.target) + assert ( + value is not None + ), "Expected constants value for non-persistent buffer" + node = placeholder_nodes[arg.name] mutable_producer_node_name = mutable_buffer_target_producers.get( input_spec.target @@ -682,7 +800,9 @@ def import_program( # Import all nodes and return. node_importer.import_nodes( - all_producer_nodes.values(), skip_placeholders_outputs=True + all_producer_nodes.values(), + skip_placeholders_outputs=True, + import_symbolic_shape_expressions=import_symbolic_shape_expressions, ) node_importer.return_node_values(loc, user_outputs) self.symbol_table.insert(func_op) @@ -694,6 +814,7 @@ def import_frozen_program( *, func_name: str = "main", func_visibility: Optional[str] = None, + import_symbolic_shape_expressions: bool = False, ) -> Operation: """Imports a consolidated torch.export.ExportedProgram instance. @@ -728,6 +849,10 @@ def import_frozen_program( state_dict = prog.state_dict arg_replacements: Dict[str, Any] = {} + # Populate symbolic guards for dynamic shapes (if any) + if import_symbolic_shape_expressions: + self._cc.set_symbolic_guards(prog) + # If there is no "constants" attribute, consult the "state_dict". Otherwise, only look # at "constants". Relevant upstream patch: https://github.com/pytorch/pytorch/pull/118969 if hasattr(prog, "constants"): @@ -774,7 +899,10 @@ def import_frozen_program( g.erase_node(node) return self.import_stateless_graph( - g, func_name=func_name, func_visibility=func_visibility + g, + func_name=func_name, + func_visibility=func_visibility, + import_symbolic_shape_expressions=import_symbolic_shape_expressions, ) def import_graph_module(self, gm: GraphModule) -> Operation: @@ -791,6 +919,7 @@ def import_stateless_graph( *, func_name: str = "main", func_visibility: Optional[str] = None, + import_symbolic_shape_expressions: bool = False, ) -> Operation: """Low-level import of a functionalized, assumed stateless Graph as a func. @@ -815,7 +944,9 @@ def import_stateless_graph( self._cc, entry_block, ) - node_importer.import_nodes(g.nodes) + node_importer.import_nodes( + g.nodes, import_symbolic_shape_expressions=import_symbolic_shape_expressions + ) self.symbol_table.insert(func) return func @@ -844,6 +975,17 @@ def _graph_to_function_meta(self, g: Graph) -> Tuple[FunctionType, Location]: result_types.append( IrType.parse("!torch.none", context=self._c) ) + elif isinstance(result_node, torch.Tensor): + result_types.append( + self._cc.tensor_to_vtensor_type(result_node) + ) + elif type(result_node) in SCALAR_TYPE_TO_TORCH_MLIR_TYPE: + result_types.append( + IrType.parse( + SCALAR_TYPE_TO_TORCH_MLIR_TYPE[type(result_node)], + self._c, + ) + ) else: result_types.append(self._cc.node_val_to_type(result_node)) return ( @@ -859,6 +1001,7 @@ class ContextCache: "_c", "_dtype_to_type", "_tensor_metadata_cache", + "_symbolic_guards", "_py_attr_tracker", # Types. "torch_bool_type", @@ -877,6 +1020,7 @@ def __init__( self._tensor_metadata_cache: Dict[ Tuple[torch.Size, torch.dtype, Optional[SparsityMeta], bool], IrType ] = {} + self._symbolic_guards: Dict = {} self._py_attr_tracker = py_attr_tracker or RefTracker() # Common types. 
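# A minimal usage sketch of the import_symbolic_shape_expressions flag threaded
# through the import entry points above. Assumptions: the context/dialect setup
# mirrors torch-mlir's own fx helpers, AddOne and batch are made-up names, and
# only the new keyword argument comes from this patch.
import torch
from torch_mlir import ir
from torch_mlir.dialects import torch as torch_d
from torch_mlir.extras.fx_importer import FxImporter

class AddOne(torch.nn.Module):
    def forward(self, x):
        return x + 1.0

batch = torch.export.Dim("batch", min=2, max=100)
prog = torch.export.export(
    AddOne(), (torch.rand(4, 3),), dynamic_shapes={"x": {0: batch}}
)

ctx = ir.Context()
torch_d.register_dialect(ctx)
importer = FxImporter(context=ctx)
importer.import_program(prog, import_symbolic_shape_expressions=True)
print(importer.module)  # expect torch.symbolic_int / torch.bind_symbolic_shape ops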
@@ -901,20 +1045,27 @@ def get_vtensor_type( shape: torch.Size, dtype: torch.dtype, *, - sparsity: Optional[SparsityMeta] = None, + val: Optional[torch.Tensor] = None, mutable: bool = False, ): """Return IrType for !torch.vtensor with the given shape and dtype""" stem = "torch.tensor" if mutable else "torch.vtensor" shape_asm = self.format_asm_shape(shape) mlir_dtype = str(self.dtype_to_type(dtype)) - if sparsity is not None: - encoding = sparsity_encoding(shape, sparsity) - assert encoding is not None + if val is not None and val.layout in [ + torch.sparse_coo, + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + ]: + # This is a sparse tensor. + encoding = sparsity_encoding(val) return IrType.parse( f"!{stem}<[{shape_asm}],{str(mlir_dtype)},{encoding}>", context=self._c, ) + # This is a dense tensor. return IrType.parse( f"!{stem}<[{shape_asm}],{str(mlir_dtype)}>", context=self._c ) @@ -923,24 +1074,25 @@ def node_val_to_type(self, node: torch_fx.Node, *, mutable: bool = False) -> IrT try: tensor_meta = node.meta.get("tensor_meta") val = node.meta.get("val") - sparsity = node.meta.get("sparsity", None) except KeyError as e: raise RuntimeError( f"FIXME: Illegal access to torch.fx.Node.meta: {e} ({node.meta.keys()} : {node.meta})" ) - return self.value_info_to_type( - val, tensor_meta=tensor_meta, sparsity=sparsity, mutable=mutable - ) + return self.value_info_to_type(val, tensor_meta=tensor_meta, mutable=mutable) def value_info_to_type( self, val, *, tensor_meta: Optional[TensorMetadata] = None, - sparsity=None, mutable: bool = False, ): if tensor_meta is not None: + # separately handle when tensor_meta is a list. + if isinstance(val, list) and all( + isinstance(x, TorchFakeTensor) for x in val + ): + return IrType.parse("!torch.list", context=self._c) assert isinstance(tensor_meta, TensorMetadata) # Quantized tensor meta data is not preserved in our lowering, # so throw error instead of silently doing wrong thing. @@ -950,15 +1102,19 @@ def value_info_to_type( ) else: return self.tensor_metadata_to_type( - tensor_meta, sparsity=sparsity, mutable=mutable + tensor_meta, val=val, mutable=mutable ) elif val is not None: # some nodes with symbolic inputs pass a 'val' attribute rather than # tensor_meta if isinstance(val, TorchFakeTensor): return self.get_vtensor_type( - val.size(), val.dtype, sparsity=sparsity, mutable=mutable + val.size(), val.dtype, val=val, mutable=mutable ) + elif isinstance(val, list) and all( + isinstance(x, TorchFakeTensor) for x in val + ): + return IrType.parse("!torch.list", context=self._c) # Note that None is a valid scalar here, so it is important that this # is always checked as the last fallback. 
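# For reference, a plain 2-D COO tensor now flows through the val-based path
# above roughly as follows (derived from sparsity_encoding() and
# get_vtensor_type(); the exact spelling may differ):
#
#   t = torch.eye(3, 4).to_sparse()   # layout torch.sparse_coo, sparse_dim() == 2
#   sparsity_encoding(t)
#   #  -> "#sparse_tensor.encoding<{map=(d0,d1)->(d0:compressed(nonunique),
#   #         d1:singleton(soa)),posWidth=64,crdWidth=64}>"
#   # and the imported value type becomes
#   #     !torch.vtensor<[3,4],f32,#sparse_tensor.encoding<...>>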
@@ -975,19 +1131,17 @@ def tensor_metadata_to_type( self, tm: TensorMetadata, *, - sparsity: Optional[SparsityMeta] = None, + val: Optional[torch.Tensor] = None, mutable: bool = False, ) -> IrType: tm_shape = tuple( item.node if is_symbolic(item) else item for item in list(tm.shape) ) - key = (tm_shape, tm.dtype, sparsity, mutable) + key = (tm_shape, tm.dtype, val, mutable) t = self._tensor_metadata_cache.get(key) if t is None: - t = self.get_vtensor_type( - tm.shape, tm.dtype, sparsity=sparsity, mutable=mutable - ) + t = self.get_vtensor_type(tm.shape, tm.dtype, val=val, mutable=mutable) self._tensor_metadata_cache[key] = t return t @@ -1002,9 +1156,14 @@ def dtype_to_type(self, dtype: TorchDtype) -> IrType: self._dtype_to_type[dtype] = t return t + def create_vtensor_type(self, dtype: torch.dtype, size: torch.Size) -> IrType: + dtype_asm = str(self.dtype_to_type(dtype)) + return IrType.parse( + f"!torch.vtensor<{list(size)},{dtype_asm}>", context=self._c + ) + def tensor_to_vtensor_type(self, tensor: torch.Tensor) -> IrType: - dtype_asm = str(self.dtype_to_type(tensor.dtype)) - return IrType.parse(f"!torch.vtensor<{list(tensor.size())},{dtype_asm}>") + return self.create_vtensor_type(tensor.dtype, tensor.size()) def get_node_location(self, node: torch_fx.Node) -> Optional[Location]: stack_trace = node.meta.get("stack_trace") @@ -1021,6 +1180,62 @@ def get_node_location(self, node: torch_fx.Node) -> Optional[Location]: return Location.file(filename, line, col=0, context=self._c) return Location.unknown(context=self._c) + def set_symbolic_guards( + self, prog: torch.export.ExportedProgram + ) -> Dict[str, RangeConstraint]: + + # Recent PyTorch versions use `int_oo` to represent integer infinity. + # Older PyTorch versions like PyTorch stable version may not have + # `int_oo` defined just yet. + infs = (sympy.oo, int_oo) if int_oo is not None else (sympy.oo,) + + def _sympy_int_to_int(val: sympy.Expr, adjust_func: Callable): + # Convert simple sympy Integers into concrete int + if val in infs: + return torch.iinfo(torch.int64).max + if val in tuple(-inf for inf in infs): + return torch.iinfo(torch.int64).min + if isinstance(val, sympy.Integer): + return int(val) + # TODO: Remove this adjustment when fractional ranges are removed + return adjust_func(val) + + contains_symbolic_ints = False + sym_int_types = ( + (sympy.Integer, IntInfinity, NegativeIntInfinity) + if IntInfinity is not None + else sympy.Integer + ) + for val in prog.range_constraints.values(): + if ( + isinstance(val.lower, sym_int_types) + and isinstance(val.upper, sym_int_types) + and not val.is_bool + ): + contains_symbolic_ints = True + break + if contains_symbolic_ints: + # Build a map from shape symbol name to `RangeConstraint` object + # capturing `min_val`` and `max_val`` constraints for that + # symbol. Translate sympy integers to regular integers. + # + # Example: + # { + # 's0': RangeConstraint(min_val=5, max_val=10), + # 's1': RangeConstraint(min_val=0, max_val=100), + # 's3': RangeConstraint(min_val=0, max_val=9223372036854775806), + # } + self._symbolic_guards = { + str(k): RangeConstraint( + _sympy_int_to_int(v.lower, math.ceil), + _sympy_int_to_int(v.upper, math.floor), + ) + for k, v in prog.range_constraints.items() + } + + def get_symbolic_guards(self) -> Dict[str, RangeConstraint]: + return self._symbolic_guards + class GraphNodeImporter: """Imports graph nodes into an MLIR function. 
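# The guards collected by set_symbolic_guards() above materialize in the IR
# roughly like this (illustrative; the printed form is determined by the
# torch.symbolic_int / torch.bind_symbolic_shape ops created further below):
#
#   %0 = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int
#   torch.bind_symbolic_shape %arg0, [%0], affine_map<()[s0] -> (s0, 4)> : !torch.vtensor<[?,4],f32>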
@@ -1034,7 +1249,9 @@ class GraphNodeImporter: "_cc", "_on_node_produced", "_v", + "_symbol_to_value", "_multi_result_nodes", + "_unpack_list_values", "fx_importer", ] @@ -1052,11 +1269,17 @@ def __init__( # Map of (Node, result_index) to MLIR Value or a callback that lazily # constructs and returns a value. self._v: Dict[Union[Callable[[], Value], Tuple[torch_fx.Node, int]], Value] = {} + # Map of Shape Symbol to MLIR Value + self._symbol_to_value: Dict[str, Value] = {} # Map of node name to hook that should be called when it is produced. self._on_node_produced: Dict[str, Callable[[Value], None]] = {} # Statically multi-result nodes which we have de-tupled are noted here. # They will have their getitem calls short-circuited. self._multi_result_nodes: Set[torch_fx.Node] = set() + # If a OP returns a list, then it needs to be unpacked entirely using + # prim.ListUnpack. Cache the result of these nodes so that it only + # unpacks once instead of every time that getitem is used + self._unpack_list_values: Dict[torch_fx.Node, Tuple[Value]] = {} def bind_node_value( self, @@ -1092,6 +1315,28 @@ def resolve_node_value(self, node: Node, result_index: int = 0) -> Value: self._v[key] = value return value + def bind_symbol_value( + self, + shape_symbol: str, + value: Value, + ): + """Binds a shape symbol to a global SSA value (and asserts if already bound).""" + assert ( + shape_symbol not in self._symbol_to_value + ), f"Symbol already has a value: {shape_symbol}" + self._symbol_to_value[shape_symbol] = value + + def resolve_symbol_value(self, shape_symbol: str) -> Value: + """Resolves a shape symbol to a value.""" + try: + binding = self._symbol_to_value[shape_symbol] + except KeyError: + raise KeyError( + f"Shape symbol {shape_symbol} has not been bound to an MLIR value" + ) + if isinstance(binding, Value): + return binding + def import_mutable_to_vtensor( self, loc: Location, node: Node, mutable_value: Value, producer_node_name: str ) -> Value: @@ -1174,10 +1419,20 @@ def return_node_values(self, loc, nodes: List[Node]): func_dialect.ReturnOp(operands, loc=loc) def import_nodes( - self, nodes: Iterable[Node], *, skip_placeholders_outputs: bool = False + self, + nodes: Iterable[Node], + *, + skip_placeholders_outputs: bool = False, + import_symbolic_shape_expressions: bool = False, ): with InsertionPoint(self._b): loc = Location.unknown() + + # Import dynamic shape symbols and guards (if any) + if import_symbolic_shape_expressions: + symbolic_guards = self._cc.get_symbolic_guards() + self._import_shape_symbols_with_guards(loc, symbolic_guards) + num_placeholders = 0 for node in nodes: op = node.op @@ -1194,29 +1449,7 @@ def import_nodes( elif op == "call_function": target = node.target if target == operator.getitem: - # Special case handling of getitem for when it is resolving - # against a function call that we know has returned multiple - # results. We short-circuit this case because we have modeled - # function calls to natively return multiple results vs tupling. - getitem_ref, getitem_index = node.args - if getitem_ref in self._multi_result_nodes: - try: - self.bind_node_value( - node, - self.resolve_node_value(getitem_ref, getitem_index), - ) - except IndexError: - raise RuntimeError( - f"getitem de-aliasing failed. This likely " - f"indicates a programmer error that usually " - f"would have happened at runtime. Please " - f"notify developers if this case happens " - f"(at {loc})." 
- ) - else: - raise NotImplementedError( - f"General getitem access to non-multi-result ops" - ) + self._import_getitem(loc, node) elif target in SYMBOLIC_TORCH_OPS or ( is_symbolic(node.meta.get("val")) and is_builtin_function_or_method(target) @@ -1224,7 +1457,7 @@ def import_nodes( self._import_symbolic_torch_op(loc, node, target) elif isinstance(target, TorchOpOverload): # Dispatch to an ATen op. - self._import_torch_op_overload(loc, node, target) + self._import_torch_op_overload(loc, node) elif isinstance(target, HigherOrderOperator): self._import_hop(loc, node, target) else: @@ -1237,6 +1470,9 @@ def import_nodes( operands = [self._import_argument(loc, arg) for arg in node.args[0]] func_dialect.ReturnOp(operands, loc=loc) + if import_symbolic_shape_expressions: + self._create_bind_symbolic_shape_ops(loc, node) + def _promote_symbolic_scalar_int_float(self, loc, graph, param): temp_target = torch.ops.aten.Float.Scalar temp_node = Node( @@ -1392,42 +1628,18 @@ def _import_hop_auto_functionalized( self.bind_node_value(node, value, i + bind_none) def _import_torch_op_overload( - self, loc: Location, node: torch_fx.Node, target: TorchOpOverload + self, + loc: Location, + node: torch_fx.Node, + concrete_target: Optional[TorchOpOverload] = None, ): - # TODO: Convert this cascade of ifs to a table-driven - # replace lift_fresh_copy with clone op - if target == torch.ops.aten.lift_fresh_copy.default: - node.target = target = torch.ops.aten.clone.default - node.args = (node.args[0],) - node.kwargs = {"memory_format": None} - elif target == torch.ops.aten.lift_fresh_copy.out: - # TODO: It seems not possible to hit this case from user code. - # Retaining in case if it is triggered internally somehow, but - # it can most likely be removed once assuming full - # functionalization in all cases. - node.target = target = torch.ops.aten.clone.out - node.args = (node.args[0],) - node.kwargs = {"memory_format": None, "out": node.args[1]} - # TODO: generalize empty.memory_format in the future - # Currently, the aten.baddbmm.default op for Unet includes multiplying an - # empty.memory_format input with a constant, which creates NaN values - # because empty.memory_format contains uninitialized data. Converting - # aten.baddbmm.default -> aten.zeros.default fixes the correctness issue - elif target == torch.ops.aten.empty.memory_format: - if len(node.users) == 1: - for key_node in node.users: - if key_node.target == torch.ops.aten.baddbmm.default: - node.target = target = torch.ops.aten.zeros.default - elif target == torch.ops.aten._local_scalar_dense.default: - input_type = node.args[0].meta["tensor_meta"].dtype - if input_type.is_floating_point: - node.target = target = torch.ops.aten.Float.Tensor - else: - node.target = target = torch.ops.aten.Int.Tensor - node.args = (node.args[0],) - elif target == torch.ops.aten._assert_async.msg: - # TODO: A more suitable op to replace it? 
- return + if concrete_target is None: + node = node_canonicalize(node) + if not node: + return + target = node.target + else: + target = concrete_target schema = target._schema assert isinstance(schema, FunctionSchema) @@ -1483,6 +1695,69 @@ def _import_torch_op_overload( for i, value in enumerate(operation.results): self.bind_node_value(node, value, i) + def _import_shape_symbols_with_guards( + self, loc: Location, symbolic_guards: Dict[str, RangeConstraint] + ): + for symbol, constraints in symbolic_guards.items(): + # Create torch.sym_int ops + operation = Operation.create( + name="torch.symbolic_int", + attributes={ + "symbol_name": StringAttr.get(symbol), + "min_val": self._cc.integer_attr(constraints.min_val, 64), + "max_val": self._cc.integer_attr(constraints.max_val, 64), + }, + results=[self._cc.torch_int_type], + loc=loc, + ) + self.bind_symbol_value(symbol, operation.result) + + def _create_bind_symbolic_shape_ops(self, loc: Location, node: torch_fx.Node): + node_val = node.meta.get("val") + if (node_val is not None) and isinstance(node_val, TorchFakeTensor): + # Only create bind ops if the shapes contain symbolic sizes. + # Query the bool attribute `_has_symbolic_sizes_strides` on node.meta["val"]. + if node_val._has_symbolic_sizes_strides: + # Read node metadata to obtain shape symbols and expressions + symbols_set = set() + shape_exprs = [] + for s in node_val.size(): + if isinstance(s, torch.SymInt): + symbols_set.update(s.node.expr.free_symbols) + shape_exprs.append(s.node.expr) + else: + assert isinstance(s, int) + shape_exprs.append(s) + + # Map from sympy shape symbols to local symbols in the affine map + symbols_set = sorted(symbols_set, key=lambda x: x.name) + symbols_map = { + str(symbol): AffineSymbolExpr.get(i) + for i, symbol in enumerate(symbols_set) + } + + # Convert symbolic shape expressions into affine expressions + affine_exprs = [ + sympy_expr_to_semi_affine_expr(expr, symbols_map) + for expr in shape_exprs + ] + + affine_map = AffineMap.get(0, len(symbols_set), affine_exprs) + + # Build operand list + operand_list = [] + operand_list.append(self.resolve_node_value(node)) + for symbol in symbols_map.keys(): + operand_list.append(self.resolve_symbol_value(symbol)) + + # Create torch.bind_symbolic_shape ops + Operation.create( + name="torch.bind_symbolic_shape", + attributes={"shape_expressions": AffineMapAttr.get(affine_map)}, + operands=operand_list, + loc=loc, + ) + def _import_argument( self, loc: Location, arg: NodeArgument, expected_jit_type=None ) -> Value: @@ -1511,41 +1786,66 @@ def _import_argument( ): # promote scalars to tensor types as appropriate argument_value = self._import_scalar_as_tensor(loc, arg) - else: + elif LITERAL_CONVERTER_MAP.lookup(type(arg)) is not None: with loc: argument_value = self._import_literal(arg) - return self._convert_type(loc, argument_value, expected_jit_type) + else: + raise TypeError(f"Unsupported argument type {arg.__class__}") + with loc: + return self._convert_type(argument_value, expected_jit_type) - def _convert_type(self, loc: Location, val: Value, expected_jit_type): + def _convert_type( + self, + val: Value, + expected_type, + dtype: Optional[torch.dtype] = None, + size: Optional[torch.Size] = None, + ): """ When the type of 'value' and the type in the schema do not match, attempt to perform automatic type conversion. 
example: test/python/fx_importer/basic_test.py::test_full """ + if not expected_type: + return val op_name = None result_type = None # TODO: If additional types require conversion in the future, # consider implementing a table-driven approach. + operands = [val] if val.type == self._cc.torch_bool_type: - if isinstance(expected_jit_type, torch.FloatType): + if isinstance(expected_type, torch.FloatType): op_name = "torch.aten.Float.bool" result_type = self._cc.torch_float_type - elif isinstance(expected_jit_type, (torch.IntType, torch.NumberType)): + elif isinstance(expected_type, (torch.IntType, torch.NumberType)): op_name = "torch.aten.Int.bool" result_type = self._cc.torch_int_type + elif expected_type is torch.Tensor: + op_name = "torch.prims.convert_element_type" + result_type = self._cc.create_vtensor_type(dtype, size) + operands.append( + LITERAL_CONVERTER_MAP.lookup(torch.dtype)(dtype, self, self._cc) + ) if op_name is None: return val - with loc: - return Operation.create( - name=op_name, results=[result_type], operands=[val] - ).result + return Operation.create( + name=op_name, results=[result_type], operands=operands + ).result - def _import_literal(self, py_value: Any) -> Value: + def _import_literal(self, py_value: Any, info: Optional[InputInfo] = None) -> Value: + orig_value = None + if isinstance(py_value, torch.Tensor) and py_value.dtype == torch.bool: + orig_value = py_value + py_value = py_value.to(torch.uint8) # Apply the conversion callback. - user_value = self.fx_importer._hooks.resolve_literal(self, py_value) + user_value = self.fx_importer._hooks.resolve_literal(self, py_value, info) if user_value is not None: assert isinstance(user_value, Value) + if orig_value is not None: + user_value = self._convert_type( + user_value, torch.Tensor, orig_value.dtype, orig_value.size() + ) return user_value # Default conversion path. @@ -1554,7 +1854,12 @@ def _import_literal(self, py_value: Any) -> Value: raise TypeError( f"Unsupported argument -> literal conversion for {py_value.__class__}" ) - return converter(py_value, self, self._cc) + result = converter(py_value, self, self._cc) + if orig_value is not None: + result = self._convert_type( + result, torch.Tensor, orig_value.dtype, orig_value.size() + ) + return result def _import_input(self, py_value: Any, info: InputInfo) -> Value: # Try the hook. @@ -1568,7 +1873,7 @@ def _import_input(self, py_value: Any, info: InputInfo) -> Value: raise ValueError( f"Cannot import {info.input_spec} as a literal because it is mutable" ) - return self._import_literal(py_value) + return self._import_literal(py_value, info) def _import_scalar_as_tensor(self, loc: Location, arg: NodeArgument) -> Value: tensor_arg = torch.tensor(arg) @@ -1668,6 +1973,51 @@ def _import_default_value(self, loc: Location, arg, expected_jit_type) -> Value: with loc: return cvt(arg, self, self._cc) + def _import_getitem(self, loc: Location, node: torch.fx.Node): + ref_node, index = node.args + if ref_node in self._multi_result_nodes: + # Special case handling of getitem for when it is resolving + # against a function call that we know has returned multiple + # results. We short-circuit this case because we have modeled + # function calls to natively return multiple results vs tupling. + try: + self.bind_node_value( + node, + self.resolve_node_value(ref_node, index), + ) + except IndexError: + raise RuntimeError( + f"getitem de-aliasing failed. This likely " + f"indicates a programmer error that usually " + f"would have happened at runtime. 
Please " + f"notify developers if this case happens " + f"(at {loc})." + ) + else: + # handle nodes that return a torch.list<...> at the MLIR level + # NOTE: the length of the list must be knowable at compile time. + if ref_node not in self._unpack_list_values: + node_result = self.resolve_node_value(ref_node, 0) + if str(node_result.type) in TORCH_LIST_TYPES: + result_types = [ + self._cc.value_info_to_type(v) for v in ref_node.meta["val"] + ] + operation = Operation.create( + "torch.prim.ListUnpack", + results=result_types, + operands=[node_result], + loc=loc, + ) + self._unpack_list_values[ref_node] = tuple(operation.results) + + try: + self.bind_node_value(node, self._unpack_list_values[ref_node][index]) + except IndexError: + raise RuntimeError( + f"getitem failed. " + f"getitem only supports lists of known length. (at {loc})" + ) + def _unpack_node_result_types( self, node: torch.fx.Node, schema: FunctionSchema ) -> List[IrType]: @@ -1702,16 +2052,19 @@ def _make_constant_op( ) -def create_mlir_tensor_type(tensor: torch.Tensor) -> IrType: +def _create_mlir_tensor_type(dtype: torch.dtype, size: torch.Size) -> IrType: try: - dtype = tensor.dtype element_type = TORCH_DTYPE_TO_MLIR_TYPE[dtype]() - tensor_type = RankedTensorType.get(tuple(tensor.size()), element_type) + tensor_type = RankedTensorType.get(size, element_type) return tensor_type except KeyError: raise TypeError(f"Could not map Torch dtype {dtype} to an MLIR type") +def create_mlir_tensor_type(tensor: torch.Tensor) -> IrType: + return _create_mlir_tensor_type(tensor.dtype, tensor.size()) + + def _make_vtensor_literal_op( tensor: torch.Tensor, vtensor_type: IrType, py_attr_tracker: "RefTracker" ) -> Operation: @@ -1894,13 +2247,15 @@ def track(self, referrent: Any) -> RefMapping: if existing: return existing info = RefMapping(referrent) - if referrent is not Empty: - weakref.finalize(referrent, self._ref_finalizer, ref_id) + # Finalizer is removed due to a memory leak + # See: https://github.com/iree-org/iree-turbine/issues/281 + # if referrent is not Empty: + # weakref.finalize(referrent, self._ref_finalizer, ref_id) self._refs[ref_id] = info return info - def _ref_finalizer(self, ref_id: int): - del self._refs[ref_id] + # def _ref_finalizer(self, ref_id: int): + # del self._refs[ref_id] ################################################################################ @@ -1995,6 +2350,10 @@ def _ref_finalizer(self, ref_id: int): "vtensor": "!torch.list>", } +TORCH_LIST_TYPES = set(PY_TYPE_TO_TORCH_LIST_TYPE.values()) | set( + PY_TYPE_TO_TORCH_OPTIONAL_LIST_TYPE.values() +) + SCALAR_TYPE_TO_TORCH_MLIR_TYPE = { torch.SymInt: "!torch.int", torch.SymFloat: "!torch.float", @@ -2016,3 +2375,97 @@ def _ref_finalizer(self, ref_id: int): "torch.aten.sub.Tensor": "torch.aten.sub.Scalar", "torch.aten.floor_divide": "torch.aten.floor_divide.Scalar", } + + +NODE_CANONICALIZE: Dict[TorchOpOverload, Callable] = {} + + +def register_canonicalize(op: TorchOpOverload): + def wrapper(func): + NODE_CANONICALIZE[op] = func + return func + + return wrapper + + +@register_canonicalize(torch.ops.aten.lift_fresh_copy.default) +def lift_fresh_copy_default(node: torch_fx.Node): + # replace lift_fresh_copy with clone op + node.target = torch.ops.aten.clone.default + node.args = (node.args[0],) + node.kwargs = {"memory_format": None} + return node + + +@register_canonicalize(torch.ops.aten.lift_fresh_copy.out) +def lift_fresh_copy_out(node: torch_fx.Node): + # TODO: It seems not possible to hit this case from user code. 
+    # Retained in case it is triggered internally somehow, but it can most
+    # likely be removed once full functionalization can be assumed in all
+    # cases.
+    node.target = torch.ops.aten.clone.out
+    node.kwargs = {"memory_format": None, "out": node.args[1]}
+    node.args = (node.args[0],)
+    return node
+
+
+@register_canonicalize(torch.ops.aten.empty.memory_format)
+def empty_memory_format(node: torch_fx.Node):
+    # TODO: generalize empty.memory_format in the future.
+    # Currently, the aten.baddbmm.default op for Unet includes multiplying an
+    # empty.memory_format input with a constant, which creates NaN values
+    # because empty.memory_format contains uninitialized data. Converting
+    # aten.baddbmm.default -> aten.zeros.default fixes the correctness issue.
+    if len(node.users) == 1:
+        for key_node in node.users:
+            if key_node.target == torch.ops.aten.baddbmm.default:
+                node.target = torch.ops.aten.zeros.default
+    return node
+
+
+@register_canonicalize(torch.ops.aten._local_scalar_dense.default)
+def aten__local_scalar_dense_default(node: torch_fx.Node):
+    input_type = node.args[0].meta["tensor_meta"].dtype
+    if input_type.is_floating_point:
+        node.target = torch.ops.aten.Float.Tensor
+    else:
+        node.target = torch.ops.aten.Int.Tensor
+    node.args = (node.args[0],)
+    return node
+
+
+@register_canonicalize(torch.ops.aten._assert_async.msg)
+def aten__assert_async_msg(node: torch_fx.Node):
+    # TODO: Is there a more suitable op to replace it with?
+    return None
+
+
+@register_canonicalize(torch.ops.aten._unsafe_index_put.default)
+def aten__unsafe_index_put_default(node: torch_fx.Node):
+    node.target = torch.ops.aten._unsafe_index_put.hacked_twin
+    return node
+
+
+@register_canonicalize(torch.ops.aten._embedding_bag_forward_only.default)
+def aten__embedding_bag_forward_only_default(node: torch_fx.Node):
+    node.target = torch.ops.aten.embedding_bag.padding_idx
+    embedding_bag_args = [
+        ("scale_grad_by_freq", False),
+        ("mode", 0),
+        ("sparse", False),
+        ("per_sample_weights", None),
+        ("include_last_offset", False),
+        ("padding_idx", None),
+    ]
+    node_kwargs = dict(node.kwargs)
+    for k, v in embedding_bag_args[len(node.args) - 3 :]:
+        if k not in node_kwargs:
+            node_kwargs[k] = v
+    node.kwargs = node_kwargs
+    return node
+
+
+def node_canonicalize(node: torch_fx.Node):
+    if node.target in NODE_CANONICALIZE:
+        return NODE_CANONICALIZE[node.target](node)
+    return node
diff --git a/python/torch_mlir/extras/onnx_importer.py b/python/torch_mlir/extras/onnx_importer.py
index e0d3529d942e..7ce3647ee8c4 100644
--- a/python/torch_mlir/extras/onnx_importer.py
+++ b/python/torch_mlir/extras/onnx_importer.py
@@ -34,8 +34,9 @@
 ) from e

 from typing import Optional, List, Dict, Tuple
+import warnings

-from dataclasses import dataclass
+from dataclasses import dataclass, field

 import numpy as np
 import re
@@ -90,6 +91,45 @@ class Config:
     # making an assumption.
     elide_initialized_inputs: bool = True

+    # Some ONNX operators are defined by ONNX functions and will be
+    # automatically expanded (see get_operator_function() below) to MLIR
+    # functions by the importer. This option allows allowlisting functions that
+    # should be expanded. If this is None, then allowlisting is not used (all
+    # functions not explicitly denylisted will be expanded).
+    #
+    # Since function expansion has not always been supported, the default should
+    # be to use allowlisting, to avoid disruption.
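+    #
+    # For example, a caller that wants every non-denylisted function expanded
+    # can clear the allowlist before importing (this is what the
+    # `--disable-function-expansion-allowlist` flag of the import_onnx tool
+    # does):
+    #
+    #   config = Config()
+    #   config.function_expansion_allowlists_by_domain = None
+    #   model_info = ModelInfo(model_proto, config=config)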
+ function_expansion_allowlists_by_domain: Optional[Dict[str, set[str]]] = field( + default_factory=lambda: { + # Default domain (ONNX built-in ops) + "": { + "MeanVarianceNormalization", + } + } + ) + + # Some ONNX operators are defined by ONNX functions and will be + # automatically expanded (see get_operator_function() below) to MLIR + # functions by the importer. This option allows denylisting functions that + # should not be expanded. + function_expansion_denylists_by_domain: Dict[str, set[str]] = field( + default_factory=lambda: { + # Default domain (ONNX built-in ops) + "": { + # CastLike's second input `target_type` is used only for its + # type (T2), from which its output's type is inferred, but + # because its value is unused, ONNX's shape inference doesn't + # annotate the input value with a type, so looking up the + # function by the provided input types will fail. + "CastLike", + # ONNX errors when trying to infer the type of the Loop op + # within this function: "[ShapeInferenceError] Inferred shape + # and existing shape differ in rank: (1) vs (0)" + "Range", + } + } + ) + class ModelInfo: """Top-level accounting and accessors for an ONNX model.""" @@ -111,7 +151,12 @@ def create_module(self, context: Optional[Context] = None) -> Module: class GraphInfo: """Information about a Graph within a model.""" - def __init__(self, model_info: ModelInfo, graph_proto: onnx.GraphProto): + def __init__( + self, + model_info: ModelInfo, + graph_proto: onnx.GraphProto, + is_subgraph: bool = False, + ): self.model_info = model_info self.graph_proto = graph_proto self.initializer_map: Dict[str, onnx.TensorProto] = { @@ -129,7 +174,11 @@ def __init__(self, model_info: ModelInfo, graph_proto: onnx.GraphProto): # Generate the effective input map, which for old models can be a # subset of the input map. - if model_info and model_info.config.elide_initialized_inputs: + if ( + not is_subgraph + and model_info + and model_info.config.elide_initialized_inputs + ): self.input_map = { k: v for k, v in self.declared_input_map.items() @@ -149,9 +198,20 @@ def find_type_proto_for_name(self, name: str) -> onnx.TypeProto: # Node outputs don't typically have type information, but shape inference # will associate them in the value_info. If not there, it may be a # graph output, which must have type information. 
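+        # As a fallback, also consult the declared graph inputs and, failing
+        # that, synthesize a tensor type from an initializer of the same name.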
- value_info = self.value_info_map.get(name) or self.output_map.get(name) + value_info = ( + self.value_info_map.get(name) + or self.output_map.get(name) + or self.declared_input_map.get(name) + ) if value_info is not None: return value_info.type + + tensor_proto = self.initializer_map.get(name) + if tensor_proto is not None: + return onnx.helper.make_tensor_type_proto( + tensor_proto.data_type, tensor_proto.dims + ) + # No type information is associated, this can occur when the value is unused: return "" @@ -172,6 +232,8 @@ class NodeImporter: __slots__ = [ "_c", "_cc", + "_m", + "_mc", "_gi", "_p", "_b", @@ -185,9 +247,13 @@ def __init__( parent_op: Operation, block: Block, context_cache: "ContextCache", + module_op: Operation, + module_cache: "ModuleCache", ): self._c = parent_op.context self._cc = context_cache + self._m = module_op + self._mc = module_cache self._gi = graph_info self._p = parent_op self._b = block @@ -195,9 +261,19 @@ def __init__( @classmethod def define_function( - cls, graph_info: GraphInfo, module_op: Operation + cls, + graph_info: GraphInfo, + module_op: Operation, + context_cache: Optional["ContextCache"] = None, + module_cache: Optional["ModuleCache"] = None, + private: bool = False, ) -> "NodeImporter": - cc = ContextCache(module_op.context) + cc = ( + context_cache + if context_cache is not None + else ContextCache(module_op.context) + ) + mc = module_cache if module_cache is not None else ModuleCache(module_op, cc) with module_op.context, Location.name(f"graph:{graph_info.graph_proto.name}"): body = module_op.regions[0].blocks[0] func_name = graph_info.graph_proto.name @@ -209,11 +285,23 @@ def define_function( for out in graph_info.output_map.values() ] ftype = FunctionType.get(input_types, output_types) - func_op = func_dialect.FuncOp(func_name, ftype, ip=InsertionPoint(body)) + func_op = func_dialect.FuncOp( + func_name, + ftype, + ip=InsertionPoint(body), + visibility="private" if private else None, + ) block = func_op.add_entry_block( [Location.name(k) for k in graph_info.input_map.keys()] ) - imp = NodeImporter(graph_info, parent_op=func_op, block=block, context_cache=cc) + imp = NodeImporter( + graph_info, + parent_op=func_op, + block=block, + context_cache=cc, + module_op=module_op, + module_cache=mc, + ) for node_name, input_value in zip(graph_info.input_map.keys(), block.arguments): imp._nv_map[node_name] = input_value imp._populate_graph_attrs(func_op) @@ -293,6 +381,8 @@ def get_none(self): def import_node(self, node: onnx.NodeProto): with InsertionPoint(self._b), Location.name(node.name): op_type = node.op_type + op_domain = node.domain + # Handle special op types that materialize to non-op IR constructs. # Handlers return True if the op was handled, else this function # should process it as a general node. @@ -303,33 +393,58 @@ def import_node(self, node: onnx.NodeProto): return # General node import. 
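+            # If the operator is backed by an ONNX function and expansion is
+            # enabled, it is imported as a call to that function; otherwise it
+            # is emitted as a generic torch.operator, roughly of the form:
+            #   %0 = torch.operator "onnx.Add"(%lhs, %rhs) : (...) -> ...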
input_values = [] + input_type_protos = [] for input_name in node.input: try: input_values.append(self._nv_map[input_name]) + # Missing optional arguments will have empty types + input_type_protos.append( + self._gi.find_type_proto_for_name(input_name) + or onnx.TypeProto() + ) except KeyError: raise OnnxImportError( f"Non topologically produced ONNX node input '{input_name}': {node}" ) - output_names = list(node.output) - output_types = [ - self._cc.type_proto_to_type(self._gi.find_type_proto_for_name(n)) - for n in output_names - ] - - attrs = self.import_attributes(node.attribute) - attrs["name"] = StringAttr.get(f"onnx.{op_type}") - regions = self.count_regions(node.attribute) - - custom_op = Operation.create( - name="torch.operator", - results=output_types, - operands=input_values, - attributes=attrs, - regions=regions, + output_names = [] + output_type_protos = [] + output_types = [] + for output_name in node.output: + output_names.append(output_name) + type_proto = self._gi.find_type_proto_for_name(output_name) + output_type_protos.append(type_proto) + output_types.append(self._cc.type_proto_to_type(type_proto)) + + for opset_import in self._gi.model_info.model_proto.opset_import: + if opset_import.domain == op_domain: + opset_version = opset_import.version + break + operator_func_op = self._mc.get_operator_function( + op_type, + op_domain, + opset_version, + input_type_protos, + output_type_protos, + node, + self._gi.model_info.config, ) - self.import_regions(node.attribute, custom_op) + if operator_func_op is not None: + custom_op = func_dialect.CallOp(operator_func_op, input_values) + else: + attrs = self.import_attributes(node.attribute) + attrs["name"] = StringAttr.get(f"onnx.{op_type}") + regions = self.count_regions(node.attribute) + custom_op = Operation.create( + name="torch.operator", + results=output_types, + operands=input_values, + attributes=attrs, + regions=regions, + ) + self.import_regions(node.attribute, custom_op) + for output_name, output_value in zip(output_names, custom_op.results): self._nv_map[output_name] = output_value @@ -387,9 +502,14 @@ def import_regions(self, onnx_attrs: List[onnx.AttributeProto], op): *block_types, arg_locs=[op.location] * len(block_types) ) block = region.blocks[0] - graph_info = GraphInfo(None, attr.g) + graph_info = GraphInfo(self._gi.model_info, attr.g, is_subgraph=True) imp = NodeImporter( - graph_info, parent_op=op, block=block, context_cache=self._cc + graph_info, + parent_op=op, + block=block, + context_cache=self._cc, + module_op=self._m, + module_cache=self._mc, ) for node_name, input_value in zip(block_names, block.arguments): @@ -579,6 +699,10 @@ def tensor_proto_to_builtin_type(self, tp: onnx.TensorProto) -> IrType: def type_proto_to_type(self, tp: onnx.TypeProto) -> IrType: if tp == "": + warnings.warn( + "Found a node without a valid type proto. Consider updating the opset_version of" + " the model and/or running the importer with the flag '--clear-domain'." + ) return self.get_none_type() tt = tp.tensor_type @@ -603,6 +727,11 @@ def type_proto_to_type(self, tp: onnx.TypeProto) -> IrType: element_type = self.get_optional_element_type(ot.elem_type) return self.get_optional_type(element_type) + # Check if TypeProto is empty (sometimes happens for unused function + # arguments) + if tp.WhichOneof("value") is None: + return self.get_none_type() + # TODO: Others if ever needed. Or we consider ourselves DNN-only. # See TypeProto: sequence_type, map_type, optional_type, sparse_tensor_type. 
         raise OnnxImportError(f"Unsupported ONNX TypeProto: {tp}")
@@ -610,7 +739,10 @@ def type_proto_to_type(self, tp: onnx.TypeProto) -> IrType:
     def _sanitize_name(self, name):
         if not name.isidentifier():
             name = "_" + name
-        return re.sub("[:/]", "_", name)
+
+        # Remove characters that are invalid in MLIR identifier names.
+        # https://mlir.llvm.org/docs/LangRef/#identifiers-and-keywords
+        return re.sub(r"[^\w\.]", "_", name)

     def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute:
         tensor_type = self.tensor_proto_to_builtin_type(tp)
@@ -631,6 +763,323 @@ def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute:
         return handler(tp)

+
+def _shallow_copy_and_clear_protobuf_list(protobuf_list) -> list:
+    """
+    Workaround for .clear() not being available on protobuf lists for some
+    reason.
+    """
+    copy = list(protobuf_list)
+    while len(protobuf_list) > 0:
+        protobuf_list.pop()
+    return copy
+
+
+def _bind_attributes_on_node(
+    interior_node: onnx.NodeProto,
+    caller_node: onnx.NodeProto,
+    op_schema: onnx.defs.OpSchema,
+) -> onnx.NodeProto:
+    """
+    Helper for _specialize_function_and_create_model() that binds concrete
+    values to attributes on a node in the interior of a function.
+
+    This should behave the same as ONNX's C++ attribute binder; please use it
+    as a reference: https://github.com/onnx/onnx/blob/88f8ef15cfaa3138d336f3502aed5018d802bf43/onnx/shape_inference/attribute_binder.h#L15-L64
+    """
+
+    def _bind_attributes_in_subgraph(
+        old_subgraph: onnx.GraphProto,
+        caller_node: onnx.NodeProto,
+        op_schema: onnx.defs.OpSchema,
+    ) -> onnx.GraphProto:
+        """
+        Recurse to bind attributes in a subgraph.
+        """
+        new_subgraph = onnx.GraphProto()
+        new_subgraph.CopyFrom(old_subgraph)
+        old_nodes = _shallow_copy_and_clear_protobuf_list(new_subgraph.node)
+        for old_node in old_nodes:
+            new_subgraph.node.append(
+                _bind_attributes_on_node(old_node, caller_node, op_schema)
+            )
+        return new_subgraph
+
+    def _bind_attribute(
+        old_attribute: onnx.AttributeProto,
+        caller_node: onnx.NodeProto,
+        op_schema: onnx.defs.OpSchema,
+    ) -> Optional[onnx.AttributeProto]:
+        """
+        Bind a single attribute.
+
+        Bound values either come from attributes on the node calling the
+        function, or from default values. If the attribute is optional and has
+        no default value, and no value was provided by the caller, None is
+        returned and the attribute should be removed.
+        """
+
+        ref_name = old_attribute.ref_attr_name
+        if not ref_name:
+            if not old_attribute.g and len(old_attribute.graphs) == 0:
+                return old_attribute
+
+            # Recurse to bind attributes on subgraphs. ONNX's implementation of
+            # attribute binding only does this for subgraphs that didn't come
+            # from a referenced attribute value, so this code doesn't either.
+ new_attribute = onnx.AttributeProto() + new_attribute.CopyFrom(old_attribute) + if new_attribute.g: + new_attribute.g = _bind_attributes_in_subgraph( + new_attribute.g, caller_node, op_schema + ) + if new_attribute.graphs: + old_subgraphs = _shallow_copy_and_clear_protobuf_list( + new_attribute.graphs + ) + for old_subgraph in old_subgraphs: + new_attribute.graphs.append( + _bind_attributes_in_subgraph( + old_subgraph, caller_node, op_schema + ) + ) + return new_attribute + + for call_attribute in caller_node.attribute: + if call_attribute.name == ref_name: + new_attribute = onnx.AttributeProto() + new_attribute.CopyFrom(call_attribute) + new_attribute.name = old_attribute.name + return new_attribute + + # The default value is sometimes empty for optional attributes + # that don't have a default, in which case it is dropped. + default_value = op_schema.attributes[ref_name].default_value + if default_value and default_value.type: + new_attribute = onnx.AttributeProto() + new_attribute.CopyFrom(default_value) + new_attribute.name = old_attribute.name + return new_attribute + + return None + + new_node = onnx.NodeProto() + new_node.CopyFrom(interior_node) + old_attributes = _shallow_copy_and_clear_protobuf_list(new_node.attribute) + for node_attribute in old_attributes: + new_attribute = _bind_attribute(node_attribute, caller_node, op_schema) + if new_attribute is not None: + new_node.attribute.append(new_attribute) + continue + return new_node + + +def _specialize_function_and_create_model( + function_proto: onnx.FunctionProto, + op_schema: onnx.defs.OpSchema, + name_to_give_model: str, + input_type_protos: list[onnx.TypeProto], + output_type_protos: list[onnx.TypeProto], + caller_node: onnx.NodeProto, +) -> onnx.ModelProto: + """ + Helper for ModuleCache::get_operator_function() that specializes a function + and coverts it to a model. + + An ONNX function may be polymorphic, parameterized over the types of its + inputs and values of its attributes (~= compile-time constants). We need to + monomorphize it for importing into MLIR. It seems like the only practical + way to do this is by turning it into a model: + - models can have types on their inputs and outputs, unlike functions + - ONNX provides a function to do shape inference (providing concrete + types for everything in the body) for models, but not for functions + - the rest of the code in this importer can only handle models, not + functions + """ + + graph_proto = onnx.GraphProto() + + for input_name, input_type_proto in zip(function_proto.input, input_type_protos): + input_proto = onnx.ValueInfoProto() + input_proto.name = input_name + input_proto.type.CopyFrom(input_type_proto) + graph_proto.input.append(input_proto) + output_proto = onnx.ValueInfoProto() + + for output_name, output_type_proto in zip( + function_proto.output, output_type_protos + ): + output_proto.name = output_name + output_proto.type.CopyFrom(output_type_proto) + graph_proto.output.append(output_proto) + + for node in function_proto.node: + # Import referenced attributes from call-site or default values + graph_proto.node.append(_bind_attributes_on_node(node, caller_node, op_schema)) + + graph_proto.name = name_to_give_model + + model_proto = onnx.ModelProto() + model_proto.opset_import.extend(function_proto.opset_import) + # FIXME: is this the correct IR version, or should it be the latest, or the + # one used by the actual model, or something else? 
+ model_proto.ir_version = onnx.helper.find_min_ir_version_for( + function_proto.opset_import + ) + model_proto.graph.CopyFrom(graph_proto) + + model_proto = onnx.shape_inference.infer_shapes( + model_proto, check_type=True, strict_mode=True, data_prop=True + ) + graph_proto = model_proto.graph + + # Useful for debugging. + # onnx.checker.check_model(model_proto, full_check=True) + + return model_proto + + +class ModuleCache: + """Caches per-module lookups of various things.""" + + __slots__ = [ + "_m", + "_cc", + "_operator_function_map", + ] + + def __init__(self, module_op: Operation, context_cache: ContextCache): + self._m = module_op + self._cc = context_cache + self._operator_function_map: Dict[str, func_dialect.FuncOp] = {} + + def get_operator_function( + self, + op_name: str, + op_domain: str, + opset_version: int, + input_type_protos: list[onnx.TypeProto], + output_type_protos: list[onnx.TypeProto], + caller_node: onnx.NodeProto, + config: Config, + ) -> Optional[func_dialect.FuncOp]: + """ + Get or create MLIR function corresponding to an ONNX operator. + + Returns None for ONNX operators that aren't functions. + """ + + allowlists = config.function_expansion_allowlists_by_domain + denylists = config.function_expansion_denylists_by_domain + + if allowlists is not None and not ( + op_domain in allowlists and op_name in allowlists[op_domain] + ): + return None + + if op_domain in denylists and op_name in denylists[op_domain]: + return None + + op_schema = onnx.defs.get_schema( + op_name, domain=op_domain, max_inclusive_version=opset_version + ) + + # The get_schema() lookup above should get the right version of the + # operator definition, but the function body can change slightly + # within a single operator version, as explained in + # https://github.com/onnx/onnx/blob/093a8d335a66ea136eb1f16b3a1ce6237ee353ab/onnx/defs/schema.h#L1070-L1086 + # There also seem to be cases where a function goes from being not + # context-dependent to context-dependent. + f = lambda ver: ver <= opset_version + ncd_function_version = max( + filter(f, op_schema.function_opset_versions), + default=None, + ) + cd_function_version = max( + filter(f, op_schema.context_dependent_function_opset_versions), + default=None, + ) + if ncd_function_version is None and cd_function_version is None: + # No relevant function definition + return None + if ncd_function_version is not None and ( + cd_function_version is None or cd_function_version < ncd_function_version + ): + specific_version = ncd_function_version + is_context_dependent = False + else: + specific_version = cd_function_version + is_context_dependent = True + + # This is both a key for memoization of function importing and also a + # name mangling scheme, so it must include all information needed to + # uniquely identify a function and anything it might be parameterized + # over. + key = repr( + ( + op_name, + op_domain, + opset_version, + input_type_protos, + # Though output types can be inferred from input types, it does + # not seem to be the case that there's only one legal set of + # outputs for a given set of inputs. When attemtping to always + # use onnx.shape_inference.infer_function_output_types instead + # of the caller-provided types, sometimes IR verification fails + output_type_protos, + # Avoid including the attributes twice (once on their own and + # once as part of the node) for context-dependent functions, + # avoid including unused parts of the node for other functions. 
+ caller_node if is_context_dependent else caller_node.attribute, + ) + ) + + existing = self._operator_function_map.get(key) + if existing is not None: + return existing + + if is_context_dependent: + function_proto_str = ( + op_schema.get_context_dependent_function_with_opset_version( + specific_version, + caller_node.SerializeToString(), + [ + t.SerializeToString() if not isinstance(t, bytes) else t + for t in input_type_protos + ], + ) + ) + else: + function_proto_str = op_schema.get_function_with_opset_version( + specific_version + ) + if not function_proto_str: + raise OnnxImportError( + f"Function lookup for {op_name}/{op_domain}/{specific_version}/{is_context_dependent} failed unexpectedly. This probably indicates a bug." + ) + function_proto = onnx.onnx_pb.FunctionProto() + function_proto.ParseFromString(function_proto_str) + + tmp_model_proto = _specialize_function_and_create_model( + function_proto, + op_schema, + key, + input_type_protos, + output_type_protos, + caller_node, + ) + + tmp_model_info = ModelInfo(tmp_model_proto) + tmp_graph_info = GraphInfo(tmp_model_info, tmp_model_proto.graph) + # Mark function as private so it will be thrown away after inlining + imp = NodeImporter.define_function( + tmp_graph_info, self._m, self._cc, self, private=True + ) + imp.import_all() + func_op = imp._p + + self._operator_function_map[key] = func_op + return func_op + + ELEM_TYPE_TO_IR_TYPE_CB = { onnx.TensorProto.DataType.FLOAT: lambda: F32Type.get(), onnx.TensorProto.DataType.UINT8: lambda: IntegerType.get_unsigned(8), @@ -652,6 +1101,8 @@ def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute: onnx.TensorProto.DataType.FLOAT8E5M2: lambda: Float8E5M2Type.get(), onnx.TensorProto.DataType.FLOAT8E5M2FNUZ: lambda: Float8E5M2FNUZType.get(), onnx.TensorProto.DataType.STRING: lambda: "!torch.str", + onnx.TensorProto.DataType.UINT4: lambda: IntegerType.get_unsigned(4), + onnx.TensorProto.DataType.INT4: lambda: IntegerType.get_signed(4), # Ommitted: STRING, } @@ -688,6 +1139,9 @@ def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute: ), signless=False, ), + onnx.TensorProto.DataType.UINT8: lambda tp: DenseElementsAttr.get( + np.asarray(tp.int32_data, dtype=np.uint8).reshape(tp.dims), signless=False + ), onnx.TensorProto.DataType.INT8: lambda tp: DenseElementsAttr.get( np.asarray(tp.int32_data, dtype=np.int8).reshape(tp.dims), signless=False ), diff --git a/python/torch_mlir/fx.py b/python/torch_mlir/fx.py index 834cffd63ff0..192533729d94 100644 --- a/python/torch_mlir/fx.py +++ b/python/torch_mlir/fx.py @@ -4,6 +4,7 @@ # Also available under a BSD-style license. See LICENSE. from typing import Optional, Union, Dict, Tuple, Any, Callable +from packaging import version import warnings @@ -12,11 +13,11 @@ import torch.nn as nn from torch.export import ExportedProgram -from torch_mlir.extras.fx_importer import FxImporter, FxImporterHooks -from torch_mlir import ir -from torch_mlir.dialects import torch as torch_d -from torch_mlir.extras.fx_decomp_util import get_decomposition_table -from torch_mlir.compiler_utils import ( +from .extras.fx_importer import FxImporter, FxImporterHooks +from . 
import ir
+from .dialects import torch as torch_d
+from .extras.fx_decomp_util import get_decomposition_table
+from .compiler_utils import (
     OutputType,
     run_pipeline_with_repro_report,
     lower_mlir_module,
@@ -27,31 +28,36 @@ def _module_lowering(
     verbose,
     output_type,
     torch_mod,
-    backend_legal_ops=None,
     extra_library_file_name=None,
+    backend_legal_ops=None,
 ):
-    if output_type == OutputType.TORCH:
+    if output_type == OutputType.RAW:
         if verbose:
             print(torch_mod)
         return torch_mod
-    # TODO: pass backend_legal_ops/extra_library_file_name by caller
-    if backend_legal_ops is None:
-        backend_legal_ops = []
+    # TODO: pass extra_library_file_name by caller
+
+    backend_legal_op_arg_str = ""
+    if backend_legal_ops is not None:
+        if len(backend_legal_ops) > 0:
+            backend_legal_op_arg_str = "backend-legal-ops=" + ",".join(
+                backend_legal_ops
+            )
+
     if extra_library_file_name is None:
         extra_library_file_name = ""
     option_string = (
-        "{backend-legal-ops="
-        + ",".join(backend_legal_ops)
+        "{"
+        + backend_legal_op_arg_str
         + " extra-library="
         + extra_library_file_name
-        + " shape-dtype-refine="
-        + ("false" if not backend_legal_ops and not extra_library_file_name else "true")
         + "}"
     )
+
     run_pipeline_with_repro_report(
         torch_mod,
-        f"builtin.module(torch-function-to-torch-backend-pipeline{option_string})",
+        f"builtin.module(func.func(torch-match-quantized-custom-ops), torchdynamo-export-to-torch-backend-pipeline{option_string})",
        "Lowering TorchFX IR -> Torch Backend IR",
         enable_ir_printing=verbose,
     )
@@ -61,15 +67,17 @@ def _module_lowering(

 def export_and_import(
     f: Union[nn.Module, ExportedProgram],
     *args,
-    output_type: Union[str, OutputType] = OutputType.TORCH,
+    output_type: Union[str, OutputType] = OutputType.RAW,
     fx_importer: Optional[FxImporter] = None,
     dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
     experimental_support_mutation: bool = False,
+    import_symbolic_shape_expressions: bool = False,
     hooks: Optional[FxImporterHooks] = None,
     decomposition_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None,
     func_name: str = "main",
     enable_graph_printing: bool = False,
     enable_ir_printing: bool = False,
+    backend_legal_ops: Optional[list[str]] = None,
     **kwargs,
 ):
     context = ir.Context()
@@ -80,7 +88,11 @@ def export_and_import(
     if isinstance(f, ExportedProgram):
         prog = f
     else:
-        prog = torch.export.export(f, args, kwargs, dynamic_shapes=dynamic_shapes)
+        # PyTorch 2.1 or lower doesn't have the `dynamic_shapes` keyword argument in torch.export
+        if version.Version(torch.__version__) >= version.Version("2.2.0"):
+            prog = torch.export.export(f, args, kwargs, dynamic_shapes=dynamic_shapes)
+        else:
+            prog = torch.export.export(f, args, kwargs)
     if decomposition_table is None:
         decomposition_table = get_decomposition_table()
     if decomposition_table:
@@ -90,23 +102,35 @@ def export_and_import(
     if experimental_support_mutation:
         if torch.__version__ < "2.3.0.dev20240207":
             warnings.warn("Mutable program import only supported on PyTorch 2.3+")
-        fx_importer.import_program(prog, func_name=func_name)
+        fx_importer.import_program(
+            prog,
+            func_name=func_name,
+            import_symbolic_shape_expressions=import_symbolic_shape_expressions,
+        )
     else:
-        fx_importer.import_frozen_program(prog, func_name=func_name)
+        fx_importer.import_frozen_program(
+            prog,
+            func_name=func_name,
+            import_symbolic_shape_expressions=import_symbolic_shape_expressions,
+        )
     return _module_lowering(
-        enable_ir_printing, OutputType.get(output_type), fx_importer.module
+        enable_ir_printing,
+        OutputType.get(output_type),
+
fx_importer.module, + backend_legal_ops=backend_legal_ops, ) def stateless_fx_import( gm: torch.fx.GraphModule, - output_type: Union[str, OutputType] = OutputType.TORCH, + output_type: Union[str, OutputType] = OutputType.RAW, fx_importer: Optional[FxImporter] = None, hooks: Optional[FxImporterHooks] = None, model_name: str = "main", enable_graph_printing: bool = False, enable_ir_printing: bool = False, + backend_legal_ops: Optional[list[str]] = None, ): if enable_graph_printing: gm.print_readable() @@ -116,5 +140,8 @@ def stateless_fx_import( fx_importer = FxImporter(context=context, hooks=hooks) fx_importer.import_stateless_graph(gm.graph, func_name=model_name) return _module_lowering( - enable_ir_printing, OutputType.get(output_type), fx_importer.module + enable_ir_printing, + OutputType.get(output_type), + fx_importer.module, + backend_legal_ops=backend_legal_ops, ) diff --git a/python/torch_mlir/tools/import_onnx/__main__.py b/python/torch_mlir/tools/import_onnx/__main__.py index 92ae3c7eb356..4f852d34bb0a 100644 --- a/python/torch_mlir/tools/import_onnx/__main__.py +++ b/python/torch_mlir/tools/import_onnx/__main__.py @@ -20,6 +20,7 @@ import sys import onnx +import onnx.version from ...extras import onnx_importer @@ -30,10 +31,14 @@ def main(args: argparse.Namespace): + config = onnx_importer.Config() + if args.disable_function_expansion_allowlist: + config.function_expansion_allowlists_by_domain = None + model_proto = load_onnx_model(args) context = Context() torch_d.register_dialect(context) - model_info = onnx_importer.ModelInfo(model_proto) + model_info = onnx_importer.ModelInfo(model_proto, config=config) m = model_info.create_module(context=context).operation imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m) imp.import_all() @@ -79,7 +84,17 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto: raw_model = onnx.load(args.input_file) else: raw_model = onnx.load(args.input_file, load_external_data=False) - onnx.load_external_data_for_model(raw_model, args.data_dir) + onnx.load_external_data_for_model(raw_model, str(args.data_dir)) + + if args.opset_version: + raw_model = onnx.version_converter.convert_version( + raw_model, args.opset_version + ) + + if args.clear_domain: + graph = raw_model.graph + for n in graph.node: + n.ClearField("domain") # Run the checker to test whether the file is above the threshold for # in-memory shape inference. If not, go ahead and do the shape inference. @@ -122,7 +137,7 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto: # Load the temp file and the external data. 
inferred_model = onnx.load(temp_inferred_file, load_external_data=False) data_dir = Path(input_dir if args.temp_dir is None else args.data_dir) - onnx.load_external_data_for_model(inferred_model, data_dir) + onnx.load_external_data_for_model(inferred_model, str(data_dir)) # Remove the inferred shape file unless asked to keep it if not args.keep_temps: @@ -149,6 +164,14 @@ def parse_arguments(argv=None) -> argparse.Namespace: action=argparse.BooleanOptionalAction, help="Toggle data propogation for onnx shape inference", ) + parser.add_argument( + "--clear-domain", + dest="clear_domain", + default=False, + action=argparse.BooleanOptionalAction, + help="If enabled, this will clear the domain attribute from each node" + " in the onnx graph before performing shape inference.", + ) parser.add_argument( "--keep-temps", action="store_true", help="Keep intermediate files" ) @@ -170,6 +193,18 @@ def parse_arguments(argv=None) -> argparse.Namespace: " Defaults to the directory of the input file.", type=Path, ) + parser.add_argument( + "--opset-version", + help="Allows specification of a newer opset_version to update the model" + " to before importing to MLIR. This can sometime assist with shape inference.", + type=int, + ) + parser.add_argument( + "--disable-function-expansion-allowlist", + action="store_true", + help="Disable the allowlist for ONNX function expansion," + " allowing non-allowlisted functions to be expanded.", + ) args = parser.parse_args(argv) return args diff --git a/python/torch_mlir/tools/opt/__main__.py b/python/torch_mlir/tools/opt/__main__.py new file mode 100644 index 000000000000..26cd61402878 --- /dev/null +++ b/python/torch_mlir/tools/opt/__main__.py @@ -0,0 +1,40 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +"""Torch-MLIR modular optimizer driver + +Typically, when installed from a wheel, this can be invoked as: + + torch-mlir-opt [options] + +To see available passes, dialects, and options, run: + + torch-mlir-opt --help +""" +import os +import platform +import subprocess +import sys + +from typing import Optional + + +def _get_builtin_tool(exe_name: str) -> Optional[str]: + if platform.system() == "Windows": + exe_name = exe_name + ".exe" + this_path = os.path.dirname(__file__) + tool_path = os.path.join(this_path, "..", "..", "_mlir_libs", exe_name) + return tool_path + + +def main(args=None): + if args is None: + args = sys.argv[1:] + exe = _get_builtin_tool("torch-mlir-opt") + return subprocess.call(args=[exe] + args) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/pytorch-hash.txt b/pytorch-hash.txt index 3424cb46aad1..7089e150ff9a 100644 --- a/pytorch-hash.txt +++ b/pytorch-hash.txt @@ -1 +1 @@ -1b7523fbe9d0a0c81930673f4374c6e69fa293b6 +5f7ce38e44791817d326467813e354fde1d01db0 diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt index 7b73c61f4e13..a3531096e833 100644 --- a/pytorch-requirements.txt +++ b/pytorch-requirements.txt @@ -1,3 +1,7 @@ --f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html +-f https://download.pytorch.org/whl/nightly/cpu/torch/ +# The nightly wheels for pytorch are regularly deleted and we don't bump the +# versions at the same pace. The wheels will therefore be cached on the xilinx +# release page, and we use this page as an additional source for the wheels. 
+-f https://xilinx.github.io/torch-mlir/package-index/ --pre -torch==2.4.0.dev20240505 +torch==2.7.0.dev20250310 diff --git a/requirements.txt b/requirements.txt index f346b53da470..6c86e58ae9c8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ --r pytorch-requirements.txt -r build-requirements.txt +-r pytorch-requirements.txt +-r torchvision-requirements.txt -r test-requirements.txt diff --git a/setup.py b/setup.py index 6f5f5d5d1c3b..b04a15004e76 100644 --- a/setup.py +++ b/setup.py @@ -46,6 +46,7 @@ # that here, and just package up its contents. import os import pathlib +import platform import shutil import subprocess import sys @@ -75,7 +76,7 @@ def _check_env_flag(name: str, default=None) -> bool: # If true, enable LTC build by default TORCH_MLIR_ENABLE_LTC = _check_env_flag("TORCH_MLIR_ENABLE_LTC", True) TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS = _check_env_flag( - "TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS", False + "TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS", True ) LLVM_INSTALL_DIR = os.getenv("LLVM_INSTALL_DIR", None) SRC_DIR = pathlib.Path(__file__).parent.absolute() @@ -198,6 +199,15 @@ def run(self): shutil.copytree(python_package_dir, target_dir, symlinks=False) + torch_mlir_opt_src = os.path.join(cmake_build_dir, "bin", "torch-mlir-opt") + torch_mlir_opt_dst = os.path.join( + target_dir, "torch_mlir", "_mlir_libs", "torch-mlir-opt" + ) + if platform.system() == "Windows": + torch_mlir_opt_src += ".exe" + torch_mlir_opt_dst += ".exe" + shutil.copy2(torch_mlir_opt_src, torch_mlir_opt_dst, follow_symlinks=False) + class CMakeExtension(Extension): def __init__(self, name, sourcedir=""): @@ -223,13 +233,13 @@ def build_extension(self, ext): EXT_MODULES = [ CMakeExtension("torch_mlir._mlir_libs._torchMlir"), ] -NAME = "torch-mlir-core" +NAME = "torch-mlir" # If building PyTorch extensions, customize. 
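+# For example, a build that should ship the PyTorch extensions would disable
+# TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS (see _check_env_flag above),
+# which renames the package to "torch-mlir-ext" and pulls in a matching
+# torch requirement below.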
if not TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS: import torch - NAME = "torch-mlir" + NAME = "torch-mlir-ext" INSTALL_REQUIRES.extend( [ f"torch=={torch.__version__}".split("+", 1)[0], @@ -266,7 +276,8 @@ def build_extension(self, ext): }, entry_points={ "console_scripts": [ - "torch-mlir-import-onnx = torch_mlir.tools.import_onnx:_cli_main", + "torch-mlir-import-onnx = torch_mlir.tools.import_onnx.__main__:_cli_main", + "torch-mlir-opt = torch_mlir.tools.opt.__main__:main", ], }, zip_safe=False, diff --git a/stable-requirements.txt b/stable-requirements.txt new file mode 100644 index 000000000000..6acd25b582b0 --- /dev/null +++ b/stable-requirements.txt @@ -0,0 +1,3 @@ +--index-url https://download.pytorch.org/whl/cpu +torch==2.5.1+cpu +torchvision==0.20.1+cpu diff --git a/test-requirements.txt b/test-requirements.txt index b21e8dfcd021..42278b3cbcf6 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1,5 +1,5 @@ pillow dill multiprocess -onnx==1.15.0 +onnx==1.16.1 mpmath==1.3.0 diff --git a/test/CAPI/torch.c b/test/CAPI/torch.c index d42cf96d554c..3d1308f08b25 100644 --- a/test/CAPI/torch.c +++ b/test/CAPI/torch.c @@ -33,12 +33,12 @@ static void testTensor(MlirContext ctx, intptr_t numSizes, int64_t *sizes, bool TTT##hasDtype = torchMlirTorch##TTT##TypeHasDtype(TTT##Type); \ fprintf(stderr, #TTT "Type %s hasDtype: %d\n", testName, TTT##hasDtype); \ if (TTT##hasSizes) { \ - fprintf(stderr, #TTT "Type %s rank: %zu\n", testName, \ + fprintf(stderr, #TTT "Type %s rank: %" PRId64 "\n", testName, \ torchMlirTorch##TTT##TypeGetRank(TTT##Type)); \ int64_t *TTT##Sizes = malloc(sizeof(int64_t) * numSizes); \ torchMlirTorch##TTT##TypeGetSizes(TTT##Type, TTT##Sizes); \ for (int i = 0; i < numSizes; ++i) { \ - fprintf(stderr, #TTT "Type %s pos %d size: %ld\n", testName, i, \ + fprintf(stderr, #TTT "Type %s pos %d size: %" PRId64 "\n", testName, i, \ TTT##Sizes[i]); \ } \ } \ diff --git a/test/Conversion/TorchConversionToMLProgram/basic.mlir b/test/Conversion/TorchConversionToMLProgram/basic.mlir index c7fb38e1c5b0..262ada6f283d 100644 --- a/test/Conversion/TorchConversionToMLProgram/basic.mlir +++ b/test/Conversion/TorchConversionToMLProgram/basic.mlir @@ -17,3 +17,16 @@ module { return %seed : i64 } } + +// ----- + +module { + func.func @no_seed_needed(%arg0: tensor<2x3xf32>) -> !torch.vtensor<[2,3],f32> { + %0 = torch_c.from_builtin_tensor %arg0 : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32> + return %0 : !torch.vtensor<[2,3],f32> + } +} + +// CHECK-NOT: ml_program.global +// CHECK-LABEL: @no_seed_needed +// CHECK-NEXT: torch_c.from_builtin_tensor diff --git a/test/Conversion/TorchConversionToMLProgram/multiple_functions.mlir b/test/Conversion/TorchConversionToMLProgram/multiple_functions.mlir index 8ef04d95166e..da2424fc3ba2 100644 --- a/test/Conversion/TorchConversionToMLProgram/multiple_functions.mlir +++ b/test/Conversion/TorchConversionToMLProgram/multiple_functions.mlir @@ -11,5 +11,5 @@ module { func.func private @f7() -> i64 } -// CHECK: ml_program.global private mutable @global_seed(dense<0> : tensor) : tensor +// CHECK-NOT: ml_program.global private mutable @global_seed(dense<0> : tensor) : tensor // CHECK-NOT: @global_seed diff --git a/test/Conversion/TorchOnnxToTorch/op_wise_version.mlir b/test/Conversion/TorchOnnxToTorch/op_wise_version.mlir new file mode 100644 index 000000000000..f35ecf3aeca5 --- /dev/null +++ b/test/Conversion/TorchOnnxToTorch/op_wise_version.mlir @@ -0,0 +1,17 @@ +// RUN: torch-mlir-opt <%s --split-input-file -convert-torch-onnx-to-torch | 
FileCheck %s + +// CHECK-LABEL: @test_quantizelinear_opset_16_op_19 +func.func @test_quantizelinear_opset_16_op_19(%arg0: !torch.vtensor<[6],f32>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],si8>) -> !torch.vtensor<[6],si8> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 16 : si64} { + // CHECK-NOT: torch.operator + %0 = torch.operator "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) {torch.onnx_meta.version = 19 : si64} : (!torch.vtensor<[6],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],si8>) -> !torch.vtensor<[6],si8> + return %0 : !torch.vtensor<[6],si8> +} + +// ----- + +// CHECK-LABEL: @test_quantizelinear_no_opset_op_19 +func.func @test_quantizelinear_no_opset_op_19(%arg0: !torch.vtensor<[6],f32>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],si8>) -> !torch.vtensor<[6],si8> attributes {torch.onnx_meta.ir_version = 9 : si64} { + // CHECK-NOT: torch.operator + %0 = torch.operator "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) {torch.onnx_meta.version = 19 : si64} : (!torch.vtensor<[6],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],si8>) -> !torch.vtensor<[6],si8> + return %0 : !torch.vtensor<[6],si8> +} diff --git a/test/Conversion/TorchOnnxToTorch/ops/if.mlir b/test/Conversion/TorchOnnxToTorch/ops/if.mlir index 1d95a3f5fc3a..09d3472fdf81 100644 --- a/test/Conversion/TorchOnnxToTorch/ops/if.mlir +++ b/test/Conversion/TorchOnnxToTorch/ops/if.mlir @@ -18,3 +18,24 @@ func.func @test_ifop_basic(%arg0: !torch.vtensor<[1],i1>, %arg1: !torch.vtensor< } return %0 : !torch.vtensor<[1],f32> } + +// ----- + +// CHECK-LABEL: func.func @test_ifop_cast_shape +// CHECK: %[[IF:.*]] = torch.prim.If %{{.*}} -> (!torch.vtensor<[?],si64>) +// CHECK-DAG: %[[CAST:.*]] = torch.tensor_static_info_cast %{{.*}} : !torch.vtensor<[0],si64> to !torch.vtensor<[?],si64> +// CHECK-DAG: torch.prim.If.yield %[[CAST]] : !torch.vtensor<[?],si64> +// CHECK-DAG: } else { +// CHECK-DAG: %[[SQUEEZE:.*]] = torch.prims.squeeze %arg1, %{{.*}} : !torch.vtensor<[?,1],si64>, !torch.list -> !torch.vtensor<[?],si64> +// CHECK-DAG: torch.prim.If.yield %[[SQUEEZE]] : !torch.vtensor<[?],si64> +func.func @test_ifop_cast_shape(%arg0: !torch.vtensor<[1],i1>, %arg1: !torch.vtensor<[?,1],si64>) -> !torch.vtensor<[?],si64> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "conditional_example", torch.onnx_meta.producer_version = ""} { + %0 = torch.operator "onnx.If"(%arg0) : (!torch.vtensor<[1],i1>) -> !torch.vtensor<[?],si64> { + %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> + %2 = torch.operator "onnx.Squeeze"(%arg1, %1) : (!torch.vtensor<[?,1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[?],si64> + torch.operator_terminator %2 : !torch.vtensor<[?],si64> + }, { + %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<0xsi64>} : () -> !torch.vtensor<[0],si64> + torch.operator_terminator %1 : !torch.vtensor<[0],si64> + } + return %0 : !torch.vtensor<[?],si64> +} diff --git a/test/Conversion/TorchOnnxToTorch/ops/lstm.mlir b/test/Conversion/TorchOnnxToTorch/ops/lstm.mlir index bb1821088d12..1d230e79ebdf 100644 --- a/test/Conversion/TorchOnnxToTorch/ops/lstm.mlir +++ b/test/Conversion/TorchOnnxToTorch/ops/lstm.mlir @@ -16,10 +16,71 @@ // CHECK-DAG: torch.prim.Loop.condition // CHECK-DAG: } // CHECK: } -module { - func.func @test_lstm_basic(%arg0: !torch.vtensor<[15,2,4],f32>, %arg1: !torch.vtensor<[1,12,4],f32>, %arg2: 
!torch.vtensor<[1,12,3],f32>, %arg3: !torch.vtensor<[1,24],f32>) -> (!torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>) attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} { - %none = torch.constant.none - %0:3 = torch.operator "onnx.LSTM"(%arg0, %arg1, %arg2, %arg3) {torch.onnx.hidden_size = 3 : si64} : (!torch.vtensor<[15,2,4],f32>, !torch.vtensor<[1,12,4],f32>, !torch.vtensor<[1,12,3],f32>, !torch.vtensor<[1,24],f32>) -> (!torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>) - return %0#0, %0#1, %0#2 : !torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32> - } + +func.func @test_lstm_basic(%arg0: !torch.vtensor<[15,2,4],f32>, %arg1: !torch.vtensor<[1,12,4],f32>, %arg2: !torch.vtensor<[1,12,3],f32>, %arg3: !torch.vtensor<[1,24],f32>) -> (!torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>) attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} { + %none = torch.constant.none + %0:3 = torch.operator "onnx.LSTM"(%arg0, %arg1, %arg2, %arg3) {torch.onnx.hidden_size = 3 : si64} : (!torch.vtensor<[15,2,4],f32>, !torch.vtensor<[1,12,4],f32>, !torch.vtensor<[1,12,3],f32>, !torch.vtensor<[1,24],f32>) -> (!torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>) + return %0#0, %0#1, %0#2 : !torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_lstm_bidirectional_with_initial_bias( +// CHECK-SAME: %[[X:.*]]: !torch.vtensor<[32,32,192],f32>, +// CHECK-SAME: %[[W:.*]]: !torch.vtensor<[2,192,192],f32>, +// CHECK-SAME: %[[R:.*]]: !torch.vtensor<[2,192,48],f32>, +// CHECK-SAME: %[[B:.*]]: !torch.vtensor<[2,384],f32>) +// CHECK: %[[FORWARD_LOOP_RES:.*]]:3 = torch.prim.Loop %[[MAX_TRIP_FWD:.*]], %[[LOOP_COND_FWD:.*]], init(%[[Y_FWD:.*]], %[[INITIAL_H_FWD:.*]], %[[INITIAL_C_FWD:.*]]) { +// CHECK: ^bb0(%[[FORWARD_LOOP_INDEX:.*]]: !torch.int, %[[Y_PREV_FWD:.*]]: !torch.vtensor<[32,32,48],f32>, %[[H_PREV_FWD:.*]]: !torch.vtensor<[32,48],f32>, %[[C_PREV_FWD:.*]]: !torch.vtensor<[32,48],f32>): +// CHECK-DAG: torch.aten.select.int +// CHECK-DAG: torch.aten.linear +// CHECK-DAG: torch.aten.sigmoid +// CHECK-DAG: torch.aten.tanh +// CHECK-DAG: torch.prim.Loop.condition +// CHECK: } +// CHECK: torch.aten.flip +// CHECK: %[[REVERSE_LOOP_RES:.*]]:3 = torch.prim.Loop %[[MAX_TRIPS_REV:.*]], %[[LOOP_COND_REV:.*]], init(%[[Y_REV:.*]], %[[INITIAL_H_REV:.*]], %[[INITIAL_C_REV:.*]]) { +// CHECK: ^bb0(%[[REVERSE_LOOP_INDEX:.*]]: !torch.int, %[[Y_PREV_REV:.*]]: !torch.vtensor<[32,32,48],f32>, %[[H_PREV_REV:.*]]: !torch.vtensor<[32,48],f32>, %[[C_PREV_REV:.*]]: !torch.vtensor<[32,48],f32>): +// CHECK-DAG: torch.aten.select.int +// CHECK-DAG: torch.aten.linear +// CHECK-DAG: torch.aten.sigmoid +// CHECK-DAG: torch.aten.tanh +// CHECK-DAG: torch.prim.Loop.condition +// CHECK: } +// CHECK: torch.aten.flip +// CHECK: return %[[Y:.*]], %[[Y_H:.*]], %[[Y_C:.*]] : !torch.vtensor<[32,2,32,48],f32>, !torch.vtensor<[2,32,48],f32>, !torch.vtensor<[2,32,48],f32> +// CHECK: } + +func.func @test_lstm_bidirectional_with_initial_bias(%arg0: !torch.vtensor<[32,32,192],f32>, %arg1: !torch.vtensor<[2,192,192],f32>, %arg2: !torch.vtensor<[2,192,48],f32>, %arg3: 
!torch.vtensor<[2,384],f32>) -> (!torch.vtensor<[32,2,32,48],f32>, !torch.vtensor<[2,32,48],f32>, !torch.vtensor<[2,32,48],f32>) attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} { + %none = torch.constant.none + %0:3 = torch.operator "onnx.LSTM"(%arg0, %arg1, %arg2, %arg3) {torch.onnx.direction = "bidirectional", torch.onnx.hidden_size = 48 : si64, torch.onnx.layout = 0 : si64} : (!torch.vtensor<[32,32,192],f32>, !torch.vtensor<[2,192,192],f32>, !torch.vtensor<[2,192,48],f32>, !torch.vtensor<[2,384],f32>) -> (!torch.vtensor<[32,2,32,48],f32>, !torch.vtensor<[2,32,48],f32>, !torch.vtensor<[2,32,48],f32>) + return %0#0, %0#1, %0#2 : !torch.vtensor<[32,2,32,48],f32>, !torch.vtensor<[2,32,48],f32>, !torch.vtensor<[2,32,48],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_lstm_batchwise_two_outputs( +// CHECK-SAME: %[[X_LAYOUT_1:.*]]: !torch.vtensor<[3,1,2],f32>, +// CHECK-SAME: %[[W:.*]]: !torch.vtensor<[1,28,2],f32>, +// CHECK-SAME: %[[R:.*]]: !torch.vtensor<[1,28,7],f32>) +// CHECK: torch.aten.transpose.int +// CHECK: %[[LOOP_RES:.*]]:3 = torch.prim.Loop %[[MAX_TRIP:.*]], %[[LOOP_COND_FWD:.*]], init(%[[Y:.*]], %[[INITIAL_H:.*]], %[[INITIAL_C:.*]]) { +// CHECK: ^bb0(%[[LOOP_INDEX:.*]]: !torch.int, %[[Y_PREV:.*]]: !torch.vtensor<[1,3,7],f32>, %[[H_PREV:.*]]: !torch.vtensor<[3,7],f32>, %[[C_PREV:.*]]: !torch.vtensor<[3,7],f32>): +// CHECK-DAG: torch.aten.select.int +// CHECK-DAG: torch.aten.linear +// CHECK-DAG: torch.aten.sigmoid +// CHECK-DAG: torch.aten.tanh +// CHECK-DAG: torch.prim.Loop.condition +// CHECK: } +// CHECK-DAG: torch.aten.transpose.int +// CHECK-DAG: torch.aten.transpose.int +// CHECK-DAG: torch.aten.transpose.int +// CHECK-DAG: torch.aten.transpose.int +// CHECK: return %[[Y:.*]], %[[Y_H:.*]] : !torch.vtensor<[3,1,1,7],f32>, !torch.vtensor<[3,1,7],f32> +// CHECK: } + +func.func @test_lstm_batchwise_two_outputs(%arg0: !torch.vtensor<[3,1,2],f32>, %arg1: !torch.vtensor<[1,28,2],f32>, %arg2: !torch.vtensor<[1,28,7],f32>) -> (!torch.vtensor<[3,1,1,7],f32>, !torch.vtensor<[3,1,7],f32>) attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %none = torch.constant.none + %0:2 = torch.operator "onnx.LSTM"(%arg0, %arg1, %arg2) {torch.onnx.hidden_size = 7 : si64, torch.onnx.layout = 1 : si64} : (!torch.vtensor<[3,1,2],f32>, !torch.vtensor<[1,28,2],f32>, !torch.vtensor<[1,28,7],f32>) -> (!torch.vtensor<[3,1,1,7],f32>, !torch.vtensor<[3,1,7],f32>) + return %0#0, %0#1 : !torch.vtensor<[3,1,1,7],f32>, !torch.vtensor<[3,1,7],f32> } diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index a87ec4f8f43f..5e62efa00cf7 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -735,6 +735,19 @@ func.func @test_asinh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4 // ----- +// CHECK-LABEL: @test_deform_conv +func.func @test_deform_conv(%arg0: !torch.vtensor<[1,1,7,6],f32>, %arg1: !torch.vtensor<[1,8,6,5],f32>, %arg2: !torch.vtensor<[1,1,2,2],f32>, %arg3: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,1,6,5],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = 
"2.4.0"} { + // CHECK: %[[cstOne:.*]] = torch.constant.float 1.000000e+00 + // CHECK: %[[mask:.*]] = torch.aten.full %[[sizeList:.*]], %[[cstOne]] + // CHECK-SAME: -> !torch.vtensor<[1,4,6,5],f32> + // CHECK: torch.torchvision.deform_conv2d %arg0, %arg2, %arg1, %[[mask]], %arg3 + // CHECK-SAME: : !torch.vtensor<[1,1,7,6],f32>, !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,8,6,5],f32>, !torch.vtensor<[1,4,6,5],f32>, !torch.vtensor<[1],f32>, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.bool -> !torch.vtensor<[1,1,6,5],f32> + %1 = torch.operator "onnx.DeformConv"(%arg0, %arg2, %arg1, %arg3) {torch.onnx.dilations = [1 : si64, 1 : si64], torch.onnx.group = 1 : si64, torch.onnx.kernel_shape = [2 : si64, 2 : si64], torch.onnx.offset_group = 1 : si64, torch.onnx.pads = [0 : si64, 0 : si64, 0 : si64, 0 : si64], torch.onnx.strides = [1 : si64, 1 : si64]} : (!torch.vtensor<[1,1,7,6],f32>, !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,8,6,5],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,1,6,5],f32> + return %1 : !torch.vtensor<[1,1,6,5],f32> +} + +// ----- + // CHECK-LABEL: @test_dequantizelinear_si8 func.func @test_dequantizelinear_si8(%arg0: !torch.vtensor<[6],si8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],si8>) -> !torch.vtensor<[6],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} { %0 = torch.operator "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],si8>, !torch.vtensor<[],f32>, !torch.vtensor<[],si8>) -> !torch.vtensor<[6],f32> @@ -748,6 +761,19 @@ func.func @test_dequantizelinear_si8(%arg0: !torch.vtensor<[6],si8>, %arg1: !tor // ----- +// CHECK-LABEL: @test_dequantizelinear_si16 +func.func @test_dequantizelinear_si16(%arg0: !torch.vtensor<[6],si16>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],si16>) -> !torch.vtensor<[6],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} { + %0 = torch.operator "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],si16>, !torch.vtensor<[],f32>, !torch.vtensor<[],si16>) -> !torch.vtensor<[6],f32> + // CHECK: %[[SCALE:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[ZP:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],si16> -> !torch.int + // CHECK: %[[MAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[SCALE]], %[[ZP]] + // CHECK: %[[DEQ:.+]] = torch.aten.dequantize.self %[[MAKE]] + // CHECK: return %[[DEQ]] + return %0 : !torch.vtensor<[6],f32> +} + +// ----- + // CHECK-LABEL: @test_dequantizelinear_ui8 func.func @test_dequantizelinear_ui8(%arg0: !torch.vtensor<[6],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>) -> !torch.vtensor<[6],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} { %0 = torch.operator "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>) -> !torch.vtensor<[6],f32> @@ -774,6 +800,22 @@ func.func @test_dequantizelinear_i32(%arg0: !torch.vtensor<[6],si32>, %arg1: !to // ----- +// CHECK-LABEL: @test_dequantizelinear_fp8 +func.func @test_dequantizelinear_fp8(%arg0: !torch.vtensor<[6],f8E4M3FN>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],f8E4M3FN>) -> !torch.vtensor<[6],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} { + // CHECK: %[[SCALE:.+]] = torch.aten.item %arg1 : 
!torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[ZP:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],f8E4M3FN> -> !torch.float + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[DTY:.+]] = torch.constant.int 6 + // CHECK: %[[TO:.+]] = torch.aten.to.dtype %arg0, %[[DTY]], %[[FALSE]], %[[FALSE]], %[[NONE]] + // CHECK: %[[ONE:.+]] = torch.constant.float 1.000000e+00 + // CHECK: %[[SUB:.+]] = torch.aten.sub.Scalar %[[TO]], %[[ZP]], %[[ONE]] + // CHECK: %[[MUL:.+]] = torch.aten.mul.Scalar %[[SUB]], %[[SCALE]] + %0 = torch.operator "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],f8E4M3FN>, !torch.vtensor<[],f32>, !torch.vtensor<[],f8E4M3FN>) -> !torch.vtensor<[6],f32> + return %0 : !torch.vtensor<[6],f32> +} + +// ----- // CHECK-LABEL: @test_div_bcast func.func @test_div_bcast(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { @@ -946,12 +988,12 @@ func.func @test_averagepool_with_padding(%arg0: !torch.vtensor<[1,20,64,48],f32> func.func @test_conv_with_strides_no_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>, %arg1: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,3,2],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[C0:.*]] = torch.constant.int 0 // CHECK: %[[C0_0:.*]] = torch.constant.int 0 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_0]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[C1:.*]] = torch.constant.int 1 // CHECK: %[[C1_0:.*]] = torch.constant.int 1 // CHECK: %[[C2:.*]] = torch.constant.int 2 // CHECK: %[[C2_0:.*]] = torch.constant.int 2 // CHECK: %[[C0_1:.*]] = torch.constant.int 0 - // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_0]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0_1]], %[[C0_1]] : (!torch.int, !torch.int) -> !torch.list @@ -969,12 +1011,12 @@ func.func @test_conv_with_strides_no_padding(%arg0: !torch.vtensor<[1,1,7,5],f32 func.func @test_conv_with_strides_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>, %arg1: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,4,3],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[C1:.*]] = torch.constant.int 1 // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[C1_1:.*]] = torch.constant.int 1 // CHECK: %[[C1_2:.*]] = torch.constant.int 1 // CHECK: %[[C2:.*]] = torch.constant.int 2 // CHECK: %[[C2_0:.*]] = torch.constant.int 2 // CHECK: %[[C0:.*]] = torch.constant.int 0 - // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] 
: (!torch.int, !torch.int) -> !torch.list // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list @@ -988,16 +1030,135 @@ func.func @test_conv_with_strides_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>, // ----- +// CHECK-LABEL: @test_conv_with_asymmetric_padding +func.func @test_conv_with_asymmetric_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>, %arg1: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,4,3],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[int0:.*]] = torch.constant.int 0 + // CHECK: %[[int2:.*]] = torch.constant.int 2 + // CHECK: %[[int0_0:.*]] = torch.constant.int 0 + // CHECK: %[[int2_1:.*]] = torch.constant.int 2 + // CHECK: %[[int0_2:.*]] = torch.constant.int 0 + // CHECK: %[[int0_3:.*]] = torch.constant.int 0 + // CHECK: %[[FakePADS:.*]] = torch.prim.ListConstruct %[[int0_0]], %[[int0_3]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OGPADS:.*]] = torch.prim.ListConstruct %[[int0]], %[[int2]], %[[int2_1]], %[[int0_2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[str:.*]] = torch.constant.str "constant" + // CHECK: %[[float0:.*]] = torch.constant.float 0.000 + // CHECK: %[[PrePad:.*]] = torch.aten.pad %arg0, %[[OGPADS]], %[[str]], %[[float0]] : !torch.vtensor<[1,1,7,5],f32>, !torch.list, !torch.str, !torch.float -> !torch.vtensor<[1,1,9,7],f32> + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_2:.*]] = torch.constant.int 1 + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[C2_0:.*]] = torch.constant.int 2 + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool false + // CHECK: %[[BIAS:.*]] = torch.constant.none + // CHECK: %[[GROUPS:.*]] = torch.constant.int 1 + // CHECK: %[[Conv:.*]] = torch.aten.convolution %[[PrePad]], %arg1, %[[BIAS]], %[[STRIDE]], %[[FakePADS]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,9,7],f32>, !torch.vtensor<[1,1,3,3],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1,4,3],f32> + // CHECK: return %[[Conv]] + %0 = torch.operator "onnx.Conv"(%arg0, %arg1) {torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.pads = [2 : si64, 0 : si64, 0 : si64, 2 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,7,5],f32>, !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,4,3],f32> + return %0 : !torch.vtensor<[1,1,4,3],f32> +} + +// ----- + +// CHECK-LABEL: @test_conv_with_autopad +func.func @test_conv_with_autopad(%arg0: !torch.vtensor<[1,1,12,7],f32>, %arg1: !torch.vtensor<[1,1,2,3],f32>) -> !torch.vtensor<[1,1,3,3],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C1:.*]] 
= torch.constant.int 0 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_2:.*]] = torch.constant.int 1 + // CHECK: %[[C2:.*]] = torch.constant.int 4 + // CHECK: %[[C2_0:.*]] = torch.constant.int 3 + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool false + // CHECK: %[[BIAS:.*]] = torch.constant.none + // CHECK: %[[GROUPS:.*]] = torch.constant.int 1 + // CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,12,7],f32>, !torch.vtensor<[1,1,2,3],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1,3,3],f32> + %0 = torch.operator "onnx.Conv"(%arg0, %arg1) {torch.onnx.kernel_shape = [2 : si64, 3 : si64], torch.onnx.auto_pad = "SAME_LOWER", torch.onnx.strides = [4 : si64, 3 : si64]} : (!torch.vtensor<[1,1,12,7],f32>, !torch.vtensor<[1,1,2,3],f32>) -> !torch.vtensor<[1,1,3,3],f32> + return %0 : !torch.vtensor<[1,1,3,3],f32> +} + +// ----- + +// CHECK-LABEL: @test_conv_with_autopad_asymmetric +func.func @test_conv_with_autopad_asymmetric(%arg0: !torch.vtensor<[1,1,15,9],f32>, %arg1: !torch.vtensor<[1,1,4,4],f32>) -> !torch.vtensor<[1,1,4,3],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[int1:.*]] = torch.constant.int 1 + // CHECK: %[[int2:.*]] = torch.constant.int 2 + // CHECK: %[[int0:.*]] = torch.constant.int 0 + // CHECK: %[[int0_0:.*]] = torch.constant.int 0 + // CHECK: %[[int1_1:.*]] = torch.constant.int 1 + // CHECK: %[[int0_2:.*]] = torch.constant.int 0 + // CHECK: %[[FakePADS:.*]] = torch.prim.ListConstruct %[[int0]], %[[int0_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OGPADS:.*]] = torch.prim.ListConstruct %[[int1]], %[[int2]], %[[int0_0]], %[[int1_1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[str:.*]] = torch.constant.str "constant" + // CHECK: %[[float0:.*]] = torch.constant.float 0.000 + // CHECK: %[[PrePad:.*]] = torch.aten.pad %arg0, %[[OGPADS]], %[[str]], %[[float0]] : !torch.vtensor<[1,1,15,9],f32>, !torch.list, !torch.str, !torch.float -> !torch.vtensor<[1,1,16,12],f32> + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_2:.*]] = torch.constant.int 1 + // CHECK: %[[C4:.*]] = torch.constant.int 4 + // CHECK: %[[C4_0:.*]] = torch.constant.int 4 + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C4]], %[[C4_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool false + // CHECK: 
%[[BIAS:.*]] = torch.constant.none + // CHECK: %[[GROUPS:.*]] = torch.constant.int 1 + // CHECK: %[[Conv:.*]] = torch.aten.convolution %[[PrePad]], %arg1, %[[BIAS]], %[[STRIDE]], %[[FakePADS]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,16,12],f32>, !torch.vtensor<[1,1,4,4],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1,4,3],f32> + // CHECK: return %[[Conv]] + %0 = torch.operator "onnx.Conv"(%arg0, %arg1) {torch.onnx.kernel_shape = [4 : si64, 4 : si64], torch.onnx.auto_pad = "SAME_UPPER", torch.onnx.strides = [4 : si64, 4 : si64]} : (!torch.vtensor<[1,1,15,9],f32>, !torch.vtensor<[1,1,4,4],f32>) -> !torch.vtensor<[1,1,4,3],f32> + return %0 : !torch.vtensor<[1,1,4,3],f32> +} + +// ----- + +// CHECK-LABEL: @test_conv_with_autopad_asymmetric_lower +func.func @test_conv_with_autopad_asymmetric_lower(%arg0: !torch.vtensor<[1,1,15,9],f32>, %arg1: !torch.vtensor<[1,1,4,4],f32>) -> !torch.vtensor<[1,1,4,3],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[int2:.*]] = torch.constant.int 2 + // CHECK: %[[int1:.*]] = torch.constant.int 1 + // CHECK: %[[int0:.*]] = torch.constant.int 0 + // CHECK: %[[int1_0:.*]] = torch.constant.int 1 + // CHECK: %[[int0_1:.*]] = torch.constant.int 0 + // CHECK: %[[int0_2:.*]] = torch.constant.int 0 + // CHECK: %[[FakePADS:.*]] = torch.prim.ListConstruct %[[int0]], %[[int0_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OGPADS:.*]] = torch.prim.ListConstruct %[[int2]], %[[int1]], %[[int1_0]], %[[int0_1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[str:.*]] = torch.constant.str "constant" + // CHECK: %[[float0:.*]] = torch.constant.float 0.000 + // CHECK: %[[PrePad:.*]] = torch.aten.pad %arg0, %[[OGPADS]], %[[str]], %[[float0]] : !torch.vtensor<[1,1,15,9],f32>, !torch.list, !torch.str, !torch.float -> !torch.vtensor<[1,1,16,12],f32> + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_2:.*]] = torch.constant.int 1 + // CHECK: %[[C4:.*]] = torch.constant.int 4 + // CHECK: %[[C4_0:.*]] = torch.constant.int 4 + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C4]], %[[C4_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool false + // CHECK: %[[BIAS:.*]] = torch.constant.none + // CHECK: %[[GROUPS:.*]] = torch.constant.int 1 + // CHECK: %[[Conv:.*]] = torch.aten.convolution %[[PrePad]], %arg1, %[[BIAS]], %[[STRIDE]], %[[FakePADS]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,16,12],f32>, !torch.vtensor<[1,1,4,4],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1,4,3],f32> + // CHECK: return %[[Conv]] + %0 = torch.operator "onnx.Conv"(%arg0, %arg1) {torch.onnx.kernel_shape = [4 : si64, 4 : si64], torch.onnx.auto_pad = "SAME_LOWER", torch.onnx.strides = [4 : si64, 4 : si64]} : (!torch.vtensor<[1,1,15,9],f32>, !torch.vtensor<[1,1,4,4],f32>) -> !torch.vtensor<[1,1,4,3],f32> + return %0 : 
!torch.vtensor<[1,1,4,3],f32> +} + +// ----- + // CHECK-LABEL: @test_conv_with_bias_strides_padding func.func @test_conv_with_bias_strides_padding(%arg0: !torch.vtensor<[?,?,224,224],f32>, %arg1: !torch.vtensor<[64,3,7,7],f32>, %arg2: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,64,112,112],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[C3:.*]] = torch.constant.int 3 // CHECK: %[[C3_0:.*]] = torch.constant.int 3 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C3]], %[[C3_0]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[C1:.*]] = torch.constant.int 1 // CHECK: %[[C1_0:.*]] = torch.constant.int 1 // CHECK: %[[C2:.*]] = torch.constant.int 2 // CHECK: %[[C2_0:.*]] = torch.constant.int 2 // CHECK: %[[C0:.*]] = torch.constant.int 0 - // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C3]], %[[C3_0]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list @@ -1168,6 +1329,78 @@ func.func @test_convtranspose(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torc // ----- +// CHECK-LABEL: @test_convtranspose_autopad_same_upper + func.func @test_convtranspose_autopad_same_upper(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,2,4,4],f32>) -> !torch.vtensor<[1,2,6,6],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "user-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_2:.*]] = torch.constant.int 1 + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[C2_3:.*]] = torch.constant.int 2 + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C0_4:.*]] = torch.constant.int 0 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_3]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_4]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool true + // CHECK: %[[BIAS:.*]] = torch.constant.none + // CHECK: %[[GROUPS:.*]] = torch.constant.int 1 + // CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,4,4],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,2,6,6],f32> + %4 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) {torch.onnx.auto_pad="SAME_UPPER", torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,4,4],f32>) -> !torch.vtensor<[1,2,6,6],f32> + return %4 : !torch.vtensor<[1,2,6,6],f32> + } + +// ----- + +// CHECK-LABEL: 
@test_convtranspose_autopad_same_lower + func.func @test_convtranspose_autopad_same_lower(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,2,4,4],f32>) -> !torch.vtensor<[1,2,6,6],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "user-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_2:.*]] = torch.constant.int 1 + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[C2_3:.*]] = torch.constant.int 2 + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C0_4:.*]] = torch.constant.int 0 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_3]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_4]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool true + // CHECK: %[[BIAS:.*]] = torch.constant.none + // CHECK: %[[GROUPS:.*]] = torch.constant.int 1 + // CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,4,4],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,2,6,6],f32> + %4 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) {torch.onnx.auto_pad="SAME_LOWER", torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,4,4],f32>) -> !torch.vtensor<[1,2,6,6],f32> + return %4 : !torch.vtensor<[1,2,6,6],f32> + } + +// ----- + +// CHECK-LABEL: @test_convtranspose_autopad_valid + func.func @test_convtranspose_autopad_valid(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,2,4,4],f32>) -> !torch.vtensor<[1,2,8,8],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "user-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C0_0:.*]] = torch.constant.int 0 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[C2_2:.*]] = torch.constant.int 2 + // CHECK: %[[C0_3:.*]] = torch.constant.int 0 + // CHECK: %[[C0_4:.*]] = torch.constant.int 0 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0_3]], %[[C0_4]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool true + // CHECK: %[[BIAS:.*]] = torch.constant.none + // CHECK: %[[GROUPS:.*]] = torch.constant.int 1 + // CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], 
%[[GROUPS]] : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,4,4],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,2,8,8],f32> + %4 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) {torch.onnx.auto_pad="VALID", torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,4,4],f32>) -> !torch.vtensor<[1,2,8,8],f32> + return %4 : !torch.vtensor<[1,2,8,8],f32> + } + +// ----- + // CHECK-LABEL: @test_batchnorm_epsilon func.func @test_batchnorm_epsilon(%arg0: !torch.vtensor<[2,3,4,5],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>, %arg3: !torch.vtensor<[3],f32>, %arg4: !torch.vtensor<[3],f32>) -> !torch.vtensor<[2,3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[FALSE:.*]] = torch.constant.bool false @@ -1192,6 +1425,34 @@ func.func @test_batchnorm_example(%arg0: !torch.vtensor<[2,3,4,5],f32>, %arg1: ! // ----- +// CHECK-LABEL: func.func @test_batchnorm_training +func.func @test_batchnorm_training(%arg0: !torch.vtensor<[1,16,27],f32>, %arg1: !torch.vtensor<[16],f32>, %arg2: !torch.vtensor<[16],f32>, %arg3: !torch.vtensor<[16],f32>, %arg4: !torch.vtensor<[16],f32>) -> (!torch.vtensor<[1,16,27],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>) attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[MOMENTUM:.*]] = torch.constant.float 1.000000e+00 +// CHECK: %[[EPSILON:.*]] = torch.constant.float 9.9999997473787516E-6 +// CHECK: %[[CST0:.*]] = torch.constant.int 0 +// CHECK: %[[CST2:.*]] = torch.constant.int 2 +// CHECK: %[[REDUCE_DIMS:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST2]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[CURRENT_MEAN:.*]] = torch.aten.mean.dim %arg0, %[[REDUCE_DIMS]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[1,16,27],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[16],f32> +// CHECK: %[[CURRENT_VAR:.*]] = torch.aten.var.dim %arg0, %[[REDUCE_DIMS]], %[[FALSE]], %[[FALSE]] : !torch.vtensor<[1,16,27],f32>, !torch.list, !torch.bool, !torch.bool -> !torch.vtensor<[16],f32> +// CHECK: %[[MEAN_MUL_MOMENTUM:.*]] = torch.aten.mul.Scalar %arg3, %[[MOMENTUM]] : !torch.vtensor<[16],f32>, !torch.float -> !torch.vtensor<[16],f32> +// CHECK: %[[CURR_MEAN_MUL_MOMENTUM:.*]] = torch.aten.mul.Scalar %[[CURRENT_MEAN]], %[[MOMENTUM]] : !torch.vtensor<[16],f32>, !torch.float -> !torch.vtensor<[16],f32> +// CHECK: %[[CST1:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_0:.*]] = torch.aten.sub.Tensor %[[MEAN_MUL_MOMENTUM]], %[[CURR_MEAN_MUL_MOMENTUM]], %[[CST1]] : !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>, !torch.int -> !torch.vtensor<[16],f32> +// CHECK: %[[RUNNING_MEAN:.*]] = torch.aten.add.Tensor %[[VAL_0]], %[[CURRENT_MEAN]], %[[CST1]] : !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>, !torch.int -> !torch.vtensor<[16],f32> +// CHECK: %[[VAR_MUL_MOMENTUM:.*]] = torch.aten.mul.Scalar %arg4, %[[MOMENTUM]] : !torch.vtensor<[16],f32>, !torch.float -> !torch.vtensor<[16],f32> +// CHECK: %[[CURR_VAR_MUL_MOMENTUM:.*]] = torch.aten.mul.Scalar %[[CURRENT_VAR]], %[[MOMENTUM]] : !torch.vtensor<[16],f32>, !torch.float -> 
!torch.vtensor<[16],f32> +// CHECK: %[[VAL_1:.*]] = torch.aten.sub.Tensor %[[VAR_MUL_MOMENTUM]], %[[CURR_VAR_MUL_MOMENTUM]], %[[CST1]] : !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>, !torch.int -> !torch.vtensor<[16],f32> +// CHECK: %[[RUNNING_VAR:.*]] = torch.aten.add.Tensor %[[VAL_1]], %[[CURRENT_VAR]], %[[CST1]] : !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>, !torch.int -> !torch.vtensor<[16],f32> +// CHECK: %[[Y:.*]] = torch.aten.batch_norm %arg0, %arg1, %arg2, %[[CURRENT_MEAN]], %[[CURRENT_VAR]], %[[FALSE]], %[[MOMENTUM]], %[[EPSILON]], %[[FALSE]] : !torch.vtensor<[1,16,27],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[1,16,27],f32> +// CHECK: return %[[Y]], %[[RUNNING_MEAN]], %[[RUNNING_VAR]] : !torch.vtensor<[1,16,27],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32> + %0:3 = torch.operator "onnx.BatchNormalization"(%arg0, %arg1, %arg2, %arg3, %arg4) {torch.onnx.epsilon = 9.99999974E-6 : f32, torch.onnx.momentum = 1.000000e+00 : f32, torch.onnx.training_mode = 1 : si64} : (!torch.vtensor<[1,16,27],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>) -> (!torch.vtensor<[1,16,27],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32>) + return %0#0, %0#1, %0#2 : !torch.vtensor<[1,16,27],f32>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],f32> +} + +// ----- + // CHECK-LABEL: @test_concat_1d_axis_0 func.func @test_concat_1d_axis_0(%arg0: !torch.vtensor<[2],f32>, %arg1: !torch.vtensor<[2],f32>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.list @@ -1419,16 +1680,13 @@ func.func @test_expand_dim2_shape2(%arg0: !torch.vtensor<[1,4],f32>, %arg1: !tor // CHECK-DAG: %[[INT0_0:.+]] = torch.constant.int 0 // CHECK-DAG: %[[SEL0:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32> // CHECK-DAG: %[[ITEM0:.+]] = torch.aten.item %[[SEL0]] : !torch.vtensor<[],si32> -> !torch.int - // CHECK-DAG: %[[I0:.+]] = torch.constant.int 0 - // CHECK-DAG: %[[SZ0:.+]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[1,4],f32>, !torch.int -> !torch.int - // CHECK-DAG: %[[MX0:.+]] = torch.prim.max.int %[[ITEM0]], %[[SZ0]] : !torch.int, !torch.int -> !torch.int // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1 // CHECK-DAG: %[[SEL1:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT1]] : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32> // CHECK-DAG: %[[ITEM1:.+]] = torch.aten.item %[[SEL1]] : !torch.vtensor<[],si32> -> !torch.int - // CHECK-DAG: %[[I1:.+]] = torch.constant.int 1 - // CHECK-DAG: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[1,4],f32>, !torch.int -> !torch.int - // CHECK-DAG: %[[MX1:.+]] = torch.prim.max.int %[[ITEM1]], %[[SZ1]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[MX0]], %[[MX1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK-DAG: %[[Im1:.+]] = torch.constant.int -1 + // CHECK-DAG: %[[INT1_1:.+]] = torch.constant.int 1 + // CHECK-DAG: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[INT1_1]] : 
!torch.vtensor<[1,4],f32>, !torch.int -> !torch.int + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM0]], %[[Im1]] : (!torch.int, !torch.int) -> !torch.list // CHECK: torch.aten.broadcast_to %arg0, %[[LIST]] : !torch.vtensor<[1,4],f32>, !torch.list -> !torch.vtensor<[3,4],f32> %0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[1,4],f32>, !torch.vtensor<[2],si32>) -> !torch.vtensor<[3,4],f32> return %0 : !torch.vtensor<[3,4],f32> @@ -1445,16 +1703,15 @@ func.func @test_expand_dim2_shape3(%arg0: !torch.vtensor<[3,1],f32>, %arg1: !tor // CHECK-NEXT: %[[I1:.+]] = torch.constant.int 1 // CHECK-NEXT: %[[SEL1:.+]] = torch.aten.select.int %arg1, %[[I0]], %[[I1]] // CHECK-NEXT: %[[ITEM1:.+]] = torch.aten.item %[[SEL1]] + // CHECK-NEXT: %[[Im1:.+]] = torch.constant.int -1 // CHECK-NEXT: %[[D1:.+]] = torch.constant.int 0 // CHECK-NEXT: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[D1]] - // CHECK-NEXT: %[[MX1:.+]] = torch.prim.max.int %[[ITEM1]], %[[SZ1]] : !torch.int, !torch.int -> !torch.int + // CHECK-NEXT: %[[GE:.+]] = torch.aten.ge.int + // CHECK-NEXT: torch.runtime.assert %[[GE]] // CHECK-NEXT: %[[I2:.+]] = torch.constant.int 2 // CHECK-NEXT: %[[SEL2:.+]] = torch.aten.select.int %arg1, %[[I0]], %[[I2]] // CHECK-NEXT: %[[ITEM2:.+]] = torch.aten.item %[[SEL2]] - // CHECK-NEXT: %[[D2:.+]] = torch.constant.int 1 - // CHECK-NEXT: %[[SZ2:.+]] = torch.aten.size.int %arg0, %[[D2]] - // CHECK-NEXT: %[[MX2:.+]] = torch.prim.max.int %[[ITEM2]], %[[SZ2]] - // CHECK-NEXT: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM0]], %[[MX1]], %[[MX2]] + // CHECK-NEXT: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM0]], %[[Im1]], %[[ITEM2]] // CHECK-NEXT: %[[EXPAND:.+]] = torch.aten.broadcast_to %arg0, %[[LIST]] // CHECK: return %[[EXPAND]] %0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[3,1],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,3,6],f32> @@ -1521,7 +1778,7 @@ func.func @test_training_dropout_zero_ratio(%arg0: !torch.vtensor<[3,4,5],f32>, // CHECK-LABEL: @test_elu_default func.func @test_elu_default(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: torch.aten.elu %arg0, %float0.000000e00, %float1.000000e00, %float1.000000e00 : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.float, !torch.float -> !torch.vtensor<[3,4,5],f32> + // CHECK: torch.aten.elu %arg0, %float1.000000e00, %float1.000000e00_0, %float1.000000e00_0 : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.float, !torch.float -> !torch.vtensor<[3,4,5],f32> %0 = torch.operator "onnx.Elu"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> return %0 : !torch.vtensor<[3,4,5],f32> } @@ -1649,8 +1906,8 @@ func.func @ints_constant() -> !torch.vtensor<[2], si64> attributes {torch.onnx_m // ----- -// CHECK-LABEL: @dense_constant -func.func @dense_constant() -> () attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} { +// CHECK-LABEL: @dense_resource_constant +func.func @dense_resource_constant() -> () attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} { // CHECK: torch.vtensor.literal(dense<[0, 10, 128, 17000]> : tensor<4xsi32>) : !torch.vtensor<[4],si32> %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_int32> : tensor<4xsi32>} : () -> !torch.vtensor<[4],si32> // CHECK: 
torch.vtensor.literal(dense<[0.000000e+00, 1.000000e+01, 1.280000e+02, 1.700000e+04]> : tensor<4xf32>) : !torch.vtensor<[4],f32> @@ -1827,6 +2084,24 @@ func.func @test_flatten_1d_axis_1(%arg0: !torch.vtensor<[2],f32>) -> !torch.vten // ----- +// CHECK-LABEL: func.func @test_constant_of_shape_arg_input +func.func @test_constant_of_shape_arg_input(%arg0: !torch.vtensor<[2], si64>) -> !torch.vtensor<[?,?], f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK: %[[EXTRACT_0:.*]] = torch.aten.select.int %arg0, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ELE_0:.*]] = torch.aten.item %[[EXTRACT_0]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[EXTRACT_1:.*]] = torch.aten.select.int %arg0, %[[INT0]], %[[INT1]] : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ELE_1:.*]] = torch.aten.item %[[EXTRACT_1]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[ELE_0]], %[[ELE_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FILL_VAL:.*]] = torch.constant.float 0.000000e+00 + // CHECK: %[[ATEN_FULL:.*]] = torch.aten.full %[[DIM_LIST]], %[[FILL_VAL]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?],f32> + %0 = "torch.operator"(%arg0) <{name = "onnx.ConstantOfShape"}> : (!torch.vtensor<[2], si64>) -> !torch.vtensor<[?,?], f32> + return %0 : !torch.vtensor<[?,?], f32> +} +// ----- + // CHECK-LABEL: func.func @test_constant_of_shape_dense_float_default func.func @test_constant_of_shape_dense_float_default() -> !torch.vtensor<[2,3,4], f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[SHAPE_CST:.*]] = torch.vtensor.literal(dense<[2, 3, 4]> : tensor<3xsi64>) : !torch.vtensor<[3],si64> @@ -2312,3 +2587,332 @@ func.func @test_hammingwindow(%arg0: !torch.vtensor<[],si32>) -> !torch.vtensor< %0 = torch.operator "onnx.HammingWindow"(%arg0) {torch.onnx.periodic = 1 : si64} : (!torch.vtensor<[],si32>) -> !torch.vtensor<[10],f32> return %0 : !torch.vtensor<[10],f32> } + +// ----- + +// CHECK-LABEL: func.func @test_col2im +func.func @test_col2im(%arg0: !torch.vtensor<[1,5,5],f32>, %arg1: !torch.vtensor<[2],si64>, %arg2: !torch.vtensor<[2],si64>) -> !torch.vtensor<[1,1,5,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[INT1_0:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[INT1_1:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[DILATIONSLIST:.*]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK-DAG: %[[INT1_2:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[INT1_3:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[STRIDESLIST:.*]] = torch.prim.ListConstruct %[[INT1_2]], %[[INT1_3]] : (!torch.int, !torch.int) -> !torch.list + // CHECK-DAG: %[[INT0_0:.*]] = 
torch.constant.int 0 + // CHECK-DAG: %[[INT0_1:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[PADLIST:.*]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT0_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK-DAG: %[[INT0_2:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[INT0_3:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[NUMTOTENSOR_0:.*]] = torch.prim.NumToTensor.Scalar %[[INT0_3]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK-DAG: %[[INDEXSELECT_0:.*]] = torch.aten.index_select %arg1, %[[INT0_2]], %[[NUMTOTENSOR_0]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK-DAG: %[[ITEM0:.*]] = torch.aten.item %[[INDEXSELECT_0]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK-DAG: %[[NUMTOTENSOR_1:.*]] = torch.prim.NumToTensor.Scalar %[[INT0_3]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK-DAG: %[[INDEXSELECT_1:.*]] = torch.aten.index_select %arg2, %[[INT0_2]], %[[NUMTOTENSOR_1]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK-DAG: %[[ITEM_1:.*]] = torch.aten.item %[[INDEXSELECT_1]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK-DAG: %[[INT1_4:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[NUMTOTENSOR_2:.*]] = torch.prim.NumToTensor.Scalar %[[INT1_4]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK-DAG: %[[INDEXSELECT_2:.*]] = torch.aten.index_select %arg1, %[[INT0_2]], %[[NUMTOTENSOR_2]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK-DAG: %[[ITEM_2:.*]] = torch.aten.item %[[INDEXSELECT_2]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK-DAG: %[[NUMTOTENSOR_3:.*]] = torch.prim.NumToTensor.Scalar %[[INT1_4]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK-DAG: %[[INDEXSELECT_3:.*]] = torch.aten.index_select %arg2, %[[INT0_2]], %[[NUMTOTENSOR_3]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK-DAG: %[[ITEM_3:.*]] = torch.aten.item %[[INDEXSELECT_3]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK-DAG: %[[IMAGESHAPELIST:.*]] = torch.prim.ListConstruct %[[ITEM0]], %[[ITEM_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK-DAG: %[[BLOCKSHAPELIST:.*]] = torch.prim.ListConstruct %[[ITEM_1]], %[[ITEM_3]] : (!torch.int, !torch.int) -> !torch.list + // CHECK-DAG: %[[COL2IM:.*]] = torch.aten.col2im %arg0, %[[IMAGESHAPELIST]], %[[BLOCKSHAPELIST]], %[[DILATIONSLIST]], %[[PADLIST]], %[[STRIDESLIST]] : !torch.vtensor<[1,5,5],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.list -> !torch.vtensor<[1,1,5,5],f32> + // CHECK-DAG: return %[[COL2IM]] : !torch.vtensor<[1,1,5,5],f32> + %0 = torch.operator "onnx.Col2Im"(%arg0, %arg1, %arg2) : (!torch.vtensor<[1,5,5],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[1,1,5,5],f32> + return %0 : !torch.vtensor<[1,1,5,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_col2im_pads +func.func @test_col2im_pads(%arg0: !torch.vtensor<[1,5,15],f32>, %arg1: !torch.vtensor<[2],si64>, %arg2: !torch.vtensor<[2],si64>) -> !torch.vtensor<[1,1,5,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[INT1_0:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[INT1_1:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[DILATIONSLIST:.*]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]] : (!torch.int, !torch.int) -> !torch.list + 
// CHECK-DAG: %[[INT1_2:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[INT1_3:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[STRIDESLIST:.*]] = torch.prim.ListConstruct %[[INT1_2]], %[[INT1_3]] : (!torch.int, !torch.int) -> !torch.list + // CHECK-DAG: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[INT1_4:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[PADLIST:.*]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT1_4]] : (!torch.int, !torch.int) -> !torch.list + // CHECK-DAG: %[[INT0_2:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[INT0_3:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[NUMTOTENSOR_0:.*]] = torch.prim.NumToTensor.Scalar %[[INT0_3]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK-DAG: %[[INDEXSELECT_0:.*]] = torch.aten.index_select %arg1, %[[INT0_2]], %[[NUMTOTENSOR_0]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK-DAG: %[[ITEM0:.*]] = torch.aten.item %[[INDEXSELECT_0]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK-DAG: %[[NUMTOTENSOR_1:.*]] = torch.prim.NumToTensor.Scalar %[[INT0_3]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK-DAG: %[[INDEXSELECT_1:.*]] = torch.aten.index_select %arg2, %[[INT0_2]], %[[NUMTOTENSOR_1]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK-DAG: %[[ITEM_1:.*]] = torch.aten.item %[[INDEXSELECT_1]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK-DAG: %[[INT1_5:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[NUMTOTENSOR_2:.*]] = torch.prim.NumToTensor.Scalar %[[INT1_5]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK-DAG: %[[INDEXSELECT_2:.*]] = torch.aten.index_select %arg1, %[[INT0_2]], %[[NUMTOTENSOR_2]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK-DAG: %[[ITEM_2:.*]] = torch.aten.item %[[INDEXSELECT_2]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK-DAG: %[[NUMTOTENSOR_3:.*]] = torch.prim.NumToTensor.Scalar %[[INT1_5]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK-DAG: %[[INDEXSELECT_3:.*]] = torch.aten.index_select %arg2, %[[INT0_2]], %[[NUMTOTENSOR_3]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK-DAG: %[[ITEM_3:.*]] = torch.aten.item %[[INDEXSELECT_3]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK-DAG: %[[IMAGESHAPELIST:.*]] = torch.prim.ListConstruct %[[ITEM0]], %[[ITEM_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK-DAG: %[[BLOCKSHAPELIST:.*]] = torch.prim.ListConstruct %[[ITEM_1]], %[[ITEM_3]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[COL2IM:.*]] = torch.aten.col2im %arg0, %[[IMAGESHAPELIST]], %[[BLOCKSHAPELIST]], %[[DILATIONSLIST]], %[[PADLIST]], %[[STRIDESLIST]] : !torch.vtensor<[1,5,15],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.list -> !torch.vtensor<[1,1,5,5],f32> + // CHECK: return %[[COL2IM]] : !torch.vtensor<[1,1,5,5],f32> + %0 = torch.operator "onnx.Col2Im"(%arg0, %arg1, %arg2) {torch.onnx.pads = [0 : si64, 1 : si64, 0 : si64, 1 : si64]} : (!torch.vtensor<[1,5,15],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[1,1,5,5],f32> + return %0 : !torch.vtensor<[1,1,5,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_col2im_dilations +func.func @test_col2im_dilations(%arg0: !torch.vtensor<[1,4,5],f32>, %arg1: !torch.vtensor<[2],si64>, %arg2: !torch.vtensor<[2],si64>) -> !torch.vtensor<[1,1,6,6],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : 
si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[INT1_0:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[INT5_0:.*]] = torch.constant.int 5 + // CHECK-DAG: %[[DILATIONSLIST:.*]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT5_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK-DAG: %[[INT1_1:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[INT1_2:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[STRIDESLIST:.*]] = torch.prim.ListConstruct %[[INT1_1]], %[[INT1_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK-DAG: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[INT0_1:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[PADLIST:.*]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT0_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK-DAG: %[[INT0_2:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[INT0_3:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[NUMTOTENSOR_0:.*]] = torch.prim.NumToTensor.Scalar %[[INT0_3]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK-DAG: %[[INDEXSELECT_0:.*]] = torch.aten.index_select %arg1, %[[INT0_2]], %[[NUMTOTENSOR_0]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK-DAG: %[[ITEM0:.*]] = torch.aten.item %[[INDEXSELECT_0]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK-DAG: %[[NUMTOTENSOR_1:.*]] = torch.prim.NumToTensor.Scalar %[[INT0_3]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK-DAG: %[[INDEXSELECT_1:.*]] = torch.aten.index_select %arg2, %[[INT0_2]], %[[NUMTOTENSOR_1]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK-DAG: %[[ITEM_1:.*]] = torch.aten.item %[[INDEXSELECT_1]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK-DAG: %[[INT1_3:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[NUMTOTENSOR_2:.*]] = torch.prim.NumToTensor.Scalar %[[INT1_3]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK-DAG: %[[INDEXSELECT_2:.*]] = torch.aten.index_select %arg1, %[[INT0_2]], %[[NUMTOTENSOR_2]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK-DAG: %[[ITEM_2:.*]] = torch.aten.item %[[INDEXSELECT_2]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK-DAG: %[[NUMTOTENSOR_3:.*]] = torch.prim.NumToTensor.Scalar %[[INT1_3]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK-DAG: %[[INDEXSELECT_3:.*]] = torch.aten.index_select %arg2, %[[INT0_2]], %[[NUMTOTENSOR_3]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK-DAG: %[[ITEM_3:.*]] = torch.aten.item %[[INDEXSELECT_3]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK-DAG: %[[IMAGESHAPELIST:.*]] = torch.prim.ListConstruct %[[ITEM0]], %[[ITEM_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK-DAG: %[[BLOCKSHAPELIST:.*]] = torch.prim.ListConstruct %[[ITEM_1]], %[[ITEM_3]] : (!torch.int, !torch.int) -> !torch.list + // CHECK-DAG: %[[COL2IM:.*]] = torch.aten.col2im %arg0, %[[IMAGESHAPELIST]], %[[BLOCKSHAPELIST]], %[[DILATIONSLIST]], %[[PADLIST]], %[[STRIDESLIST]] : !torch.vtensor<[1,4,5],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.list -> !torch.vtensor<[1,1,6,6],f32> + // CHECK-DAG: return %[[COL2IM]] : !torch.vtensor<[1,1,6,6],f32> + %0 = torch.operator "onnx.Col2Im"(%arg0, %arg1, %arg2) {torch.onnx.dilations = [1 : si64, 5 : si64]} : (!torch.vtensor<[1,4,5],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[1,1,6,6],f32> + return %0 : !torch.vtensor<[1,1,6,6],f32> 
+} + +// ----- + +// CHECK-LABEL: func.func @test_col2im_strides +func.func @test_col2im_strides(%arg0: !torch.vtensor<[1,9,4],f32>, %arg1: !torch.vtensor<[2],si64>, %arg2: !torch.vtensor<[2],si64>) -> !torch.vtensor<[1,1,5,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[INT1_0:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[INT1_1:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[DILATIONSLIST:.*]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK-DAG: %[[INT2_0:.*]] = torch.constant.int 2 + // CHECK-DAG: %[[INT2_1:.*]] = torch.constant.int 2 + // CHECK-DAG: %[[STRIDESLIST:.*]] = torch.prim.ListConstruct %[[INT2_0]], %[[INT2_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK-DAG: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[INT0_1:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[PADLIST:.*]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT0_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK-DAG: %[[INT0_2:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[INT0_3:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[NUMTOTENSOR_0:.*]] = torch.prim.NumToTensor.Scalar %[[INT0_3]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK-DAG: %[[INDEXSELECT_0:.*]] = torch.aten.index_select %arg1, %[[INT0_2]], %[[NUMTOTENSOR_0]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK-DAG: %[[ITEM0:.*]] = torch.aten.item %[[INDEXSELECT_0]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK-DAG: %[[NUMTOTENSOR_1:.*]] = torch.prim.NumToTensor.Scalar %[[INT0_3]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK-DAG: %[[INDEXSELECT_1:.*]] = torch.aten.index_select %arg2, %[[INT0_2]], %[[NUMTOTENSOR_1]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK-DAG: %[[ITEM_1:.*]] = torch.aten.item %[[INDEXSELECT_1]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK-DAG: %[[INT1_2:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[NUMTOTENSOR_2:.*]] = torch.prim.NumToTensor.Scalar %[[INT1_2]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK-DAG: %[[INDEXSELECT_2:.*]] = torch.aten.index_select %arg1, %[[INT0_2]], %[[NUMTOTENSOR_2]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK-DAG: %[[ITEM_2:.*]] = torch.aten.item %[[INDEXSELECT_2]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK-DAG: %[[NUMTOTENSOR_3:.*]] = torch.prim.NumToTensor.Scalar %[[INT1_2]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK-DAG: %[[INDEXSELECT_3:.*]] = torch.aten.index_select %arg2, %[[INT0_2]], %[[NUMTOTENSOR_3]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK-DAG: %[[ITEM_3:.*]] = torch.aten.item %[[INDEXSELECT_3]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK-DAG: %[[IMAGESHAPELIST:.*]] = torch.prim.ListConstruct %[[ITEM0]], %[[ITEM_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK-DAG: %[[BLOCKSHAPELIST:.*]] = torch.prim.ListConstruct %[[ITEM_1]], %[[ITEM_3]] : (!torch.int, !torch.int) -> !torch.list + // CHECK-DAG: %[[COL2IM:.*]] = torch.aten.col2im %arg0, %[[IMAGESHAPELIST]], %[[BLOCKSHAPELIST]], %[[DILATIONSLIST]], %[[PADLIST]], %[[STRIDESLIST]] : !torch.vtensor<[1,9,4],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.list -> !torch.vtensor<[1,1,5,5],f32> + // 
CHECK-DAG: return %[[COL2IM]] : !torch.vtensor<[1,1,5,5],f32> + %0 = torch.operator "onnx.Col2Im"(%arg0, %arg1, %arg2) {torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,9,4],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[1,1,5,5],f32> + return %0 : !torch.vtensor<[1,1,5,5],f32> +} + +// ----- + +// CHECK-LABEL: @test_center_crop_pad_crop_and_pad +func.func @test_center_crop_pad_crop_and_pad(%arg0: !torch.vtensor<[20,8,3],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[10,10,3],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[STR:.*]] = torch.constant.str "floor" + // CHECK: %[[C0_0:.*]] = torch.constant.int 0 + // CHECK: %[[C0_1:.*]] = torch.constant.int 0 + // CHECK: %[[INDEX:.*]] = torch.prim.NumToTensor.Scalar %[[C0_1]] : !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[SELECT:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[],si64> -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM:.*]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[C0_2:.*]] = torch.constant.int 0 + // CHECK: %[[SIZE:.*]] = torch.aten.size.int %arg0, %[[C0_2]] : !torch.vtensor<[20,8,3],f32>, !torch.int -> !torch.int + // CHECK: %[[SUB:.*]] = torch.aten.sub.int %[[SIZE]], %[[ITEM]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SUB_TENSOR:.*]] = torch.prim.NumToTensor.Scalar %[[SUB]] : !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[DIV:.*]] = torch.aten.div.Scalar_mode %[[SUB_TENSOR]], %[[C2]], %[[STR]] : !torch.vtensor<[],si64>, !torch.int, !torch.str -> !torch.vtensor<[],si64> + // CHECK: %[[ITEM_0:.*]] = torch.aten.item %[[DIV]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[ITEM_0]], %[[ITEM]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE:.*]] = torch.aten.slice.Tensor %arg0, %[[C0_0]], %[[ITEM_0]], %[[ADD]], %[[C1]] : !torch.vtensor<[20,8,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?,3],f32> + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[INDEX_0:.*]] = torch.prim.NumToTensor.Scalar %[[C1_1]] : !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[SELECT_0:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX_0]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[],si64> -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM_1:.*]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[C1_2:.*]] = torch.constant.int 1 + // CHECK: %[[SIZE_0:.*]] = torch.aten.size.int %[[SLICE]], %[[C1_2]] : !torch.vtensor<[?,?,3],f32>, !torch.int -> !torch.int + // CHECK: %[[SUB_0:.*]] = torch.aten.sub.int %[[ITEM_1]], %[[SIZE_0]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SUB_TENSOR_0:.*]] = torch.prim.NumToTensor.Scalar %[[SUB_0]] : !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[DIV_0:.*]] = torch.aten.div.Scalar_mode %[[SUB_TENSOR_0]], %[[C2]], %[[STR]] : !torch.vtensor<[],si64>, !torch.int, !torch.str -> !torch.vtensor<[],si64> + // CHECK: %[[ITEM_2:.*]] = torch.aten.item %[[DIV_0]] : !torch.vtensor<[],si64> -> !torch.int + // 
CHECK: %[[ADD_0:.*]] = torch.aten.add.int %[[ITEM_2]], %[[SIZE_0]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[C0_3:.*]] = torch.constant.int 0 + // CHECK: %[[SIZE_1:.*]] = torch.aten.size.int %[[SLICE]], %[[C0_3]] : !torch.vtensor<[?,?,3],f32>, !torch.int -> !torch.int + // CHECK: %[[C2_0:.*]] = torch.constant.int 2 + // CHECK: %[[SIZE_2:.*]] = torch.aten.size.int %[[SLICE]], %[[C2_0]] : !torch.vtensor<[?,?,3],f32>, !torch.int -> !torch.int + // CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[SIZE_1]], %[[ITEM_1]], %[[SIZE_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[ZEROS:.*]] = torch.aten.zeros %[[SHAPE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10,10,3],f32> + // CHECK: torch.aten.slice_scatter %[[ZEROS]], %[[SLICE]], %[[C1_0]], %[[ITEM_2]], %[[ADD_0]], %[[C1]] : !torch.vtensor<[10,10,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[10,10,3],f32> + %0 = torch.operator "onnx.CenterCropPad"(%arg0, %arg1) : (!torch.vtensor<[20,8,3],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[10,10,3],f32> + return %0 : !torch.vtensor<[10,10,3],f32> +} + +// ----- + +// CHECK-LABEL: @test_center_crop_pad_crop_axes_chw +func.func @test_center_crop_pad_crop_axes_chw(%arg0: !torch.vtensor<[3,20,8],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,10,9],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[STR:.*]] = torch.constant.str "floor" + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C0_0:.*]] = torch.constant.int 0 + // CHECK: %[[INDEX:.*]] = torch.prim.NumToTensor.Scalar %[[C0_0]] : !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[SELECT:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[],si64> -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM:.*]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[SIZE:.*]] = torch.aten.size.int %arg0, %[[C1_1]] : !torch.vtensor<[3,20,8],f32>, !torch.int -> !torch.int + // CHECK: %[[SUB:.*]] = torch.aten.sub.int %[[SIZE]], %[[ITEM]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SUB_TENSOR:.*]] = torch.prim.NumToTensor.Scalar %[[SUB]] : !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[DIV:.*]] = torch.aten.div.Scalar_mode %[[SUB_TENSOR]], %[[C2]], %[[STR]] : !torch.vtensor<[],si64>, !torch.int, !torch.str -> !torch.vtensor<[],si64> + // CHECK: %[[ITEM_0:.*]] = torch.aten.item %[[DIV]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[ITEM_0]], %[[ITEM]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE:.*]] = torch.aten.slice.Tensor %arg0, %[[C1_0]], %[[ITEM_0]], %[[ADD]], %[[C1]] : !torch.vtensor<[3,20,8],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,?,?],f32> + // CHECK: %[[C2_0:.*]] = torch.constant.int 2 + // CHECK: %[[C1_2:.*]] = torch.constant.int 1 + // CHECK: %[[INDEX_0:.*]] = torch.prim.NumToTensor.Scalar %[[C1_2]] : !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[SELECT_0:.*]] = 
torch.aten.index_select %arg1, %[[C0]], %[[INDEX_0]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[],si64> -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM_1:.*]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[C2_1:.*]] = torch.constant.int 2 + // CHECK: %[[SIZE_0:.*]] = torch.aten.size.int %[[SLICE]], %[[C2_1]] : !torch.vtensor<[3,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[SUB_0:.*]] = torch.aten.sub.int %[[ITEM_1]], %[[SIZE_0]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SUB_TENSOR_0:.*]] = torch.prim.NumToTensor.Scalar %[[SUB_0]] : !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[DIV_0:.*]] = torch.aten.div.Scalar_mode %[[SUB_TENSOR_0]], %[[C2]], %[[STR]] : !torch.vtensor<[],si64>, !torch.int, !torch.str -> !torch.vtensor<[],si64> + // CHECK: %[[ITEM_2:.*]] = torch.aten.item %[[DIV_0]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[ADD_0:.*]] = torch.aten.add.int %[[ITEM_2]], %[[SIZE_0]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[C0_1:.*]] = torch.constant.int 0 + // CHECK: %[[SIZE_1:.*]] = torch.aten.size.int %[[SLICE]], %[[C0_1]] : !torch.vtensor<[3,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[C1_3:.*]] = torch.constant.int 1 + // CHECK: %[[SIZE_2:.*]] = torch.aten.size.int %[[SLICE]], %[[C1_3]] : !torch.vtensor<[3,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[SIZE_1]], %[[SIZE_2]], %[[ITEM_1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[ZEROS:.*]] = torch.aten.zeros %[[SHAPE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,10,9],f32> + // CHECK: torch.aten.slice_scatter %[[ZEROS]], %[[SLICE]], %[[C2_0]], %[[ITEM_2]], %[[ADD_0]], %[[C1]] : !torch.vtensor<[3,10,9],f32>, !torch.vtensor<[3,?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,10,9],f32> + %0 = torch.operator "onnx.CenterCropPad"(%arg0, %arg1) {torch.onnx.axes = [1 : si64, 2 : si64]} : (!torch.vtensor<[3,20,8],f32>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,10,9],f32> + return %0 : !torch.vtensor<[3,10,9],f32> +} + +// ----- + +// CHECK-LABEL: @test_center_crop_pad_crop_negative_axes_hwc +func.func @test_center_crop_pad_crop_negative_axes_hwc(%arg0: !torch.vtensor<[20,8,3],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[10,9,3],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[STR:.*]] = torch.constant.str "floor" + // CHECK: %[[C0_0:.*]] = torch.constant.int 0 + // CHECK: %[[C0_1:.*]] = torch.constant.int 0 + // CHECK: %[[INDEX:.*]] = torch.prim.NumToTensor.Scalar %[[C0_1]] : !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[SELECT:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[],si64> -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM:.*]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[C0_2:.*]] = torch.constant.int 0 + // CHECK: %[[SIZE:.*]] = torch.aten.size.int %arg0, %[[C0_2]] : !torch.vtensor<[20,8,3],f32>, !torch.int -> !torch.int + // CHECK: %[[SUB:.*]] = torch.aten.sub.int %[[SIZE]], 
%[[ITEM]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SUB_TENSOR:.*]] = torch.prim.NumToTensor.Scalar %[[SUB]] : !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[DIV:.*]] = torch.aten.div.Scalar_mode %[[SUB_TENSOR]], %[[C2]], %[[STR]] : !torch.vtensor<[],si64>, !torch.int, !torch.str -> !torch.vtensor<[],si64> + // CHECK: %[[ITEM_0:.*]] = torch.aten.item %[[DIV]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[ITEM_0]], %[[ITEM]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE:.*]] = torch.aten.slice.Tensor %arg0, %[[C0_0]], %[[ITEM_0]], %[[ADD]], %[[C1]] : !torch.vtensor<[20,8,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?,3],f32> + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[INDEX_0:.*]] = torch.prim.NumToTensor.Scalar %[[C1_1]] : !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[SELECT_0:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX_0]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[],si64> -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM_1:.*]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[C1_2:.*]] = torch.constant.int 1 + // CHECK: %[[SIZE_0:.*]] = torch.aten.size.int %[[SLICE]], %[[C1_2]] : !torch.vtensor<[?,?,3],f32>, !torch.int -> !torch.int + // CHECK: %[[SUB_0:.*]] = torch.aten.sub.int %[[ITEM_1]], %[[SIZE_0]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SUB_TENSOR_0:.*]] = torch.prim.NumToTensor.Scalar %[[SUB_0]] : !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[DIV_0:.*]] = torch.aten.div.Scalar_mode %[[SUB_TENSOR_0]], %[[C2]], %[[STR]] : !torch.vtensor<[],si64>, !torch.int, !torch.str -> !torch.vtensor<[],si64> + // CHECK: %[[ITEM_2:.*]] = torch.aten.item %[[DIV_0]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[ADD_0:.*]] = torch.aten.add.int %[[ITEM_2]], %[[SIZE_0]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[C0_3:.*]] = torch.constant.int 0 + // CHECK: %[[SIZE_1:.*]] = torch.aten.size.int %[[SLICE]], %[[C0_3]] : !torch.vtensor<[?,?,3],f32>, !torch.int -> !torch.int + // CHECK: %[[C2_0:.*]] = torch.constant.int 2 + // CHECK: %[[SIZE_2:.*]] = torch.aten.size.int %[[SLICE]], %[[C2_0]] : !torch.vtensor<[?,?,3],f32>, !torch.int -> !torch.int + // CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[SIZE_1]], %[[ITEM_1]], %[[SIZE_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[ZEROS:.*]] = torch.aten.zeros %[[SHAPE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[10,9,3],f32> + // CHECK: torch.aten.slice_scatter %[[ZEROS]], %[[SLICE]], %[[C1_0]], %[[ITEM_2]], %[[ADD_0]], %[[C1]] : !torch.vtensor<[10,9,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[10,9,3],f32> + %0 = torch.operator "onnx.CenterCropPad"(%arg0, %arg1) {torch.onnx.axes = [-3 : si64, -2 : si64]} : (!torch.vtensor<[20,8,3],f32>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[10,9,3],f32> + return %0 : !torch.vtensor<[10,9,3],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_dft_fft +func.func @test_dft_fft(%arg0: !torch.vtensor<[10,10,1],f32>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[10,10,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "dft_example", torch.onnx_meta.producer_version = ""} { + // CHECK-SAME: 
%[[INPUT_SIGNAL:.*]]: !torch.vtensor<[10,10,1],f32>, + // CHECK-SAME: %[[AXIS:.*]]: !torch.vtensor<[],si64>) -> !torch.vtensor<[10,10,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "dft_example", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[SIG_LEN:.*]] = torch.constant.none + // CHECK: %[[DIM:.*]] = torch.aten.item %[[AXIS]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[NORM:.*]] = torch.constant.str "backward" + // CHECK: %[[FILL_VAL:.*]] = torch.constant.float 0.000000e+00 + // CHECK: %[[ONE:.*]] = torch.constant.int 1 + // CHECK: %[[ZERO:.*]] = torch.constant.int 0 + // CHECK: %[[PAD_DIM_LIST:.*]] = torch.prim.ListConstruct %[[ZERO]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[MODE:.*]] = torch.constant.str "constant" + // CHECK: %[[INPUT_PADDED:.*]] = torch.aten.pad %[[INPUT_SIGNAL]], %[[PAD_DIM_LIST]], %[[MODE]], %[[FILL_VAL]] : !torch.vtensor<[10,10,1],f32>, !torch.list, !torch.str, !torch.float -> !torch.vtensor<[10,10,2],f32> + // CHECK: %[[INPUT_T_CMPLX:.*]] = torch.aten.view_as_complex %[[INPUT_PADDED]] : !torch.vtensor<[10,10,2],f32> -> !torch.vtensor<[10,10],complex> + // CHECK: %[[FFT_CMPLX:.*]] = torch.aten.fft_fft %[[INPUT_T_CMPLX]], %[[SIG_LEN]], %[[DIM]], %[[NORM]] : !torch.vtensor<[10,10],complex>, !torch.none, !torch.int, !torch.str -> !torch.vtensor<[10,10],complex> + // CHECK: %[[FFT_RES_REAL:.*]] = torch.aten.view_as_real %[[FFT_CMPLX]] : !torch.vtensor<[10,10],complex> -> !torch.vtensor<[10,10,2],f32> + // CHECK: return %[[FFT_RES_REAL]] : !torch.vtensor<[10,10,2],f32> + %none = torch.constant.none + %0 = torch.operator "onnx.DFT"(%arg0, %none, %arg1) : (!torch.vtensor<[10,10,1],f32>, !torch.none, !torch.vtensor<[],si64>) -> !torch.vtensor<[10,10,2],f32> + return %0 : !torch.vtensor<[10,10,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_dft_inverse_real +func.func @test_dft_inverse_real(%arg0: !torch.vtensor<[10,10,1],f32>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[10,10,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "dft_example", torch.onnx_meta.producer_version = ""} { + // CHECK-SAME: %[[INPUT_SIGNAL:.*]]: !torch.vtensor<[10,10,1],f32>, + // CHECK-SAME: %[[AXIS:.*]]: !torch.vtensor<[],si64>) -> !torch.vtensor<[10,10,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "dft_example", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[SIG_LEN:.*]] = torch.constant.none + // CHECK: %[[DIM:.*]] = torch.aten.item %[[AXIS]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[NORM:.*]] = torch.constant.str "backward" + // CHECK: %[[FILL_VAL:.*]] = torch.constant.float 0.000000e+00 + // CHECK: %[[ONE:.*]] = torch.constant.int 1 + // CHECK: %[[ZERO:.*]] = torch.constant.int 0 + // CHECK: %[[PAD_DIM_LIST:.*]] = torch.prim.ListConstruct %[[ZERO]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[MODE:.*]] = torch.constant.str "constant" + // CHECK: %[[INPUT_PADDED:.*]] = torch.aten.pad %[[INPUT_SIGNAL]], %[[PAD_DIM_LIST]], %[[MODE]], %[[FILL_VAL]] : !torch.vtensor<[10,10,1],f32>, !torch.list, !torch.str, !torch.float -> !torch.vtensor<[10,10,2],f32> + // CHECK: %[[INPUT_T_CMPLX:.*]] = torch.aten.view_as_complex %[[INPUT_PADDED]] : 
!torch.vtensor<[10,10,2],f32> -> !torch.vtensor<[10,10],complex> + // CHECK: %[[IFFT_CMPLX:.*]] = torch.aten.fft_ifft %[[INPUT_T_CMPLX]], %[[SIG_LEN]], %[[DIM]], %[[NORM]] : !torch.vtensor<[10,10],complex>, !torch.none, !torch.int, !torch.str -> !torch.vtensor<[10,10],complex> + // CHECK: %[[IFFT_RES_REAL:.*]] = torch.aten.view_as_real %[[IFFT_CMPLX]] : !torch.vtensor<[10,10],complex> -> !torch.vtensor<[10,10,2],f32> + // CHECK: return %[[IFFT_RES_REAL]] : !torch.vtensor<[10,10,2],f32> + %none = torch.constant.none + %0 = torch.operator "onnx.DFT"(%arg0, %none, %arg1) {torch.onnx.inverse = 1 : si64} : (!torch.vtensor<[10,10,1],f32>, !torch.none, !torch.vtensor<[],si64>) -> !torch.vtensor<[10,10,2],f32> + return %0 : !torch.vtensor<[10,10,2],f32> +} diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index c0f93864f9ee..b2c718bceace 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -78,7 +78,7 @@ func.func @test_gather_scalar(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch. // CHECK: %[[SEL:.+]] = torch.aten.where.self %[[LT]], %[[ADD]], %arg1 // CHECK: %[[FLAT:.+]] = torch.aten.unsqueeze %[[SEL]], %[[ZERO]] : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> // CHECK: %[[ISEL:.+]] = torch.aten.index_select %arg0, %[[AXIS]], %[[FLAT]] - // CHECK: %[[RES:.+]] = torch.aten.squeeze %[[ISEL]] : !torch.vtensor<[1,4,5],f32> -> !torch.vtensor<[4,5],f32> + // CHECK: %[[RES:.+]] = torch.aten.squeeze.dim %[[ISEL]], %[[AXIS]] : !torch.vtensor<[1,4,5],f32>, !torch.int -> !torch.vtensor<[4,5],f32> // CHECK: return %[[RES]] %0 = torch.operator "onnx.Gather"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[], si64>) -> !torch.vtensor<[4,5],f32> return %0 : !torch.vtensor<[4,5],f32> @@ -180,11 +180,51 @@ func.func @test_gather_nd_1D_indices(%arg0: !torch.vtensor<[2,6,8,5],f32>, %arg1 // ----- +// CHECK-LABEL: func.func @test_gathernd_example_int32_batch_dim1 +func.func @test_gathernd_example_int32_batch_dim1(%arg0: !torch.vtensor<[2,2,2],si32>, %arg1: !torch.vtensor<[2,1],si64>) -> !torch.vtensor<[2,2],si32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64} { + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[DIM0:.+]] = torch.aten.size.int %arg0, %[[INT0]] + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: %[[DIM1:.+]] = torch.aten.size.int %arg0, %[[INT1]] + // CHECK: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: %[[DIM2:.+]] = torch.aten.size.int %arg0, %[[INT2]] + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[INT1_1:.+]] = torch.constant.int 1 + // CHECK: %[[INT0_2:.+]] = torch.constant.int 0 + // CHECK: %[[B0:.+]] = torch.aten.size.int %arg1, %[[INT0_2]] + // CHECK: %[[INT1_3:.+]] = torch.constant.int 1 + // CHECK: %[[INT1_4:.+]] = torch.constant.int 1 + // CHECK: %[[SLICE:.+]] = torch.aten.slice.Tensor %arg1, %[[INT1_3]], %[[INT0_0]], %[[INT1_4]], %[[INT1_1]] + // CHECK: %[[LT:.+]] = torch.aten.lt.Scalar %[[SLICE]], %[[INT0_0]] + // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %[[SLICE]], %[[DIM1]], %[[INT1_1]] + // CHECK: %[[WHERE:.+]] = torch.aten.where.self %[[LT]], %[[ADD]], %[[SLICE]] + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[B0]], %[[INT1_1]] + // CHECK: %[[VIEW:.+]] = torch.aten.view %[[WHERE]], %[[LIST]] + // CHECK: %[[INT1_5:.+]] = torch.constant.int 1 + // CHECK: %[[UNSQ:.+]] = 
torch.aten.unsqueeze %[[VIEW]], %[[INT1_5]] + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[DIM0]], %[[INT1_1]], %[[DIM2]] + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[EXPAND:.+]] = torch.aten.expand %[[UNSQ]], %[[LIST]], %[[FALSE]] + // CHECK: %[[INT1_6:.+]] = torch.constant.int 1 + // CHECK: %[[GATHER:.+]] = torch.aten.gather %arg0, %[[INT1_5]], %[[EXPAND]], %[[FALSE]] + // CHECK: %[[SQ:.+]] = torch.aten.squeeze.dim %[[GATHER]], %[[INT1_5]] + %none = torch.constant.none + %0 = torch.operator "onnx.GatherND"(%arg0, %arg1) {torch.onnx.batch_dims = 1 : si64} : (!torch.vtensor<[2,2,2],si32>, !torch.vtensor<[2,1],si64>) -> !torch.vtensor<[2,2],si32> + return %0 : !torch.vtensor<[2,2],si32> +} + +// ----- + // CHECK-LABEL: func.func @test_gather_elements func.func @test_gather_elements(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5], si64>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} { // CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0 + // CHECK-DAG: %[[ONE:.+]] = torch.constant.int 1 + // CHECK-DAG: %[[DIM:.+]] = torch.aten.size.int %arg0, %[[INT0]] + // CHECK-DAG: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[DIM]], %[[ONE]] + // CHECK-DAG: %[[LT:.+]] = torch.aten.lt.Scalar %arg1, %[[INT0]] + // CHECK-DAG: %[[WHERE:.+]] = torch.aten.where.self %[[LT]], %[[ADD]], %arg1 // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false - // CHECK: %[[GATHER:.+]] = torch.aten.gather %arg0, %[[INT0]], %arg1, %[[FALSE]] + // CHECK: %[[GATHER:.+]] = torch.aten.gather %arg0, %[[INT0]], %[[WHERE]], %[[FALSE]] %0 = torch.operator "onnx.GatherElements"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5], si64>) -> !torch.vtensor<[3,4,5],f32> return %0 : !torch.vtensor<[3,4,5],f32> } @@ -274,6 +314,62 @@ func.func @test_gemm_alpha_beta(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch. 
// ----- +// CHECK-LABEL: func.func @test_lppool_2d +func.func @test_lppool_2d(%arg0: !torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,31,31],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 22 : si64} { + // CHECK: %[[I1:.*]] = torch.constant.int 1 + // CHECK: %[[I2:.*]] = torch.constant.int 2 + // CHECK: %[[NE:.*]] = torch.aten.mul %[[I2]], %[[I1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[I2_1:.*]] = torch.constant.int 2 + // CHECK: %[[NE1:.*]] = torch.aten.mul %[[I2_1]], %[[NE]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[K:.*]] = torch.prim.ListConstruct %[[I2]], %[[I2_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[I0:.*]] = torch.constant.int 0 + // CHECK: %[[I0_1:.*]] = torch.constant.int 0 + // CHECK: %[[I0_2:.*]] = torch.constant.int 0 + // CHECK: %[[I0_3:.*]] = torch.constant.int 0 + // CHECK: %[[PAD:.*]] = torch.prim.ListConstruct %[[I0]], %[[I0_1]], %[[I0_2]], %[[I0_3]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[I1_1:.*]] = torch.constant.int 1 + // CHECK: %[[I1_2:.*]] = torch.constant.int 1 + // CHECK: %[[STR:.*]] = torch.prim.ListConstruct %[[I1_1]], %[[I1_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[CEIL:.*]] = torch.constant.bool false + // CHECK: %[[CIP:.*]] = torch.constant.bool true + // CHECK: %[[P:.*]] = torch.constant.int 2 + // CHECK: %[[ABS:.*]] = torch.aten.abs %arg0 : !torch.vtensor<[1,3,32,32],f32> -> !torch.vtensor<[1,3,32,32],f32> + // CHECK: %[[POW:.*]] = torch.aten.pow.Tensor_Scalar %[[ABS]], %[[P]] : !torch.vtensor<[1,3,32,32],f32>, !torch.int -> !torch.vtensor<[1,3,32,32],f32> + // CHECK: %[[AVG:.*]] = torch.aten.avg_pool2d %[[POW]], %[[K]], %[[STR]], %[[PAD]], %[[CEIL]], %[[CIP]], %[[I1]] : !torch.vtensor<[1,3,32,32],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.int -> !torch.vtensor<[1,3,31,31],f32> + // CHECK: %[[INVP:.*]] = torch.constant.float 5.000000e-01 + // CHECK: torch.aten.pow.Tensor_Scalar %[[AVG]], %[[INVP]] : !torch.vtensor<[1,3,31,31],f32>, !torch.float -> !torch.vtensor<[1,3,31,31],f32> + %0 = torch.operator "onnx.LpPool"(%arg0) {torch.onnx.kernel_shape = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,31,31],f32> + return %0 : !torch.vtensor<[1,3,31,31],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_lppool_1d +func.func @test_lppool_1d(%arg0: !torch.vtensor<[1,3,32],f32>) -> !torch.vtensor<[1,3,31],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 22 : si64} { + // CHECK: %[[I1:.*]] = torch.constant.int 1 + // CHECK: %[[I2:.*]] = torch.constant.int 2 + // CHECK: %[[NE:.*]] = torch.aten.mul %[[I2]], %[[I1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[K:.*]] = torch.prim.ListConstruct %[[I2]] : (!torch.int) -> !torch.list + // CHECK: %[[I0:.*]] = torch.constant.int 0 + // CHECK: %[[I0_1:.*]] = torch.constant.int 0 + // CHECK: %[[PAD:.*]] = torch.prim.ListConstruct %[[I0]], %[[I0_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[I1_1:.*]] = torch.constant.int 1 + // CHECK: %[[STR:.*]] = torch.prim.ListConstruct %[[I1_1]] : (!torch.int) -> !torch.list + // CHECK: %[[CEIL:.*]] = torch.constant.bool false + // CHECK: %[[CIP:.*]] = torch.constant.bool true + // CHECK: %[[P:.*]] = torch.constant.int 2 + // CHECK: %[[ABS:.*]] = torch.aten.abs %arg0 : !torch.vtensor<[1,3,32],f32> -> !torch.vtensor<[1,3,32],f32> + // CHECK: %[[POW:.*]] = torch.aten.pow.Tensor_Scalar %[[ABS]], %[[P]] : 
!torch.vtensor<[1,3,32],f32>, !torch.int -> !torch.vtensor<[1,3,32],f32> + // CHECK: %[[AVG:.*]] = torch.aten.avg_pool1d %[[POW]], %[[K]], %[[STR]], %[[PAD]], %[[CEIL]], %[[CIP]] : !torch.vtensor<[1,3,32],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool -> !torch.vtensor<[1,3,31],f32> + // CHECK: %[[POW_0:.*]] = torch.aten.mul.Scalar %[[AVG]], %[[NE]] : !torch.vtensor<[1,3,31],f32>, !torch.int -> !torch.vtensor<[1,3,31],f32> + // CHECK: %[[INVP:.*]] = torch.constant.float 5.000000e-01 + // CHECK: torch.aten.pow.Tensor_Scalar %[[POW_0]], %[[INVP]] : !torch.vtensor<[1,3,31],f32>, !torch.float -> !torch.vtensor<[1,3,31],f32> + %0 = torch.operator "onnx.LpPool"(%arg0) {torch.onnx.kernel_shape = [2 : si64]} : (!torch.vtensor<[1,3,32],f32>) -> !torch.vtensor<[1,3,31],f32> + return %0 : !torch.vtensor<[1,3,31],f32> +} + +// ----- + // CHECK-LABEL : func.func @test_layer_norm func.func @test_layer_norm(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[3,4],f32>, %arg2: !torch.vtensor<[3,4],f32>) -> (!torch.vtensor<[3,4], f32>, !torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { @@ -310,6 +406,137 @@ func.func @test_leaky_relu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor // ----- +// CHECK-LABEL: func.func @test_lrn_default +func.func @test_lrn_default(%arg0: !torch.vtensor<[20,10,3,50],f32>) -> !torch.vtensor<[20,10,3,50],f32> attributes {torch.onnx_meta.opset_version = 17 : si64} { + // CHECK-DAG: %[[TRUE:.+]] = torch.constant.bool true + // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false + // CHECK-DAG: %[[F0:.+]] = torch.constant.float 0.000000e+00 + // CHECK-DAG: %[[ALPHA:.*]] = torch.constant.float 9.9999997473787516E-5 + // CHECK-DAG: %[[BETA:.*]] = torch.constant.float 7.500000e-01 + // CHECK-DAG: %[[BIAS:.*]] = torch.constant.float 1.000000e+00 + // CHECK-DAG: %[[INSQ:.*]] = torch.aten.mul.Tensor %arg0, %arg0 + + // CHECK-DAG: %[[I20:.*]] = torch.constant.int 20 + // CHECK-DAG: %[[I1:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[I10:.*]] = torch.constant.int 10 + // CHECK-DAG: %[[I3:.+]] = torch.constant.int 3 + // CHECK-DAG: %[[IMINUS1:.+]] = torch.constant.int -1 + // CHECK-DAG: %[[VIEWSHAPE:.*]] = torch.prim.ListConstruct %[[I20]], %[[I1]], %[[I10]], %[[I3]], %[[IMINUS1]] + + // CHECK-DAG: %[[VIEW1:.*]] = torch.aten.view %[[INSQ]], %[[VIEWSHAPE]] + + // CHECK-DAG: %[[I0:.+]] = torch.constant.int 0 + // CHECK-DAG: %[[I0_2:.+]] = torch.constant.int 0 + // CHECK-DAG: %[[I0_3:.+]] = torch.constant.int 0 + // CHECK-DAG: %[[I0_4:.+]] = torch.constant.int 0 + // CHECK-DAG: %[[I1_2:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[I1_3:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[PADDING:.*]] = torch.prim.ListConstruct %[[I0]], %[[I0_2]], %[[I0_3]], %[[I0_4]], %[[I1_2]], %[[I1_3]] + + // CHECK-DAG: %[[PADDED:.*]] = torch.aten.constant_pad_nd %[[VIEW1]], %[[PADDING]], %[[F0]] + + // CHECK-DAG: %[[I3_2:.+]] = torch.constant.int 3 + // CHECK-DAG: %[[I1_4:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[I1_5:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[KERNELSIZE:.*]] = torch.prim.ListConstruct %[[I3_2]], %[[I1_4]], %[[I1_5]] + + // CHECK-DAG: %[[I1_6:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[I1_7:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[I1_8:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[STRIDES:.*]] = torch.prim.ListConstruct %[[I1_6]], %[[I1_7]], %[[I1_8]] + + 
// CHECK-DAG: %[[I0_5:.+]] = torch.constant.int 0 + // CHECK-DAG: %[[I0_6:.+]] = torch.constant.int 0 + // CHECK-DAG: %[[I0_7:.+]] = torch.constant.int 0 + // CHECK-DAG: %[[POOLPADDING:.*]] = torch.prim.ListConstruct %[[I0_5]], %[[I0_6]], %[[I0_7]] + + // CHECK-DAG: %[[POOL3D:.*]] = torch.aten.avg_pool3d %[[PADDED]], %[[KERNELSIZE]], %[[STRIDES]], %[[POOLPADDING]], %[[FALSE]], %[[TRUE]] + // CHECK-DAG: %[[SQUEEZED:.*]] = torch.aten.squeeze.dim %[[POOL3D]], %[[I1]] + + // CHECK-DAG: %[[I20_2:.*]] = torch.constant.int 20 + // CHECK-DAG: %[[I10_2:.*]] = torch.constant.int 10 + // CHECK-DAG: %[[I3_2:.+]] = torch.constant.int 3 + // CHECK-DAG: %[[I50_2:.+]] = torch.constant.int 50 + // CHECK-DAG: %[[ISHAPE:.*]] = torch.prim.ListConstruct %[[I20_2]], %[[I10_2]], %[[I3_2]], %[[I50_2]] + + // CHECK-DAG: %[[VIEW2:.*]] = torch.aten.view %[[SQUEEZED]], %[[ISHAPE]] + // CHECK-DAG: %[[POSTALPHA:.*]] = torch.aten.mul.Scalar %[[VIEW2]], %[[ALPHA]] + // CHECK-DAG: %[[POSTBIAS:.*]] = torch.aten.add.Scalar %[[POSTALPHA]], %[[BIAS]], %[[I1]] + // CHECK-DAG: %[[POSTBETA:.*]] = torch.aten.pow.Tensor_Scalar %[[POSTBIAS]], %[[BETA]] + // CHECK-DAG: %[[OUTPUT:.*]] = torch.aten.div.Tensor %arg0, %[[POSTBETA]] + // CHECK: return %[[OUTPUT]] + %0 = torch.operator "onnx.LRN"(%arg0) {torch.onnx.size = 3 : si64} : (!torch.vtensor<[20,10,3,50],f32>) -> !torch.vtensor<[20,10,3,50],f32> + return %0 : !torch.vtensor<[20,10,3,50],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_lrn_with_optionals +func.func @test_lrn_with_optionals(%arg0: !torch.vtensor<[13,19,100,200],f32>) -> !torch.vtensor<[13,19,100,200],f32> attributes {torch.onnx_meta.opset_version = 17 : si64} { + // CHECK-DAG: %[[TRUE:.+]] = torch.constant.bool true + // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false + // CHECK-DAG: %[[F0:.+]] = torch.constant.float 0.000000e+00 + // CHECK-DAG: %[[ALPHA:.*]] = torch.constant.float 0.0020000000949949026 + // CHECK-DAG: %[[BETA:.*]] = torch.constant.float 0.64999997615814209 + // CHECK-DAG: %[[BIAS:.*]] = torch.constant.float 3.000000e+00 + // CHECK-DAG: %[[INSQ:.*]] = torch.aten.mul.Tensor %arg0, %arg0 + + // CHECK-DAG: %[[I13:.*]] = torch.constant.int 13 + // CHECK-DAG: %[[I1:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[I19:.*]] = torch.constant.int 19 + // CHECK-DAG: %[[I100:.+]] = torch.constant.int 100 + // CHECK-DAG: %[[IMINUS1:.+]] = torch.constant.int -1 + // CHECK-DAG: %[[VIEWSHAPE:.*]] = torch.prim.ListConstruct %[[I13]], %[[I1]], %[[I19]], %[[I100]], %[[IMINUS1]] + + // CHECK-DAG: %[[VIEW1:.*]] = torch.aten.view %[[INSQ]], %[[VIEWSHAPE]] + + // CHECK-DAG: %[[I0:.+]] = torch.constant.int 0 + // CHECK-DAG: %[[I0_2:.+]] = torch.constant.int 0 + // CHECK-DAG: %[[I0_3:.+]] = torch.constant.int 0 + // CHECK-DAG: %[[I0_4:.+]] = torch.constant.int 0 + // CHECK-DAG: %[[I2:.*]] = torch.constant.int 2 + // CHECK-DAG: %[[I2_2:.*]] = torch.constant.int 2 + // CHECK-DAG: %[[PADDING:.*]] = torch.prim.ListConstruct %[[I0]], %[[I0_2]], %[[I0_3]], %[[I0_4]], %[[I2]], %[[I2_2]] + + // CHECK-DAG: %[[PADDED:.*]] = torch.aten.constant_pad_nd %[[VIEW1]], %[[PADDING]], %[[F0]] + + // CHECK-DAG: %[[I5:.+]] = torch.constant.int 5 + // CHECK-DAG: %[[I1_4:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[I1_5:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[KERNELSIZE:.*]] = torch.prim.ListConstruct %[[I5]], %[[I1_4]], %[[I1_5]] + + // CHECK-DAG: %[[I1_6:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[I1_7:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[I1_8:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[STRIDES:.*]] = 
torch.prim.ListConstruct %[[I1_6]], %[[I1_7]], %[[I1_8]] + + // CHECK-DAG: %[[I0_5:.+]] = torch.constant.int 0 + // CHECK-DAG: %[[I0_6:.+]] = torch.constant.int 0 + // CHECK-DAG: %[[I0_7:.+]] = torch.constant.int 0 + // CHECK-DAG: %[[POOLPADDING:.*]] = torch.prim.ListConstruct %[[I0_5]], %[[I0_6]], %[[I0_7]] + + // CHECK-DAG: %[[POOL3D:.*]] = torch.aten.avg_pool3d %[[PADDED]], %[[KERNELSIZE]], %[[STRIDES]], %[[POOLPADDING]], %[[FALSE]], %[[TRUE]] + // CHECK-DAG: %[[SQUEEZED:.*]] = torch.aten.squeeze.dim %[[POOL3D]], %[[I1]] + + // CHECK-DAG: %[[I13_2:.*]] = torch.constant.int 13 + // CHECK-DAG: %[[I19_2:.*]] = torch.constant.int 19 + // CHECK-DAG: %[[I100_2:.+]] = torch.constant.int 100 + // CHECK-DAG: %[[I200_2:.+]] = torch.constant.int 200 + // CHECK-DAG: %[[ISHAPE:.*]] = torch.prim.ListConstruct %[[I13_2]], %[[I19_2]], %[[I100_2]], %[[I200_2]] + + // CHECK-DAG: %[[VIEW2:.*]] = torch.aten.view %[[SQUEEZED]], %[[ISHAPE]] + // CHECK-DAG: %[[POSTALPHA:.*]] = torch.aten.mul.Scalar %[[VIEW2]], %[[ALPHA]] + // CHECK-DAG: %[[POSTBIAS:.*]] = torch.aten.add.Scalar %[[POSTALPHA]], %[[BIAS]], %[[I1]] + // CHECK-DAG: %[[POSTBETA:.*]] = torch.aten.pow.Tensor_Scalar %[[POSTBIAS]], %[[BETA]] + // CHECK-DAG: %[[OUTPUT:.*]] = torch.aten.div.Tensor %arg0, %[[POSTBETA]] + // CHECK: return %[[OUTPUT]] + %none = torch.constant.none + %0 = torch.operator "onnx.LRN"(%arg0) {torch.onnx.alpha = 2.000000e-03 : f32, torch.onnx.beta = 6.500000e-01 : f32, torch.onnx.bias = 3.000000e+00 : f32, torch.onnx.size = 5 : si64} : (!torch.vtensor<[13,19,100,200],f32>) -> !torch.vtensor<[13,19,100,200],f32> + return %0 : !torch.vtensor<[13,19,100,200],f32> +} + +// ----- + // CHECK-LABEL: @test_matmul_2d func.func @test_matmul_2d(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[3,4],f32>, !torch.vtensor<[4,3],f32> -> !torch.vtensor<[3,3],f32> @@ -375,6 +602,36 @@ func.func @test_matmulinteger_batched(%arg0: !torch.vtensor<[7,4,3],ui8>, %arg1: // ----- +// CHECK-LABEL: func.func @test_multinomial_default +func.func @test_multinomial_default(%arg0: !torch.vtensor<[3,5],f64>) -> !torch.vtensor<[3, 1],si32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VAL_1:.*]] = torch.constant.int 3 + // CHECK: %[[VAL_2:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_3:.*]] = torch.constant.none + // CHECK: %[[VAL_4:.*]] = torch.constant.bool true + // CHECK: %[[VAL_5:.*]] = torch.aten.multinomial %arg0, %[[VAL_2]], %[[VAL_4]], %[[VAL_3]] : !torch.vtensor<[3,5],f64>, !torch.int, !torch.bool, !torch.none -> !torch.vtensor<[3,1],si64> + // CHECK: %[[VAL_6:.*]] = torch.constant.bool false + // CHECK: %[[VAL_7:.*]] = torch.aten.to.dtype %[[VAL_5]], %[[VAL_1]], %[[VAL_6]], %[[VAL_6]], %[[VAL_3]] : !torch.vtensor<[3,1],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,1],si32> + // CHECK: return %[[VAL_7]] : !torch.vtensor<[3,1],si32> + %0 = torch.operator "onnx.Multinomial"(%arg0) : (!torch.vtensor<[3,5],f64>) -> !torch.vtensor<[3,1],si32> + return %0 : !torch.vtensor<[3,1],si32> +} + +// CHECK-LABEL: func.func @test_multinomial_dtype_double_samplenum_4 +func.func 
@test_multinomial_dtype_double_samplenum_4(%arg0: !torch.vtensor<[3,5],f64>) -> !torch.vtensor<[3, 4],f64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VAL_1:.*]] = torch.constant.int 7 + // CHECK: %[[VAL_2:.*]] = torch.constant.int 4 + // CHECK: %[[VAL_3:.*]] = torch.constant.none + // CHECK: %[[VAL_4:.*]] = torch.constant.bool true + // CHECK: %[[VAL_5:.*]] = torch.aten.multinomial %arg0, %[[VAL_2]], %[[VAL_4]], %[[VAL_3]] : !torch.vtensor<[3,5],f64>, !torch.int, !torch.bool, !torch.none -> !torch.vtensor<[3,4],si64> + // CHECK: %[[VAL_6:.*]] = torch.constant.bool false + // CHECK: %[[VAL_7:.*]] = torch.aten.to.dtype %[[VAL_5]], %[[VAL_1]], %[[VAL_6]], %[[VAL_6]], %[[VAL_3]] : !torch.vtensor<[3,4],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f64> + // CHECK: return %[[VAL_7]] : !torch.vtensor<[3,4],f64> + %0 = torch.operator "onnx.Multinomial"(%arg0) {torch.onnx.dtype = 11 : si64, torch.onnx.sample_size = 4 : si64} : (!torch.vtensor<[3,5],f64>) -> !torch.vtensor<[3,4],f64> + return %0 : !torch.vtensor<[3,4],f64> +} + +// ----- + // CHECK-LABEL: func.func @test_maxpool_2d_default func.func @test_maxpool_2d_default(%arg0: !torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,31,31],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} { // CHECK: %[[I2:.*]] = torch.constant.int 2 @@ -448,8 +705,8 @@ func.func @test_maxpool_3d_default(%arg0: !torch.vtensor<[1,3,32,32,32],f32>) -> // CHECK-LABEL: func.func @test_maxpool_pad func.func @test_maxpool_pad(%arg0: !torch.vtensor<[1,64,111,111],f32>) -> !torch.vtensor<[1,64,56,56],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} { // CHECK: %[[INT1_0:.+]] = torch.constant.int 1 - // CHECK: %[[INT1_1:.+]] = torch.constant.int 1 - // CHECK: %[[INT2_0:.+]] = torch.constant.int 2 + // CHECK: %[[INT1_1:.+]] = torch.constant.int 2 + // CHECK: %[[INT2_0:.+]] = torch.constant.int 1 // CHECK: %[[INT2_1:.+]] = torch.constant.int 2 // CHECK: %[[PADI:.+]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]], %[[INT2_0]], %[[INT2_1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %[[MIN:.+]] = torch.constant.float -1.7976931348623157E+308 @@ -473,6 +730,86 @@ func.func @test_maxpool_pad(%arg0: !torch.vtensor<[1,64,111,111],f32>) -> !torch return %0 : !torch.vtensor<[1,64,56,56],f32> } +// ----- + +// CHECK-LABEL: func.func @test_maxpool_2d_same_lower +func.func @test_maxpool_2d_same_lower(%arg0: !torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,32,32],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} { + // CHECK: %[[int1:.*]] = torch.constant.int 1 + // CHECK: %[[int0:.*]] = torch.constant.int 0 + // CHECK: %[[int1_0:.*]] = torch.constant.int 1 + // CHECK: %[[int0_1:.*]] = torch.constant.int 0 + // CHECK: %[[list0:.*]] = torch.prim.ListConstruct %[[int1]], %[[int0]], %[[int1_0]], %[[int0_1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[FLOAT0:.*]] = torch.constant.float -1.7976931348623157E+308 + // CHECK: %[[FUNC1:.*]] = torch.aten.constant_pad_nd %arg0, %[[list0]], %[[FLOAT0]] : !torch.vtensor<[1,3,32,32],f32>, !torch.list, !torch.float -> !torch.vtensor<[1,3,33,33],f32> + // CHECK: %[[int2:.*]] = torch.constant.int 2 + // CHECK: %[[int2_2:.*]] = 
torch.constant.int 2 + // CHECK: %[[list1:.*]] = torch.prim.ListConstruct %[[int2]], %[[int2_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[int0_3:.*]] = torch.constant.int 0 + // CHECK: %[[int0_4:.*]] = torch.constant.int 0 + // CHECK: %[[list2:.*]] = torch.prim.ListConstruct %[[int0_3]], %[[int0_4]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[int1_5:.*]] = torch.constant.int 1 + // CHECK: %[[int1_6:.*]] = torch.constant.int 1 + // CHECK: %[[list3:.*]] = torch.prim.ListConstruct %[[int1_5]], %[[int1_6]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[int1_7:.*]] = torch.constant.int 1 + // CHECK: %[[int1_8:.*]] = torch.constant.int 1 + // CHECK: %[[list4:.*]] = torch.prim.ListConstruct %[[int1_7]], %[[int1_8]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[FUNC6:.*]] = torch.aten.max_pool2d %[[FUNC1]], %[[list1]], %[[list3]], %[[list2]], %[[list4]], %[[FALSE]] : !torch.vtensor<[1,3,33,33],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[1,3,32,32],f32> + %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.auto_pad = "SAME_LOWER", torch.onnx.kernel_shape = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,32,32],f32> + return %0 : !torch.vtensor<[1,3,32,32],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_maxpool_2d_same_upper +func.func @test_maxpool_2d_same_upper(%arg0: !torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,32,32],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[int0:.*]] = torch.constant.int 0 + // CHECK: %[[int1:.*]] = torch.constant.int 1 + // CHECK: %[[int0_0:.*]] = torch.constant.int 0 + // CHECK: %[[int1_1:.*]] = torch.constant.int 1 + // CHECK: %[[list0:.*]] = torch.prim.ListConstruct %[[int0]], %[[int1]], %[[int0_0]], %[[int1_1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[FLOAT0:.*]] = torch.constant.float -1.7976931348623157E+308 + // CHECK: %[[FUNC1:.*]] = torch.aten.constant_pad_nd %arg0, %[[list0]], %[[FLOAT0]] : !torch.vtensor<[1,3,32,32],f32>, !torch.list, !torch.float -> !torch.vtensor<[1,3,33,33],f32> + // CHECK: %[[int2:.*]] = torch.constant.int 2 + // CHECK: %[[int2_2:.*]] = torch.constant.int 2 + // CHECK: %[[list1:.*]] = torch.prim.ListConstruct %[[int2]], %[[int2_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[int0_3:.*]] = torch.constant.int 0 + // CHECK: %[[int0_4:.*]] = torch.constant.int 0 + // CHECK: %[[list2:.*]] = torch.prim.ListConstruct %[[int0_3]], %[[int0_4]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[int1_5:.*]] = torch.constant.int 1 + // CHECK: %[[int1_6:.*]] = torch.constant.int 1 + // CHECK: %[[list3:.*]] = torch.prim.ListConstruct %[[int1_5]], %[[int1_6]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[int1_7:.*]] = torch.constant.int 1 + // CHECK: %[[int1_8:.*]] = torch.constant.int 1 + // CHECK: %[[list4:.*]] = torch.prim.ListConstruct %[[int1_7]], %[[int1_8]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[FUNC6:.*]] = torch.aten.max_pool2d %[[FUNC1]], %[[list1]], %[[list3]], %[[list2]], %[[list4]], %[[FALSE]] : !torch.vtensor<[1,3,33,33],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[1,3,32,32],f32> + %0 = torch.operator 
"onnx.MaxPool"(%arg0) {torch.onnx.auto_pad = "SAME_UPPER", torch.onnx.kernel_shape = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,32,32],f32> + return %0 : !torch.vtensor<[1,3,32,32],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_maxpool_2d_precomputed_same_upper +func.func @test_maxpool_2d_precomputed_same_upper(%arg0: !torch.vtensor<[1,1,5,5],f32>) -> !torch.vtensor<[1,1,3,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64}{ + // CHECK: %[[int3:.*]] = torch.constant.int 3 + // CHECK: %[[int3_0:.*]] = torch.constant.int 3 + // CHECK: %[[list0:.*]] = torch.prim.ListConstruct %[[int3]], %[[int3_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[int1:.*]] = torch.constant.int 1 + // CHECK: %[[int1_1:.*]] = torch.constant.int 1 + // CHECK: %[[list1:.*]] = torch.prim.ListConstruct %[[int1]], %[[int1_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[int2:.*]] = torch.constant.int 2 + // CHECK: %[[int2_2:.*]] = torch.constant.int 2 + // CHECK: %[[list2:.*]] = torch.prim.ListConstruct %[[int2]], %[[int2_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[int1_3:.*]] = torch.constant.int 1 + // CHECK: %[[int1_4:.*]] = torch.constant.int 1 + // CHECK: %[[list3:.*]] = torch.prim.ListConstruct %[[int1_3]], %[[int1_4]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[FUNC4:.*]] = torch.aten.max_pool2d %arg0, %[[list0]], %[[list2]], %[[list1]], %[[list3]], %[[FALSE]] : !torch.vtensor<[1,1,5,5],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[1,1,3,3],f32> +%0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.auto_pad = "SAME_UPPER", torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,5,5],f32>) -> !torch.vtensor<[1,1,3,3],f32> +return %0 : !torch.vtensor<[1,1,3,3],f32> +} + // ----- @@ -499,6 +836,71 @@ func.func @test_maxpool_symmetric_pad(%arg0: !torch.vtensor<[1,64,112,112],f32>) // ----- +// CHECK-LABEL: func.func @test_maxroipool +func.func @test_maxroipool(%arg0: !torch.vtensor<[8,3,32,32],f32>, %arg1: !torch.vtensor<[2,5],f32>) -> !torch.vtensor<[2,3,2,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: %[[INT2_0:.*]] = torch.constant.int 2 + // CHECK: %[[LIST0:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00 + // CHECK: %[[TRUE:.*]] = torch.constant.bool true + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[INT2_1:.*]] = torch.constant.int 2 + // CHECK: %[[INT3:.*]] = torch.constant.int 3 + // CHECK: %[[INT4:.*]] = torch.constant.int 4 + // CHECK: %[[INT5:.*]] = torch.constant.int 5 + // CHECK: %[[INT2_2:.*]] = torch.constant.int 2 + // CHECK: %[[INT1_3:.*]] = torch.constant.int 1 + // CHECK: %[[SELECT1:.*]] = torch.aten.select.int %arg1, %[[INT1]], %[[INT0]] : !torch.vtensor<[2,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[?],f32> + // CHECK: %[[CAST1:.*]] = torch.aten._cast_Long %[[SELECT1]], %[[TRUE]] : !torch.vtensor<[?],f32>, !torch.bool -> !torch.vtensor<[?],si64> + // CHECK: %[[SLICE1:.*]] = torch.aten.slice.Tensor %arg1, %[[INT1]], 
%[[INT1]], %[[INT5]], %[[INT1]] : !torch.vtensor<[2,5],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,4],f32> + // CHECK: %[[MUL1:.*]] = torch.aten.mul.Scalar %[[SLICE1]], %[[FLOAT1]] : !torch.vtensor<[?,4],f32>, !torch.float -> !torch.vtensor<[?,4],f32> + // CHECK: %[[CAST2:.*]] = torch.aten._cast_Long %[[MUL1]], %[[TRUE]] : !torch.vtensor<[?,4],f32>, !torch.bool -> !torch.vtensor<[?,4],si64> + // CHECK: %[[INT0_4:.*]] = torch.constant.int 0 + // CHECK: %[[SELECT2:.*]] = torch.aten.select.int %[[CAST2]], %[[INT0]], %[[INT0_4]] : !torch.vtensor<[?,4],si64>, !torch.int, !torch.int -> !torch.vtensor<[4],si64> + // CHECK: %[[SELECT3:.*]] = torch.aten.select.int %[[CAST1]], %[[INT0]], %[[INT0_4]] : !torch.vtensor<[?],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ITEM1:.*]] = torch.aten.item %[[SELECT3]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[SELECT4:.*]] = torch.aten.select.int %[[SELECT2]], %[[INT0]], %[[INT0]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ITEM2:.*]] = torch.aten.item %[[SELECT4]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[SELECT5:.*]] = torch.aten.select.int %[[SELECT2]], %[[INT0]], %[[INT1]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ITEM3:.*]] = torch.aten.item %[[SELECT5]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[SELECT6:.*]] = torch.aten.select.int %[[SELECT2]], %[[INT0]], %[[INT2_1]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ITEM4:.*]] = torch.aten.item %[[SELECT6]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[SELECT7:.*]] = torch.aten.select.int %[[SELECT2]], %[[INT0]], %[[INT3]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ITEM5:.*]] = torch.aten.item %[[SELECT7]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[ADD1:.*]] = torch.aten.add %[[ITEM4]], %[[INT1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[ADD2:.*]] = torch.aten.add %[[ITEM5]], %[[INT1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SELECT8:.*]] = torch.aten.select.int %arg0, %[[INT0]], %[[ITEM1]] : !torch.vtensor<[8,3,32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,32,32],f32> + // CHECK: %[[SLICE2:.*]] = torch.aten.slice.Tensor %[[SELECT8]], %[[INT1_3]], %[[ITEM3]], %[[ADD2]], %[[INT1]] : !torch.vtensor<[3,32,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,?,?],f32> + // CHECK: %[[SLICE3:.*]] = torch.aten.slice.Tensor %[[SLICE2]], %[[INT2_2]], %[[ITEM2]], %[[ADD1]], %[[INT1]] : !torch.vtensor<[3,?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,?,?],f32> + // CHECK: %[[RESULT0:.*]], %[[RESULT1:.*]] = torch.aten.adaptive_max_pool2d %[[SLICE3]], %[[LIST0]] : !torch.vtensor<[3,?,?],f32>, !torch.list -> !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],si64> + // CHECK: %[[INT1_5:.*]] = torch.constant.int 1 + // CHECK: %[[SELECT9:.*]] = torch.aten.select.int %[[CAST2]], %[[INT0]], %[[INT1_5]] : !torch.vtensor<[?,4],si64>, !torch.int, !torch.int -> !torch.vtensor<[4],si64> + // CHECK: %[[SELECT10:.*]] = torch.aten.select.int %[[CAST1]], %[[INT0]], %[[INT1_5]] : !torch.vtensor<[?],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ITEM6:.*]] = torch.aten.item %[[SELECT10]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[SELECT11:.*]] = torch.aten.select.int %[[SELECT9]], %[[INT0]], %[[INT0]] : 
!torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ITEM7:.*]] = torch.aten.item %[[SELECT11]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[SELECT12:.*]] = torch.aten.select.int %[[SELECT9]], %[[INT0]], %[[INT1]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ITEM8:.*]] = torch.aten.item %[[SELECT12]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[SELECT13:.*]] = torch.aten.select.int %[[SELECT9]], %[[INT0]], %[[INT2_1]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ITEM9:.*]] = torch.aten.item %[[SELECT13]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[SELECT14:.*]] = torch.aten.select.int %[[SELECT9]], %[[INT0]], %[[INT3]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ITEM10:.*]] = torch.aten.item %[[SELECT14]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[ADD3:.*]] = torch.aten.add %[[ITEM9]], %[[INT1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[ADD4:.*]] = torch.aten.add %[[ITEM10]], %[[INT1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SELECT15:.*]] = torch.aten.select.int %arg0, %[[INT0]], %[[ITEM6]] : !torch.vtensor<[8,3,32,32],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,32,32],f32> + // CHECK: %[[SLICE4:.*]] = torch.aten.slice.Tensor %[[SELECT15]], %[[INT1_3]], %[[ITEM8]], %[[ADD4]], %[[INT1]] : !torch.vtensor<[3,32,32],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,?,?],f32> + // CHECK: %[[SLICE5:.*]] = torch.aten.slice.Tensor %[[SLICE4]], %[[INT2_2]], %[[ITEM7]], %[[ADD3]], %[[INT1]] : !torch.vtensor<[3,?,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,?,?],f32> + // CHECK: %[[RESULT0_6:.*]], %[[RESULT1_7:.*]] = torch.aten.adaptive_max_pool2d %[[SLICE5]], %[[LIST0]] : !torch.vtensor<[3,?,?],f32>, !torch.list -> !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],si64> + // CHECK: %[[LIST1:.*]] = torch.prim.ListConstruct %[[RESULT0]], %[[RESULT0_6]] : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32>) -> !torch.list> + // CHECK: %[[STACK:.*]] = torch.aten.stack %[[LIST1]], %[[INT0]] : !torch.list>, !torch.int -> !torch.vtensor<[2,3,2,2],f32> + // CHECK: return %[[STACK]] : !torch.vtensor<[2,3,2,2],f32> + %0 = torch.operator "onnx.MaxRoiPool"(%arg0, %arg1) {torch.onnx.pooled_shape = [2 : si64, 2 : si64], torch.onnx.spatial_scale = 1.000000e+00 : f32} : (!torch.vtensor<[8,3,32,32],f32>, !torch.vtensor<[2,5],f32>) -> !torch.vtensor<[2,3,2,2],f32> + return %0 : !torch.vtensor<[2,3,2,2],f32> +} + +// ----- + // CHECK-LABEL: @test_gelu_default_1 func.func @test_gelu_default_1(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[STR1:.*]] = torch.constant.str "none" @@ -563,32 +965,40 @@ func.func @test_grid_sampler02(%arg0: !torch.vtensor<[5,10,10,4],f32>, %arg1: !t // ----- -// CHECK-LABEL: @test_grid_sampler03 -// CHECK: %[[INT1:.*]] = torch.constant.int 1 -// CHECK: %[[INT0:.*]] = torch.constant.int 0 -// CHECK: %[[B0:.*]] = torch.constant.bool true -// CHECK: %[[A0:.*]] = torch.aten.grid_sampler %arg0, %arg1, %[[INT1]], %[[INT0]], %[[B0]] : !torch.vtensor<[5,10,10,4],f32> -func.func @test_grid_sampler03(%arg0: !torch.vtensor<[5,10,10,4],f32>, %arg1: !torch.vtensor<[5,7,8,2],f32>) -> 
!torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - %0 = torch.operator "onnx.GridSample" (%arg0, %arg1) {torch.onnx.align_corners = 1 : si64, torch.onnx.mode = "nearest", torch.onnx.padding_mode = "zeros"} : (!torch.vtensor<[5,10,10,4],f32>, !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> - return %0 : !torch.vtensor<[?,?,?,?],f32> +// CHECK-LABEL: func.func @test_oldest_pad +func.func @test_oldest_pad(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 1 : si64} { + // CHECK: %[[int0:.*]] = torch.constant.int 0 + // CHECK: %[[int0_0:.*]] = torch.constant.int 0 + // CHECK: %[[int2:.*]] = torch.constant.int 2 + // CHECK: %[[int0_1:.*]] = torch.constant.int 0 + // CHECK: %[[float0:.*]] = torch.constant.float 0.000000e+00 + // CHECK: %[[list:.*]] = torch.prim.ListConstruct %[[int0_0]], %[[int0_1]], %[[int0]], %[[int2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[pad:.*]] = torch.aten.constant_pad_nd %arg0, %[[list]], %[[float0]] : !torch.vtensor<[3,4],f32>, !torch.list, !torch.float -> !torch.vtensor<[5,4],f32> + // CHECK: return %[[pad]] : !torch.vtensor<[5,4],f32> + %0 = torch.operator "onnx.Pad"(%arg0) {torch.onnx.mode = "constant", torch.onnx.paddings = [0 : si64, 0 : si64, 2 : si64, 0 : si64], torch.onnx.value = 0.000000e+00 : f32} : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[5,4],f32> + return %0 : !torch.vtensor<[5,4],f32> } // ----- -// CHECK-LABEL: func.func @test_less_or_equal -func.func @test_less_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32> - // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32> - // CHECK: torch.aten.le.Tensor %[[ARG0]], %[[ARG1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],i1> - %0 = torch.operator "onnx.LessOrEqual"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> - return %0 : !torch.vtensor<[3,4,5],i1> +// CHECK-LABEL: func.func @test_old_pad +func.func @test_old_pad(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 11 : si64} { + // CHECK: %[[int0:.*]] = torch.constant.int 0 + // CHECK: %[[int0_0:.*]] = torch.constant.int 0 + // CHECK: %[[int2:.*]] = torch.constant.int 2 + // CHECK: %[[int0_1:.*]] = torch.constant.int 0 + // CHECK: %[[float0:.*]] = torch.constant.float 0.000000e+00 + // CHECK: %[[list:.*]] = torch.prim.ListConstruct %[[int0_0]], %[[int0_1]], %[[int0]], %[[int2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[pad:.*]] = torch.aten.constant_pad_nd %arg0, %[[list]], %[[float0]] : !torch.vtensor<[3,4],f32>, !torch.list, !torch.float -> !torch.vtensor<[5,4],f32> + // CHECK: return %[[pad]] : !torch.vtensor<[5,4],f32> + %0 = torch.operator "onnx.Pad"(%arg0) {torch.onnx.mode = "constant", torch.onnx.pads = [0 : si64, 0 : si64, 2 : si64, 0 : si64], torch.onnx.value = 0.000000e+00 : f32} : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[5,4],f32> + return %0 : 
!torch.vtensor<[5,4],f32> } // ----- // CHECK-LABEL: func.func @test_pad func.func @test_pad(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>, %arg2: !torch.vtensor<[], f32>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} { - // CHECK: %[[VAL:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],f32> -> !torch.float // CHECK: %[[INT0:.+]] = torch.constant.int 0 // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 // CHECK: %[[SELECT_0:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> @@ -602,9 +1012,9 @@ func.func @test_pad(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], // CHECK: %[[INT3:.+]] = torch.constant.int 3 // CHECK: %[[SELECT_3:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT3]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> // CHECK: %[[ITEM_3:.+]] = torch.aten.item %[[SELECT_3]] : !torch.vtensor<[],si64> -> !torch.int - // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM_0]], %[[ITEM_2]], %[[ITEM_1]], %[[ITEM_3]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - // CHECK: %[[STR:.+]] = torch.constant.str "constant" - // CHECK: %[[PAD:.+]] = torch.aten.pad %arg0, %[[LIST]], %[[STR]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list, !torch.str, !torch.float -> !torch.vtensor<[5,4],f32> + // CHECK: %[[VAL:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM_1]], %[[ITEM_3]], %[[ITEM_0]], %[[ITEM_2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[PAD:.+]] = torch.aten.constant_pad_nd %arg0, %[[LIST]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list, !torch.float -> !torch.vtensor<[5,4],f32> // CHECK: return %[[PAD]] : !torch.vtensor<[5,4],f32> %0 = torch.operator "onnx.Pad"(%arg0, %arg1, %arg2) {torch.onnx.mode = "constant"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>, !torch.vtensor<[], f32>) -> !torch.vtensor<[5,4],f32> return %0 : !torch.vtensor<[5,4],f32> @@ -612,12 +1022,36 @@ func.func @test_pad(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], // ----- +// CHECK-LABEL: func.func @test_i32pad +func.func @test_i32pad(%arg0: !torch.vtensor<[3,4],si32>, %arg1: !torch.vtensor<[4], si64>, %arg2: !torch.vtensor<[], si32>) -> !torch.vtensor<[5,4],si32> attributes {torch.onnx_meta.opset_version = 19 : si64} { + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[SELECT_0:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ITEM_0:.+]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: %[[SELECT_1:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT1]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ITEM_1:.+]] = torch.aten.item %[[SELECT_1]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: %[[SELECT_2:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT2]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ITEM_2:.+]] = torch.aten.item %[[SELECT_2]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT3:.+]] = torch.constant.int 3 + // CHECK: %[[SELECT_3:.+]] = 
torch.aten.select.int %arg1, %[[INT0]], %[[INT3]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ITEM_3:.+]] = torch.aten.item %[[SELECT_3]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[VAL:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],si32> -> !torch.int + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM_1]], %[[ITEM_3]], %[[ITEM_0]], %[[ITEM_2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[PAD:.+]] = torch.aten.constant_pad_nd %arg0, %[[LIST]], %[[VAL]] : !torch.vtensor<[3,4],si32>, !torch.list, !torch.int -> !torch.vtensor<[5,4],si32> + // CHECK: return %[[PAD]] : !torch.vtensor<[5,4],si32> + %0 = torch.operator "onnx.Pad"(%arg0, %arg1, %arg2) {torch.onnx.mode = "constant"} : (!torch.vtensor<[3,4],si32>, !torch.vtensor<[4], si64>, !torch.vtensor<[], si32>) -> !torch.vtensor<[5,4],si32> + return %0 : !torch.vtensor<[5,4],si32> +} + +// ----- + // CHECK-LABEL: @test_pad_optional_constant // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,4],f32> // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[4],si64> // CHECK: %[[VAL:.+]] = torch.constant.float 0 -// CHECK: %[[CONST_STR:.*]] = torch.constant.str "constant" -// CHECK: torch.aten.pad %[[ARG0]], %{{.*}}, %[[CONST_STR]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list, !torch.str, !torch.float -> !torch.vtensor<[5,4],f32> +// CHECK: torch.aten.constant_pad_nd %[[ARG0]], %{{.*}}, %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list, !torch.float -> !torch.vtensor<[5,4],f32> func.func @test_pad_optional_constant(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} { %0 = torch.operator "onnx.Pad"(%arg0, %arg1) {torch.onnx.mode = "constant"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32> @@ -626,12 +1060,135 @@ func.func @test_pad_optional_constant(%arg0: !torch.vtensor<[3,4],f32>, %arg1: ! 
// ----- +// CHECK-LABEL: @test_pad_wrap +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,4],f32> +// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[4],si64> +// CHECK: %[[VAL:.+]] = torch.constant.none +// CHECK: %[[STR:.+]] = torch.constant.str "circular" +// CHECK: torch.aten.pad %[[ARG0]], %{{.*}}, %[[STR]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list, !torch.str, !torch.none -> !torch.vtensor<[5,4],f32> + +func.func @test_pad_wrap(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} { + %0 = torch.operator "onnx.Pad"(%arg0, %arg1) {torch.onnx.mode = "wrap"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32> + return %0 : !torch.vtensor<[5,4],f32> +} + +// ----- + +// CHECK-LABEL: @test_pad_edge +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,4],f32> +// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[4],si64> +// CHECK: %[[VAL:.+]] = torch.constant.none +// CHECK: %[[STR:.+]] = torch.constant.str "replicate" +// CHECK: torch.aten.pad %[[ARG0]], %{{.*}}, %[[STR]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list, !torch.str, !torch.none -> !torch.vtensor<[5,4],f32> + +func.func @test_pad_edge(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} { + %0 = torch.operator "onnx.Pad"(%arg0, %arg1) {torch.onnx.mode = "edge"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32> + return %0 : !torch.vtensor<[5,4],f32> +} + +// ----- + +func.func @test_center_crop_pad_crop_axes_chw_expanded(%arg0: !torch.vtensor<[4,5],f32>, %arg1: !torch.vtensor<[4],si64>, %arg2: !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} { + + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[ZERO:.+]] = torch.constant.int 0 + + // CHECK: %[[IDX:.+]] = torch.constant.int 0 + // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[ZERO]], %[[IDX]] + // CHECK: %[[PAD0:.+]] = torch.aten.item %[[SEL]] + + // CHECK: %[[IDX:.+]] = torch.constant.int 1 + // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[ZERO]], %[[IDX]] + // CHECK: %[[PAD1:.+]] = torch.aten.item %[[SEL]] + + // CHECK: %[[IDX:.+]] = torch.constant.int 2 + // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[ZERO]], %[[IDX]] + // CHECK: %[[PAD2:.+]] = torch.aten.item %[[SEL]] + + // CHECK: %[[IDX:.+]] = torch.constant.int 3 + // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[ZERO]], %[[IDX]] + // CHECK: %[[PAD3:.+]] = torch.aten.item %[[SEL]] + + // CHECK: %[[ZERO:.+]] = torch.constant.int 0 + // CHECK: %[[RANK:.+]] = torch.constant.int 2 + + // CHECK: %[[IDX:.+]] = torch.constant.int 0 + // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg2, %[[ZERO]], %[[IDX]] + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]] + // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[ZERO]] + // CHECK: %[[INT:.+]] = torch.aten.Int.bool %[[LT]] + // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[INT]], %[[RANK]] + // CHECK: %[[AXIS0:.+]] = torch.aten.add.int %[[MUL]], %[[ITEM]] + + // CHECK: %[[IDX:.+]] = torch.constant.int 1 + // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg2, %[[ZERO]], %[[IDX]] + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]] + // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[ZERO]] + // CHECK: %[[INT:.+]] = torch.aten.Int.bool %[[LT]] + // CHECK: %[[MUL:.+]] = 
torch.aten.mul.int %[[INT]], %[[RANK]] + // CHECK: %[[AXIS1:.+]] = torch.aten.add.int %[[MUL]], %[[ITEM]] + + + // CHECK: %[[AX:.+]] = torch.constant.int 0 + // CHECK: %[[EQ:.+]] = torch.aten.eq.int %[[AXIS0]], %[[AX]] + // CHECK: %[[INT:.+]] = torch.aten.Int.bool %[[EQ]] + // CHECK: %[[MUL0:.+]] = torch.aten.mul.int %[[INT]], %[[PAD0]] + // CHECK: %[[MUL1:.+]] = torch.aten.mul.int %[[INT]], %[[PAD2]] + // CHECK: %[[ADD0:.+]] = torch.aten.add.int %[[ZERO]], %[[MUL0]] + // CHECK: %[[ADD1:.+]] = torch.aten.add.int %[[ZERO]], %[[MUL1]] + + // CHECK: %[[EQ:.+]] = torch.aten.eq.int %[[AXIS1]], %[[AX]] + // CHECK: %[[INT:.+]] = torch.aten.Int.bool %[[EQ]] + // CHECK: %[[MUL0:.+]] = torch.aten.mul.int %[[INT]], %[[PAD1]] + // CHECK: %[[MUL1:.+]] = torch.aten.mul.int %[[INT]], %[[PAD3]] + // CHECK: %[[BEGIN0:.+]] = torch.aten.add.int %[[ADD0]], %[[MUL0]] + // CHECK: %[[END0:.+]] = torch.aten.add.int %[[ADD1]], %[[MUL1]] + + // CHECK: %[[AX:.+]] = torch.constant.int 1 + // CHECK: %[[EQ:.+]] = torch.aten.eq.int %[[AXIS0]], %[[AX]] + // CHECK: %[[INT:.+]] = torch.aten.Int.bool %[[EQ]] + // CHECK: %[[MUL0:.+]] = torch.aten.mul.int %[[INT]], %[[PAD0]] + // CHECK: %[[MUL1:.+]] = torch.aten.mul.int %[[INT]], %[[PAD2]] + // CHECK: %[[ADD0:.+]] = torch.aten.add.int %[[ZERO]], %[[MUL0]] + // CHECK: %[[ADD1:.+]] = torch.aten.add.int %[[ZERO]], %[[MUL1]] + + // CHECK: %[[EQ:.+]] = torch.aten.eq.int %[[AXIS1]], %[[AX]] + // CHECK: %[[INT:.+]] = torch.aten.Int.bool %[[EQ]] + // CHECK: %[[MUL0:.+]] = torch.aten.mul.int %[[INT]], %[[PAD1]] + // CHECK: %[[MUL1:.+]] = torch.aten.mul.int %[[INT]], %[[PAD3]] + // CHECK: %[[BEGIN1:.+]] = torch.aten.add.int %[[ADD0]], %[[MUL0]] + // CHECK: %[[END1:.+]] = torch.aten.add.int %[[ADD1]], %[[MUL1]] + + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[BEGIN1]], %[[END1]], %[[BEGIN0]], %[[END0]] + // CHECK: %[[MODE:.+]] = torch.constant.str "constant" + // CHECK: %[[PAD:.+]] = torch.aten.pad %arg0, %[[LIST]], %[[MODE]], %[[NONE]] + %none = torch.constant.none + %0 = torch.operator "onnx.Pad"(%arg0, %arg1, %none, %arg2) : (!torch.vtensor<[4,5],f32>, !torch.vtensor<[4],si64>, !torch.none, !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_pow - func.func @test_pow(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: torch.aten.pow.Tensor_Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> - %0 = torch.operator "onnx.Pow"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> - return %0 : !torch.vtensor<[3,4,5],f32> - } +func.func @test_pow(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.pow.Tensor_Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Pow"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : 
!torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_pow_i32 +func.func @test_pow_i32(%arg0: !torch.vtensor<[3,4,5],si32>, %arg1: !torch.vtensor<[3,4,5],si32>) -> !torch.vtensor<[3,4,5],si32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[POW:.+]] = torch.aten.pow.Tensor_Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],si32>, !torch.vtensor<[3,4,5],si32> -> !torch.vtensor<[3,4,5],f64> + // CHECK: %[[DTY:.+]] = torch.constant.int 3 + // CHECK: %[[RES:.+]] = torch.aten.to.dtype %[[POW]], %[[DTY]], %[[FALSE]], %[[FALSE]], %[[NONE]] + // CHECK: return %[[RES]] + %0 = torch.operator "onnx.Pow"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],si32>, !torch.vtensor<[3,4,5],si32>) -> !torch.vtensor<[3,4,5],si32> + return %0 : !torch.vtensor<[3,4,5],si32> +} // ----- @@ -639,21 +1196,21 @@ func.func @test_pad_optional_constant(%arg0: !torch.vtensor<[3,4],f32>, %arg1: ! func.func @test_hardsigmoid_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[ALPHA_FLOAT:.*]] = torch.constant.float 5.000000e-01 // CHECK: %[[BETA_FLOAT:.*]] = torch.constant.float 0.60000002384185791 - // CHECK: %[[ALPHA_MULTI_X_PLUS_BETA:.*]] = torch.aten.add.Scalar %arg0, %[[BETA_FLOAT:.*]], %[[ALPHA_FLOAT:.*]] : !torch.vtensor<[3],f32>, !torch.float, !torch.float -> !torch.vtensor<[3],f32> - // CHECK: %[[INT_1:.*]] = torch.constant.int 1 + // CHECK: %[[ALPHA_MULTI_X:.*]] = torch.aten.mul.Scalar %arg0, %[[ALPHA_FLOAT]] : !torch.vtensor<[3],f32>, !torch.float -> !torch.vtensor<[3],f32> + // CHECK: %[[F1:.*]] = torch.constant.float 1.000000e+00 + // CHECK: %[[ALPHA_MULTI_X_PLUS_BETA:.*]] = torch.aten.add.Scalar %[[ALPHA_MULTI_X]], %[[BETA_FLOAT]], %[[F1]] : !torch.vtensor<[3],f32>, !torch.float, !torch.float -> !torch.vtensor<[3],f32> // CHECK: %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]] = torch.prim.ListConstruct : () -> !torch.list // CHECK: %[[NONE_FOR_ONE:.*]] = torch.constant.none // CHECK: %[[INT_TYPE_FOR_TENSOR_ONE:.*]] = torch.constant.int 6 - // CHECK: %[[ONE_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]], %[[INT_1:.*]], %[[INT_TYPE_FOR_TENSOR_ONE:.*]], %[[NONE_FOR_ONE:.*]], %[[NONE_1:.*]], %[[NONE_1:.*]] : !torch.list, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> + // CHECK: %[[ONE_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]], %[[F1:.*]], %[[INT_TYPE_FOR_TENSOR_ONE:.*]], %[[NONE_FOR_ONE:.*]], %[[NONE_1:.*]], %[[NONE_1:.*]] : !torch.list, !torch.float, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> // CHECK: %[[MIN_EXPRESSION:.*]] = torch.aten.minimum %[[ONE_TENSOR:.*]], %[[ALPHA_MULTI_X_PLUS_BETA:.*]] : !torch.vtensor<[],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> - // CHECK: %[[INT_0:.*]] = torch.constant.int 0 + // CHECK: %[[F0:.*]] = torch.constant.float 0.000000e+00 // CHECK: %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]] = torch.prim.ListConstruct : () -> !torch.list // CHECK: %[[NONE_FOR_ZERO:.*]] = torch.constant.none // CHECK: %[[INT_TYPE_FOR_TENSOR_ZERO:.*]] = torch.constant.int 6 - // CHECK: %[[ZERO_TENSOR:.*]] = torch.aten.full 
%[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]], %[[INT_0:.*]], %[[INT_TYPE_FOR_TENSOR_ZERO:.*]], %[[NONE_FOR_ZERO:.*]], %none_0, %none_0 : !torch.list, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> + // CHECK: %[[ZERO_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]], %[[F0:.*]], %[[INT_TYPE_FOR_TENSOR_ZERO:.*]], %[[NONE_FOR_ZERO:.*]], %none_0, %none_0 : !torch.list, !torch.float, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> // CHECK: %[[RESULT:.*]] = torch.aten.maximum %[[ZERO_TENSOR:.*]], %[[MIN_EXPRESSION:.*]] : !torch.vtensor<[],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> // CHECK: return %[[RESULT:.*]] : !torch.vtensor<[3],f32> - %0 = torch.operator "onnx.HardSigmoid"(%arg0) {torch.onnx.alpha = 5.000000e-01 : f32, torch.onnx.beta = 6.000000e-01 : f32} : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> return %0 : !torch.vtensor<[3],f32> } @@ -664,18 +1221,19 @@ func.func @test_hardsigmoid_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vt func.func @test_hardsigmoid(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[ALPHA_FLOAT:.*]] = torch.constant.float 5.000000e-01 // CHECK: %[[BETA_FLOAT:.*]] = torch.constant.float 0.60000002384185791 - // CHECK: %[[ALPHA_MULTI_X_PLUS_BETA:.*]] = torch.aten.add.Scalar %arg0, %[[BETA_FLOAT:.*]], %[[ALPHA_FLOAT:.*]] : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.float -> !torch.vtensor<[3,4,5],f32> - // CHECK: %[[INT_1:.*]] = torch.constant.int 1 + // CHECK: %[[ALPHA_MULTI_X:.*]] = torch.aten.mul.Scalar %arg0, %[[ALPHA_FLOAT]] : !torch.vtensor<[3,4,5],f32>, !torch.float -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[F1:.*]] = torch.constant.float 1.000000e+00 + // CHECK: %[[ALPHA_MULTI_X_PLUS_BETA:.*]] = torch.aten.add.Scalar %[[ALPHA_MULTI_X]], %[[BETA_FLOAT]], %[[F1]] : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.float -> !torch.vtensor<[3,4,5],f32> // CHECK: %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]] = torch.prim.ListConstruct : () -> !torch.list // CHECK: %[[NONE_FOR_ONE:.*]] = torch.constant.none // CHECK: %[[INT_TYPE_FOR_TENSOR_ONE:.*]] = torch.constant.int 6 - // CHECK: %[[ONE_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]], %[[INT_1:.*]], %[[INT_TYPE_FOR_TENSOR_ONE:.*]], %[[NONE_FOR_ONE:.*]], %[[NONE_1:.*]], %[[NONE_1:.*]] : !torch.list, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> + // CHECK: %[[ONE_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]], %[[F1:.*]], %[[INT_TYPE_FOR_TENSOR_ONE:.*]], %[[NONE_FOR_ONE:.*]], %[[NONE_1:.*]], %[[NONE_1:.*]] : !torch.list, !torch.float, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> // CHECK: %[[MIN_EXPRESSION:.*]] = torch.aten.minimum %[[ONE_TENSOR:.*]], %[[ALPHA_MULTI_X_PLUS_BETA:.*]] : !torch.vtensor<[],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> - // CHECK: %[[INT_0:.*]] = torch.constant.int 0 + // CHECK: %[[F0:.*]] = torch.constant.float 0.000000e+00 // CHECK: %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]] = torch.prim.ListConstruct : () -> !torch.list // CHECK: %[[NONE_FOR_ZERO:.*]] = torch.constant.none // CHECK: %[[INT_TYPE_FOR_TENSOR_ZERO:.*]] = torch.constant.int 6 - // CHECK: %[[ZERO_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]], 
%[[INT_0:.*]], %[[INT_TYPE_FOR_TENSOR_ZERO:.*]], %[[NONE_FOR_ZERO:.*]], %none_0, %none_0 : !torch.list, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> + // CHECK: %[[ZERO_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]], %[[F0:.*]], %[[INT_TYPE_FOR_TENSOR_ZERO:.*]], %[[NONE_FOR_ZERO:.*]], %none_0, %none_0 : !torch.list, !torch.float, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> // CHECK: %[[RESULT:.*]] = torch.aten.maximum %[[ZERO_TENSOR:.*]], %[[MIN_EXPRESSION:.*]] : !torch.vtensor<[],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> // CHECK: return %[[RESULT:.*]] : !torch.vtensor<[3,4,5],f32> %0 = torch.operator "onnx.HardSigmoid"(%arg0) {torch.onnx.alpha = 5.000000e-01 : f32, torch.onnx.beta = 6.000000e-01 : f32} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> @@ -688,18 +1246,19 @@ func.func @test_hardsigmoid(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtenso func.func @test_hardsigmoid_default(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[ALPHA_FLOAT:.*]] = torch.constant.float 0.20000000298023224 // CHECK: %[[BETA_FLOAT:.*]] = torch.constant.float 5.000000e-01 - // CHECK: %[[ALPHA_MULTI_X_PLUS_BETA:.*]] = torch.aten.add.Scalar %arg0, %[[BETA_FLOAT:.*]], %[[ALPHA_FLOAT:.*]] : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.float -> !torch.vtensor<[3,4,5],f32> - // CHECK: %[[INT_1:.*]] = torch.constant.int 1 + // CHECK: %[[ALPHA_MULTI_X:.*]] = torch.aten.mul.Scalar %arg0, %[[ALPHA_FLOAT]] : !torch.vtensor<[3,4,5],f32>, !torch.float -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[F1:.*]] = torch.constant.float 1.000000e+00 + // CHECK: %[[ALPHA_MULTI_X_PLUS_BETA:.*]] = torch.aten.add.Scalar %[[ALPHA_MULTI_X]], %[[BETA_FLOAT]], %[[F1]] : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.float -> !torch.vtensor<[3,4,5],f32> // CHECK: %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]] = torch.prim.ListConstruct : () -> !torch.list // CHECK: %[[NONE_FOR_ONE:.*]] = torch.constant.none // CHECK: %[[INT_TYPE_FOR_TENSOR_ONE:.*]] = torch.constant.int 6 - // CHECK: %[[ONE_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]], %[[INT_1:.*]], %[[INT_TYPE_FOR_TENSOR_ONE:.*]], %[[NONE_FOR_ONE:.*]], %[[NONE_1:.*]], %[[NONE_1:.*]] : !torch.list, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> + // CHECK: %[[ONE_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]], %[[F1:.*]], %[[INT_TYPE_FOR_TENSOR_ONE:.*]], %[[NONE_FOR_ONE:.*]], %[[NONE_1:.*]], %[[NONE_1:.*]] : !torch.list, !torch.float, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> // CHECK: %[[MIN_EXPRESSION:.*]] = torch.aten.minimum %[[ONE_TENSOR:.*]], %[[ALPHA_MULTI_X_PLUS_BETA:.*]] : !torch.vtensor<[],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> - // CHECK: %[[INT_0:.*]] = torch.constant.int 0 + // CHECK: %[[F0:.*]] = torch.constant.float 0.000000e+00 // CHECK: %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]] = torch.prim.ListConstruct : () -> !torch.list // CHECK: %[[NONE_FOR_ZERO:.*]] = torch.constant.none // CHECK: %[[INT_TYPE_FOR_TENSOR_ZERO:.*]] = torch.constant.int 6 - // CHECK: %[[ZERO_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]], %[[INT_0:.*]], %[[INT_TYPE_FOR_TENSOR_ZERO:.*]], 
%[[NONE_FOR_ZERO:.*]], %none_0, %none_0 : !torch.list, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> + // CHECK: %[[ZERO_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]], %[[F0:.*]], %[[INT_TYPE_FOR_TENSOR_ZERO:.*]], %[[NONE_FOR_ZERO:.*]], %none_0, %none_0 : !torch.list, !torch.float, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> // CHECK: torch.aten.maximum %[[ZERO_TENSOR:.*]], %[[MIN_EXPRESSION:.*]] : !torch.vtensor<[],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> %0 = torch.operator "onnx.HardSigmoid"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> return %0 : !torch.vtensor<[3,4,5],f32> @@ -743,6 +1302,89 @@ func.func @test_globalaveragepool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f3 // ----- +// CHECK-LABEL: @test_globalmaxpool +func.func @test_globalmaxpool(%arg0: !torch.vtensor<[1,3,5,5],f32>) -> !torch.vtensor<[1,3,1,1],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C5:.*]] = torch.constant.int 5 + // CHECK: %[[C5_0:.*]] = torch.constant.int 5 + // CHECK: %[[KERNELSIZE:.*]] = torch.prim.ListConstruct %[[C5]], %[[C5_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[DILATION:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: torch.aten.max_pool2d %arg0, %[[KERNELSIZE]], %[[STRIDE]], %[[PADDING]], %[[DILATION]], %[[FALSE]] : !torch.vtensor<[1,3,5,5],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[1,3,1,1],f32> + %0 = torch.operator "onnx.GlobalMaxPool"(%arg0) : (!torch.vtensor<[1,3,5,5],f32>) -> !torch.vtensor<[1,3,1,1],f32> + return %0 : !torch.vtensor<[1,3,1,1],f32> +} + +// ----- + +// CHECK-LABEL: @test_globalmaxpool_precomputed +func.func @test_globalmaxpool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,1,1],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C3:.*]] = torch.constant.int 3 + // CHECK: %[[C3_0:.*]] = torch.constant.int 3 + // CHECK: %[[KERNELSIZE:.*]] = torch.prim.ListConstruct %[[C3]], %[[C3_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[DILATION:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: torch.aten.max_pool2d %arg0, %[[KERNELSIZE]], %[[STRIDE]], %[[PADDING]], %[[DILATION]], %[[FALSE]] : !torch.vtensor<[1,1,3,3],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> 
!torch.vtensor<[1,1,1,1],f32> + %0 = torch.operator "onnx.GlobalMaxPool"(%arg0) : (!torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,1,1],f32> + return %0 : !torch.vtensor<[1,1,1,1],f32> +} + +// ----- + +// CHECK-LABEL: @test_globallppool +func.func @test_globallppool(%arg0: !torch.vtensor<[1,3,5,5],f32>) -> !torch.vtensor<[1,3,1,1],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 2 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C5:.*]] = torch.constant.int 5 + // CHECK: %[[E1:.*]] = torch.aten.mul %[[C5]], %[[C1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[C5_0:.*]] = torch.constant.int 5 + // CHECK: %[[E2:.*]] = torch.aten.mul %[[C5_0]], %[[E1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[KERNELSIZE:.*]] = torch.prim.ListConstruct %[[C5]], %[[C5_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[ABS:.*]] = torch.aten.abs %arg0 : !torch.vtensor<[1,3,5,5],f32> -> !torch.vtensor<[1,3,5,5],f32> + // CHECK: %[[CP:.*]] = torch.constant.int 2 + // CHECK: %[[POW1:.*]] = torch.aten.pow.Tensor_Scalar %[[ABS]], %[[CP]] : !torch.vtensor<[1,3,5,5],f32>, !torch.int -> !torch.vtensor<[1,3,5,5],f32> + // CHECK: %[[AVGPOOL:.*]] = torch.aten.avg_pool2d %[[POW1]], %[[KERNELSIZE]], %[[STRIDE]], %[[PADDING]], %[[FALSE]], %[[FALSE]], %[[C1]] : !torch.vtensor<[1,3,5,5],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.int -> !torch.vtensor<[1,3,1,1],f32> + // CHECK: %[[INVP:.*]] = torch.constant.float 5.000000e-01 + // CHECK: torch.aten.pow.Tensor_Scalar %[[AVGPOOL]], %[[INVP]] : !torch.vtensor<[1,3,1,1],f32>, !torch.float -> !torch.vtensor<[1,3,1,1],f32> + %0 = torch.operator "onnx.GlobalLpPool"(%arg0) : (!torch.vtensor<[1,3,5,5],f32>) -> !torch.vtensor<[1,3,1,1],f32> + return %0 : !torch.vtensor<[1,3,1,1],f32> +} + +// ----- + +// CHECK-LABEL: @test_globallppool_1d +func.func @test_globallppool_1d(%arg0: !torch.vtensor<[1,3,5],f32>) -> !torch.vtensor<[1,3,1],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 2 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C5:.*]] = torch.constant.int 5 + // CHECK: %[[E1:.*]] = torch.aten.mul %[[C5]], %[[C1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[KERNELSIZE:.*]] = torch.prim.ListConstruct %[[C5]] : (!torch.int) -> !torch.list + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]] : (!torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C1]] : (!torch.int) -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[ABS:.*]] = torch.aten.abs %arg0 : !torch.vtensor<[1,3,5],f32> -> !torch.vtensor<[1,3,5],f32> + // CHECK: %[[CP:.*]] = torch.constant.int 2 + // CHECK: %[[POW1:.*]] = torch.aten.pow.Tensor_Scalar %[[ABS]], %[[CP]] : !torch.vtensor<[1,3,5],f32>, !torch.int -> !torch.vtensor<[1,3,5],f32> + // CHECK: %[[AVGPOOL:.*]] = torch.aten.avg_pool1d %[[POW1]], %[[KERNELSIZE]], 
%[[STRIDE]], %[[PADDING]], %[[FALSE]], %[[FALSE]] : !torch.vtensor<[1,3,5],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool -> !torch.vtensor<[1,3,1],f32> + // CHECK: %[[MUL:.*]] = torch.aten.mul.Scalar %[[AVGPOOL]], %[[E1]] : !torch.vtensor<[1,3,1],f32>, !torch.int -> !torch.vtensor<[1,3,1],f32> + // CHECK: %[[INVP:.*]] = torch.constant.float 5.000000e-01 + // CHECK: torch.aten.pow.Tensor_Scalar %[[MUL]], %[[INVP]] : !torch.vtensor<[1,3,1],f32>, !torch.float -> !torch.vtensor<[1,3,1],f32> + %0 = torch.operator "onnx.GlobalLpPool"(%arg0) : (!torch.vtensor<[1,3,5],f32>) -> !torch.vtensor<[1,3,1],f32> + return %0 : !torch.vtensor<[1,3,1],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_max_example func.func @test_max_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.maximum %arg0, %arg1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> @@ -893,13 +1535,60 @@ func.func @test_not_2d(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4], // ----- -// CHECK-LABEL: func.func @test_nonzero - func.func @test_nonzero(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: torch.aten.nonzero %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],si64> - %0 = torch.operator "onnx.NonZero"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],si64> - return %0 : !torch.vtensor<[3,4,5],si64> +// CHECK-LABEL: func.func @test_nllloss_ii +func.func @test_nllloss_ii(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1: !torch.vtensor<[3,2],si64>) -> !torch.vtensor<[],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VAL_3:.*]] = torch.constant.none + // CHECK: %[[VAL_4:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_5:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_6:.*]], %[[VAL_7:.*]] = torch.aten.nll_loss_forward %arg0, %arg1, %[[VAL_3]], %[[VAL_5]], %[[VAL_4]] : !torch.vtensor<[3,5,2],f32>, !torch.vtensor<[3,2],si64>, !torch.none, !torch.int, !torch.int -> !torch.vtensor<[],f32>, !torch.vtensor<[],f32> + // CHECK: return %[[VAL_6]] : !torch.vtensor<[],f32> + %0 = torch.operator "onnx.NegativeLogLikelihoodLoss"(%arg0, %arg1) {torch.onnx.ignore_index = 1 : si64, torch.onnx.reduction = "mean"} : (!torch.vtensor<[3,5,2],f32>, !torch.vtensor<[3,2],si64>) -> !torch.vtensor<[],f32> + return %0 : !torch.vtensor<[],f32> } +// CHECK-LABEL: func.func @test_nllloss_ii_ignore_default +func.func @test_nllloss_ii_ignore_default(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1: !torch.vtensor<[3,2],si64>) -> !torch.vtensor<[],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VAL_3:.*]] = torch.constant.none + // CHECK: %[[VAL_4:.*]] = torch.constant.int -100 + // CHECK: %[[VAL_5:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_6:.*]], %[[VAL_7:.*]] = torch.aten.nll_loss_forward %arg0, 
%arg1, %[[VAL_3]], %[[VAL_5]], %[[VAL_4]] : !torch.vtensor<[3,5,2],f32>, !torch.vtensor<[3,2],si64>, !torch.none, !torch.int, !torch.int -> !torch.vtensor<[],f32>, !torch.vtensor<[],f32> + // CHECK: return %[[VAL_6]] : !torch.vtensor<[],f32> + %0 = torch.operator "onnx.NegativeLogLikelihoodLoss"(%arg0, %arg1) {torch.onnx.reduction = "mean"} : (!torch.vtensor<[3,5,2],f32>, !torch.vtensor<[3,2],si64>) -> !torch.vtensor<[],f32> + return %0 : !torch.vtensor<[],f32> +} + +// CHECK-LABEL: func.func @test_nllloss_ii_reduction_sum +func.func @test_nllloss_ii_reduction_sum(%arg0: !torch.vtensor<[3,5,6,6],f32>, %arg1: !torch.vtensor<[3,6,6],si64>) -> !torch.vtensor<[],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VAL_3:.*]] = torch.constant.none + // CHECK: %[[VAL_4:.*]] = torch.constant.int -100 + // CHECK: %[[VAL_5:.*]] = torch.constant.int 2 + // CHECK: %[[VAL_6:.*]], %[[VAL_7:.*]] = torch.aten.nll_loss_forward %arg0, %arg1, %[[VAL_3]], %[[VAL_5]], %[[VAL_4]] : !torch.vtensor<[3,5,6,6],f32>, !torch.vtensor<[3,6,6],si64>, !torch.none, !torch.int, !torch.int -> !torch.vtensor<[],f32>, !torch.vtensor<[],f32> + // CHECK: return %[[VAL_6]] : !torch.vtensor<[],f32> + %0 = torch.operator "onnx.NegativeLogLikelihoodLoss"(%arg0, %arg1) {torch.onnx.reduction = "sum"} : (!torch.vtensor<[3,5,6,6],f32>, !torch.vtensor<[3,6,6],si64>) -> !torch.vtensor<[],f32> + return %0 : !torch.vtensor<[],f32> +} + +// CHECK-LABEL: func.func @test_nllloss_iii_reduction_none_ignore_negative +func.func @test_nllloss_iii_reduction_none_ignore_negative(%arg0: !torch.vtensor<[3,5,6],f32>, %arg1: !torch.vtensor<[3,6],si64>, %arg2: !torch.vtensor<[5],f32>) -> !torch.vtensor<[],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VAL_4:.*]] = torch.constant.int -1 + // CHECK: %[[VAL_5:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_6:.*]], %[[VAL_7:.*]] = torch.aten.nll_loss_forward %arg0, %arg1, %arg2, %[[VAL_5]], %[[VAL_4]] : !torch.vtensor<[3,5,6],f32>, !torch.vtensor<[3,6],si64>, !torch.vtensor<[5],f32>, !torch.int, !torch.int -> !torch.vtensor<[],f32>, !torch.vtensor<[],f32> + // CHECK: return %[[VAL_6]] : !torch.vtensor<[],f32> + %0 = torch.operator "onnx.NegativeLogLikelihoodLoss"(%arg0, %arg1, %arg2) {torch.onnx.ignore_index = -1 : si64, torch.onnx.reduction = "none"} : (!torch.vtensor<[3,5,6],f32>, !torch.vtensor<[3,6],si64>, !torch.vtensor<[5],f32>) -> !torch.vtensor<[],f32> + return %0 : !torch.vtensor<[],f32> +} + +// ----- + +func.func @test_nonzero(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[1,?],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[ZERO:.*]] = torch.constant.int 0 + // CHECK: %[[ONE:.*]] = torch.constant.int 1 + // CHECK: %[[NONZERO:.*]] = torch.aten.nonzero %arg0 : !torch.vtensor<[?],f32> -> !torch.vtensor<[?,1],si64> + // CHECK: %[[TRANSPOSE:.*]] = torch.aten.transpose.int %[[NONZERO]], %[[ZERO]], %[[ONE]] : !torch.vtensor<[?,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,?],si64> + %0 = torch.operator "onnx.NonZero"(%arg0) : (!torch.vtensor<[?],f32>) -> !torch.vtensor<[1,?],si64> + return %0 : !torch.vtensor<[1,?],si64> +} + // ----- // 
CHECK-LABEL: func.func @test_or2d @@ -989,3 +1678,659 @@ func.func @test_hardswish(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor< %0 = torch.operator "onnx.HardSwish"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> return %0 : !torch.vtensor<[3,4,5],f32> } + +// ----- + +// CHECK-LABEL: func.func @test_hardmax +func.func @test_hardmax(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[AXIS:.+]] = torch.constant.int 1 + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[ARGMAX:.+]] = torch.aten.argmax %arg0, %[[AXIS]], %[[FALSE]] + // CHECK: %[[CLASSES:.+]] = torch.aten.size.int %arg0, %[[AXIS]] + // CHECK: %[[ONEHOT:.+]] = torch.aten.one_hot %[[ARGMAX]], %[[CLASSES]] + // CHECK: %[[PERM0:.+]] = torch.constant.int 0 + // CHECK: %[[PERM2:.+]] = torch.constant.int 2 + // CHECK: %[[PERM1:.+]] = torch.constant.int 1 + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[PERM0]], %[[PERM2]], %[[PERM1]] + // CHECK: %[[PERMUTE:.+]] = torch.aten.permute %[[ONEHOT]], %[[LIST]] + // CHECK: return %[[PERMUTE]] + %0 = torch.operator "onnx.Hardmax"(%arg0) {torch.onnx.axis = -2 : si64} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_onehot_negative_indices +func.func @test_onehot_negative_indices(%arg0: !torch.vtensor<[3],si64>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[10,3],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64} { + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[ITEM:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[INT:.*]] = torch.aten.Int.Scalar %[[ITEM]] : !torch.float -> !torch.int + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C1:.*]]= torch.constant.int 1 + // CHECK: %[[LT:.*]] = torch.aten.lt.Scalar %arg0, %[[C0]] : !torch.vtensor<[3],si64>, !torch.int -> !torch.vtensor<[3],i1> + // CHECK: %[[ADD:.*]] = torch.aten.add.Scalar %arg0, %[[INT]], %[[C1]]: !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[3],si64> + // CHECK: %[[WHERE:.*]] = torch.aten.where.self %[[LT]], %[[ADD]], %arg0 : !torch.vtensor<[3],i1>, !torch.vtensor<[3],si64>, !torch.vtensor<[3],si64> -> !torch.vtensor<[3],si64> + // CHECK: %[[SELECT:.*]] = torch.aten.select.int %arg2, %[[C0]], %[[C0]] : !torch.vtensor<[2],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32> + // CHECK: %[[ITEM_0:.*]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],f32> -> !torch.float + // CHECK: %[[SELECT:.*]] = torch.aten.select.int %arg2, %[[C0]], %[[C1]]: !torch.vtensor<[2],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32> + // CHECK: %[[ITEM_1:.*]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],f32> -> !torch.float + // CHECK: %[[ONEHOT:.*]] = torch.aten.one_hot %[[WHERE]], %[[INT]] : !torch.vtensor<[3],si64>, !torch.int -> !torch.vtensor<[3,?],si32> + // CHECK: %[[D0:.+]] = torch.constant.int 1 + // CHECK: %[[D1:.+]] = torch.constant.int 0 + // CHECK: %[[TRANS:.+]] = torch.aten.transpose.int %[[ONEHOT]], %[[D1]], %[[D0]] + // CHECK: %[[C11:.*]] = torch.constant.int 11 + // CHECK: %[[NONE_0:.*]] = torch.constant.none + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[DTYPE:.*]] = 
torch.aten.to.dtype %[[TRANS]], %[[C11]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[?,3],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,3],i1> + // CHECK: %[[RESULT:.*]] = torch.aten.where.Scalar %[[DTYPE]], %[[ITEM_1]], %[[ITEM_0]] : !torch.vtensor<[?,3],i1>, !torch.float, !torch.float -> !torch.vtensor<[10,3],f32> + // CHECK: return %[[RESULT]] : !torch.vtensor<[10,3],f32> + %none = torch.constant.none + %0 = torch.operator "onnx.OneHot"(%arg0, %arg1, %arg2) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[3],si64>, !torch.vtensor<[],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[10,3],f32> + return %0 : !torch.vtensor<[10,3],f32> +} + +// ----- + +// CHECK-LABEL: @test_lpnormalization +func.func @test_lpnormalization(%arg0: !torch.vtensor<[3,4,5,6,7],f32>) -> !torch.vtensor<[3,4,5,6,7],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[CST2:.*]] = torch.constant.int 2 + // CHECK: %[[CST2_0:.*]] = torch.constant.int 2 + // CHECK: %[[TRUE:.*]] = torch.constant.bool true + // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST2]] : (!torch.int) -> !torch.list + // CHECK: %[[NORM:.*]] = torch.aten.norm.ScalarOpt_dim %arg0, %[[CST2_0]], %[[DIMS]], %[[TRUE]] : !torch.vtensor<[3,4,5,6,7],f32>, !torch.int, !torch.list, !torch.bool -> !torch.vtensor<[3,4,1,6,7],f32> + // CHECK: %[[OUT:.*]] = torch.aten.div.Tensor %arg0, %[[NORM]] : !torch.vtensor<[3,4,5,6,7],f32>, !torch.vtensor<[3,4,1,6,7],f32> -> !torch.vtensor<[3,4,5,6,7],f32> + // CHECK: return %[[OUT]] : !torch.vtensor<[3,4,5,6,7],f32> + %0 = torch.operator "onnx.LpNormalization"(%arg0) {torch.onnx.axis = 2 : si64, torch.onnx.p = 2 : si64} : (!torch.vtensor<[3,4,5,6,7],f32>) -> !torch.vtensor<[3,4,5,6,7],f32> + return %0 : !torch.vtensor<[3,4,5,6,7],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_maxunpool_export_without_output_shape +func.func @test_maxunpool_export_without_output_shape(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[INT1_0:.*]] = torch.constant.int 1 + // CHECK: %[[INT4:.*]] = torch.constant.int 4 + // CHECK: %[[INT4_0:.*]] = torch.constant.int 4 + // CHECK: %[[OUTPUT_SHAPE:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1_0]], %[[INT4]], %[[INT4_0]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[RESULT:.*]] = torch.aten.max_unpool2d %arg0, %arg1, %[[OUTPUT_SHAPE]] : !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>, !torch.list -> !torch.vtensor<[1,1,4,4],f32> + // return %[[RESULT]] : !torch.vtensor<[1,1,4,4],f32> + %0 = torch.operator "onnx.MaxUnpool"(%arg0, %arg1) {torch.onnx.kernel_shape = [2 : si64, 2 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32> + return %0 : !torch.vtensor<[1,1,4,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_maxunpool3d_export_without_output_shape +func.func @test_maxunpool3d_export_without_output_shape(%arg0: !torch.vtensor<[1,1,2,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2,2],si64>) -> !torch.vtensor<[1,1,4,4,4],f32> attributes 
{torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[INT1_0:.*]] = torch.constant.int 1 + // CHECK: %[[INT4:.*]] = torch.constant.int 4 + // CHECK: %[[INT4_0:.*]] = torch.constant.int 4 + // CHECK: %[[INT4_1:.*]] = torch.constant.int 4 + // CHECK: %[[OUTPUT_SHAPE:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1_0]], %[[INT4]], %[[INT4_0]], %[[INT4_1]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT0_1:.*]] = torch.constant.int 0 + // CHECK: %[[INT0_2:.*]] = torch.constant.int 0 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT0_1]], %[[INT0_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: %[[INT2_1:.*]] = torch.constant.int 2 + // CHECK: %[[INT2_2:.*]] = torch.constant.int 2 + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2_1]], %[[INT2_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[RESULT:.*]] = torch.aten.max_unpool3d %arg0, %arg1, %[[OUTPUT_SHAPE]], %[[STRIDE]], %[[PADDING]] : !torch.vtensor<[1,1,2,2,2],f32>, !torch.vtensor<[1,1,2,2,2],si64>, !torch.list, !torch.list, !torch.list -> !torch.vtensor<[1,1,4,4,4],f32> + // return %[[RESULT]] : !torch.vtensor<[1,1,4,4,4],f32> + %0 = torch.operator "onnx.MaxUnpool"(%arg0, %arg1) {torch.onnx.kernel_shape = [2 : si64, 2 : si64, 2 : si64], torch.onnx.strides = [2 : si64, 2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,2,2,2],f32>, !torch.vtensor<[1,1,2,2,2],si64>) -> !torch.vtensor<[1,1,4,4,4],f32> + return %0 : !torch.vtensor<[1,1,4,4,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_group_normalization +func.func @test_group_normalization(%arg0: !torch.vtensor<[3,4,2,2],f32>, %arg1: !torch.vtensor<[4],f32>, %arg2: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3,4,2,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[EPSILON:.*]] = torch.constant.float 9.9999997473787516E-6 + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[RESULT:.*]] = torch.aten.group_norm %arg0, %int2, %arg1, %arg2, %[[EPSILON]], %[[FALSE:.*]] : !torch.vtensor<[3,4,2,2],f32>, !torch.int, !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,2,2],f32> + // CHECK: return %[[RESULT]] : !torch.vtensor<[3,4,2,2],f32> + %0 = torch.operator "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {torch.onnx.num_groups = 2 : si64} : (!torch.vtensor<[3,4,2,2],f32>, !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>) -> !torch.vtensor<[3,4,2,2],f32> + return %0 : !torch.vtensor<[3,4,2,2],f32> +} + +// ----- + +func.func @test_group_normalization_epsilon(%arg0: !torch.vtensor<[3,4,2,2],f32>, %arg1: !torch.vtensor<[4],f32>, %arg2: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3,4,2,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[EPSILON:.*]] = torch.constant.float 0.0099999997764825821 + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: %[[FALSE:.*]] = 
torch.constant.bool false + // CHECK: %[[RESULT:.*]] = torch.aten.group_norm %arg0, %int2, %arg1, %arg2, %[[EPSILON]], %[[FALSE:.*]] : !torch.vtensor<[3,4,2,2],f32>, !torch.int, !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,2,2],f32> + // CHECK: return %[[RESULT]] : !torch.vtensor<[3,4,2,2],f32> + %0 = torch.operator "onnx.GroupNormalization"(%arg0, %arg1, %arg2) {torch.onnx.epsilon = 0.00999999977 : f32, torch.onnx.num_groups = 2 : si64} : (!torch.vtensor<[3,4,2,2],f32>, !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>) -> !torch.vtensor<[3,4,2,2],f32> + return %0 : !torch.vtensor<[3,4,2,2],f32> +} + +// ----- + +// CHECK-LABEL: @test_optional +func.func @test_optional(%arg0: !torch.list>) -> !torch.optional>> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64} { + // CHECK: %[[RESULT:.*]] = torch.derefine %arg0 : !torch.list> to !torch.optional>> + // CHECK: return %[[RESULT]] : !torch.optional>> + %0 = torch.operator "onnx.Optional"(%arg0) : (!torch.list>) -> !torch.optional>> + return %0 : !torch.optional>> +} + +// ----- + +// CHECK-LABEL: @test_optional_get_element_optional_sequence +func.func @test_optional_get_element_optional_sequence(%arg0: !torch.optional>>) -> !torch.list> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[RESULT:.*]] = torch.prim.unchecked_cast %arg0 : !torch.optional>> -> !torch.list> + // CHECK: return %[[RESULT]] : !torch.list> + %0 = torch.operator "onnx.OptionalGetElement"(%arg0) : (!torch.optional>>) -> !torch.list> + return %0 : !torch.list> +} + +// ----- + +// CHECK-LABEL: @test_optional_get_element_optional_tensor +func.func @test_optional_get_element_optional_tensor(%arg0: !torch.optional>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[RESULT:.*]] = torch.prim.unchecked_cast %arg0 : !torch.optional> -> !torch.vtensor<[4],f32> + // CHECK: return %[[RESULT]] : !torch.vtensor<[4],f32> + %0 = torch.operator "onnx.OptionalGetElement"(%arg0) : (!torch.optional>) -> !torch.vtensor<[4],f32> + return %0 : !torch.vtensor<[4],f32> +} + +// ----- + +// CHECK-LABEL: @test_optional_get_element_sequence +func.func @test_optional_get_element_sequence(%arg0: !torch.list>) -> !torch.list> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: return %arg0 : !torch.list> + %0 = torch.operator "onnx.OptionalGetElement"(%arg0) : (!torch.list>) -> !torch.list> + return %0 : !torch.list> +} + +// ----- + +// CHECK-LABEL: @test_optional_get_element_tensor +func.func @test_optional_get_element_tensor(%arg0: !torch.vtensor<[4],f32>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: return %arg0 : !torch.vtensor<[4],f32> + %0 = torch.operator "onnx.OptionalGetElement"(%arg0) : (!torch.vtensor<[4],f32>) -> !torch.vtensor<[4],f32> + return %0 : !torch.vtensor<[4],f32> +} + +// ----- + +// CHECK-LABEL: @test_optional_has_element_empty_none_input +func.func 
@test_optional_has_element_empty_none_input() -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[DTYPE:.*]] = torch.constant.int 11 + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[NONE_0:.*]] = torch.constant.none + // CHECK: %[[DATA_LIST:.*]] = torch.prim.ListConstruct %[[INT0]] : (!torch.int) -> !torch.list + // CHECK: %[[RESULT:.*]] = torch.aten.tensor %[[DATA_LIST]], %[[DTYPE]], %[[NONE_0]], %[[FALSE]] : !torch.list, !torch.int, !torch.none, !torch.bool -> !torch.vtensor<[],i1> + // CHECK: return %[[RESULT]] : !torch.vtensor<[],i1> + %none = torch.constant.none + %0 = torch.operator "onnx.OptionalHasElement"(%none) : (!torch.none) -> !torch.vtensor<[],i1> + return %0 : !torch.vtensor<[],i1> +} + +// ----- + +// CHECK-LABEL: @test_optional_has_element_empty_no_input +func.func @test_optional_has_element_empty_no_input() -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[DTYPE:.*]] = torch.constant.int 11 + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[DATA_LIST:.*]] = torch.prim.ListConstruct %[[INT0]] : (!torch.int) -> !torch.list + // CHECK: %[[RESULT:.*]] = torch.aten.tensor %[[DATA_LIST]], %[[DTYPE]], %[[NONE]], %[[FALSE]] : !torch.list, !torch.int, !torch.none, !torch.bool -> !torch.vtensor<[],i1> + // CHECK: return %[[RESULT]] : !torch.vtensor<[],i1> + %0 = torch.operator "onnx.OptionalHasElement"() : () -> !torch.vtensor<[],i1> + return %0 : !torch.vtensor<[],i1> +} + +// ----- + +// CHECK-LABEL: @test_optional_has_element_empty_optional_input +func.func @test_optional_has_element_empty_optional_input(%arg0: !torch.optional>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[DTYPE:.*]] = torch.constant.int 11 + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[DATA_LIST:.*]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list + // CHECK: %[[RESULT:.*]] = torch.aten.tensor %[[DATA_LIST]], %[[DTYPE]], %[[NONE]], %[[FALSE]] : !torch.list, !torch.int, !torch.none, !torch.bool -> !torch.vtensor<[],i1> + // CHECK: return %[[RESULT]] : !torch.vtensor<[],i1> + %0 = torch.operator "onnx.OptionalHasElement"(%arg0) : (!torch.optional>) -> !torch.vtensor<[],i1> + return %0 : !torch.vtensor<[],i1> +} + +// ----- + +// CHECK-LABEL: @test_optional_has_element_optional_tensor_input +func.func @test_optional_has_element_optional_tensor_input(%arg0: !torch.optional>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[DTYPE:.*]] = torch.constant.int 11 + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: 
%[[NONE:.*]] = torch.constant.none + // CHECK: %[[DATA_LIST:.*]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list + // CHECK: %[[RESULT:.*]] = torch.aten.tensor %[[DATA_LIST]], %[[DTYPE]], %[[NONE]], %[[FALSE]] : !torch.list, !torch.int, !torch.none, !torch.bool -> !torch.vtensor<[],i1> + // CHECK: return %[[RESULT]] : !torch.vtensor<[],i1> + %0 = torch.operator "onnx.OptionalHasElement"(%arg0) : (!torch.optional>) -> !torch.vtensor<[],i1> + return %0 : !torch.vtensor<[],i1> +} + +// ----- + +// CHECK-LABEL: @test_optional_has_element_optional_list_tensor_input +func.func @test_optional_has_element_optional_list_tensor_input(%arg0: !torch.optional>>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[DTYPE:.*]] = torch.constant.int 11 + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[DATA_LIST:.*]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list + // CHECK: %[[RESULT:.*]] = torch.aten.tensor %[[DATA_LIST]], %[[DTYPE]], %[[NONE]], %[[FALSE]] : !torch.list, !torch.int, !torch.none, !torch.bool -> !torch.vtensor<[],i1> + // CHECK: return %[[RESULT]] : !torch.vtensor<[],i1> + %0 = torch.operator "onnx.OptionalHasElement"(%arg0) : (!torch.optional>>) -> !torch.vtensor<[],i1> + return %0 : !torch.vtensor<[],i1> +} + +// ----- + +// CHECK-LABEL: @test_optional_has_element_tensor_input +func.func @test_optional_has_element_tensor_input(%arg0: !torch.vtensor<[4],f32>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[DTYPE:.*]] = torch.constant.int 11 + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[DATA_LIST:.*]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list + // CHECK: %[[RESULT:.*]] = torch.aten.tensor %[[DATA_LIST]], %[[DTYPE]], %[[NONE]], %[[FALSE]] : !torch.list, !torch.int, !torch.none, !torch.bool -> !torch.vtensor<[],i1> + // CHECK: return %[[RESULT]] : !torch.vtensor<[],i1> + %0 = torch.operator "onnx.OptionalHasElement"(%arg0) : (!torch.vtensor<[4],f32>) -> !torch.vtensor<[],i1> + return %0 : !torch.vtensor<[],i1> +} + +// ----- + +// CHECK-LABEL: @test_optional_has_element_list_tensor_input +func.func @test_optional_has_element_list_tensor_input(%arg0: !torch.list>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[DTYPE:.*]] = torch.constant.int 11 + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[DATA_LIST:.*]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list + // CHECK: %[[RESULT:.*]] = torch.aten.tensor %[[DATA_LIST]], %[[DTYPE]], %[[NONE]], %[[FALSE]] : !torch.list, !torch.int, !torch.none, !torch.bool -> !torch.vtensor<[],i1> + // CHECK: return %[[RESULT]] : !torch.vtensor<[],i1> + %0 = torch.operator "onnx.OptionalHasElement"(%arg0) : (!torch.list>) -> 
!torch.vtensor<[],i1> + return %0 : !torch.vtensor<[],i1> +} + +// ----- + +// CHECK-LABEL: func.func @test_loop_forlike +func.func @test_loop_forlike(%arg0: !torch.vtensor<[],si64>, %arg1: !torch.vtensor<[],i1>, %arg2: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "loop_example", torch.onnx_meta.producer_version = ""} { + // CHECK-SAME: %[[MAX_TRIP_COUNT_INP:.*]]: !torch.vtensor<[],si64>, + // CHECK-SAME: %[[CONDITION_INP:.*]]: !torch.vtensor<[],i1>, + // CHECK-SAME: %[[LCD_1:.*]]: !torch.vtensor<[1],f32> + // CHECK: %[[NONE_0:.*]] = torch.constant.none + // CHECK: %[[MAX_TRIP_COUNT_INT:.*]] = torch.aten.item %[[MAX_TRIP_COUNT_INP]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[CONDITION_INT:.*]] = torch.aten.item %[[CONDITION_INP]] : !torch.vtensor<[],i1> -> !torch.int + // CHECK: %[[CONDITION_BOOL:.*]] = torch.aten.Bool.int %[[CONDITION_INT]] : !torch.int -> !torch.bool + // CHECK: %[[TRUE:.*]] = torch.constant.bool true + // CHECK: %[[LOOP:.*]] = torch.prim.Loop %[[MAX_TRIP_COUNT_INT]], %[[TRUE]], init(%[[LCD_1]]) { + // CHECK: ^bb0(%[[ITER_NUM:.*]]: !torch.int, %[[LCD_1_BODY:.*]]: !torch.vtensor<[1],f32>): + // CHECK: %[[ITER_NUM_T:.*]] = torch.prim.NumToTensor.Scalar %[[ITER_NUM]] : !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[NONE_1:.*]] = torch.constant.none + // CHECK: %[[CLONE_INP_COND:.*]] = torch.aten.clone %[[CONDITION_INP]], %[[NONE_1]] : !torch.vtensor<[],i1>, !torch.none -> !torch.vtensor<[],i1> + // CHECK: %[[CONST_ARR:.*]] = torch.vtensor.literal(dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00]> : tensor<5xf32>) : !torch.vtensor<[5],f32> + // CHECK: %[[ONE_T:.*]] = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[ONE_0:.*]] = torch.constant.int 1 + // CHECK: %[[ADD_ONE_T:.*]] = torch.aten.add.Tensor %[[ITER_NUM_T]], %[[ONE_T]], %[[ONE_0]] : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ZERO_T:.*]] = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[ZERO_0:.*]] = torch.constant.int 0 + // CHECK: %[[ITER_NUM_RT:.*]] = torch.aten.unsqueeze %[[ITER_NUM_T]], %[[ZERO_0]] : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ZERO_1:.*]] = torch.constant.int 0 + // CHECK: %[[ADD_ONE_RT:.*]] = torch.aten.unsqueeze %[[ADD_ONE_T]], %[[ZERO_1]] : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[NONE_2:.*]] = torch.constant.none + // CHECK: %[[ONE_1:.*]] = torch.constant.int 1 + // CHECK: %[[ONE_SIZE_LIST:.*]] = torch.prim.ListConstruct %[[ONE_1]] : (!torch.int) -> !torch.list + // CHECK: %[[ONES_T:.*]] = torch.aten.ones %[[ONE_SIZE_LIST]], %[[NONE_2]], %[[NONE_2]], %[[NONE_2]], %[[NONE_2]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si64> + // CHECK: %[[ZERO_2:.*]] = torch.constant.int 0 + // CHECK: %[[ZERO_3:.*]] = torch.constant.int 0 + // CHECK: %[[ZERO_T_1:.*]] = torch.prim.NumToTensor.Scalar %[[ZERO_3]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ITER_NUM_INDEXED:.*]] = torch.aten.index_select %[[ITER_NUM_RT]], %[[ZERO_2]], %[[ZERO_T_1]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK: %[[ITER_NUM_INT:.*]] = torch.aten.item %[[ITER_NUM_INDEXED]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[INC_INDEXED:.*]] = 
torch.aten.index_select %[[ADD_ONE_RT]], %[[ZERO_2]], %[[ZERO_T_1]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK: %[[INC_INT:.*]] = torch.aten.item %[[INC_INDEXED]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[SLICE_INDEX_T:.*]] = torch.aten.index_select %[[ONES_T]], %[[ZERO_2]], %[[ZERO_T_1]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK: %[[INDEX_INT:.*]] = torch.aten.item %[[SLICE_INDEX_T]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[INPUT_SLICE:.*]] = torch.aten.slice.Tensor %[[CONST_ARR]], %[[ZERO_3]], %[[ITER_NUM_INT]], %[[INC_INT]], %[[INDEX_INT]] : !torch.vtensor<[5],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?],f32> + // CHECK: %[[ONE_2:.*]] = torch.constant.int 1 + // CHECK: %[[INTERM_RES:.*]] = torch.aten.add.Tensor %[[LCD_1_BODY]], %[[INPUT_SLICE]], %[[ONE_2]] : !torch.vtensor<[1],f32>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[1],f32> + // CHECK: torch.prim.Loop.condition %[[TRUE]], iter(%[[INTERM_RES]] : !torch.vtensor<[1],f32>) + // CHECK: } : (!torch.int, !torch.bool, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32> + // CHECK: return %[[LOOP]] : !torch.vtensor<[1],f32> + %none = torch.constant.none + %0 = torch.operator "onnx.Loop"(%arg0, %arg1, %arg2) : (!torch.vtensor<[],si64>, !torch.vtensor<[],i1>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32> { + ^bb0(%arg3: !torch.vtensor<[],si64>, %arg4: !torch.vtensor<[],i1>, %arg5: !torch.vtensor<[1],f32>): + %1 = torch.operator "onnx.Identity"(%arg4) : (!torch.vtensor<[],i1>) -> !torch.vtensor<[],i1> + %2 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00]> : tensor<5xf32>} : () -> !torch.vtensor<[5],f32> + %3 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor} : () -> !torch.vtensor<[],si64> + %4 = torch.operator "onnx.Add"(%arg3, %3) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> + %5 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor} : () -> !torch.vtensor<[],si64> + %6 = torch.operator "onnx.Unsqueeze"(%arg3, %5) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[1],si64> + %7 = torch.operator "onnx.Unsqueeze"(%4, %5) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[1],si64> + %8 = torch.operator "onnx.Slice"(%2, %6, %7) : (!torch.vtensor<[5],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[?],f32> + %9 = torch.operator "onnx.Add"(%arg5, %8) : (!torch.vtensor<[1],f32>, !torch.vtensor<[?],f32>) -> !torch.vtensor<[1],f32> + torch.operator_terminator %1, %9 : !torch.vtensor<[],i1>, !torch.vtensor<[1],f32> + } + return %0 : !torch.vtensor<[1],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_nonmaxsuppression_identical_boxes( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,10,4],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1,1,10],f32>, +// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[1],si64>, +// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[1],f32>, +// CHECK-SAME: %[[VAL_4:.*]]: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +func.func @test_nonmaxsuppression_identical_boxes(%arg0: !torch.vtensor<[1,10,4],f32>, 
%arg1: !torch.vtensor<[1,1,10],f32>, %arg2: !torch.vtensor<[1],si64>, %arg3: !torch.vtensor<[1],f32>, %arg4: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VAL_5:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_6:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_7:.*]] = torch.aten.size.int %[[VAL_0]], %[[VAL_5]] : !torch.vtensor<[1,10,4],f32>, !torch.int -> !torch.int + // CHECK: %[[VAL_8:.*]] = torch.aten.eq.int %[[VAL_7]], %[[VAL_6]] : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.runtime.assert %[[VAL_8]], "squeeze operation possible for dim only when input_shape[dim] == 1." + // CHECK: %[[VAL_9:.*]] = torch.aten.squeeze.dim %[[VAL_0]], %[[VAL_5]] : !torch.vtensor<[1,10,4],f32>, !torch.int -> !torch.vtensor<[10,4],f32> + // CHECK: %[[VAL_10:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_11:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_12:.*]] = torch.aten.size.int %[[VAL_1]], %[[VAL_10]] : !torch.vtensor<[1,1,10],f32>, !torch.int -> !torch.int + // CHECK: %[[VAL_13:.*]] = torch.aten.eq.int %[[VAL_12]], %[[VAL_11]] : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.runtime.assert %[[VAL_13]], "squeeze operation possible for dim only when input_shape[dim] == 1." + // CHECK: %[[VAL_14:.*]] = torch.aten.squeeze.dim %[[VAL_1]], %[[VAL_10]] : !torch.vtensor<[1,1,10],f32>, !torch.int -> !torch.vtensor<[1,10],f32> + // CHECK: %[[VAL_15:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_16:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_17:.*]] = torch.aten.size.int %[[VAL_14]], %[[VAL_15]] : !torch.vtensor<[1,10],f32>, !torch.int -> !torch.int + // CHECK: %[[VAL_18:.*]] = torch.aten.eq.int %[[VAL_17]], %[[VAL_16]] : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.runtime.assert %[[VAL_18]], "squeeze operation possible for dim only when input_shape[dim] == 1." 
+ // CHECK: %[[VAL_19:.*]] = torch.aten.squeeze.dim %[[VAL_14]], %[[VAL_15]] : !torch.vtensor<[1,10],f32>, !torch.int -> !torch.vtensor<[10],f32> + // CHECK: %[[VAL_20:.*]] = torch.aten.item %[[VAL_4]] : !torch.vtensor<[1],f32> -> !torch.float + // CHECK: %[[VAL_21:.*]] = torch.aten.min %[[VAL_19]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[],f32> + // CHECK: %[[VAL_22:.*]] = torch.aten.item %[[VAL_21]] : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[VAL_23:.*]] = torch.aten.ge.float %[[VAL_22]], %[[VAL_20]] : !torch.float, !torch.float -> !torch.bool + // CHECK: torch.runtime.assert %[[VAL_23]], "unimplemented: score_threshold should be <= min(scores)" + // CHECK: %[[VAL_24:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_25:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_26:.*]] = torch.constant.float 0.000000e+00 + // CHECK: %[[VAL_27:.*]] = torch.aten.item %arg3 : !torch.vtensor<[1],f32> -> !torch.float + // CHECK: %[[VAL_28:.*]] = torch.aten.item %arg2 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[VAL_29:.*]] = torch.torchvision.nms %[[VAL_9]], %[[VAL_19]], %[[VAL_27]] : !torch.vtensor<[10,4],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[?],si64> + // CHECK: %[[VAL_30:.*]] = torch.aten.size.int %[[VAL_29]], %[[VAL_24]] : !torch.vtensor<[?],si64>, !torch.int -> !torch.int + // CHECK: %[[VAL_31:.*]] = torch.aten.gt.int %[[VAL_30]], %[[VAL_28]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[VAL_32:.*]] = torch.prim.If %[[VAL_31]] -> (!torch.vtensor<[1],si64>) + // CHECK: %[[SLICE:.*]] = torch.aten.slice.Tensor %[[VAL_29]], %[[VAL_24]], %[[VAL_24]], %[[VAL_28]], %[[VAL_25]] : !torch.vtensor<[?],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.prim.If.yield %[[SLICE]] : !torch.vtensor<[1],si64> + // CHECK: } else { + // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[VAL_29]] : !torch.vtensor<[?],si64> to !torch.vtensor<[1],si64> + // CHECK: torch.prim.If.yield %[[CAST]] : !torch.vtensor<[1],si64> + // CHECK: } + // CHECK: %[[VAL_33:.*]] = torch.aten.unsqueeze %[[VAL_32]], %[[VAL_25]] : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64> + // CHECK: %[[VAL_34:.*]] = torch.aten.size.int %[[VAL_33]], %[[VAL_24]] : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.int + // CHECK: %[[VAL_35:.*]] = torch.constant.int 2 + // CHECK: %[[VAL_36:.*]] = torch.prim.ListConstruct %[[VAL_34]], %[[VAL_35]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[VAL_37:.*]] = torch.constant.none + // CHECK: %[[VAL_38:.*]] = torch.aten.zeros %[[VAL_36]], %[[VAL_37]], %[[VAL_37]], %[[VAL_37]], %[[VAL_37]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64> + // CHECK: %[[VAL_39:.*]] = torch.prim.ListConstruct %[[VAL_38]], %[[VAL_33]] : (!torch.vtensor<[1,2],si64>, !torch.vtensor<[1,1],si64>) -> !torch.list + // CHECK: %[[VAL_40:.*]] = torch.aten.cat %[[VAL_39]], %[[VAL_25]] : !torch.list, !torch.int -> !torch.vtensor<[1,3],si64> + // CHECK: return %[[VAL_40]] : !torch.vtensor<[1,3],si64> + %0 = torch.operator "onnx.NonMaxSuppression"(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.vtensor<[1,10,4],f32>, !torch.vtensor<[1,1,10],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> + return %0 : !torch.vtensor<[1,3],si64> +} + +// ----- + +// CHECK-LABEL: func.func @test_nonmaxsuppression_single_box( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,4],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: 
!torch.vtensor<[1,1,1],f32>, +// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[1],si64>, +// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[1],f32>, +// CHECK-SAME: %[[VAL_4:.*]]: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} +func.func @test_nonmaxsuppression_single_box(%arg0: !torch.vtensor<[1,1,4],f32>, %arg1: !torch.vtensor<[1,1,1],f32>, %arg2: !torch.vtensor<[1],si64>, %arg3: !torch.vtensor<[1],f32>, %arg4: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VAL_5:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_6:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_7:.*]] = torch.aten.size.int %[[VAL_0]], %[[VAL_5]] : !torch.vtensor<[1,1,4],f32>, !torch.int -> !torch.int + // CHECK: %[[VAL_8:.*]] = torch.aten.eq.int %[[VAL_7]], %[[VAL_6]] : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.runtime.assert %[[VAL_8]], "squeeze operation possible for dim only when input_shape[dim] == 1." + // CHECK: %[[VAL_9:.*]] = torch.aten.squeeze.dim %[[VAL_0]], %[[VAL_5]] : !torch.vtensor<[1,1,4],f32>, !torch.int -> !torch.vtensor<[1,4],f32> + // CHECK: %[[VAL_10:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_11:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_12:.*]] = torch.aten.size.int %[[VAL_1]], %[[VAL_10]] : !torch.vtensor<[1,1,1],f32>, !torch.int -> !torch.int + // CHECK: %[[VAL_13:.*]] = torch.aten.eq.int %[[VAL_12]], %[[VAL_11]] : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.runtime.assert %[[VAL_13]], "squeeze operation possible for dim only when input_shape[dim] == 1." + // CHECK: %[[VAL_14:.*]] = torch.aten.squeeze.dim %[[VAL_1]], %[[VAL_10]] : !torch.vtensor<[1,1,1],f32>, !torch.int -> !torch.vtensor<[1,1],f32> + // CHECK: %[[VAL_15:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_16:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_17:.*]] = torch.aten.size.int %[[VAL_14]], %[[VAL_15]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.int + // CHECK: %[[VAL_18:.*]] = torch.aten.eq.int %[[VAL_17]], %[[VAL_16]] : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.runtime.assert %[[VAL_18]], "squeeze operation possible for dim only when input_shape[dim] == 1." 
+ // CHECK: %[[VAL_19:.*]] = torch.aten.squeeze.dim %[[VAL_14]], %[[VAL_15]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.vtensor<[1],f32> + // CHECK: %[[VAL_20:.*]] = torch.aten.item %[[VAL_4]] : !torch.vtensor<[1],f32> -> !torch.float + // CHECK: %[[VAL_21:.*]] = torch.aten.min %[[VAL_19]] : !torch.vtensor<[1],f32> -> !torch.vtensor<[],f32> + // CHECK: %[[VAL_22:.*]] = torch.aten.item %[[VAL_21]] : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[VAL_23:.*]] = torch.aten.ge.float %[[VAL_22]], %[[VAL_20]] : !torch.float, !torch.float -> !torch.bool + // CHECK: torch.runtime.assert %[[VAL_23]], "unimplemented: score_threshold should be <= min(scores)" + // CHECK: %[[VAL_24:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_25:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_26:.*]] = torch.constant.float 0.000000e+00 + // CHECK: %[[VAL_27:.*]] = torch.aten.item %arg3 : !torch.vtensor<[1],f32> -> !torch.float + // CHECK: %[[VAL_28:.*]] = torch.aten.item %arg2 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[VAL_29:.*]] = torch.torchvision.nms %[[VAL_9]], %[[VAL_19]], %[[VAL_27]] : !torch.vtensor<[1,4],f32>, !torch.vtensor<[1],f32>, !torch.float -> !torch.vtensor<[?],si64> + // CHECK: %[[VAL_30:.*]] = torch.aten.size.int %[[VAL_29]], %[[VAL_24]] : !torch.vtensor<[?],si64>, !torch.int -> !torch.int + // CHECK: %[[VAL_31:.*]] = torch.aten.gt.int %[[VAL_30]], %[[VAL_28]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[VAL_32:.*]] = torch.prim.If %[[VAL_31]] -> (!torch.vtensor<[1],si64>) + // CHECK: %[[SLICE:.*]] = torch.aten.slice.Tensor %[[VAL_29]], %[[VAL_24]], %[[VAL_24]], %[[VAL_28]], %[[VAL_25]] : !torch.vtensor<[?],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.prim.If.yield %[[SLICE]] : !torch.vtensor<[1],si64> + // CHECK: } else { + // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[VAL_29]] : !torch.vtensor<[?],si64> to !torch.vtensor<[1],si64> + // CHECK: torch.prim.If.yield %[[CAST]] : !torch.vtensor<[1],si64> + // CHECK: } + // CHECK: %[[VAL_33:.*]] = torch.aten.unsqueeze %[[VAL_32]], %[[VAL_25]] : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64> + // CHECK: %[[VAL_34:.*]] = torch.aten.size.int %[[VAL_33]], %[[VAL_24]] : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.int + // CHECK: %[[VAL_35:.*]] = torch.constant.int 2 + // CHECK: %[[VAL_36:.*]] = torch.prim.ListConstruct %[[VAL_34]], %[[VAL_35]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[VAL_37:.*]] = torch.constant.none + // CHECK: %[[VAL_38:.*]] = torch.aten.zeros %[[VAL_36]], %[[VAL_37]], %[[VAL_37]], %[[VAL_37]], %[[VAL_37]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64> + // CHECK: %[[VAL_39:.*]] = torch.prim.ListConstruct %[[VAL_38]], %[[VAL_33]] : (!torch.vtensor<[1,2],si64>, !torch.vtensor<[1,1],si64>) -> !torch.list + // CHECK: %[[VAL_40:.*]] = torch.aten.cat %[[VAL_39]], %[[VAL_25]] : !torch.list, !torch.int -> !torch.vtensor<[1,3],si64> + // CHECK: return %[[VAL_40]] : !torch.vtensor<[1,3],si64> + %0 = torch.operator "onnx.NonMaxSuppression"(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.vtensor<[1,1,4],f32>, !torch.vtensor<[1,1,1],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> + return %0 : !torch.vtensor<[1,3],si64> +} + +// ----- + +// CHECK-LABEL: func.func @test_nonmaxsuppression_center_point_box( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,4],f32>, +// CHECK-SAME: %[[VAL_1:.*]]:
!torch.vtensor<[1,1,1],f32>, +// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[1],si64>, +// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[1],f32>, +// CHECK-SAME: %[[VAL_4:.*]]: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +func.func @test_nonmaxsuppression_center_point_box(%arg0: !torch.vtensor<[1,1,4],f32>, %arg1: !torch.vtensor<[1,1,1],f32>, %arg2: !torch.vtensor<[1],si64>, %arg3: !torch.vtensor<[1],f32>, %arg4: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VAL_5:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_6:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_7:.*]] = torch.aten.size.int %[[VAL_0]], %[[VAL_5]] : !torch.vtensor<[1,1,4],f32>, !torch.int -> !torch.int + // CHECK: %[[VAL_8:.*]] = torch.aten.eq.int %[[VAL_7]], %[[VAL_6]] : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.runtime.assert %[[VAL_8]], "squeeze operation possible for dim only when input_shape[dim] == 1." + // CHECK: %[[VAL_9:.*]] = torch.aten.squeeze.dim %[[VAL_0]], %[[VAL_5]] : !torch.vtensor<[1,1,4],f32>, !torch.int -> !torch.vtensor<[1,4],f32> + // CHECK: %[[VAL_10:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_11:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_12:.*]] = torch.aten.size.int %[[VAL_1]], %[[VAL_10]] : !torch.vtensor<[1,1,1],f32>, !torch.int -> !torch.int + // CHECK: %[[VAL_13:.*]] = torch.aten.eq.int %[[VAL_12]], %[[VAL_11]] : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.runtime.assert %[[VAL_13]], "squeeze operation possible for dim only when input_shape[dim] == 1." + // CHECK: %[[VAL_14:.*]] = torch.aten.squeeze.dim %[[VAL_1]], %[[VAL_10]] : !torch.vtensor<[1,1,1],f32>, !torch.int -> !torch.vtensor<[1,1],f32> + // CHECK: %[[VAL_15:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_16:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_17:.*]] = torch.aten.size.int %[[VAL_14]], %[[VAL_15]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.int + // CHECK: %[[VAL_18:.*]] = torch.aten.eq.int %[[VAL_17]], %[[VAL_16]] : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.runtime.assert %[[VAL_18]], "squeeze operation possible for dim only when input_shape[dim] == 1." 
+ // CHECK: %[[VAL_19:.*]] = torch.aten.squeeze.dim %[[VAL_14]], %[[VAL_15]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.vtensor<[1],f32> + // CHECK: %[[VAL_20:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_21:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_22:.*]] = torch.constant.int 2 + // CHECK: %[[VAL_23:.*]] = torch.constant.int 4 + // CHECK: %[[VAL_24:.*]] = torch.constant.float 2.000000e+00 + // CHECK: %[[VAL_25:.*]] = torch.aten.slice.Tensor %[[VAL_9]], %[[VAL_21]], %[[VAL_20]], %[[VAL_22]], %[[VAL_21]] : !torch.vtensor<[1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,2],f32> + // CHECK: %[[VAL_26:.*]] = torch.aten.slice.Tensor %[[VAL_9]], %[[VAL_21]], %[[VAL_22]], %[[VAL_23]], %[[VAL_21]] : !torch.vtensor<[1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,2],f32> + // CHECK: %[[VAL_27:.*]] = torch.aten.div.Scalar %[[VAL_26]], %[[VAL_24]] : !torch.vtensor<[?,2],f32>, !torch.float -> !torch.vtensor<[?,2],f32> + // CHECK: %[[VAL_28:.*]] = torch.aten.sub.Tensor %[[VAL_25]], %[[VAL_27]], %[[VAL_21]] : !torch.vtensor<[?,2],f32>, !torch.vtensor<[?,2],f32>, !torch.int -> !torch.vtensor<[?,2],f32> + // CHECK: %[[VAL_29:.*]] = torch.aten.add.Tensor %[[VAL_25]], %[[VAL_27]], %[[VAL_21]] : !torch.vtensor<[?,2],f32>, !torch.vtensor<[?,2],f32>, !torch.int -> !torch.vtensor<[?,2],f32> + // CHECK: %[[VAL_30:.*]] = torch.prim.ListConstruct %[[VAL_28]], %[[VAL_29]] : (!torch.vtensor<[?,2],f32>, !torch.vtensor<[?,2],f32>) -> !torch.list + // CHECK: %[[VAL_31:.*]] = torch.aten.cat %[[VAL_30]], %[[VAL_21]] : !torch.list, !torch.int -> !torch.vtensor<[1,4],f32> + // CHECK: %[[VAL_32:.*]] = torch.aten.item %[[VAL_4]] : !torch.vtensor<[1],f32> -> !torch.float + // CHECK: %[[VAL_33:.*]] = torch.aten.min %[[VAL_19]] : !torch.vtensor<[1],f32> -> !torch.vtensor<[],f32> + // CHECK: %[[VAL_34:.*]] = torch.aten.item %[[VAL_33]] : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[VAL_35:.*]] = torch.aten.ge.float %[[VAL_34]], %[[VAL_32]] : !torch.float, !torch.float -> !torch.bool + // CHECK: torch.runtime.assert %[[VAL_35]], "unimplemented: score_threshold should be <= min(scores)" + // CHECK: %[[VAL_36:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_37:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_38:.*]] = torch.constant.float 0.000000e+00 + // CHECK: %[[VAL_39:.*]] = torch.aten.item %[[VAL_3]] : !torch.vtensor<[1],f32> -> !torch.float + // CHECK: %[[VAL_40:.*]] = torch.aten.item %[[VAL_2]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[VAL_41:.*]] = torch.torchvision.nms %[[VAL_31]], %[[VAL_19]], %[[VAL_39]] : !torch.vtensor<[1,4],f32>, !torch.vtensor<[1],f32>, !torch.float -> !torch.vtensor<[?],si64> + // CHECK: %[[VAL_42:.*]] = torch.aten.size.int %[[VAL_41]], %[[VAL_36]] : !torch.vtensor<[?],si64>, !torch.int -> !torch.int + // CHECK: %[[VAL_43:.*]] = torch.aten.gt.int %[[VAL_42]], %[[VAL_40]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[VAL_44:.*]] = torch.prim.If %[[VAL_43]] -> (!torch.vtensor<[1],si64>) { + // CHECK: %[[VAL_45:.*]] = torch.aten.slice.Tensor %[[VAL_41]], %[[VAL_36]], %[[VAL_36]], %[[VAL_40]], %[[VAL_37]] : !torch.vtensor<[?],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.prim.If.yield %[[VAL_45]] : !torch.vtensor<[1],si64> + // CHECK: } else { + // CHECK: %[[VAL_46:.*]] = torch.tensor_static_info_cast %[[VAL_41]] : !torch.vtensor<[?],si64> to !torch.vtensor<[1],si64> + // CHECK: torch.prim.If.yield %[[VAL_46]] : !torch.vtensor<[1],si64> + // 
CHECK: } + // CHECK: %[[VAL_47:.*]] = torch.aten.unsqueeze %[[VAL_44]], %[[VAL_37]] : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64> + // CHECK: %[[VAL_48:.*]] = torch.aten.size.int %[[VAL_47]], %[[VAL_36]] : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.int + // CHECK: %[[VAL_49:.*]] = torch.constant.int 2 + // CHECK: %[[VAL_50:.*]] = torch.prim.ListConstruct %[[VAL_48]], %[[VAL_49]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[VAL_51:.*]] = torch.constant.none + // CHECK: %[[VAL_52:.*]] = torch.aten.zeros %[[VAL_50]], %[[VAL_51]], %[[VAL_51]], %[[VAL_51]], %[[VAL_51]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64> + // CHECK: %[[VAL_53:.*]] = torch.prim.ListConstruct %[[VAL_52]], %[[VAL_47]] : (!torch.vtensor<[1,2],si64>, !torch.vtensor<[1,1],si64>) -> !torch.list + // CHECK: %[[VAL_54:.*]] = torch.aten.cat %[[VAL_53]], %[[VAL_37]] : !torch.list, !torch.int -> !torch.vtensor<[1,3],si64> + // CHECK: return %[[VAL_54]] : !torch.vtensor<[1,3],si64> + %0 = torch.operator "onnx.NonMaxSuppression"(%arg0, %arg1, %arg2, %arg3, %arg4) {torch.onnx.center_point_box = 1 : si64} : (!torch.vtensor<[1,1,4],f32>, !torch.vtensor<[1,1,1],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> + return %0 : !torch.vtensor<[1,3],si64> +} +// ----- + +// CHECK-LABEL: func.func @test_mwm +func.func @test_mwm(%arg0: !torch.vtensor<[],si64>, %arg1: !torch.vtensor<[],si64>, %arg2: !torch.vtensor<[],si64>, %arg3: !torch.vtensor<[],f32>, %arg4: !torch.vtensor<[],f32>) -> !torch.vtensor<[9,8],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "test_mwm", torch.onnx_meta.producer_version = ""} { + // CHECK-SAME: %[[NUM_MEL_BINS_ARG:.*]]: !torch.vtensor<[],si64>, %[[DFT_LENGTH_ARG:.*]]: !torch.vtensor<[],si64>, %[[SAMPLE_RATE_ARG:.*]]: !torch.vtensor<[],si64>, + // CHECK-SAME: %[[LOWER_EDGE_HZ_ARG:.*]]: !torch.vtensor<[],f32>, + // CHECK-SAME: %[[UPPER_EDGE_HZ_ARG:.*]]: !torch.vtensor<[],f32> + // CHECK: %[[VAL_5:.*]] = torch.constant.none + // CHECK: %[[NUM_MEL_BINS_ITEM:.*]] = torch.aten.item %[[NUM_MEL_BINS_ARG]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[SAMPLE_RATE_ITEM:.*]] = torch.aten.item %[[SAMPLE_RATE_ARG]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[LOWER_EDGE_HZ_ITEM:.*]] = torch.aten.item %[[LOWER_EDGE_HZ_ARG]] : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[UPPER_EDGE_HZ_ITEM:.*]] = torch.aten.item %[[UPPER_EDGE_HZ_ARG]] : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[VAL_10:.*]] = torch.constant.none + // CHECK: %[[VAL_11:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_12:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_13:.*]] = torch.constant.int 2 + // CHECK: %[[VAL_14:.*]] = torch.constant.int 3 + // CHECK: %[[VAL_15:.*]] = torch.constant.int 6 + // CHECK: %[[VAL_16:.*]] = torch.aten.floor_divide.Scalar %[[DFT_LENGTH_ARG]], %[[VAL_13]] : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[VAL_17:.*]] = torch.aten.add.Scalar %[[VAL_16]], %[[VAL_12]], %[[VAL_12]] : !torch.vtensor<[],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[NUM_SPECTROGRAM_BINS_ITEM:.*]] = torch.aten.item %[[VAL_17]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[VAL_19:.*]] = torch.constant.float 2.595000e+03 + // CHECK: %[[VAL_20:.*]] = torch.constant.float 7.000000e+02 + // CHECK: %[[VAL_21:.*]] = 
torch.constant.float 1.000000e+01 + // CHECK: %[[VAL_22:.*]] = torch.constant.float 1.000000e+00 + // CHECK: %[[CONST_LN_TO_LOG10:.*]] = torch.constant.float 0.43429448190325182 + // CHECK: %[[VAL_24:.*]] = torch.aten.div.float %[[LOWER_EDGE_HZ_ITEM]], %[[VAL_20]] : !torch.float, !torch.float -> !torch.float + // CHECK: %[[VAL_25:.*]] = torch.prim.NumToTensor.Scalar %[[VAL_24]] : !torch.float -> !torch.vtensor<[],f32> + // CHECK: %[[VAL_26:.*]] = torch.aten.add.Scalar %[[VAL_25]], %[[VAL_12]], %[[VAL_12]] : !torch.vtensor<[],f32>, !torch.int, !torch.int -> !torch.vtensor<[],f32> + // CHECK: %[[VAL_27:.*]] = torch.aten.log %[[VAL_26]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32> + // CHECK: %[[VAL_28:.*]] = torch.aten.mul.Scalar %[[VAL_27]], %[[CONST_LN_TO_LOG10]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32> + // CHECK: %[[LOW_FREQ_MEL:.*]] = torch.aten.mul.Scalar %[[VAL_28]], %[[VAL_19]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32> + // CHECK: %[[VAL_30:.*]] = torch.aten.div.float %[[UPPER_EDGE_HZ_ITEM]], %[[VAL_20]] : !torch.float, !torch.float -> !torch.float + // CHECK: %[[VAL_31:.*]] = torch.prim.NumToTensor.Scalar %[[VAL_30]] : !torch.float -> !torch.vtensor<[],f32> + // CHECK: %[[VAL_32:.*]] = torch.aten.add.Scalar %[[VAL_31]], %[[VAL_12]], %[[VAL_12]] : !torch.vtensor<[],f32>, !torch.int, !torch.int -> !torch.vtensor<[],f32> + // CHECK: %[[VAL_33:.*]] = torch.aten.log %[[VAL_32]] : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32> + // CHECK: %[[VAL_34:.*]] = torch.aten.mul.Scalar %[[VAL_33]], %[[CONST_LN_TO_LOG10]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32> + // CHECK: %[[HIGH_FREQ_MEL:.*]] = torch.aten.mul.Scalar %[[VAL_34]], %[[VAL_19]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32> + // CHECK: %[[VAL_36:.*]] = torch.aten.sub.Tensor %[[HIGH_FREQ_MEL]], %[[LOW_FREQ_MEL]], %[[VAL_12]] : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32> + // CHECK: %[[VAL_37:.*]] = torch.aten.add.int %[[NUM_MEL_BINS_ITEM]], %[[VAL_13]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[MEL_STEP:.*]] = torch.aten.div.Scalar %[[VAL_36]], %[[VAL_37]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32> + // CHECK: %[[LOW_BINS_INIT:.*]] = torch.aten.arange %[[NUM_MEL_BINS_ITEM]], %[[VAL_14]], %[[VAL_10]], %[[VAL_10]], %[[VAL_10]] : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[8],si32> + // CHECK: %[[CENTER_BINS_INIT:.*]] = torch.aten.arange %[[NUM_MEL_BINS_ITEM]], %[[VAL_14]], %[[VAL_10]], %[[VAL_10]], %[[VAL_10]] : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[8],si32> + // CHECK: %[[HIGH_BINS_INIT:.*]] = torch.aten.arange %[[NUM_MEL_BINS_ITEM]], %[[VAL_14]], %[[VAL_10]], %[[VAL_10]], %[[VAL_10]] : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[8],si32> + // CHECK: %[[VAL_42:.*]] = torch.aten.add.Scalar %[[DFT_LENGTH_ARG]], %[[VAL_12]], %[[VAL_12]] : !torch.vtensor<[],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[VAL_43:.*]] = torch.aten.item %[[VAL_42]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[VAL_44:.*]] = torch.constant.bool false + // CHECK: %[[VAL_45:.*]] = torch.aten.mul.Tensor %[[LOW_BINS_INIT]], %[[MEL_STEP]] : !torch.vtensor<[8],si32>, !torch.vtensor<[],f32> -> !torch.vtensor<[8],f32> + // CHECK: %[[VAL_46:.*]] = torch.aten.add.Tensor %[[VAL_45]], %[[LOW_FREQ_MEL]], %[[VAL_12]] : !torch.vtensor<[8],f32>, 
!torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[8],f32> + // CHECK: %[[VAL_47:.*]] = torch.aten.div.Scalar %[[VAL_46]], %[[VAL_19]] : !torch.vtensor<[8],f32>, !torch.float -> !torch.vtensor<[8],f32> + // CHECK: %[[VAL_48:.*]] = torch.aten.clone %[[VAL_46]], %[[VAL_10]] : !torch.vtensor<[8],f32>, !torch.none -> !torch.vtensor<[8],f32> + // CHECK: %[[VAL_49:.*]] = torch.aten.fill.Scalar %[[VAL_48]], %[[VAL_21]] : !torch.vtensor<[8],f32>, !torch.float -> !torch.vtensor<[8],f32> + // CHECK: %[[VAL_50:.*]] = torch.aten.pow.Tensor_Tensor %[[VAL_49]], %[[VAL_47]] : !torch.vtensor<[8],f32>, !torch.vtensor<[8],f32> -> !torch.vtensor<[8],f32> + // CHECK: %[[VAL_51:.*]] = torch.aten.sub.Scalar %[[VAL_50]], %[[VAL_12]], %[[VAL_12]] : !torch.vtensor<[8],f32>, !torch.int, !torch.int -> !torch.vtensor<[8],f32> + // CHECK: %[[VAL_52:.*]] = torch.aten.mul.Scalar %[[VAL_51]], %[[VAL_20]] : !torch.vtensor<[8],f32>, !torch.float -> !torch.vtensor<[8],f32> + // CHECK: %[[VAL_53:.*]] = torch.aten.mul.Scalar %[[VAL_52]], %[[VAL_43]] : !torch.vtensor<[8],f32>, !torch.int -> !torch.vtensor<[8],f32> + // CHECK: %[[VAL_54:.*]] = torch.aten.div.Scalar %[[VAL_53]], %[[SAMPLE_RATE_ITEM]] : !torch.vtensor<[8],f32>, !torch.int -> !torch.vtensor<[8],f32> + // CHECK: %[[VAL_55:.*]] = torch.aten.to.dtype %[[VAL_54]], %[[VAL_14]], %[[VAL_44]], %[[VAL_44]], %[[VAL_10]] : !torch.vtensor<[8],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[8],si32> + // CHECK: %[[LOW_BINS:.*]] = torch.aten.unsqueeze %[[VAL_55]], %[[VAL_11]] : !torch.vtensor<[8],si32>, !torch.int -> !torch.vtensor<[1,8],si32> + // CHECK: %[[VAL_57:.*]] = torch.aten.add.Scalar %[[CENTER_BINS_INIT]], %[[VAL_12]], %[[VAL_12]] : !torch.vtensor<[8],si32>, !torch.int, !torch.int -> !torch.vtensor<[8],si32> + // CHECK: %[[VAL_58:.*]] = torch.aten.mul.Tensor %[[VAL_57]], %[[MEL_STEP]] : !torch.vtensor<[8],si32>, !torch.vtensor<[],f32> -> !torch.vtensor<[8],f32> + // CHECK: %[[VAL_59:.*]] = torch.aten.add.Tensor %[[VAL_58]], %[[LOW_FREQ_MEL]], %[[VAL_12]] : !torch.vtensor<[8],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[8],f32> + // CHECK: %[[VAL_60:.*]] = torch.aten.div.Scalar %[[VAL_59]], %[[VAL_19]] : !torch.vtensor<[8],f32>, !torch.float -> !torch.vtensor<[8],f32> + // CHECK: %[[VAL_61:.*]] = torch.aten.clone %[[VAL_59]], %[[VAL_10]] : !torch.vtensor<[8],f32>, !torch.none -> !torch.vtensor<[8],f32> + // CHECK: %[[VAL_62:.*]] = torch.aten.fill.Scalar %[[VAL_61]], %[[VAL_21]] : !torch.vtensor<[8],f32>, !torch.float -> !torch.vtensor<[8],f32> + // CHECK: %[[VAL_63:.*]] = torch.aten.pow.Tensor_Tensor %[[VAL_62]], %[[VAL_60]] : !torch.vtensor<[8],f32>, !torch.vtensor<[8],f32> -> !torch.vtensor<[8],f32> + // CHECK: %[[VAL_64:.*]] = torch.aten.sub.Scalar %[[VAL_63]], %[[VAL_12]], %[[VAL_12]] : !torch.vtensor<[8],f32>, !torch.int, !torch.int -> !torch.vtensor<[8],f32> + // CHECK: %[[VAL_65:.*]] = torch.aten.mul.Scalar %[[VAL_64]], %[[VAL_20]] : !torch.vtensor<[8],f32>, !torch.float -> !torch.vtensor<[8],f32> + // CHECK: %[[VAL_66:.*]] = torch.aten.mul.Scalar %[[VAL_65]], %[[VAL_43]] : !torch.vtensor<[8],f32>, !torch.int -> !torch.vtensor<[8],f32> + // CHECK: %[[VAL_67:.*]] = torch.aten.div.Scalar %[[VAL_66]], %[[SAMPLE_RATE_ITEM]] : !torch.vtensor<[8],f32>, !torch.int -> !torch.vtensor<[8],f32> + // CHECK: %[[VAL_68:.*]] = torch.aten.to.dtype %[[VAL_67]], %[[VAL_14]], %[[VAL_44]], %[[VAL_44]], %[[VAL_10]] : !torch.vtensor<[8],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[8],si32> + // CHECK: 
%[[CENTER_BINS:.*]] = torch.aten.unsqueeze %[[VAL_68]], %[[VAL_11]] : !torch.vtensor<[8],si32>, !torch.int -> !torch.vtensor<[1,8],si32> + // CHECK: %[[VAL_70:.*]] = torch.aten.add.Scalar %[[HIGH_BINS_INIT]], %[[VAL_13]], %[[VAL_12]] : !torch.vtensor<[8],si32>, !torch.int, !torch.int -> !torch.vtensor<[8],si32> + // CHECK: %[[VAL_71:.*]] = torch.aten.mul.Tensor %[[VAL_70]], %[[MEL_STEP]] : !torch.vtensor<[8],si32>, !torch.vtensor<[],f32> -> !torch.vtensor<[8],f32> + // CHECK: %[[VAL_72:.*]] = torch.aten.add.Tensor %[[VAL_71]], %[[LOW_FREQ_MEL]], %[[VAL_12]] : !torch.vtensor<[8],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[8],f32> + // CHECK: %[[VAL_73:.*]] = torch.aten.div.Scalar %[[VAL_72]], %[[VAL_19]] : !torch.vtensor<[8],f32>, !torch.float -> !torch.vtensor<[8],f32> + // CHECK: %[[VAL_74:.*]] = torch.aten.clone %[[VAL_72]], %[[VAL_10]] : !torch.vtensor<[8],f32>, !torch.none -> !torch.vtensor<[8],f32> + // CHECK: %[[VAL_75:.*]] = torch.aten.fill.Scalar %[[VAL_74]], %[[VAL_21]] : !torch.vtensor<[8],f32>, !torch.float -> !torch.vtensor<[8],f32> + // CHECK: %[[VAL_76:.*]] = torch.aten.pow.Tensor_Tensor %[[VAL_75]], %[[VAL_73]] : !torch.vtensor<[8],f32>, !torch.vtensor<[8],f32> -> !torch.vtensor<[8],f32> + // CHECK: %[[VAL_77:.*]] = torch.aten.sub.Scalar %[[VAL_76]], %[[VAL_12]], %[[VAL_12]] : !torch.vtensor<[8],f32>, !torch.int, !torch.int -> !torch.vtensor<[8],f32> + // CHECK: %[[VAL_78:.*]] = torch.aten.mul.Scalar %[[VAL_77]], %[[VAL_20]] : !torch.vtensor<[8],f32>, !torch.float -> !torch.vtensor<[8],f32> + // CHECK: %[[VAL_79:.*]] = torch.aten.mul.Scalar %[[VAL_78]], %[[VAL_43]] : !torch.vtensor<[8],f32>, !torch.int -> !torch.vtensor<[8],f32> + // CHECK: %[[VAL_80:.*]] = torch.aten.div.Scalar %[[VAL_79]], %[[SAMPLE_RATE_ITEM]] : !torch.vtensor<[8],f32>, !torch.int -> !torch.vtensor<[8],f32> + // CHECK: %[[VAL_81:.*]] = torch.aten.to.dtype %[[VAL_80]], %[[VAL_14]], %[[VAL_44]], %[[VAL_44]], %[[VAL_10]] : !torch.vtensor<[8],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[8],si32> + // CHECK: %[[HIGH_BINS:.*]] = torch.aten.unsqueeze %[[VAL_81]], %[[VAL_11]] : !torch.vtensor<[8],si32>, !torch.int -> !torch.vtensor<[1,8],si32> + // CHECK: %[[IOTA_INIT:.*]] = torch.aten.arange %[[NUM_SPECTROGRAM_BINS_ITEM]], %[[VAL_14]], %[[VAL_10]], %[[VAL_10]], %[[VAL_10]] : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[9],si32> + // CHECK: %[[IOTA:.*]] = torch.aten.unsqueeze %[[IOTA_INIT]], %[[VAL_12]] : !torch.vtensor<[9],si32>, !torch.int -> !torch.vtensor<[9,1],si32> + // CHECK: %[[LOW_TO_CENTER:.*]] = torch.aten.sub.Tensor %[[CENTER_BINS]], %[[LOW_BINS]], %[[VAL_12]] : !torch.vtensor<[1,8],si32>, !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[1,8],si32> + // CHECK: %[[CENTER_TO_HIGH:.*]] = torch.aten.sub.Tensor %[[HIGH_BINS]], %[[CENTER_BINS]], %[[VAL_12]] : !torch.vtensor<[1,8],si32>, !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[1,8],si32> + // CHECK: %[[VAL_87:.*]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[VAL_88:.*]] = torch.constant.none + // CHECK: %[[VAL_89:.*]] = torch.constant.int 6 + // CHECK: %[[VAL_90:.*]] = torch.aten.full %[[VAL_87]], %[[VAL_12]], %[[VAL_89]], %[[VAL_88]], %[[VAL_88]], %[[VAL_88]] : !torch.list, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> + // CHECK: %[[VAL_91:.*]] = torch.aten.maximum %[[VAL_90]], %[[LOW_TO_CENTER]] : !torch.vtensor<[],f32>, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[1,8],si32> + // CHECK: 
%[[UP_SCALE:.*]] = torch.aten.to.dtype %[[VAL_91]], %[[VAL_15]], %[[VAL_44]], %[[VAL_44]], %[[VAL_10]] : !torch.vtensor<[1,8],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,8],f32> + // CHECK: %[[VAL_93:.*]] = torch.aten.maximum %[[VAL_90]], %[[CENTER_TO_HIGH]] : !torch.vtensor<[],f32>, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[1,8],si32> + // CHECK: %[[DOWN_SCALE:.*]] = torch.aten.to.dtype %[[VAL_93]], %[[VAL_15]], %[[VAL_44]], %[[VAL_44]], %[[VAL_10]] : !torch.vtensor<[1,8],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,8],f32> + // CHECK: %[[VAL_95:.*]] = torch.aten.sub.Tensor %[[IOTA]], %[[LOW_BINS]], %[[VAL_12]] : !torch.vtensor<[9,1],si32>, !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[9,8],si32> + // CHECK: %[[VAL_96:.*]] = torch.aten.to.dtype %[[VAL_95]], %[[VAL_15]], %[[VAL_44]], %[[VAL_44]], %[[VAL_10]] : !torch.vtensor<[9,8],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[9,8],f32> + // CHECK: %[[RAMP_UP:.*]] = torch.aten.div.Tensor %[[VAL_96]], %[[UP_SCALE]] : !torch.vtensor<[9,8],f32>, !torch.vtensor<[1,8],f32> -> !torch.vtensor<[9,8],f32> + // CHECK: %[[VAL_98:.*]] = torch.aten.sub.Tensor %[[HIGH_BINS]], %[[IOTA]], %[[VAL_12]] : !torch.vtensor<[1,8],si32>, !torch.vtensor<[9,1],si32>, !torch.int -> !torch.vtensor<[9,8],si32> + // CHECK: %[[VAL_99:.*]] = torch.aten.to.dtype %[[VAL_98]], %[[VAL_15]], %[[VAL_44]], %[[VAL_44]], %[[VAL_10]] : !torch.vtensor<[9,8],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[9,8],f32> + // CHECK: %[[RAMP_DOWN:.*]] = torch.aten.div.Tensor %[[VAL_99]], %[[DOWN_SCALE]] : !torch.vtensor<[9,8],f32>, !torch.vtensor<[1,8],f32> -> !torch.vtensor<[9,8],f32> + // CHECK: %[[VAL_101:.*]] = torch.aten.ge.Tensor %[[IOTA]], %[[CENTER_BINS]] : !torch.vtensor<[9,1],si32>, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[9,8],i1> + // CHECK: %[[VAL_102:.*]] = torch.aten.eq.Tensor %[[IOTA]], %[[CENTER_BINS]] : !torch.vtensor<[9,1],si32>, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[9,8],i1> + // CHECK: %[[VAL_103:.*]] = torch.aten.lt.Tensor %[[IOTA]], %[[LOW_BINS]] : !torch.vtensor<[9,1],si32>, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[9,8],i1> + // CHECK: %[[VAL_104:.*]] = torch.aten.gt.Tensor %[[IOTA]], %[[HIGH_BINS]] : !torch.vtensor<[9,1],si32>, !torch.vtensor<[1,8],si32> -> !torch.vtensor<[9,8],i1> + // CHECK: %[[RAMP_INIT:.*]] = torch.aten.where.self %[[VAL_101]], %[[RAMP_DOWN]], %[[RAMP_UP]] : !torch.vtensor<[9,8],i1>, !torch.vtensor<[9,8],f32>, !torch.vtensor<[9,8],f32> -> !torch.vtensor<[9,8],f32> + // CHECK: %[[VAL_106:.*]] = torch.aten.where.ScalarSelf %[[VAL_103]], %[[VAL_11]], %[[RAMP_INIT]] : !torch.vtensor<[9,8],i1>, !torch.int, !torch.vtensor<[9,8],f32> -> !torch.vtensor<[9,8],f32> + // CHECK: %[[VAL_107:.*]] = torch.aten.where.ScalarSelf %[[VAL_104]], %[[VAL_11]], %[[VAL_106]] : !torch.vtensor<[9,8],i1>, !torch.int, !torch.vtensor<[9,8],f32> -> !torch.vtensor<[9,8],f32> + // CHECK: %[[VAL_108:.*]] = torch.aten.eq.Scalar %[[CENTER_TO_HIGH]], %[[VAL_11]] : !torch.vtensor<[1,8],si32>, !torch.int -> !torch.vtensor<[1,8],i1> + // CHECK: %[[CORNER_CASES:.*]] = torch.aten.logical_and %[[VAL_102]], %[[VAL_108]] : !torch.vtensor<[9,8],i1>, !torch.vtensor<[1,8],i1> -> !torch.vtensor<[9,8],i1> + // CHECK: %[[RAMP:.*]] = torch.aten.where.ScalarSelf %[[CORNER_CASES]], %[[VAL_22]], %[[VAL_107]] : !torch.vtensor<[9,8],i1>, !torch.float, !torch.vtensor<[9,8],f32> -> !torch.vtensor<[9,8],f32> + // CHECK: %[[VAL_111:.*]] = 
torch.constant.int 6 + // CHECK: %[[OUTPUT:.*]] = torch.aten.to.dtype %[[RAMP]], %[[VAL_111]], %[[VAL_44]], %[[VAL_44]], %[[VAL_10]] : !torch.vtensor<[9,8],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[9,8],f32> + // CHECK: return %[[OUTPUT]] : !torch.vtensor<[9,8],f32> + %none = torch.constant.none + %0 = torch.operator "onnx.MelWeightMatrix"(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.vtensor<[],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[9,8],f32> + return %0 : !torch.vtensor<[9,8],f32> +} diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 8322d3df6602..d8a9c7f23537 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -47,6 +47,23 @@ func.func @test_quantizelinear_i32(%arg0: !torch.vtensor<[6],f32>, %arg1: !torch // ----- +// CHECK-LABEL: @test_quantizelinear_f8 +func.func @test_quantizelinear_f8(%arg0: !torch.vtensor<[6],f32>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],f32>) -> !torch.vtensor<[6],f8E4M3FN> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} { + // CHECK: %[[DTYPE:.+]] = torch.constant.int 24 + // CHECK: %[[SCALE:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[ZP:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[ONE:.+]] = torch.constant.float 1.000000e+00 + // CHECK: %[[DIV:.+]] = torch.aten.div.Scalar %arg0, %[[SCALE]] + // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %[[DIV]], %[[ZP]], %[[ONE]] + // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %[[ADD]], %[[DTYPE]], %[[FALSE]], %[[FALSE]], %[[NONE]] + %0 = torch.operator "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[6],f8E4M3FN> + return %0 : !torch.vtensor<[6],f8E4M3FN> +} + +// ----- + // CHECK-LABEL: @test_qlinearconv_nobias func.func @test_qlinearconv_nobias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[1,1,1,1],ui8>, %arg4: !torch.vtensor<[1],f32>, %arg5: !torch.vtensor<[1],ui8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1,7,7],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { %0 = torch.operator "onnx.QLinearConv"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (!torch.vtensor<[1,1,7,7],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[1,1,1,1],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1,7,7],ui8> @@ -60,12 +77,12 @@ func.func @test_qlinearconv_nobias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: // CHECK: %[[B:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg3, %[[bScale]], %[[bZp]] : !torch.vtensor<[1,1,1,1],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,1,1],!torch.quint8> // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 // CHECK: %[[INT0_1:.+]] = torch.constant.int 0 + // CHECK: %[[PAD:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT0_1]] // CHECK: 
%[[INT1_0:.+]] = torch.constant.int 1 // CHECK: %[[INT1_1:.+]] = torch.constant.int 1 // CHECK: %[[INT1_2:.+]] = torch.constant.int 1 // CHECK: %[[INT1_3:.+]] = torch.constant.int 1 // CHECK: %[[INT0_2:.+]] = torch.constant.int 0 - // CHECK: %[[PAD:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT0_1]] // CHECK: %[[KERNEL:.+]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]] // CHECK: %[[DILATION:.+]] = torch.prim.ListConstruct %[[INT1_2]], %[[INT1_3]] // CHECK: %[[STRIDE:.+]] = torch.prim.ListConstruct %[[INT0_2]], %[[INT0_2]] @@ -99,12 +116,12 @@ func.func @test_qlinearconv_bias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !t // CHECK: %[[B:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg3, %[[bScale]], %[[bZp]] : !torch.vtensor<[1,1,1,1],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,1,1],!torch.quint8> // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 // CHECK: %[[INT0_1:.+]] = torch.constant.int 0 + // CHECK: %[[PAD:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT0_1]] // CHECK: %[[INT1_0:.+]] = torch.constant.int 1 // CHECK: %[[INT1_1:.+]] = torch.constant.int 1 // CHECK: %[[INT1_2:.+]] = torch.constant.int 1 // CHECK: %[[INT1_3:.+]] = torch.constant.int 1 // CHECK: %[[INT0_2:.+]] = torch.constant.int 0 - // CHECK: %[[PAD:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT0_1]] // CHECK: %[[KERNEL:.+]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]] // CHECK: %[[DILATION:.+]] = torch.prim.ListConstruct %[[INT1_2]], %[[INT1_3]] // CHECK: %[[STRIDE:.+]] = torch.prim.ListConstruct %[[INT0_2]], %[[INT0_2]] @@ -215,10 +232,27 @@ func.func @test_round(%arg0: !torch.vtensor<[15],f32>) -> !torch.vtensor<[15],f3 // ----- +// CHECK-LABEL: func.func @test_scatter +func.func @test_scatter(%arg0: !torch.vtensor<[3,3],f32>, %arg1: !torch.vtensor<[2,3],si64>, %arg2: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[3,3],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[RESULT:.*]] = torch.aten.scatter.src %arg0, %[[INT0]], %arg1, %arg2 : !torch.vtensor<[3,3],f32>, !torch.int, !torch.vtensor<[2,3],si64>, !torch.vtensor<[2,3],f32> -> !torch.vtensor<[3,3],f32> + // CHECK: return %[[RESULT]] : !torch.vtensor<[3,3],f32> + %0 = torch.operator "onnx.Scatter"(%arg0, %arg1, %arg2) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[3,3],f32>, !torch.vtensor<[2,3],si64>, !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[3,3],f32> + return %0 : !torch.vtensor<[3,3],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_scatter_elements_with_axis func.func @test_scatter_elements_with_axis(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT1:.*]] = torch.constant.int 1 - // CHECK: torch.aten.scatter.src %arg0, %int1, %arg1, %arg2 : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32> -> !torch.vtensor<[1,5],f32> + // CHECK: %[[AXIS:.*]] = torch.constant.int 1 + // CHECK: %[[ZERO:.+]] = torch.constant.int 0 + // CHECK: %[[ONE:.+]] = torch.constant.int 1 + // CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[AXIS]] + // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]] + 
// CHECK: %[[CMP:.+]] = torch.aten.lt.Scalar %arg1, %[[ZERO]] + // CHECK: %[[WHERE:.+]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1 + // CHECK: torch.aten.scatter.src %arg0, %[[AXIS]], %[[WHERE]], %arg2 %0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> return %0 : !torch.vtensor<[1,5],f32> } @@ -227,9 +261,16 @@ func.func @test_scatter_elements_with_axis(%arg0: !torch.vtensor<[1,5],f32>, %ar // CHECK-LABEL: func.func @test_scatter_elements_with_duplicate_indices func.func @test_scatter_elements_with_duplicate_indices(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT1:.*]] = torch.constant.int 1 - // CHECK: %[[STR:.*]] = torch.constant.str "add" - // CHECK: torch.aten.scatter.reduce %arg0, %int1, %arg1, %arg2, %str : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str -> !torch.vtensor<[1,5],f32> +// CHECK: %[[AXIS:.*]] = torch.constant.int 1 +// CHECK: %[[ZERO:.*]] = torch.constant.int 0 +// CHECK: %[[FIVE:.*]] = torch.constant.int 1 +// CHECK: %[[SZ:.*]] = torch.aten.size.int %arg0, %[[AXIS]] : !torch.vtensor<[1,5],f32>, !torch.int -> !torch.int +// CHECK: %[[ADD:.*]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[FIVE]] : !torch.vtensor<[1,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,2],si64> +// CHECK: %[[CMP:.*]] = torch.aten.lt.Scalar %arg1, %[[ZERO]] : !torch.vtensor<[1,2],si64>, !torch.int -> !torch.vtensor<[1,2],i1> +// CHECK: %[[WHERE:.*]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1 : !torch.vtensor<[1,2],i1>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],si64> -> !torch.vtensor<[1,2],si64> +// CHECK: %[[STR:.*]] = torch.constant.str "sum" +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: torch.aten.scatter_reduce.two %arg0, %[[AXIS]], %[[WHERE]], %arg2, %[[STR]], %[[TRUE]] : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str, !torch.bool -> !torch.vtensor<[1,5],f32> %0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64, torch.onnx.reduction = "add"} : (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> return %0 : !torch.vtensor<[1,5],f32> } @@ -238,8 +279,14 @@ func.func @test_scatter_elements_with_duplicate_indices(%arg0: !torch.vtensor<[1 // CHECK-LABEL: func.func @test_scatter_elements_without_axis func.func @test_scatter_elements_without_axis(%arg0: !torch.vtensor<[3,3],f32>, %arg1: !torch.vtensor<[2,3],si64>, %arg2: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[3,3],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT0:.*]] = torch.constant.int 0 - // CHECK: torch.aten.scatter.src %arg0, %int0, %arg1, %arg2 : !torch.vtensor<[3,3],f32>, !torch.int, !torch.vtensor<[2,3],si64>, !torch.vtensor<[2,3],f32> -> !torch.vtensor<[3,3],f32> + // CHECK: %[[AXIS:.*]] = torch.constant.int 0 + // CHECK: %[[ZERO:.+]] = torch.constant.int 0 + // CHECK: %[[ONE:.+]] = 
torch.constant.int 1 + // CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[AXIS]] + // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]] + // CHECK: %[[CMP:.+]] = torch.aten.lt.Scalar %arg1, %[[ZERO]] + // CHECK: %[[WHERE:.+]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1 + // CHECK: torch.aten.scatter.src %arg0, %[[AXIS]], %[[WHERE]], %arg2 : !torch.vtensor<[3,3],f32>, !torch.int, !torch.vtensor<[2,3],si64>, !torch.vtensor<[2,3],f32> -> !torch.vtensor<[3,3],f32> %0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) : (!torch.vtensor<[3,3],f32>, !torch.vtensor<[2,3],si64>, !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[3,3],f32> return %0 : !torch.vtensor<[3,3],f32> } @@ -248,9 +295,16 @@ func.func @test_scatter_elements_without_axis(%arg0: !torch.vtensor<[3,3],f32>, // CHECK-LABEL: func.func @test_scatter_elements_with_reduction_mul func.func @test_scatter_elements_with_reduction_mul(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT1:.*]] = torch.constant.int 1 - // CHECK: %[[STR:.*]] = torch.constant.str "multiply" - // CHECK: torch.aten.scatter.reduce %arg0, %int1, %arg1, %arg2, %str : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str -> !torch.vtensor<[1,5],f32> +// CHECK: %[[AXIS:.*]] = torch.constant.int 1 +// CHECK: %[[ZERO:.*]] = torch.constant.int 0 +// CHECK: %[[FIVE:.*]] = torch.constant.int 1 +// CHECK: %[[SZ:.*]] = torch.aten.size.int %arg0, %[[AXIS]] : !torch.vtensor<[1,5],f32>, !torch.int -> !torch.int +// CHECK: %[[ADD:.*]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[FIVE]] : !torch.vtensor<[1,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,2],si64> +// CHECK: %[[CMP:.*]] = torch.aten.lt.Scalar %arg1, %[[ZERO]] : !torch.vtensor<[1,2],si64>, !torch.int -> !torch.vtensor<[1,2],i1> +// CHECK: %[[WHERE:.*]] = torch.aten.where.self %[[CMP]], %[[ADD]], %arg1 : !torch.vtensor<[1,2],i1>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],si64> -> !torch.vtensor<[1,2],si64> +// CHECK: %[[STR:.*]] = torch.constant.str "prod" +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: torch.aten.scatter_reduce.two %arg0, %[[AXIS]], %[[WHERE]], %arg2, %[[STR]], %[[TRUE]] : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str, !torch.bool -> !torch.vtensor<[1,5],f32> %0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64, torch.onnx.reduction = "mul"} : (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> return %0 : !torch.vtensor<[1,5],f32> } @@ -470,6 +524,18 @@ func.func @test_squeeze_two_axes(%arg0: !torch.vtensor<[3,1,4,5,1],f32>, %arg1: // ----- +// CHECK-LABEL: func.func @test_unsqueeze_dyn_dims +func.func @test_unsqueeze_dyn_dims(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,1,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "1.11.0"} { + // CHECK: %[[x0:.*]] = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + // CHECK: %[[int1:.*]] = torch.constant.int 1 + // CHECK: %[[x1:.*]] = 
torch.aten.unsqueeze %arg0, %[[int1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,1,?],f32> + %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> + %1 = torch.operator "onnx.Unsqueeze"(%arg0, %0) : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[?,1,?],f32> + return %1 : !torch.vtensor<[?,1,?],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_unsqueeze_axis_0 func.func @test_unsqueeze_axis_0(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.*]] = torch.constant.int 0 @@ -580,6 +646,21 @@ func.func @test_softmax_negative_axis(%arg0: !torch.vtensor<[3,4,5],f32>) -> !to // ----- +// CHECK-LABEL: func.func @test_softsign +func.func @test_softsign(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 22 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[ABS:.+]] = torch.aten.abs %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[RES:.+]] = torch.aten.add.Scalar %[[ABS]], %[[INT1]], %[[INT1]] : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[SCALE_T:.*]] = torch.aten.div.Tensor %arg0, %[[RES]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + // CHECK: return %[[SCALE_T]] : !torch.vtensor<[3,4,5],f32> + %none = torch.constant.none + %0 = torch.operator "onnx.Softsign"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_selu func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 6 : si64} { // CHECK-DAG: %[[F1:.+]] = torch.constant.float 1 @@ -594,7 +675,7 @@ func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4, // CHECK-LABEL: func.func @test_reduce_max_empty_set_fp func.func @test_reduce_max_empty_set_fp(%arg0: !torch.vtensor<[2,0,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK-DAG: %[[INF:.+]] = torch.constant.float 0x7FF0000000000000 + // CHECK-DAG: %[[INF:.+]] = torch.constant.float 0xFFF0000000000000 // CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2 // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1 // CHECK-DAG: %[[INT4:.+]] = torch.constant.int 4 @@ -610,7 +691,7 @@ func.func @test_reduce_max_empty_set_fp(%arg0: !torch.vtensor<[2,0,4],f32>, %arg // CHECK-LABEL: func.func @test_reduce_max_empty_set_int func.func @test_reduce_max_empty_set_int(%arg0: !torch.vtensor<[2,0,4],si32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],si32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK-DAG: 
%[[INF:.+]] = torch.constant.int 2147483647 + // CHECK-DAG: %[[INF:.+]] = torch.constant.int -2147483648 // CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2 // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1 // CHECK-DAG: %[[INT4:.+]] = torch.constant.int 4 @@ -626,17 +707,8 @@ func.func @test_reduce_max_empty_set_int(%arg0: !torch.vtensor<[2,0,4],si32>, %a // CHECK-LABEL: func.func @test_reduce_max_bool_inputs func.func @test_reduce_max_bool_inputs(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[IDX:.+]] = torch.constant.int 0 - // CHECK: %[[SZ:.+]] = torch.constant.int 0 - // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[IDX]], %[[SZ]] - // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]] - // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int - // CHECK: %[[C0:.+]] = torch.constant.int 0 - // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[C0]] : !torch.int, !torch.int -> !torch.bool - // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int - // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[ITEM]], %[[MUL]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list + // CHECK: %[[I1:.*]] = torch.constant.int 1 + // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %[[I1]] : (!torch.int) -> !torch.list // CHECK: %[[TRUE:.+]] = torch.constant.bool true // CHECK: %[[AMAX:.+]] = torch.aten.amax %arg0, %[[LST]], %[[TRUE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4,1],i1> // CHECK: return %[[AMAX]] : !torch.vtensor<[4,1],i1> @@ -648,17 +720,8 @@ func.func @test_reduce_max_bool_inputs(%arg0: !torch.vtensor<[4,2],i1>, %arg1: ! 
// CHECK-LABEL: func.func @test_reduce_max_bool_inputs_nokeepdims func.func @test_reduce_max_bool_inputs_nokeepdims(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[IDX:.+]] = torch.constant.int 0 - // CHECK: %[[SZ:.+]] = torch.constant.int 0 - // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[IDX]], %[[SZ]] - // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]] - // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int - // CHECK: %[[C0:.+]] = torch.constant.int 0 - // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[C0]] : !torch.int, !torch.int -> !torch.bool - // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int - // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[ITEM]], %[[MUL]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list + // CHECK: %[[I1:.*]] = torch.constant.int 1 + // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %[[I1]] : (!torch.int) -> !torch.list // CHECK: %[[FALSE:.+]] = torch.constant.bool false // CHECK: %[[AMAX:.+]] = torch.aten.amax %arg0, %[[LST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4],i1> // CHECK: return %[[AMAX]] : !torch.vtensor<[4],i1> @@ -670,19 +733,9 @@ func.func @test_reduce_max_bool_inputs_nokeepdims(%arg0: !torch.vtensor<[4,2],i1 // CHECK-LABEL: func.func @test_reduce_max_all_dims_default func.func @test_reduce_max_all_dims_default(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[I0:.+]] = torch.constant.int 0 - // CHECK: %[[I1:.+]] = torch.constant.int 1 - // CHECK: %[[RANK:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int - // CHECK: %[[C0:.+]] = torch.constant.int 0 - // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[I0]], %[[C0]] : !torch.int, !torch.int -> !torch.bool - // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int - // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[RANK]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[A0:.+]] = torch.aten.add.int %[[I0]], %[[MUL]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[I1]], %[[C0]] : !torch.int, !torch.int -> !torch.bool - // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int - // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[RANK]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[A1:.+]] = torch.aten.add.int %[[I1]], %[[MUL]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[A0]], %[[A1]] + // CHECK: %[[I0:.*]] = torch.constant.int 0 + // CHECK: %[[I1:.*]] = torch.constant.int 1 + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[I0]], %[[I1]] // CHECK: %[[FALSE:.+]] = torch.constant.bool false // CHECK: %[[MAX:.+]] = torch.aten.amax %arg0, %[[LIST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[],i1> // CHECK: return %[[MAX]] : !torch.vtensor<[],i1> @@ -694,13 
+747,7 @@ func.func @test_reduce_max_all_dims_default(%arg0: !torch.vtensor<[4,2],i1>) -> func.func @test_reduce_max_attr(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT1:.+]] = torch.constant.int 1 - // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int - // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[INT1]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool - // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int - // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[INT1]], %[[MUL]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list // CHECK: %[[FALSE:.+]] = torch.constant.bool false // CHECK: %[[AMAX:.+]] = torch.aten.amax %arg0, %[[LIST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4],i1> // CHECK: return %[[AMAX]] @@ -712,9 +759,12 @@ func.func @test_reduce_max_attr(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtens // CHECK-LABEL: func.func @test_reduce_l1_default_axes_keepdims_example func.func @test_reduce_l1_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[ABS:.+]] = torch.aten.abs %arg0 : !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32> + // CHECK: %[[ABS:.*]] = torch.aten.abs %arg0 // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK-DAG: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1 + // CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT1]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %[[TRUE:.+]] = torch.constant.bool true // CHECK: %[[NONE:.+]] = torch.constant.none // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[ABS]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> @@ -764,8 +814,11 @@ func.func @test_reduce_l1_do_not_keepdims_example(%arg0:!torch.vtensor<[3,2,2],f // CHECK-LABEL: func.func @test_reduce_l2_default_axes_keepdims_example func.func @test_reduce_l2_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32> + // CHECK: %[[INT0:.+]] = torch.constant.int 0 // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list 
+ // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT1]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %[[TRUE_0:.+]] = torch.constant.bool true // CHECK: %[[NONE_0:.+]] = torch.constant.none // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE_0]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> @@ -863,7 +916,10 @@ func.func @test_reduce_l2_keep_dims_int_input_example(%arg0: !torch.vtensor<[3,2 // CHECK-LABEL: func.func @test_reduce_log_sum_default_axes_keepdims_example func.func @test_reduce_log_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT1]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %[[TRUE:.+]] = torch.constant.bool true // CHECK: %[[NONE:.+]] = torch.constant.none // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> @@ -911,10 +967,113 @@ func.func @test_reduce_log_sum_do_not_keepdims_example(%arg0:!torch.vtensor<[3,2 // ----- +// CHECK-LABEL: func.func @test_reduce_log_sum_exp_default_axes_keepdims_example +func.func @test_reduce_log_sum_exp_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT7:.+]] = torch.constant.int 7 + // CHECK: %[[NONE_0:.+]] = torch.constant.none + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,2],f64> + // CHECK: %[[EXP:.+]] = torch.aten.exp %[[CAST]] : !torch.vtensor<[3,2,2],f64> -> !torch.vtensor<[3,2,2],f64> + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT1]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: %[[NONE_1:.+]] = torch.constant.none + // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[EXP]], %[[DIMS]], %[[TRUE]], %[[NONE_1]] : !torch.vtensor<[3,2,2],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f64> + // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[1,1,1],f64> -> !torch.vtensor<[1,1,1],f64> + // CHECK: %[[INT6:.+]] = torch.constant.int 
6 + // CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[LOG]], %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[1,1,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> + // CHECK: return %[[CASTLIKE]] : !torch.vtensor<[1,1,1],f32> + %0 = torch.operator "onnx.ReduceLogSumExp"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> + return %0 : !torch.vtensor<[1,1,1],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_log_sum_exp_do_not_keepdims_example_expanded +func.func @test_reduce_log_sum_exp_do_not_keepdims_example_expanded(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT7:.+]] = torch.constant.int 7 + // CHECK: %[[NONE_0:.+]] = torch.constant.none + // CHECK: %[[FALSE_0:.+]] = torch.constant.bool false + // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[INT7]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,2],f64> + // CHECK: %[[EXP:.+]] = torch.aten.exp %[[CAST]] : !torch.vtensor<[3,2,2],f64> -> !torch.vtensor<[3,2,2],f64> + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list + // CHECK: %[[FALSE_1:.+]] = torch.constant.bool false + // CHECK: %[[NONE_1:.+]] = torch.constant.none + // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[EXP]], %[[DIMS]], %[[FALSE_1]], %[[NONE_1]] : !torch.vtensor<[3,2,2],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f64> + // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2],f64> -> !torch.vtensor<[3,2],f64> + // CHECK: %[[INT6:.+]] = torch.constant.int 6 + // CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[LOG]], %[[INT6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[3,2],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32> + // CHECK: return %[[CASTLIKE]] : !torch.vtensor<[3,2],f32> + %0 = torch.operator "onnx.ReduceLogSumExp"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> + return %0 : !torch.vtensor<[3,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_log_sum_exp_keep_dims_example +func.func @test_reduce_log_sum_exp_keep_dims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT7:.+]] = torch.constant.int 7 + // CHECK: %[[NONE_0:.+]] = torch.constant.none + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f32>, 
!torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,2],f64> + // CHECK: %[[EXP:.+]] = torch.aten.exp %[[CAST]] : !torch.vtensor<[3,2,2],f64> -> !torch.vtensor<[3,2,2],f64> + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list + // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: %[[NONE_1:.+]] = torch.constant.none + // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[EXP]], %[[DIMS]], %[[TRUE]], %[[NONE_1]] : !torch.vtensor<[3,2,2],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f64> + // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2,1],f64> -> !torch.vtensor<[3,2,1],f64> + // CHECK: %[[INT6:.+]] = torch.constant.int 6 + // CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[LOG]], %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[3,2,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32> + // CHECK: return %[[CASTLIKE]] : !torch.vtensor<[3,2,1],f32> + %0 = torch.operator "onnx.ReduceLogSumExp"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> + return %0 : !torch.vtensor<[3,2,1],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_log_sum_exp_keep_dims_int_input_example +func.func @test_reduce_log_sum_exp_keep_dims_int_input_example(%arg0: !torch.vtensor<[3,2,2],si64>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT7:.+]] = torch.constant.int 7 + // CHECK: %[[NONE_0:.+]] = torch.constant.none + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[3,2,2],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,2],f64> + // CHECK: %[[EXP:.+]] = torch.aten.exp %[[CAST]] : !torch.vtensor<[3,2,2],f64> -> !torch.vtensor<[3,2,2],f64> + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list + // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: %[[NONE_1:.+]] = torch.constant.none + // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[EXP]], %[[DIMS]], %[[TRUE]], %[[NONE_1]] : !torch.vtensor<[3,2,2],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f64> + // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2,1],f64> -> !torch.vtensor<[3,2,1],f64> + // CHECK: %[[INT6:.+]] = torch.constant.int 6 + // CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[LOG]], %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : 
!torch.vtensor<[3,2,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32> + // CHECK: return %[[CASTLIKE]] : !torch.vtensor<[3,2,1],f32> + %0 = torch.operator "onnx.ReduceLogSumExp"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> + return %0 : !torch.vtensor<[3,2,1],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_reduce_sum_default_axes_keepdims_example func.func @test_reduce_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT1]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %[[TRUE:.+]] = torch.constant.bool true // CHECK: %[[NONE:.+]] = torch.constant.none // CHECK: torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> @@ -962,15 +1121,17 @@ func.func @test_reduce_sum_empty_set_non_reduced_axis_zero(%arg0: !torch.vtensor // ----- // CHECK-LABEL: func.func @test_reduce_sum_keepdims_example -func.func @test_reduce_sum_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +func.func @test_reduce_sum_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VAL_1:.*]] = torch.vtensor.literal(dense<2> : tensor<1xsi64>) : !torch.vtensor<[1],si64> // CHECK: %[[INT0:.+]] = torch.constant.int 0 // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[SELECT:.+]] = torch.aten.select.int %[[VAL_1]], %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> // CHECK: %[[DIM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %1 : (!torch.int) -> !torch.list + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[DIM]] : (!torch.int) -> !torch.list // CHECK: %[[TRUE:.+]] = torch.constant.bool true // CHECK: %[[NONE:.+]] = torch.constant.none // CHECK: torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32> + %arg1 = torch.vtensor.literal(dense<2> : tensor<1xsi64>) : !torch.vtensor<[1],si64> %0 = torch.operator "onnx.ReduceSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : 
(!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> return %0 : !torch.vtensor<[3,1,2],f32> } @@ -997,7 +1158,10 @@ func.func @test_reduce_sum_negative_axes_keepdims_example(%arg0: !torch.vtensor< func.func @test_reduce_sum_square_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32> // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT1]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %[[TRUE:.+]] = torch.constant.bool true // CHECK: %[[NONE:.+]] = torch.constant.none // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> @@ -1205,17 +1369,8 @@ func.func @test_reduce_min_empty_set_int(%arg0: !torch.vtensor<[2,0,4],si32>, %a // CHECK-LABEL: func.func @test_reduce_min_bool_inputs func.func @test_reduce_min_bool_inputs(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[IDX:.+]] = torch.constant.int 0 - // CHECK: %[[SZ:.+]] = torch.constant.int 0 - // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[IDX]], %[[SZ]] - // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]] - // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int - // CHECK: %[[C0:.+]] = torch.constant.int 0 - // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[C0]] : !torch.int, !torch.int -> !torch.bool - // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int - // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[ITEM]], %[[MUL]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %6 : (!torch.int) -> !torch.list + // CHECK: %[[I1:.+]] = torch.constant.int 1 + // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %[[I1]] : (!torch.int) -> !torch.list // CHECK: %[[TRUE:.+]] = torch.constant.bool true // CHECK: %[[AMIN:.+]] = torch.aten.amin %arg0, %[[LST]], %[[TRUE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4,1],i1> // CHECK: return %[[AMIN]] : !torch.vtensor<[4,1],i1> @@ -1227,17 +1382,8 @@ func.func @test_reduce_min_bool_inputs(%arg0: !torch.vtensor<[4,2],i1>, %arg1: ! 
// CHECK-LABEL: func.func @test_reduce_min_bool_inputs_nokeepdims func.func @test_reduce_min_bool_inputs_nokeepdims(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[IDX:.+]] = torch.constant.int 0 - // CHECK: %[[SZ:.+]] = torch.constant.int 0 - // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[IDX]], %[[SZ]] - // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]] - // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int - // CHECK: %[[C0:.+]] = torch.constant.int 0 - // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[C0]] : !torch.int, !torch.int -> !torch.bool - // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int - // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[ITEM]], %[[MUL]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %6 : (!torch.int) -> !torch.list + // CHECK: %[[I1:.+]] = torch.constant.int 1 + // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %[[I1]] : (!torch.int) -> !torch.list // CHECK: %[[FALSE:.+]] = torch.constant.bool false // CHECK: %[[AMIN:.+]] = torch.aten.amin %arg0, %[[LST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4],i1> // CHECK: return %[[AMIN]] : !torch.vtensor<[4],i1> @@ -1251,17 +1397,7 @@ func.func @test_reduce_min_bool_inputs_nokeepdims(%arg0: !torch.vtensor<[4,2],i1 func.func @test_reduce_min_all_dims_default(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[I0:.+]] = torch.constant.int 0 // CHECK: %[[I1:.+]] = torch.constant.int 1 - // CHECK: %[[RANK:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int - // CHECK: %[[C0:.+]] = torch.constant.int 0 - // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[I0]], %[[C0]] : !torch.int, !torch.int -> !torch.bool - // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int - // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[RANK]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[A0:.+]] = torch.aten.add.int %[[I0]], %[[MUL]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[I1]], %[[C0]] : !torch.int, !torch.int -> !torch.bool - // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int - // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[RANK]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[A1:.+]] = torch.aten.add.int %[[I1]], %[[MUL]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[A0]], %[[A1]] + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[I0]], %[[I1]] // CHECK: %[[FALSE:.+]] = torch.constant.bool false // CHECK: %[[MIN:.+]] = torch.aten.amin %arg0, %[[LIST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[],i1> // CHECK: return %[[MIN]] : !torch.vtensor<[],i1> @@ -1273,13 +1409,7 @@ func.func @test_reduce_min_all_dims_default(%arg0: !torch.vtensor<[4,2],i1>) -> func.func @test_reduce_min_attr(%arg0: !torch.vtensor<[4,2],i1>) -> 
!torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT1:.+]] = torch.constant.int 1 - // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int - // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[INT1]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool - // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int - // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[INT1]], %[[MUL]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list // CHECK: %[[FALSE:.+]] = torch.constant.bool false // CHECK: %[[AMIN:.+]] = torch.aten.amin %arg0, %[[LIST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4],i1> // CHECK: return %[[AMIN]] @@ -1386,6 +1516,30 @@ func.func @test_split_2d_uneven_split_opset18(%arg0: !torch.vtensor<[2,8],f32>) return %0#0, %0#1, %0#2 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,2],f32> } + +// ----- + +// CHECK-LABEL: func.func @test_split_2d_split_no_num_outputs( +// CHECK-SAME: %[[INPUT_TENSOR:.*]]: !torch.vtensor<[2,8],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,2],f32>) attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +// CHECK-DAG: %[[DIM:.+]] = torch.constant.int 1 +// CHECK-DAG: %[[SPLITS:.+]] = torch.constant.int 3 +// CHECK-DAG: %[[ONE:.+]] = torch.constant.int 1 +// CHECK-DAG: %[[ZERO:.+]] = torch.constant.int 0 +// CHECK-DAG: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[DIM]] +// CHECK-DAG: %[[ADD:.+]] = torch.aten.add.int %[[SZ1]], %[[SPLITS]] +// CHECK-DAG: %[[SUB:.+]] = torch.aten.sub.int %[[ADD]], %[[ONE]] +// CHECK-DAG: %[[SLICESZ:.+]] = torch.aten.floordiv.int %[[SUB]], %[[SPLITS]] +// CHECK-DAG: %[[START1:.+]] = torch.aten.add.int %[[ZERO]], %[[SLICESZ]] : !torch.int, !torch.int -> !torch.int +// CHECK-DAG: %[[SLICE0:.+]] = torch.aten.slice.Tensor %arg0, %[[DIM]], %[[ZERO]], %[[START1]], %[[ONE]] : !torch.vtensor<[2,8],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32> +// CHECK-DAG: %[[START2:.+]] = torch.aten.add.int %[[START1]], %[[SLICESZ]] : !torch.int, !torch.int -> !torch.int +// CHECK-DAG: %[[SLICE1:.+]] = torch.aten.slice.Tensor %arg0, %[[DIM]], %[[START1]], %[[START2]], %[[ONE]] : !torch.vtensor<[2,8],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32> +// CHECK-DAG: %[[SLICE2:.+]] = torch.aten.slice.Tensor %arg0, %[[DIM]], %[[START2]], %[[SZ1]], %[[ONE]] : !torch.vtensor<[2,8],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,2],f32> +// CHECK: return %[[SLICE0]], %[[SLICE1]], %[[SLICE2]] +func.func @test_split_2d_split_no_num_outputs(%arg0: !torch.vtensor<[2,8],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,2],f32>) attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", 
torch.onnx_meta.producer_version = ""} {
+ %0:3 = torch.operator "onnx.Split"(%arg0) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[2,8],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,2],f32>)
+ return %0#0, %0#1, %0#2 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,2],f32>
+}
+
// -----
// CHECK-LABEL: func.func @test_tan
@@ -1713,6 +1867,57 @@ func.func @test_reshape_zero_and_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32>
 return %0 : !torch.vtensor<[2],si32>
}
+// -----
+
+// CHECK-LABEL: func.func @test_tfidfvectorizer_tf_batch_onlybigrams_skip5
+ func.func @test_tfidfvectorizer_tf_batch_onlybigrams_skip5(%arg0: !torch.vtensor<[2,6],si32>) -> !torch.vtensor<[2,7],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+ // CHECK : %[[output_init:.*]] = torch.aten.zeros %[[x0:.*]], %[[none_0:.*]], %[[none_0]], %[[none_0]], %[[none_0]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,7],f32>
+ // CHECK : %[[int2_1:.*]] = torch.constant.int 2
+ // CHECK : %[[batch_loop:.*]] = torch.prim.Loop %[[int2_1]], %[[true:.*]], init(%[[output_init]]) {
+ // CHECK : ^bb0(%[[arg1:.*]]: !torch.int, %[[arg2:.*]]: !torch.vtensor<[2,7],f32>):
+ // CHECK : %[[x3:.*]] = torch.aten.add.int %[[arg1]], %[[int1:.*]] : !torch.int, !torch.int -> !torch.int
+ // CHECK : %[[x4:.*]] = torch.aten.slice.Tensor %arg0, %[[int0:.*]], %[[arg1]], %[[x3]], %[[int1]] : !torch.vtensor<[2,6],si32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,6],si32>
+ // CHECK : %[[inputbatch:.*]] = torch.aten.squeeze.dim %[[x4]], %[[int0]] : !torch.vtensor<[1,6],si32>, !torch.int -> !torch.vtensor<[6],si32>
+ // CHECK : %[[x6:.*]] = torch.aten.slice.Tensor %[[arg2]], %[[int0]], %[[arg1]], %[[x3]], %[[int1]] : !torch.vtensor<[2,7],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,7],f32>
+ // CHECK : %[[outputbatch:.*]] = torch.aten.squeeze.dim %[[x6]], %[[int0]] : !torch.vtensor<[1,7],f32>, !torch.int -> !torch.vtensor<[7],f32>
+ // CHECK : %[[int2_2:.*]] = torch.constant.int 2
+ // CHECK : %[[int0_3:.*]] = torch.constant.int 0
+ // CHECK : %[[max_skip_count:.*]] = torch.constant.int 6
+ // CHECK : %[[skip_loop:.*]] = torch.prim.Loop %[[max_skip_count]], %[[true]], init(%[[int0_3]]) {
+ // CHECK : ^bb0(%[[arg3:.*]]: !torch.int, %[[arg4:.*]]: !torch.int):
+ // CHECK : %[[x29:.*]] = torch.aten.add.int %[[arg3]], %[[int1]] : !torch.int, !torch.int -> !torch.int
+ // CHECK : %[[int6_12:.*]] = torch.constant.int 6
+ // CHECK : %[[x30:.*]] = torch.aten.sub.int %[[int2_2]], %[[int1]] : !torch.int, !torch.int -> !torch.int
+ // CHECK : %[[x31:.*]] = torch.aten.mul.int %[[x30]], %[[x29]] : !torch.int, !torch.int -> !torch.int
+ // CHECK : %[[x32:.*]] = torch.aten.sub.int %[[int6_12]], %[[x31]] : !torch.int, !torch.int -> !torch.int
+ // CHECK : %[[count_loop:.*]] = torch.prim.Loop %[[x32]], %[[true]], init(%[[arg4]]) {
+ // CHECK : ^bb0(%[[arg5:.*]]: !torch.int, %[[arg6:.*]]: !torch.int):
+ // CHECK : %[[input_2gram0:.*]] = torch.aten.select.int %[[inputbatch]], %[[int0]], %[[position0:.*]] : !torch.vtensor<[6],si32>, !torch.int, !torch.int -> !torch.vtensor<[1],si32>
+ // CHECK : %[[inputval0:.*]] = torch.aten.item %[[input_2gram0]] : !torch.vtensor<[1],si32> -> !torch.int
+ // CHECK : %[[eq0:.*]] = torch.aten.eq.int %[[inputval0]], %[[first2gram0:.*]] : !torch.int, !torch.int -> !torch.bool
+ // CHECK : %[[eq0int:.*]] = torch.aten.Int.bool %[[eq0]] : !torch.bool -> !torch.int
+ // CHECK : %[[alleq0:.*]] = torch.aten.mul.int %[[eq0int]], %[[int1_13:.*]] : !torch.int, !torch.int -> !torch.int
+ // CHECK : %[[input_2gram1:.*]] = torch.aten.select.int %[[inputbatch]], %[[int0]], %[[position1:.*]] : !torch.vtensor<[6],si32>, !torch.int, !torch.int -> !torch.vtensor<[1],si32>
+ // CHECK : %[[inputval1:.*]] = torch.aten.item %[[input_2gram1]] : !torch.vtensor<[1],si32> -> !torch.int
+ // CHECK : %[[eq1:.*]] = torch.aten.eq.int %[[inputval1]], %[[first2gram1:.*]] : !torch.int, !torch.int -> !torch.bool
+ // CHECK : %[[eq1int:.*]] = torch.aten.Int.bool %[[eq1]] : !torch.bool -> !torch.int
+ // CHECK : %[[alleq1:.*]] = torch.aten.mul.int %[[eq1int]], %[[alleq0]] : !torch.int, !torch.int -> !torch.int
+ // CHECK : %[[newcount:.*]] = torch.aten.add.int %[[arg6]], %[[alleq1]] : !torch.int, !torch.int -> !torch.int
+ // CHECK : torch.prim.Loop.condition %[[true]], iter(%[[newcount]] : !torch.int)
+ // CHECK : } : (!torch.int, !torch.bool, !torch.int) -> !torch.int
+ // CHECK : torch.prim.Loop.condition %[[true]], iter(%[[skip_loop]] : !torch.int)
+ // CHECK : } : (!torch.int, !torch.bool, !torch.int) -> !torch.int
+ // CHECK : %[[count_insert0:.*]] = torch.aten.slice_scatter %[[outputbatch]], %[[counttensor0:.*]], %[[int0]], %[[ngram_indices0:.*]], %[[ngram_indices0plus1:.*]], %[[int1]] : !torch.vtensor<[7],f32>, !torch.vtensor<[1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[7],f32>
+ // the skip_loop and count_loops repeat for each ngram in the pool_int64s, then after the last ngram frequency is counted...
+ // CHECK : %[[unsqueezecounts:.*]] = torch.aten.unsqueeze %[[lastcountinsert:.*]], %[[int0]] : !torch.vtensor<[7],f32>, !torch.int -> !torch.vtensor<[1,7],f32>
+ // CHECK : %[[count_into_output:.*]] = torch.aten.slice_scatter %[[arg2]], %[[unsqueezecounts]], %[[int0]], %[[arg1]], %[[arg1plus1:.*]], %[[int1]] : !torch.vtensor<[2,7],f32>, !torch.vtensor<[1,7],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,7],f32>
+ // CHECK : torch.prim.Loop.condition %[[true]], iter(%[[count_into_output]] : !torch.vtensor<[2,7],f32>)
+ // CHECK : } : (!torch.int, !torch.bool, !torch.vtensor<[2,7],f32>) -> !torch.vtensor<[2,7],f32>
+ // CHECK : return %[[batch_loop]] : !torch.vtensor<[2,7],f32>
+ %0 = torch.operator "onnx.TfIdfVectorizer"(%arg0) {torch.onnx.max_gram_length = 2 : si64, torch.onnx.max_skip_count = 5 : si64, torch.onnx.min_gram_length = 2 : si64, torch.onnx.mode = "TF", torch.onnx.ngram_counts = [0 : si64, 4 : si64], torch.onnx.ngram_indexes = [0 : si64, 1 : si64, 2 : si64, 3 : si64, 4 : si64, 5 : si64, 6 : si64], torch.onnx.pool_int64s = [2 : si64, 3 : si64, 5 : si64, 4 : si64, 5 : si64, 6 : si64, 7 : si64, 8 : si64, 6 : si64, 7 : si64]} : (!torch.vtensor<[2,6],si32>) -> !torch.vtensor<[2,7],f32>
+ return %0 : !torch.vtensor<[2,7],f32>
+ }
+
// -----
// CHECK-LABEL: func.func @test_range_int16_type
@@ -1916,6 +2121,40 @@ func.func @test_triu_zero(%arg0: !torch.vtensor<[0,5],si64>, %arg1: !torch.vtens
// -----
+// CHECK-LABEL: func.func @test_resize_sizes_nearest
+ func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+ %none = 
torch.constant.none + // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.coordinate_transformation_mode = "asymmetric", torch.onnx.cubic_coeff_a = -7.500000e-01 : f32, torch.onnx.mode = "nearest", torch.onnx.nearest_mode = "floor"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> + return %0 : !torch.vtensor<[?,?,?,?],f32> + } + +// ----- + +// CHECK-LABEL: func.func @test_resize_sizes_nearest +func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %none = torch.constant.none + // CHECK: %[[STR:.+]] = torch.constant.str "nearest_half_pixel,round_prefer_floor" + // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %[[STR]], %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) { + torch.onnx.coordinate_transformation_mode = "half_pixel", + torch.onnx.mode = "nearest"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> + return %0 : !torch.vtensor<[?,?,?,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_resize_sizes_linear + func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?], +f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %none = torch.constant.none + // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> + return %0 : !torch.vtensor<[?,?,?,?],f32> + } + +// ----- + // CHECK-LABEL: func.func @test_random_normal func.func @test_random_normal() -> !torch.vtensor<[10],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK-DAG: %[[I6:.+]] = torch.constant.int 6 @@ -1978,6 +2217,45 @@ func.func @test_random_uniform_like(%arg0: !torch.vtensor<[10],f32>) -> !torch.v // ----- +// CHECK-LABEL: func.func @test_sequence_construct_3 +module { + func.func @test_sequence_construct_3(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,3,4],f32>, %arg2: !torch.vtensor<[2,3,4],f32>) -> !torch.list> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", 
torch.onnx_meta.producer_version = ""} { +// CHECK: %[[SEQ:.+]] = torch.prim.ListConstruct %arg0, %arg1, %arg2 : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> +// CHECK: return %[[SEQ]] : !torch.list> + %0 = torch.operator "onnx.SequenceConstruct"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + return %0 : !torch.list> + } +} + +// ----- + +// CHECK-LABEL: func.func @test_sequence_construct_1 +module { + func.func @test_sequence_construct_1(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.list> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +// CHECK: %[[SEQ:.+]] = torch.prim.ListConstruct %arg0 : (!torch.vtensor<[2,3,4],f32>) -> !torch.list> +// CHECK: return %[[SEQ]] : !torch.list> + %0 = torch.operator "onnx.SequenceConstruct"(%arg0) : (!torch.vtensor<[2,3,4],f32>) -> !torch.list> + return %0 : !torch.list> + } +} + +// ----- + +// CHECK-LABEL: func.func @test_sequence_length +module { + func.func @test_sequence_length(%arg0: !torch.list>) -> !torch.vtensor<[],si64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +// CHECK: %[[FALSE:.+]] = torch.constant.bool false +// CHECK: %[[NONE:.+]] = torch.constant.none +// CHECK: %[[LEN:.+]] = torch.aten.len.t %arg0 : !torch.list> -> !torch.int +// CHECK: %[[LEN_AS_TEN:.+]] = torch.aten.tensor.int %[[LEN]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.int, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[],si64> +// CHECK: return %[[LEN_AS_TEN]] : !torch.vtensor<[],si64> + %0 = torch.operator "onnx.SequenceLength"(%arg0) : (!torch.list>) -> !torch.vtensor<[],si64> + return %0 : !torch.vtensor<[],si64> + } +} + +// ----- + // CHECK-LABEL: func.func @test_sce_mean_3d func.func @test_sce_mean_3d(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1: !torch.vtensor<[3,2],si64>) -> !torch.vtensor<[],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[NONE:.+]] = torch.constant.none @@ -2012,18 +2290,1258 @@ func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1: // CHECK-LABEL: func.func @test_resize_sizes_nearest func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { %none = torch.constant.none - // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> - %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.coordinate_transformation_mode = "asymmetric", torch.onnx.cubic_coeff_a = -7.500000e-01 : f32, torch.onnx.mode = "nearest", torch.onnx.nearest_mode = "floor"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> + // CHECK: %[[MODE_STR:.*]] = 
torch.constant.str "nearest" + // CHECK: torch.aten.__interpolate.size_list_scale_list + // CHECK-SAME: %[[MODE_STR]] + %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) { + torch.onnx.coordinate_transformation_mode = "asymmetric", + torch.onnx.cubic_coeff_a = -7.500000e-01 : f32, + torch.onnx.mode = "nearest", + torch.onnx.nearest_mode = "floor" + } : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> return %0 : !torch.vtensor<[?,?,?,?],f32> } // ----- +// CHECK-LABEL: func.func @test_resize_sizes_nearest_half_pixel +func.func @test_resize_sizes_nearest_half_pixel(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %none = torch.constant.none + // CHECK: %[[MODE_STR:.+]] = torch.constant.str "nearest_half_pixel,round_prefer_floor" + // CHECK: torch.aten.__interpolate.size_list_scale_list + // CHECK-SAME: %[[MODE_STR]] + %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) { + torch.onnx.coordinate_transformation_mode = "half_pixel", + torch.onnx.mode = "nearest" + } : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> + return %0 : !torch.vtensor<[?,?,?,?],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_resize_sizes_linear func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?], f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { %none = torch.constant.none - // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> - %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> + // CHECK: %[[MODE_STR:.*]] = torch.constant.str "bilinear" + // CHECK: torch.aten.__interpolate.size_list_scale_list + // CHECK-SAME: %[[MODE_STR]] + %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) { + torch.onnx.mode = "linear" + } : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> return %0 : !torch.vtensor<[?,?,?,?],f32> } + +// ----- + +// CHECK-LABEL: @test_roialign_avg + func.func @test_roialign_avg(%arg0: !torch.vtensor<[6,2,100,100],f32>, %arg1: !torch.vtensor<[30,4],f32>, %arg2: !torch.vtensor<[30],si64>) -> !torch.vtensor<[30,2,5,5],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[Dim:.*]] = torch.constant.int 1 + // CHECK: %[[Unsqueeze:.*]] = torch.aten.unsqueeze %arg2, %[[Dim]] + // CHECK: %[[cst6:.*]] = torch.constant.int 6 + // CHECK: %[[Cast:.*]] = torch.aten.to.dtype %[[Unsqueeze]], %[[cst6]] + // CHECK: %[[List:.*]] = torch.prim.ListConstruct %[[Cast]], %arg1 + // CHECK: %[[Cat:.*]] = torch.aten.cat %[[List]], %[[Dim]] + // CHECK: 
%[[Align:.*]] = torch.torchvision.roi_align %arg0, %[[Cat]] + %0 = torch.operator "onnx.RoiAlign"(%arg0, %arg1, %arg2) {torch.onnx.coordinate_transformation_mode = "output_half_pixel", torch.onnx.mode = "avg", torch.onnx.output_height = 5 : si64, torch.onnx.output_width = 5 : si64, torch.onnx.sampling_ratio = 0 : si64, torch.onnx.spatial_scale = 1.000000e+00 : f32} : (!torch.vtensor<[6,2,100,100],f32>, !torch.vtensor<[30,4],f32>, !torch.vtensor<[30],si64>) -> !torch.vtensor<[30,2,5,5],f32> + return %0 : !torch.vtensor<[30,2,5,5],f32> + } + +// ----- + +// CHECK-LABEL: @test_roialign_max + func.func @test_roialign_max(%arg0: !torch.vtensor<[6,2,100,100],f32>, %arg1: !torch.vtensor<[30,4],f32>, %arg2: !torch.vtensor<[30],si64>) -> !torch.vtensor<[30,2,5,5],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[Dim:.*]] = torch.constant.int 1 + // CHECK: %[[Unsqueeze:.*]] = torch.aten.unsqueeze %arg2, %[[Dim]] + // CHECK: %[[cst6:.*]] = torch.constant.int 6 + // CHECK: %[[Cast:.*]] = torch.aten.to.dtype %[[Unsqueeze]], %[[cst6]] + // CHECK: %[[List:.*]] = torch.prim.ListConstruct %[[Cast]], %arg1 + // CHECK: %[[Cat:.*]] = torch.aten.cat %[[List]], %[[Dim]] + // CHECK: %[[Pool:.*]], %[[Indices:.*]] = torch.torchvision.roi_pool %arg0, %[[Cat]] + // CHECK: return %[[Pool]] + %0 = torch.operator "onnx.RoiAlign"(%arg0, %arg1, %arg2) {torch.onnx.coordinate_transformation_mode = "half_pixel", torch.onnx.mode = "max", torch.onnx.output_height = 5 : si64, torch.onnx.output_width = 5 : si64, torch.onnx.sampling_ratio = 0 : si64, torch.onnx.spatial_scale = 1.000000e+00 : f32} : (!torch.vtensor<[6,2,100,100],f32>, !torch.vtensor<[30,4],f32>, !torch.vtensor<[30],si64>) -> !torch.vtensor<[30,2,5,5],f32> + return %0 : !torch.vtensor<[30,2,5,5],f32> + } + +// ----- + +// CHECK-LABEL: @test_spacetodepth_example +func.func @test_spacetodepth_example(%arg0: !torch.vtensor<[1,1,4,6],f32>) -> !torch.vtensor<[1,4,2,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[B:.*]] = torch.aten.size.int %arg0, %[[C0]] : !torch.vtensor<[1,1,4,6],f32>, !torch.int -> !torch.int + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C:.*]] = torch.aten.size.int %arg0, %[[C1]] : !torch.vtensor<[1,1,4,6],f32>, !torch.int -> !torch.int + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[H:.*]] = torch.aten.size.int %arg0, %[[C2]] : !torch.vtensor<[1,1,4,6],f32>, !torch.int -> !torch.int + // CHECK: %[[C3:.*]] = torch.constant.int 3 + // CHECK: %[[W:.*]] = torch.aten.size.int %arg0, %[[C3]] : !torch.vtensor<[1,1,4,6],f32>, !torch.int -> !torch.int + // CHECK: %[[C2_0:.*]] = torch.constant.int 2 + // CHECK: %[[C4:.*]] = torch.constant.int 4 + // CHECK: %[[H_DIV_BS:.*]] = torch.aten.div.int %[[H]], %[[C2_0]] : !torch.int, !torch.int -> !torch.float + // CHECK: %[[W_DIV_BS:.*]] = torch.aten.div.int %[[W]], %[[C2_0]] : !torch.int, !torch.int -> !torch.float + // CHECK: %[[H_DIV_BS_INT:.*]] = torch.aten.Int.float %[[H_DIV_BS]] : !torch.float -> !torch.int + // CHECK: %[[W_DIV_BS_INT:.*]] = torch.aten.Int.float %[[W_DIV_BS]] : !torch.float -> !torch.int + // CHECK: %[[RESHAPE_LIST:.*]] = torch.prim.ListConstruct %[[B]], %[[C]], %[[H_DIV_BS_INT]], %[[C2_0]], %[[W_DIV_BS_INT]], %[[C2_0]] : 
(!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ // CHECK: %[[RESHAPE:.*]] = torch.aten.reshape %arg0, %[[RESHAPE_LIST]] : !torch.vtensor<[1,1,4,6],f32>, !torch.list<int> -> !torch.vtensor<[1,1,2,2,3,2],f32>
+ // CHECK: %[[C0_0:.*]] = torch.constant.int 0
+ // CHECK: %[[C3_0:.*]] = torch.constant.int 3
+ // CHECK: %[[C5:.*]] = torch.constant.int 5
+ // CHECK: %[[C1_0:.*]] = torch.constant.int 1
+ // CHECK: %[[C2_1:.*]] = torch.constant.int 2
+ // CHECK: %[[C4_0:.*]] = torch.constant.int 4
+ // CHECK: %[[PERMUTE_DIMS:.*]] = torch.prim.ListConstruct %[[C0_0]], %[[C3_0]], %[[C5]], %[[C1_0]], %[[C2_1]], %[[C4_0]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ // CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[RESHAPE]], %[[PERMUTE_DIMS]] : !torch.vtensor<[1,1,2,2,3,2],f32>, !torch.list<int> -> !torch.vtensor<[1,2,2,1,2,3],f32>
+ // CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[C]], %[[C4]] : !torch.int, !torch.int -> !torch.int
+ // CHECK: %[[RESHAPE_LIST_0:.*]] = torch.prim.ListConstruct %[[B]], %[[MUL]], %[[H_DIV_BS_INT]], %[[W_DIV_BS_INT]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ // CHECK: %[[RESULT:.*]] = torch.aten.reshape %[[PERMUTE]], %[[RESHAPE_LIST_0]] : !torch.vtensor<[1,2,2,1,2,3],f32>, !torch.list<int> -> !torch.vtensor<[1,4,2,3],f32>
+ // CHECK: return %[[RESULT]] : !torch.vtensor<[1,4,2,3],f32>
+ %0 = torch.operator "onnx.SpaceToDepth"(%arg0) {torch.onnx.blocksize = 2 : si64} : (!torch.vtensor<[1,1,4,6],f32>) -> !torch.vtensor<[1,4,2,3],f32>
+ return %0 : !torch.vtensor<[1,4,2,3],f32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_spacetodepth
+func.func @test_spacetodepth(%arg0: !torch.vtensor<[2,2,6,6],f32>) -> !torch.vtensor<[2,8,3,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+ // CHECK: %[[C0:.*]] = torch.constant.int 0
+ // CHECK: %[[B:.*]] = torch.aten.size.int %arg0, %[[C0]] : !torch.vtensor<[2,2,6,6],f32>, !torch.int -> !torch.int
+ // CHECK: %[[C1:.*]] = torch.constant.int 1
+ // CHECK: %[[C:.*]] = torch.aten.size.int %arg0, %[[C1]] : !torch.vtensor<[2,2,6,6],f32>, !torch.int -> !torch.int
+ // CHECK: %[[C2:.*]] = torch.constant.int 2
+ // CHECK: %[[H:.*]] = torch.aten.size.int %arg0, %[[C2]] : !torch.vtensor<[2,2,6,6],f32>, !torch.int -> !torch.int
+ // CHECK: %[[C3:.*]] = torch.constant.int 3
+ // CHECK: %[[W:.*]] = torch.aten.size.int %arg0, %[[C3]] : !torch.vtensor<[2,2,6,6],f32>, !torch.int -> !torch.int
+ // CHECK: %[[C2_0:.*]] = torch.constant.int 2
+ // CHECK: %[[C4:.*]] = torch.constant.int 4
+ // CHECK: %[[H_DIV_BS:.*]] = torch.aten.div.int %[[H]], %[[C2_0]] : !torch.int, !torch.int -> !torch.float
+ // CHECK: %[[W_DIV_BS:.*]] = torch.aten.div.int %[[W]], %[[C2_0]] : !torch.int, !torch.int -> !torch.float
+ // CHECK: %[[H_DIV_BS_INT:.*]] = torch.aten.Int.float %[[H_DIV_BS]] : !torch.float -> !torch.int
+ // CHECK: %[[W_DIV_BS_INT:.*]] = torch.aten.Int.float %[[W_DIV_BS]] : !torch.float -> !torch.int
+ // CHECK: %[[RESHAPE_LIST:.*]] = torch.prim.ListConstruct %[[B]], %[[C]], %[[H_DIV_BS_INT]], %[[C2_0]], %[[W_DIV_BS_INT]], %[[C2_0]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ // CHECK: %[[RESHAPE:.*]] = torch.aten.reshape %arg0, %[[RESHAPE_LIST]] : !torch.vtensor<[2,2,6,6],f32>, !torch.list<int> -> !torch.vtensor<[2,2,3,2,3,2],f32>
+ // CHECK: %[[C0_0:.*]] = torch.constant.int 0
+ // CHECK: 
%[[C3_0:.*]] = torch.constant.int 3 + // CHECK: %[[C5:.*]] = torch.constant.int 5 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C2_1:.*]] = torch.constant.int 2 + // CHECK: %[[C4_0:.*]] = torch.constant.int 4 + // CHECK: %[[PERMUTE_DIMS:.*]] = torch.prim.ListConstruct %[[C0_0]], %[[C3_0]], %[[C5]], %[[C1_0]], %[[C2_1]], %[[C4_0]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[RESHAPE]], %[[PERMUTE_DIMS]] : !torch.vtensor<[2,2,3,2,3,2],f32>, !torch.list -> !torch.vtensor<[2,2,2,2,3,3],f32> + // CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[C]], %[[C4]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[RESHAPE_LIST_0:.*]] = torch.prim.ListConstruct %[[B]], %[[MUL]], %[[H_DIV_BS_INT]], %[[W_DIV_BS_INT]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[RESULT:.*]] = torch.aten.reshape %[[PERMUTE]], %[[RESHAPE_LIST_0]] : !torch.vtensor<[2,2,2,2,3,3],f32>, !torch.list -> !torch.vtensor<[2,8,3,3],f32> + // CHECK: return %[[RESULT]] : !torch.vtensor<[2,8,3,3],f32 + %0 = torch.operator "onnx.SpaceToDepth"(%arg0) {torch.onnx.blocksize = 2 : si64} : (!torch.vtensor<[2,2,6,6],f32>) -> !torch.vtensor<[2,8,3,3],f32> + return %0 : !torch.vtensor<[2,8,3,3],f32> +} + +// ----- + +// CHECK-LABEL: @test_spacetodepth +func.func @test_spacetodepth_dynamic_dims(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[B:.*]] = torch.aten.size.int %arg0, %[[C0]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C:.*]] = torch.aten.size.int %arg0, %[[C1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[H:.*]] = torch.aten.size.int %arg0, %[[C2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[C3:.*]] = torch.constant.int 3 + // CHECK: %[[W:.*]] = torch.aten.size.int %arg0, %[[C3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[C2_0:.*]] = torch.constant.int 2 + // CHECK: %[[C4:.*]] = torch.constant.int 4 + // CHECK: %[[H_DIV_BS:.*]] = torch.aten.div.int %[[H]], %[[C2_0]] : !torch.int, !torch.int -> !torch.float + // CHECK: %[[W_DIV_BS:.*]] = torch.aten.div.int %[[W]], %[[C2_0]] : !torch.int, !torch.int -> !torch.float + // CHECK: %[[H_DIV_BS_INT:.*]] = torch.aten.Int.float %[[H_DIV_BS]] : !torch.float -> !torch.int + // CHECK: %[[W_DIV_BS_INT:.*]] = torch.aten.Int.float %[[W_DIV_BS]] : !torch.float -> !torch.int + // CHECK: %[[RESHAPE_LIST:.*]] = torch.prim.ListConstruct %[[B]], %[[C]], %[[H_DIV_BS_INT]], %[[C2_0]], %[[W_DIV_BS_INT]], %[[C2_0]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[RESHAPE:.*]] = torch.aten.reshape %arg0, %[[RESHAPE_LIST]] : !torch.vtensor<[?,?,?,?],f32>, !torch.list -> !torch.vtensor<[?,?,?,2,?,2],f32> + // CHECK: %[[C0_0:.*]] = torch.constant.int 0 + // CHECK: %[[C3_0:.*]] = torch.constant.int 3 + // CHECK: %[[C5:.*]] = torch.constant.int 5 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C2_1:.*]] = torch.constant.int 2 + // CHECK: %[[C4_0:.*]] = torch.constant.int 4 + // CHECK: %[[PERMUTE_DIMS:.*]] = torch.prim.ListConstruct 
%[[C0_0]], %[[C3_0]], %[[C5]], %[[C1_0]], %[[C2_1]], %[[C4_0]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[RESHAPE]], %[[PERMUTE_DIMS]] : !torch.vtensor<[?,?,?,2,?,2],f32>, !torch.list -> !torch.vtensor<[?,2,2,?,?,?],f32> + // CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[C]], %[[C4]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[RESHAPE_LIST_0:.*]] = torch.prim.ListConstruct %[[B]], %[[MUL]], %[[H_DIV_BS_INT]], %[[W_DIV_BS_INT]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[RESULT:.*]] = torch.aten.reshape %[[PERMUTE]], %[[RESHAPE_LIST_0]] : !torch.vtensor<[?,2,2,?,?,?],f32>, !torch.list -> !torch.vtensor<[?,?,?,?],f32> + // CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?,?],f32 + %0 = torch.operator "onnx.SpaceToDepth"(%arg0) {torch.onnx.blocksize = 2 : si64} : (!torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> + return %0 : !torch.vtensor<[?,?,?,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @Shrink +func.func @Shrink(%arg0: !torch.vtensor<[5],f32>) -> !torch.vtensor<[5],f32> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %float1.500000e00 = torch.constant.float 1.500000e+00 + // CHECK: %float1.500000e00_0 = torch.constant.float 1.500000e+00 + // CHECK: %float0.000000e00 = torch.constant.float 0.000000e+00 + // CHECK: %float1.000000e00 = torch.constant.float 1.000000e+00 + // CHECK: %float-1.500000e00 = torch.constant.float -1.500000e+00 + // CHECK: %0 = torch.aten.lt.Scalar %arg0, %float-1.500000e00 : !torch.vtensor<[5],f32>, !torch.float -> !torch.vtensor<[5],i1> + // CHECK: %1 = torch.aten.add.Scalar %arg0, %float1.500000e00_0, %float1.000000e00 : !torch.vtensor<[5],f32>, !torch.float, !torch.float -> !torch.vtensor<[5],f32> + // CHECK: %2 = torch.aten.sub.Scalar %arg0, %float1.500000e00_0, %float1.000000e00 : !torch.vtensor<[5],f32>, !torch.float, !torch.float -> !torch.vtensor<[5],f32> + // CHECK: %3 = torch.aten.gt.Scalar %arg0, %float1.500000e00 : !torch.vtensor<[5],f32>, !torch.float -> !torch.vtensor<[5],i1> + // CHECK: %4 = torch.aten.where.ScalarOther %3, %2, %float0.000000e00 : !torch.vtensor<[5],i1>, !torch.vtensor<[5],f32>, !torch.float -> !torch.vtensor<[5],f32> + // CHECK: %5 = torch.aten.where.self %0, %1, %4 : !torch.vtensor<[5],i1>, !torch.vtensor<[5],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<[5],f32> + // CHECK: return %5 : !torch.vtensor<[5],f32> + %0 = torch.operator "onnx.Shrink"(%arg0) {torch.onnx.bias = 1.500000e+00 : f32, torch.onnx.lambd = 1.500000e+00 : f32} : (!torch.vtensor<[5],f32>) -> !torch.vtensor<[5],f32> + return %0 : !torch.vtensor<[5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_shrink_hard +func.func @test_shrink_hard(%arg0: !torch.vtensor<[5],f32>) -> !torch.vtensor<[5],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %float1.500000e00 = torch.constant.float 1.500000e+00 + // CHECK: %float0.000000e00 = torch.constant.float 0.000000e+00 + // CHECK: %float0.000000e00_0 = torch.constant.float 0.000000e+00 + // CHECK: %float1.000000e00 = torch.constant.float 1.000000e+00 + // CHECK: %float-1.500000e00 = torch.constant.float -1.500000e+00 + // CHECK: %0 = torch.aten.lt.Scalar %arg0, 
%float-1.500000e00 : !torch.vtensor<[5],f32>, !torch.float -> !torch.vtensor<[5],i1> + // CHECK: %1 = torch.aten.add.Scalar %arg0, %float0.000000e00, %float1.000000e00 : !torch.vtensor<[5],f32>, !torch.float, !torch.float -> !torch.vtensor<[5],f32> + // CHECK: %2 = torch.aten.sub.Scalar %arg0, %float0.000000e00, %float1.000000e00 : !torch.vtensor<[5],f32>, !torch.float, !torch.float -> !torch.vtensor<[5],f32> + // CHECK: %3 = torch.aten.gt.Scalar %arg0, %float1.500000e00 : !torch.vtensor<[5],f32>, !torch.float -> !torch.vtensor<[5],i1> + // CHECK: %4 = torch.aten.where.ScalarOther %3, %2, %float0.000000e00_0 : !torch.vtensor<[5],i1>, !torch.vtensor<[5],f32>, !torch.float -> !torch.vtensor<[5],f32> + // CHECK: %5 = torch.aten.where.self %0, %1, %4 : !torch.vtensor<[5],i1>, !torch.vtensor<[5],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<[5],f32> + // CHECK: return %5 : !torch.vtensor<[5],f32> + %0 = torch.operator "onnx.Shrink"(%arg0) {torch.onnx.lambd = 1.500000e+00 : f32} : (!torch.vtensor<[5],f32>) -> !torch.vtensor<[5],f32> + return %0 : !torch.vtensor<[5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_sequence_at +func.func @test_sequence_at(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,3,4],f32>, %arg2: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[VTENSOR_1:.*]] = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %arg0, %arg1, %arg2 : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + // CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list> -> !torch.int + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int + // CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE]], %[[POSITION]], %[[INT1]] : !torch.list>, !torch.none, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list>, !torch.int, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[CONCAT_LIST:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list>, !torch.list> -> !torch.list> + // CHECK: %[[ITEM_0:.*]] = torch.aten.item %[[VTENSOR_1]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[RESULT:.*]] = torch.aten.__getitem__.t %[[CONCAT_LIST]], %[[ITEM_0]] : !torch.list>, !torch.int -> !torch.vtensor<[2,3,4],f32> + // CHECK: return %[[RESULT]] : !torch.vtensor<[2,3,4],f32> + %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor} 
: () -> !torch.vtensor<[],si64> + %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor} : () -> !torch.vtensor<[],si64> + %2 = torch.operator "onnx.SequenceConstruct"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + %3 = torch.operator "onnx.SequenceErase"(%2, %0) : (!torch.list>, !torch.vtensor<[],si64>) -> !torch.list> + %4 = torch.operator "onnx.SequenceAt"(%3, %1) : (!torch.list>, !torch.vtensor<[],si64>) -> !torch.vtensor<[2,3,4],f32> + return %4 : !torch.vtensor<[2,3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_sequence_insert +func.func @test_sequence_insert(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,3,4],f32>, %arg2: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<-3> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[VTENSOR_1:.*]] = torch.vtensor.literal(dense<-1> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[VTENSOR_2:.*]] = torch.vtensor.literal(dense<-1> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %arg0, %arg1, %arg2 : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + // CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list> -> !torch.int + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int + // CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE]], %[[POSITION]], %[[INT1]] : !torch.list>, !torch.none, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list>, !torch.int, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[CONCAT_LIST:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list>, !torch.list> -> !torch.list> + // CHECK: %[[ITEM_0:.*]] = torch.aten.item %[[VTENSOR_1]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: torch.aten.insert.t %[[CONCAT_LIST]], %[[ITEM_0]], %arg0 : !torch.list>, !torch.int, !torch.vtensor<[2,3,4],f32> + // CHECK: %[[ITEM_1:.*]] = torch.aten.item %[[VTENSOR_2]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[RESULT:.*]] = torch.aten.__getitem__.t %[[CONCAT_LIST]], %[[ITEM_1]] : !torch.list>, !torch.int -> !torch.vtensor<[2,3,4],f32> + // CHECK: return %[[RESULT]] : !torch.vtensor<[2,3,4],f32> + %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<-3> : tensor} : () -> !torch.vtensor<[],si64> + %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<-1> : tensor} : () -> !torch.vtensor<[],si64> + %2 = 
torch.operator "onnx.Constant"() {torch.onnx.value = dense<-1> : tensor} : () -> !torch.vtensor<[],si64> + %3 = torch.operator "onnx.SequenceConstruct"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + %4 = torch.operator "onnx.SequenceErase"(%3, %0) : (!torch.list>, !torch.vtensor<[],si64>) -> !torch.list> + %5 = torch.operator "onnx.SequenceInsert"(%4, %arg0, %1) : (!torch.list>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[],si64>) -> !torch.list> + %6 = torch.operator "onnx.SequenceAt"(%5, %2) : (!torch.list>, !torch.vtensor<[],si64>) -> !torch.vtensor<[2,3,4],f32> + return %6 : !torch.vtensor<[2,3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_sequence_erase_at_beginning +func.func @test_sequence_erase_at_beginning(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,3,4],f32>, %arg2: !torch.vtensor<[2,3,4],f32>) -> !torch.list> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %arg0, %arg1, %arg2 : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + // CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list> -> !torch.int + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int + // CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE]], %[[POSITION]], %[[INT1]] : !torch.list>, !torch.none, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list>, !torch.int, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[RESULT:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list>, !torch.list> -> !torch.list> + // CHECK: return %[[RESULT]] : !torch.list> + %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor} : () -> !torch.vtensor<[],si64> + %3 = torch.operator "onnx.SequenceConstruct"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + %4 = torch.operator "onnx.SequenceErase"(%3, %0) : (!torch.list>, !torch.vtensor<[],si64>) -> !torch.list> + return %4 : !torch.list> +} + +// ----- + +// CHECK-LABEL: func.func @test_sequence_erase_at_end +func.func @test_sequence_erase_at_end(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,3,4],f32>, %arg2: !torch.vtensor<[2,3,4],f32>) -> !torch.list> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", 
torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<2> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %arg0, %arg1, %arg2 : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + // CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list> -> !torch.int + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int + // CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE]], %[[POSITION]], %[[INT1]] : !torch.list>, !torch.none, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list>, !torch.int, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[RESULT:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list>, !torch.list> -> !torch.list> + // CHECK: return %[[RESULT]] : !torch.list> + %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<2> : tensor} : () -> !torch.vtensor<[],si64> + %3 = torch.operator "onnx.SequenceConstruct"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + %4 = torch.operator "onnx.SequenceErase"(%3, %0) : (!torch.list>, !torch.vtensor<[],si64>) -> !torch.list> + return %4 : !torch.list> +} + +// ----- + +// CHECK-LABEL: func.func @test_sequence_erase_negative_idx +func.func @test_sequence_erase_negative_idx(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,3,4],f32>, %arg2: !torch.vtensor<[2,3,4],f32>) -> !torch.list> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<-2> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %arg0, %arg1, %arg2 : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + // CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list> -> !torch.int + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int + // CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int + // CHECK: 
%[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE]], %[[POSITION]], %[[INT1]] : !torch.list>, !torch.none, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list>, !torch.int, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[RESULT:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list>, !torch.list> -> !torch.list> + // CHECK: return %[[RESULT]] : !torch.list> + %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<-2> : tensor} : () -> !torch.vtensor<[],si64> + %3 = torch.operator "onnx.SequenceConstruct"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list> + %4 = torch.operator "onnx.SequenceErase"(%3, %0) : (!torch.list>, !torch.vtensor<[],si64>) -> !torch.list> + return %4 : !torch.list> +} + +// ----- + +// CHECK-LABEL: func.func @test_sequence_erase_empty +func.func @test_sequence_erase_empty() -> !torch.list> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[INT6:.*]] = torch.constant.int 6 + // CHECK: %[[SHAPE_LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[EMPTY_TENSOR:.*]] = torch.aten.empty.memory_format %[[SHAPE_LIST]], %[[INT6]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> + // CHECK: %[[SEQUENCE:.*]] = torch.prim.ListConstruct %[[EMPTY_TENSOR]] : (!torch.vtensor<[],f32>) -> !torch.list> + // CHECK: %[[LENGTH:.*]] = torch.aten.len.t %[[SEQUENCE]] : !torch.list> -> !torch.int + // CHECK: %[[NONE_0:.*]] = torch.constant.none + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[ITEM:.*]] = torch.aten.item %[[VTENSOR]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[CMP:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[COND:.*]] = torch.aten.Int.bool %[[CMP]] : !torch.bool -> !torch.int + // CHECK: %[[OFFSET:.*]] = torch.aten.mul.int %[[COND]], %[[LENGTH]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[POSITION:.*]] = torch.aten.add.int %[[ITEM]], %[[OFFSET]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[NONE_0]], %[[POSITION]], %[[INT1]] : !torch.list>, !torch.none, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[POSITION]], %[[INT1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_1:.*]] = torch.aten.slice.t %[[SEQUENCE]], %[[ADD]], %[[LENGTH]], %[[INT1]] : !torch.list>, !torch.int, !torch.int, !torch.int -> !torch.list> + // CHECK: %[[RESULT:.*]] = torch.aten.add.t %[[SLICE]], %[[SLICE_1]] : !torch.list>, !torch.list> -> !torch.list> + // CHECK: return %[[RESULT]] : !torch.list> + %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor} : () -> !torch.vtensor<[],si64> + %1 = torch.operator "onnx.SequenceEmpty"() : () -> !torch.list> + %4 = torch.operator "onnx.SequenceErase"(%1, %0) : (!torch.list>, 
!torch.vtensor<[],si64>) -> !torch.list> + return %4 : !torch.list> +} + +// ----- + +// CHECK-LABEL: func.func @test_sequence_empty +func.func @test_sequence_empty() -> !torch.list> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT6:.*]] = torch.constant.int 6 + // CHECK: %[[SHAPE_LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[EMPTY_TENSOR:.*]] = torch.aten.empty.memory_format %[[SHAPE_LIST]], %[[INT6]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> + // CHECK: %[[RESULT:.*]] = torch.prim.ListConstruct %[[EMPTY_TENSOR]] : (!torch.vtensor<[],f32>) -> !torch.list> + // CHECK: return %[[RESULT]] : !torch.list> + %0 = torch.operator "onnx.SequenceEmpty"() : () -> !torch.list> + return %0 : !torch.list> +} + +// ----- + +// CHECK-LABEL: func.func @test_sequence_map_add +func.func @test_sequence_map_add(%arg0: !torch.list>, %arg1: !torch.vtensor<[2,3],f32>) -> !torch.list> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[C3:.*]] = torch.constant.int 3 + // CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C3]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[ALLOC:.*]] = torch.aten.empty.memory_format %[[SHAPE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32> + // CHECK: %[[RESULT:.*]] = torch.prim.ListConstruct %[[ALLOC]] : (!torch.vtensor<[2,3],f32>) -> !torch.list> + // CHECK: %[[LEN:.*]] = torch.aten.len.t %arg0 : !torch.list> -> !torch.int + // CHECK: %[[TRUE:.*]] = torch.constant.bool true + // CHECK: %[[LOOP:.*]] = torch.prim.Loop %[[LEN]], %[[TRUE]], init(%[[RESULT]]) { + // CHECK: ^bb0(%[[ITER_NUM:.*]]: !torch.int, %[[SEQ:.*]]: !torch.list>): + // CHECK: %[[SAMPLE:.*]] = torch.aten.__getitem__.t %arg0, %[[ITER_NUM]] : !torch.list>, !torch.int -> !torch.vtensor<[2,3],f32> + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[ADD:.*]] = torch.aten.add.Tensor %[[SAMPLE]], %arg1, %[[C1]] : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.int -> !torch.vtensor<[2,3],f32> + // CHECK: %[[APPEND:.*]] = torch.aten.append.t %[[SEQ]], %[[ADD]] : !torch.list>, !torch.vtensor<[2,3],f32> -> !torch.list> + // CHECK: torch.prim.Loop.condition %[[TRUE]], iter(%[[APPEND]] : !torch.list>) + // CHECK: } : (!torch.int, !torch.bool, !torch.list>) -> !torch.list> + // CHECK: return %[[LOOP]] : !torch.list> + %0 = torch.operator "onnx.SequenceMap"(%arg0, %arg1) : (!torch.list>, !torch.vtensor<[2,3],f32>) -> !torch.list> { + ^bb0(%arg2: !torch.vtensor<[2,3],f32>, %arg3: !torch.vtensor<[2,3],f32>): + %1 = torch.operator "onnx.Add"(%arg2, %arg3) : (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> + torch.operator_terminator %1 : !torch.vtensor<[2,3],f32> + } + return %0 : !torch.list> +} + +// ----- + +// CHECK-LABEL: func.func @test_sequence_map_add_sequence_variadic +func.func @test_sequence_map_add_sequence_variadic(%arg0: !torch.list>, %arg1: !torch.list>, %arg2: 
!torch.vtensor<[?],f32>) -> !torch.list> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NEG1:.*]] = torch.constant.int -1 + // CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[NEG1]] : (!torch.int) -> !torch.list + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[ALLOC:.*]] = torch.aten.empty.memory_format %[[SHAPE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],f32> + // CHECK: %[[RESULT:.*]] = torch.prim.ListConstruct %[[ALLOC]] : (!torch.vtensor<[?],f32>) -> !torch.list> + // CHECK: %[[LEN:.*]] = torch.aten.len.t %arg0 : !torch.list> -> !torch.int + // CHECK: %[[TRUE:.*]] = torch.constant.bool true + // CHECK: %[[LOOP:.*]] = torch.prim.Loop %[[LEN]], %[[TRUE]], init(%[[RESULT]]) { + // CHECK: ^bb0(%[[ITER_NUM:.*]]: !torch.int, %[[SEQ:.*]]: !torch.list>): + // CHECK: %[[SAMPLE:.*]] = torch.aten.__getitem__.t %arg0, %[[ITER_NUM]] : !torch.list>, !torch.int -> !torch.vtensor<[?],f32> + // CHECK: %[[ADDITION_INPUT:.*]] = torch.aten.__getitem__.t %arg1, %[[ITER_NUM]] : !torch.list>, !torch.int -> !torch.vtensor<[?],f32> + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[ADD:.*]] = torch.aten.add.Tensor %[[SAMPLE]], %[[ADDITION_INPUT]], %[[C1]] : !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32> + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[ADD_0:.*]] = torch.aten.add.Tensor %[[ADD]], %arg2, %[[C1_0]] : !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32> + // CHECK: %[[APPEND:.*]] = torch.aten.append.t %[[SEQ]], %[[ADD_0]] : !torch.list>, !torch.vtensor<[?],f32> -> !torch.list> + // CHECK: torch.prim.Loop.condition %[[TRUE]], iter(%[[APPEND]] : !torch.list>) + // CHECK: } : (!torch.int, !torch.bool, !torch.list>) -> !torch.list> + // CHECK: return %[[LOOP]] : !torch.list> + %0 = torch.operator "onnx.SequenceMap"(%arg0, %arg1, %arg2) : (!torch.list>, !torch.list>, !torch.vtensor<[?],f32>) -> !torch.list> { + ^bb0(%arg3: !torch.vtensor<[?],f32>, %arg4: !torch.vtensor<[?],f32>, %arg5: !torch.vtensor<[?],f32>): + %1 = torch.operator "onnx.Add"(%arg3, %arg4) : (!torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> + %2 = torch.operator "onnx.Add"(%1, %arg5) : (!torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> + torch.operator_terminator %2 : !torch.vtensor<[?],f32> + } + return %0 : !torch.list> +} + +// ----- + +// CHECK-LABEL: func.func @test_sequence_map_identity +func.func @test_sequence_map_identity(%arg0: !torch.list>) -> !torch.list> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NEG1:.*]] = torch.constant.int -1 + // CHECK: %[[NEG1_0:.*]] = torch.constant.int -1 + // CHECK: %[[NEG1_1:.*]] = torch.constant.int -1 + // CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[NEG1]], %[[NEG1_0]], %[[NEG1_1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[ALLOC:.*]] = torch.aten.empty.memory_format %[[SHAPE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> 
!torch.vtensor<[?,?,?],f32>
+ // CHECK: %[[RESULT:.*]] = torch.prim.ListConstruct %[[ALLOC]] : (!torch.vtensor<[?,?,?],f32>) -> !torch.list>
+ // CHECK: %[[LEN:.*]] = torch.aten.len.t %arg0 : !torch.list> -> !torch.int
+ // CHECK: %[[TRUE:.*]] = torch.constant.bool true
+ // CHECK: %[[LOOP:.*]] = torch.prim.Loop %[[LEN]], %[[TRUE]], init(%[[RESULT]]) {
+ // CHECK: ^bb0(%[[ITER_NUM:.*]]: !torch.int, %[[SEQ:.*]]: !torch.list>):
+ // CHECK: %[[SAMPLE:.*]] = torch.aten.__getitem__.t %arg0, %[[ITER_NUM]] : !torch.list>, !torch.int -> !torch.vtensor<[?,?,?],f32>
+ // CHECK: %[[NONE_0:.*]] = torch.constant.none
+ // CHECK: %[[CLONE:.*]] = torch.aten.clone %[[SAMPLE]], %[[NONE_0]] : !torch.vtensor<[?,?,?],f32>, !torch.none -> !torch.vtensor<[?,?,?],f32>
+ // CHECK: %[[APPEND:.*]] = torch.aten.append.t %[[SEQ]], %[[CLONE]] : !torch.list>, !torch.vtensor<[?,?,?],f32> -> !torch.list>
+ // CHECK: torch.prim.Loop.condition %[[TRUE]], iter(%[[APPEND]] : !torch.list>)
+ // CHECK: } : (!torch.int, !torch.bool, !torch.list>) -> !torch.list>
+ // CHECK: return %[[LOOP]] : !torch.list>
+ %0 = torch.operator "onnx.SequenceMap"(%arg0) : (!torch.list>) -> !torch.list> {
+ ^bb0(%arg1: !torch.vtensor<[?,?,?],f32>):
+ %1 = torch.operator "onnx.Identity"(%arg1) : (!torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32>
+ torch.operator_terminator %1 : !torch.vtensor<[?,?,?],f32>
+ }
+ return %0 : !torch.list>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_sequence_map_extract_shapes
+func.func @test_sequence_map_extract_shapes(%arg0: !torch.list>) -> !torch.list> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+ // CHECK: %[[C3:.*]] = torch.constant.int 3
+ // CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[C3]] : (!torch.int) -> !torch.list
+ // CHECK: %[[NONE:.*]] = torch.constant.none
+ // CHECK: %[[ALLOC:.*]] = torch.aten.empty.memory_format %[[SHAPE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3],si64>
+ // CHECK: %[[RESULT:.*]] = torch.prim.ListConstruct %[[ALLOC]] : (!torch.vtensor<[3],si64>) -> !torch.list>
+ // CHECK: %[[LEN:.*]] = torch.aten.len.t %arg0 : !torch.list> -> !torch.int
+ // CHECK: %[[TRUE:.*]] = torch.constant.bool true
+ // CHECK: %[[LOOP:.*]] = torch.prim.Loop %[[LEN]], %[[TRUE]], init(%[[RESULT]]) {
+ // CHECK: ^bb0(%[[ITER_NUM:.*]]: !torch.int, %[[SEQ:.*]]: !torch.list>):
+ // CHECK: %[[SAMPLE:.*]] = torch.aten.__getitem__.t %arg0, %[[ITER_NUM]] : !torch.list>, !torch.int -> !torch.vtensor<[?,?,?],f32>
+ // CHECK: %[[SHAPE_0:.*]] = torch.aten._shape_as_tensor %[[SAMPLE]] : !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[3],si64>
+ // CHECK: %[[APPEND:.*]] = torch.aten.append.t %[[SEQ]], %[[SHAPE_0]] : !torch.list>, !torch.vtensor<[3],si64> -> !torch.list>
+ // CHECK: torch.prim.Loop.condition %[[TRUE]], iter(%[[APPEND]] : !torch.list>)
+ // CHECK: } : (!torch.int, !torch.bool, !torch.list>) -> !torch.list>
+ // CHECK: return %[[LOOP]] : !torch.list>
+ %0 = torch.operator "onnx.SequenceMap"(%arg0) : (!torch.list>) -> !torch.list> {
+ ^bb0(%arg1: !torch.vtensor<[?,?,?],f32>):
+ %1 = torch.operator "onnx.Shape"(%arg1) : (!torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[3],si64>
+ torch.operator_terminator %1 : !torch.vtensor<[3],si64>
+ }
+ return %0 : !torch.list>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_shape_start_1_end_negative_1
+func.func @test_shape_start_1_end_negative_1(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[1],si64> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 21 : si64} {
+ // CHECK: %[[SHAPE:.+]] = torch.aten._shape_as_tensor %arg0
+ // CHECK: %[[INT1_0:.+]] = torch.constant.int 1
+ // CHECK: %[[INT2_0:.+]] = torch.constant.int -1
+ // CHECK: %[[INT1_1:.+]] = torch.constant.int 1
+ // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
+ // CHECK: %[[SLICE:.+]] = torch.aten.slice.Tensor %[[SHAPE]], %[[INT0_0]], %[[INT1_0]], %[[INT2_0]], %[[INT1_1]]
+ %0 = torch.operator "onnx.Shape"(%arg0) {torch.onnx.end = -1 : si64, torch.onnx.start = 1 : si64} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[1],si64>
+ return %0 : !torch.vtensor<[1],si64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_shape_scalar
+func.func @test_shape_scalar(%arg0: !torch.vtensor<[],si64> ) -> !torch.vtensor<[?],si64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.1.0"} {
+ // CHECK: %[[SHAPE:.+]] = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[],si64> -> !torch.vtensor<[0],si64>
+ // CHECK: %[[CAST:.+]] = torch.tensor_static_info_cast %[[SHAPE]] : !torch.vtensor<[0],si64> to !torch.vtensor<[?],si64>
+ %0 = torch.operator "onnx.Shape"(%arg0) : (!torch.vtensor<[],si64>) -> !torch.vtensor<[?],si64>
+ return %0: !torch.vtensor<[?],si64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_upsample_nearest
+func.func @test_upsample_nearest(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !torch.vtensor<[4],f32>) -> !torch.vtensor<[1,1,4,6],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+ // CHECK: %[[INT0:.*]] = torch.constant.int 0
+ // CHECK: %[[INT2:.*]] = torch.constant.int 2
+ // CHECK: %[[SELECT:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT2]] : !torch.vtensor<[4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32>
+ // CHECK: %[[SCALE:.*]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],f32> -> !torch.float
+ // CHECK: %[[INT0_0:.*]] = torch.constant.int 0
+ // CHECK: %[[INT3:.*]] = torch.constant.int 3
+ // CHECK: %[[SELECT_0:.*]] = torch.aten.select.int %arg1, %[[INT0_0]], %[[INT3]] : !torch.vtensor<[4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32>
+ // CHECK: %[[SCALE_0:.*]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[1],f32> -> !torch.float
+ // CHECK: %[[SCALE_LIST:.*]] = torch.prim.ListConstruct %[[SCALE]], %[[SCALE_0]] : (!torch.float, !torch.float) -> !torch.list
+ // CHECK: %[[MODE:.*]] = torch.constant.str "nearest"
+ // CHECK: %[[NONE:.*]] = torch.constant.none
+ // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+ // CHECK: %[[UPSAMPLE:.*]] = torch.aten.__interpolate.size_list_scale_list %arg0, %[[NONE]], %[[SCALE_LIST:.*]], %[[MODE]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.vtensor<[1,1,2,2],f32>, !torch.none, !torch.list, !torch.str, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1,1,4,6],f32>
+ // CHECK: return %[[UPSAMPLE]] : !torch.vtensor<[1,1,4,6],f32>
+ %0 = torch.operator "onnx.Upsample"(%arg0, %arg1) {torch.onnx.mode = "nearest"} : (!torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[4],f32>) -> !torch.vtensor<[1,1,4,6],f32>
+ return %0 : !torch.vtensor<[1,1,4,6],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_upsample_bilinear
+func.func
@test_upsample_bilinear(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !torch.vtensor<[4],f32>) -> !torch.vtensor<[1,1,4,6],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: %[[SELECT:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT2]] : !torch.vtensor<[4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32> + // CHECK: %[[SCALE:.*]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],f32> -> !torch.float + // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK: %[[INT3:.*]] = torch.constant.int 3 + // CHECK: %[[SELECT_0:.*]] = torch.aten.select.int %arg1, %[[INT0_0]], %[[INT3]] : !torch.vtensor<[4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32> + // CHECK: %[[SCALE_0:.*]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[1],f32> -> !torch.float + // CHECK: %[[SCALE_LIST:.*]] = torch.prim.ListConstruct %[[SCALE]], %[[SCALE_0]] : (!torch.float, !torch.float) -> !torch.list + // CHECK: %[[MODE:.*]] = torch.constant.str "bilinear" + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[UPSAMPLE:.*]] = torch.aten.__interpolate.size_list_scale_list %arg0, %[[NONE]], %[[SCALE_LIST:.*]], %[[MODE]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.vtensor<[1,1,2,2],f32>, !torch.none, !torch.list, !torch.str, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1,1,4,6],f32> + // CHECK: return %[[UPSAMPLE]] : !torch.vtensor<[1,1,4,6],f32> + %0 = torch.operator "onnx.Upsample"(%arg0, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[4],f32>) -> !torch.vtensor<[1,1,4,6],f32> + return %0 : !torch.vtensor<[1,1,4,6],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_stft +func.func @test_stft(%arg0: !torch.vtensor<[1,128,1],f32>, %arg1: !torch.vtensor<[],si64>, %arg2: !torch.vtensor<[],si64>) -> !torch.vtensor<[1,15,9,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[FRAMELEN:.*]] = torch.aten.item %arg2 : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[FRAMESTEP:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[NONEVAL:.*]] = torch.constant.none + // CHECK: %[[ONESSHAPE:.*]] = torch.prim.ListConstruct %[[FRAMELEN]] : (!torch.int) -> !torch.list + // CHECK: %[[ONESLIST:.*]] = torch.aten.ones %[[ONESSHAPE]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],f32> + // CHECK: %[[INT2_0:.*]] = torch.constant.int 2 + // CHECK: %[[SQUEEZE:.*]] = torch.aten.squeeze.dim %arg0, %[[INT2_0]] : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + // CHECK: %[[FALSEVAL:.*]] = torch.constant.bool false + // CHECK: %[[TRUEVAL:.*]] = torch.constant.bool true + // CHECK: %[[STFT:.*]] = torch.aten.stft %[[SQUEEZE]], %[[FRAMELEN]], %[[FRAMESTEP]], %[[FRAMELEN]], %[[ONESLIST]], %[[FALSEVAL]], %[[TRUEVAL]], %[[TRUEVAL]], %[[FALSEVAL]] : !torch.vtensor<[1,128],f32>, !torch.int, !torch.int, !torch.int, !torch.vtensor<[?],f32>, !torch.bool, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[1,9,15],complex> + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // 
CHECK: %[[INT2_1:.*]] = torch.constant.int 2 + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[PERMUTELIST:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT2_1]], %[[INT1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[STFT]], %[[PERMUTELIST]] : !torch.vtensor<[1,9,15],complex>, !torch.list -> !torch.vtensor<[1,15,9],complex> + // CHECK: %[[VIEWASREAL:.*]] = torch.aten.view_as_real %[[PERMUTE]] : !torch.vtensor<[1,15,9],complex> -> !torch.vtensor<[1,15,9,2],f32> + // CHECK: return %[[VIEWASREAL]] : !torch.vtensor<[1,15,9,2],f32> + %none = torch.constant.none + %0 = torch.operator "onnx.STFT"(%arg0, %arg1, %none, %arg2) : (!torch.vtensor<[1,128,1],f32>, !torch.vtensor<[],si64>, !torch.none, !torch.vtensor<[],si64>) -> !torch.vtensor<[1,15,9,2],f32> + return %0 : !torch.vtensor<[1,15,9,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_stft_real_rank2 +func.func @test_stft_real_rank2(%arg0: !torch.vtensor<[1,128],f32>, %arg1: !torch.vtensor<[],si64>, %arg2: !torch.vtensor<[],si64>) -> !torch.vtensor<[1,15,9,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[FRAMELEN:.*]] = torch.aten.item %arg2 : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[FRAMESTEP:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[NONEVAL:.*]] = torch.constant.none + // CHECK: %[[ONESSHAPE:.*]] = torch.prim.ListConstruct %[[FRAMELEN]] : (!torch.int) -> !torch.list + // CHECK: %[[ONESLIST:.*]] = torch.aten.ones %[[ONESSHAPE]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],f32> + // CHECK: %[[FALSEVAL:.*]] = torch.constant.bool false + // CHECK: %[[TRUEVAL:.*]] = torch.constant.bool true + // CHECK: %[[STFT:.*]] = torch.aten.stft %arg0, %[[FRAMELEN]], %[[FRAMESTEP]], %[[FRAMELEN]], %[[ONESLIST]], %[[FALSEVAL]], %[[TRUEVAL]], %[[TRUEVAL]], %[[FALSEVAL]] : !torch.vtensor<[1,128],f32>, !torch.int, !torch.int, !torch.int, !torch.vtensor<[?],f32>, !torch.bool, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[1,9,15],complex> + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT2_1:.*]] = torch.constant.int 2 + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[PERMUTELIST:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT2_1]], %[[INT1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[STFT]], %[[PERMUTELIST]] : !torch.vtensor<[1,9,15],complex>, !torch.list -> !torch.vtensor<[1,15,9],complex> + // CHECK: %[[VIEWASREAL:.*]] = torch.aten.view_as_real %[[PERMUTE]] : !torch.vtensor<[1,15,9],complex> -> !torch.vtensor<[1,15,9,2],f32> + // CHECK: return %[[VIEWASREAL]] : !torch.vtensor<[1,15,9,2],f32> + %none = torch.constant.none + %0 = torch.operator "onnx.STFT"(%arg0, %arg1, %none, %arg2) : (!torch.vtensor<[1,128],f32>, !torch.vtensor<[],si64>, !torch.none, !torch.vtensor<[],si64>) -> !torch.vtensor<[1,15,9,2],f32> + return %0 : !torch.vtensor<[1,15,9,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_stft_with_window +func.func @test_stft_with_window(%arg0: !torch.vtensor<[1,128,1],f32>, %arg1: !torch.vtensor<[],si64>, %arg2: !torch.vtensor<[16],f32>) -> !torch.vtensor<[1,15,9,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : 
si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[FRAMELEN:.*]] = torch.constant.int 16 + // CHECK: %[[FRAMESTEP:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT2_0:.*]] = torch.constant.int 2 + // CHECK: %[[SQUEEZE:.*]] = torch.aten.squeeze.dim %arg0, %[[INT2_0]] : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + // CHECK: %[[WINDOWLEN:.*]] = torch.constant.int 16 + // CHECK: %[[FALSEVAL:.*]] = torch.constant.bool false + // CHECK: %[[TRUEVAL:.*]] = torch.constant.bool true + // CHECK: %[[STFT:.*]] = torch.aten.stft %[[SQUEEZE]], %[[FRAMELEN]], %[[FRAMESTEP]], %[[WINDOWLEN]], %arg2, %[[FALSEVAL]], %[[TRUEVAL]], %[[TRUEVAL]], %[[FALSEVAL]] : !torch.vtensor<[1,128],f32>, !torch.int, !torch.int, !torch.int, !torch.vtensor<[16],f32>, !torch.bool, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[1,9,15],complex> + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT2_1:.*]] = torch.constant.int 2 + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[PERMUTEDIMS:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT2_1]], %[[INT1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[STFT]], %[[PERMUTEDIMS]] : !torch.vtensor<[1,9,15],complex>, !torch.list -> !torch.vtensor<[1,15,9],complex> + // CHECK: %[[VIEWASREAL:.*]] = torch.aten.view_as_real %[[PERMUTE]] : !torch.vtensor<[1,15,9],complex> -> !torch.vtensor<[1,15,9,2],f32> + // CHECK: return %[[VIEWASREAL]] : !torch.vtensor<[1,15,9,2],f32> + %0 = torch.operator "onnx.STFT"(%arg0, %arg1, %arg2) : (!torch.vtensor<[1,128,1],f32>, !torch.vtensor<[],si64>, !torch.vtensor<[16],f32>) -> !torch.vtensor<[1,15,9,2],f32> + return %0 : !torch.vtensor<[1,15,9,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_stft_with_window_and_framelen +func.func @test_stft_with_window_and_framelen(%arg0: !torch.vtensor<[1,128,1],f32>, %arg1: !torch.vtensor<[],si64>, %arg2: !torch.vtensor<[16],f32>, %arg3: !torch.vtensor<[],si64>) -> !torch.vtensor<[1,15,9,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[FRAMELEN:.*]] = torch.aten.item %arg3 : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[FRAMESTEP:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[INT2_0:.*]] = torch.constant.int 2 + // CHECK: %[[SQUEEZE:.*]] = torch.aten.squeeze.dim %arg0, %[[INT2_0]] : !torch.vtensor<[1,128,1],f32>, !torch.int -> !torch.vtensor<[1,128],f32> + // CHECK: %[[WINDOWLEN:.*]] = torch.constant.int 16 + // CHECK: %[[FALSEVAL:.*]] = torch.constant.bool false + // CHECK: %[[TRUEVAL:.*]] = torch.constant.bool true + // CHECK: %[[STFT:.*]] = torch.aten.stft %[[SQUEEZE]], %[[FRAMELEN]], %[[FRAMESTEP]], %[[WINDOWLEN]], %arg2, %[[FALSEVAL]], %[[TRUEVAL]], %[[TRUEVAL]], %[[FALSEVAL]] : !torch.vtensor<[1,128],f32>, !torch.int, !torch.int, !torch.int, !torch.vtensor<[16],f32>, !torch.bool, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[1,9,15],complex> + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT2_1:.*]] = torch.constant.int 2 + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[PERMUTEDIMS:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT2_1]], %[[INT1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[PERMUTE:.*]] = torch.aten.permute 
%[[STFT]], %[[PERMUTEDIMS]] : !torch.vtensor<[1,9,15],complex>, !torch.list -> !torch.vtensor<[1,15,9],complex> + // CHECK: %[[VIEWASREAL:.*]] = torch.aten.view_as_real %[[PERMUTE]] : !torch.vtensor<[1,15,9],complex> -> !torch.vtensor<[1,15,9,2],f32> + // CHECK: return %[[VIEWASREAL]] : !torch.vtensor<[1,15,9,2],f32> + %0 = torch.operator "onnx.STFT"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[1,128,1],f32>, !torch.vtensor<[],si64>, !torch.vtensor<[16],f32>, !torch.vtensor<[],si64>) -> !torch.vtensor<[1,15,9,2],f32> + return %0 : !torch.vtensor<[1,15,9,2],f32> +} + +// ----- + +// CHECK-LABEL: @test_reversesequence_batch +func.func @test_reversesequence_batch(%arg0: !torch.vtensor<[4,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[4,4],f32> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C0_0:.*]] = torch.constant.int 0 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C0_1:.*]] = torch.constant.int 0 + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[C0_1]], %[[C1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE:.*]] = torch.aten.slice.Tensor %arg0, %[[C0_0]], %[[C0_1]], %[[ADD]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4],f32> + // CHECK: %[[INDEX:.*]] = torch.prim.NumToTensor.Scalar %[[C0_1]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[SELECT:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX]] : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM:.*]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[SLICE_0:.*]] = torch.aten.slice.Tensor %[[SLICE]], %[[C1_0]], %[[C0]], %[[ITEM]], %[[C1]] : !torch.vtensor<[1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?],f32> + // CHECK: %[[DIM:.*]] = torch.prim.ListConstruct %[[C1_0]] : (!torch.int) -> !torch.list + // CHECK: %[[FLIP:.*]] = torch.aten.flip %[[SLICE_0]], %[[DIM]] : !torch.vtensor<[1,?],f32>, !torch.list -> !torch.vtensor<[1,?],f32> + // CHECK: %[[EMBED:.*]] = torch.aten.slice_scatter %[[SLICE]], %[[FLIP]], %[[C1_0]], %[[C0]], %[[ITEM]], %[[C1]] : !torch.vtensor<[1,4],f32>, !torch.vtensor<[1,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4],f32> + // CHECK: %[[EMBED_0:.*]] = torch.aten.slice_scatter %arg0, %[[EMBED]], %[[C0_0]], %[[C0_1]], %[[ADD]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.vtensor<[1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,4],f32> + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[ADD_0:.*]] = torch.aten.add.int %[[C1_1]], %[[C1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_1:.*]] = torch.aten.slice.Tensor %[[EMBED_0]], %[[C0_0]], %[[C1_1]], %[[ADD_0]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4],f32> + // CHECK: %[[INDEX_0:.*]] = torch.prim.NumToTensor.Scalar %[[C1_1]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[SELECT_0:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX_0]] : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM_0:.*]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[1],si64> -> 
!torch.int + // CHECK: %[[SLICE_2:.*]] = torch.aten.slice.Tensor %[[SLICE_1]], %[[C1_0]], %[[C0]], %[[ITEM_0]], %[[C1]] : !torch.vtensor<[1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?],f32> + // CHECK: %[[DIM_0:.*]] = torch.prim.ListConstruct %[[C1_0]] : (!torch.int) -> !torch.list + // CHECK: %[[FLIP_0:.*]] = torch.aten.flip %[[SLICE_2]], %[[DIM_0]] : !torch.vtensor<[1,?],f32>, !torch.list -> !torch.vtensor<[1,?],f32> + // CHECK: %[[EMBED_1:.*]] = torch.aten.slice_scatter %[[SLICE_1]], %[[FLIP_0]], %[[C1_0]], %[[C0]], %[[ITEM_0]], %[[C1]] : !torch.vtensor<[1,4],f32>, !torch.vtensor<[1,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4],f32> + // CHECK: %[[EMBED_2:.*]] = torch.aten.slice_scatter %[[EMBED_0]], %[[EMBED_1]], %[[C0_0]], %[[C1_1]], %[[ADD_0]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.vtensor<[1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,4],f32> + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[ADD_1:.*]] = torch.aten.add.int %[[C2]], %[[C1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_3:.*]] = torch.aten.slice.Tensor %[[EMBED_2]], %[[C0_0]], %[[C2]], %[[ADD_1]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4],f32> + // CHECK: %[[INDEX_1:.*]] = torch.prim.NumToTensor.Scalar %[[C2]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[SELECT_1:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX_1]] : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM_1:.*]] = torch.aten.item %[[SELECT_1]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[SLICE_4:.*]] = torch.aten.slice.Tensor %[[SLICE_3]], %[[C1_0]], %[[C0]], %[[ITEM_1]], %[[C1]] : !torch.vtensor<[1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?],f32> + // CHECK: %[[DIM_1:.*]] = torch.prim.ListConstruct %[[C1_0]] : (!torch.int) -> !torch.list + // CHECK: %[[FLIP_1:.*]] = torch.aten.flip %[[SLICE_4]], %[[DIM_1]] : !torch.vtensor<[1,?],f32>, !torch.list -> !torch.vtensor<[1,?],f32> + // CHECK: %[[EMBED_3:.*]] = torch.aten.slice_scatter %[[SLICE_3]], %[[FLIP_1]], %[[C1_0]], %[[C0]], %[[ITEM_1]], %[[C1]] : !torch.vtensor<[1,4],f32>, !torch.vtensor<[1,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4],f32> + // CHECK: %[[EMBED_4:.*]] = torch.aten.slice_scatter %[[EMBED_2]], %[[EMBED_3]], %[[C0_0]], %[[C2]], %[[ADD_1]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.vtensor<[1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,4],f32> + // CHECK: %[[C3:.*]] = torch.constant.int 3 + // CHECK: %[[ADD_2:.*]] = torch.aten.add.int %[[C3]], %[[C1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_5:.*]] = torch.aten.slice.Tensor %[[EMBED_4]], %[[C0_0]], %[[C3]], %[[ADD_2]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4],f32> + // CHECK: %[[INDEX_2:.*]] = torch.prim.NumToTensor.Scalar %[[C3]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[SELECT_2:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX_2]] : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM_2:.*]] = torch.aten.item %[[SELECT_2]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[SLICE_6:.*]] = torch.aten.slice.Tensor %[[SLICE_5]], %[[C1_0]], %[[C0]], %[[ITEM_2]], %[[C1]] : 
!torch.vtensor<[1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,?],f32> + // CHECK: %[[DIM_2:.*]] = torch.prim.ListConstruct %[[C1_0]] : (!torch.int) -> !torch.list + // CHECK: %[[FLIP_2:.*]] = torch.aten.flip %[[SLICE_6]], %[[DIM_2]] : !torch.vtensor<[1,?],f32>, !torch.list -> !torch.vtensor<[1,?],f32> + // CHECK: %[[EMBED_5:.*]] = torch.aten.slice_scatter %[[SLICE_5]], %[[FLIP_2]], %[[C1_0]], %[[C0]], %[[ITEM_2]], %[[C1]] : !torch.vtensor<[1,4],f32>, !torch.vtensor<[1,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,4],f32> + // CHECK: torch.aten.slice_scatter %[[EMBED_4]], %[[EMBED_5]], %[[C0_0]], %[[C3]], %[[ADD_2]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.vtensor<[1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,4],f32> + %0 = torch.operator "onnx.ReverseSequence"(%arg0, %arg1) {torch.onnx.batch_axis = 0 : si64, torch.onnx.time_axis = 1 : si64} : (!torch.vtensor<[4,4],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[4,4],f32> + return %0 : !torch.vtensor<[4,4],f32> +} + +// ----- + +// CHECK-LABEL: @test_reversesequence_time +func.func @test_reversesequence_time(%arg0: !torch.vtensor<[4,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[4,4],f32> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C0_0:.*]] = torch.constant.int 0 + // CHECK: %[[C0_1:.*]] = torch.constant.int 0 + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[C0_1]], %[[C1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE:.*]] = torch.aten.slice.Tensor %arg0, %[[C1_0]], %[[C0_1]], %[[ADD]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,1],f32> + // CHECK: %[[INDEX:.*]] = torch.prim.NumToTensor.Scalar %[[C0_1]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[SELECT:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX]] : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM:.*]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[SLICE_0:.*]] = torch.aten.slice.Tensor %[[SLICE]], %[[C0_0]], %[[C0]], %[[ITEM]], %[[C1]] : !torch.vtensor<[4,1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,1],f32> + // CHECK: %[[DIM:.*]] = torch.prim.ListConstruct %[[C0_0]] : (!torch.int) -> !torch.list + // CHECK: %[[FLIP:.*]] = torch.aten.flip %[[SLICE_0]], %[[DIM]] : !torch.vtensor<[?,1],f32>, !torch.list -> !torch.vtensor<[?,1],f32> + // CHECK: %[[EMBED:.*]] = torch.aten.slice_scatter %[[SLICE]], %[[FLIP]], %[[C0_0]], %[[C0]], %[[ITEM]], %[[C1]] : !torch.vtensor<[4,1],f32>, !torch.vtensor<[?,1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,1],f32> + // CHECK: %[[EMBED_0:.*]] = torch.aten.slice_scatter %arg0, %[[EMBED]], %[[C1_0]], %[[C0_1]], %[[ADD]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.vtensor<[4,1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,4],f32> + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[ADD_0:.*]] = torch.aten.add.int %[[C1_1]], %[[C1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_1:.*]] = torch.aten.slice.Tensor %[[EMBED_0]], %[[C1_0]], %[[C1_1]], 
%[[ADD_0]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,1],f32> + // CHECK: %[[INDEX_0:.*]] = torch.prim.NumToTensor.Scalar %[[C1_1]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[SELECT_0:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX_0]] : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM_0:.*]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[SLICE_2:.*]] = torch.aten.slice.Tensor %[[SLICE_1]], %[[C0_0]], %[[C0]], %[[ITEM_0]], %[[C1]] : !torch.vtensor<[4,1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,1],f32> + // CHECK: %[[DIM_0:.*]] = torch.prim.ListConstruct %[[C0_0]] : (!torch.int) -> !torch.list + // CHECK: %[[FLIP_0:.*]] = torch.aten.flip %[[SLICE_2]], %[[DIM_0]] : !torch.vtensor<[?,1],f32>, !torch.list -> !torch.vtensor<[?,1],f32> + // CHECK: %[[EMBED_1:.*]] = torch.aten.slice_scatter %[[SLICE_1]], %[[FLIP_0]], %[[C0_0]], %[[C0]], %[[ITEM_0]], %[[C1]] : !torch.vtensor<[4,1],f32>, !torch.vtensor<[?,1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,1],f32> + // CHECK: %[[EMBED_2:.*]] = torch.aten.slice_scatter %[[EMBED_0]], %[[EMBED_1]], %[[C1_0]], %[[C1_1]], %[[ADD_0]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.vtensor<[4,1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,4],f32> + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[ADD_1:.*]] = torch.aten.add.int %[[C2]], %[[C1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_3:.*]] = torch.aten.slice.Tensor %[[EMBED_2]], %[[C1_0]], %[[C2]], %[[ADD_1]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,1],f32> + // CHECK: %[[INDEX_1:.*]] = torch.prim.NumToTensor.Scalar %[[C2]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[SELECT_1:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX_1]] : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM_1:.*]] = torch.aten.item %[[SELECT_1]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[SLICE_4:.*]] = torch.aten.slice.Tensor %[[SLICE_3]], %[[C0_0]], %[[C0]], %[[ITEM_1]], %[[C1]] : !torch.vtensor<[4,1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,1],f32> + // CHECK: %[[DIM_1:.*]] = torch.prim.ListConstruct %[[C0_0]] : (!torch.int) -> !torch.list + // CHECK: %[[FLIP_1:.*]] = torch.aten.flip %[[SLICE_4]], %[[DIM_1]] : !torch.vtensor<[?,1],f32>, !torch.list -> !torch.vtensor<[?,1],f32> + // CHECK: %[[EMBED_3:.*]] = torch.aten.slice_scatter %[[SLICE_3]], %[[FLIP_1]], %[[C0_0]], %[[C0]], %[[ITEM_1]], %[[C1]] : !torch.vtensor<[4,1],f32>, !torch.vtensor<[?,1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,1],f32> + // CHECK: %[[EMBED_4:.*]] = torch.aten.slice_scatter %[[EMBED_2]], %[[EMBED_3]], %[[C1_0]], %[[C2]], %[[ADD_1]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.vtensor<[4,1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,4],f32> + // CHECK: %[[C3:.*]] = torch.constant.int 3 + // CHECK: %[[ADD_2:.*]] = torch.aten.add.int %[[C3]], %[[C1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[SLICE_5:.*]] = torch.aten.slice.Tensor %[[EMBED_4]], %[[C1_0]], %[[C3]], %[[ADD_2]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,1],f32> + // CHECK: 
%[[INDEX_2:.*]] = torch.prim.NumToTensor.Scalar %[[C3]] : !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[SELECT_2:.*]] = torch.aten.index_select %arg1, %[[C0]], %[[INDEX_2]] : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM_2:.*]] = torch.aten.item %[[SELECT_2]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[SLICE_6:.*]] = torch.aten.slice.Tensor %[[SLICE_5]], %[[C0_0]], %[[C0]], %[[ITEM_2]], %[[C1]] : !torch.vtensor<[4,1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,1],f32> + // CHECK: %[[DIM_2:.*]] = torch.prim.ListConstruct %[[C0_0]] : (!torch.int) -> !torch.list + // CHECK: %[[FLIP_2:.*]] = torch.aten.flip %[[SLICE_6]], %[[DIM_2]] : !torch.vtensor<[?,1],f32>, !torch.list -> !torch.vtensor<[?,1],f32> + // CHECK: %[[EMBED_5:.*]] = torch.aten.slice_scatter %[[SLICE_5]], %[[FLIP_2]], %[[C0_0]], %[[C0]], %[[ITEM_2]], %[[C1]] : !torch.vtensor<[4,1],f32>, !torch.vtensor<[?,1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,1],f32> + // CHECK: torch.aten.slice_scatter %[[EMBED_4]], %[[EMBED_5]], %[[C1_0]], %[[C3]], %[[ADD_2]], %[[C1]] : !torch.vtensor<[4,4],f32>, !torch.vtensor<[4,1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,4],f32> + %0 = torch.operator "onnx.ReverseSequence"(%arg0, %arg1) {torch.onnx.batch_axis = 1 : si64, torch.onnx.time_axis = 0 : si64} : (!torch.vtensor<[4,4],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[4,4],f32> + return %0 : !torch.vtensor<[4,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_scatternd( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,4,4],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,1],si64>, +// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[2,4,4],f32>) -> !torch.vtensor<[4,4,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +func.func @test_scatternd(%arg0: !torch.vtensor<[4,4,4],f32>, %arg1: !torch.vtensor<[2,1],si64>, %arg2: !torch.vtensor<[2,4,4],f32>) -> !torch.vtensor<[4,4,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VAL_3:.*]] = torch.constant.none + // CHECK: %[[VAL_4:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_5:.*]] = torch.aten.size.int %[[VAL_0]], %[[VAL_4]] : !torch.vtensor<[4,4,4],f32>, !torch.int -> !torch.int + // CHECK: %[[VAL_6:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_7:.*]] = torch.aten.size.int %[[VAL_0]], %[[VAL_6]] : !torch.vtensor<[4,4,4],f32>, !torch.int -> !torch.int + // CHECK: %[[VAL_8:.*]] = torch.constant.int 2 + // CHECK: %[[VAL_9:.*]] = torch.aten.size.int %[[VAL_0]], %[[VAL_8]] : !torch.vtensor<[4,4,4],f32>, !torch.int -> !torch.int + // CHECK: %[[VAL_10:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_11:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_12:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_13:.*]] = torch.aten.size.int %[[VAL_1]], %[[VAL_12]] : !torch.vtensor<[2,1],si64>, !torch.int -> !torch.int + // CHECK: %[[VAL_14:.*]] = torch.aten.mul.int %[[VAL_11]], %[[VAL_13]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[VAL_15:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_16:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_17:.*]] = torch.aten.slice.Tensor %[[VAL_1]], %[[VAL_15]], %[[VAL_10]], %[[VAL_16]], %[[VAL_11]] : 
!torch.vtensor<[2,1],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,1],si64> + // CHECK: %[[VAL_18:.*]] = torch.aten.lt.Scalar %[[VAL_17]], %[[VAL_10]] : !torch.vtensor<[2,1],si64>, !torch.int -> !torch.vtensor<[2,1],i1> + // CHECK: %[[VAL_19:.*]] = torch.aten.add.Scalar %[[VAL_17]], %[[VAL_5]], %[[VAL_11]] : !torch.vtensor<[2,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,1],si64> + // CHECK: %[[VAL_20:.*]] = torch.aten.where.self %[[VAL_18]], %[[VAL_19]], %[[VAL_17]] : !torch.vtensor<[2,1],i1>, !torch.vtensor<[2,1],si64>, !torch.vtensor<[2,1],si64> -> !torch.vtensor<[2,1],si64> + // CHECK: %[[VAL_21:.*]] = torch.prim.ListConstruct %[[VAL_13]], %[[VAL_11]], %[[VAL_11]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[VAL_22:.*]] = torch.aten.view %[[VAL_20]], %[[VAL_21]] : !torch.vtensor<[2,1],si64>, !torch.list -> !torch.vtensor<[2,1,1],si64> + // CHECK: %[[VAL_23:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_24:.*]] = torch.aten.flatten.using_ints %[[VAL_22]], %[[VAL_10]], %[[VAL_23]] : !torch.vtensor<[2,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,1,1],si64> + // CHECK: %[[VAL_25:.*]] = torch.aten.flatten.using_ints %[[VAL_2]], %[[VAL_10]], %[[VAL_23]] : !torch.vtensor<[2,4,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,4,4],f32> + // CHECK: %[[VAL_26:.*]] = torch.prim.ListConstruct %[[VAL_14]], %[[VAL_7]], %[[VAL_9]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[VAL_27:.*]] = torch.constant.bool false + // CHECK: %[[VAL_28:.*]] = torch.aten.expand %[[VAL_24]], %[[VAL_26]], %[[VAL_27]] : !torch.vtensor<[2,1,1],si64>, !torch.list, !torch.bool -> !torch.vtensor<[2,4,4],si64> + // CHECK: %[[VAL_29:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_30:.*]] = torch.aten.flatten.using_ints %[[VAL_0]], %[[VAL_10]], %[[VAL_29]] : !torch.vtensor<[4,4,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[4,4,4],f32> + // CHECK: %[[VAL_31:.*]] = torch.aten.scatter.src %[[VAL_30]], %[[VAL_10]], %[[VAL_28]], %[[VAL_25]] : !torch.vtensor<[4,4,4],f32>, !torch.int, !torch.vtensor<[2,4,4],si64>, !torch.vtensor<[2,4,4],f32> -> !torch.vtensor<[4,4,4],f32> + // CHECK: return %[[VAL_31]] : !torch.vtensor<[4,4,4],f32> + // CHECK: } + %none = torch.constant.none + %0 = torch.operator "onnx.ScatterND"(%arg0, %arg1, %arg2) : (!torch.vtensor<[4,4,4],f32>, !torch.vtensor<[2,1],si64>, !torch.vtensor<[2,4,4],f32>) -> !torch.vtensor<[4,4,4],f32> + return %0 : !torch.vtensor<[4,4,4],f32> +} + +// CHECK-LABEL: func.func @test_scatternd_add( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,4,4],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,1],si64>, +// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[2,4,4],f32>) -> !torch.vtensor<[4,4,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +func.func @test_scatternd_add(%arg0: !torch.vtensor<[4,4,4],f32>, %arg1: !torch.vtensor<[2,1],si64>, %arg2: !torch.vtensor<[2,4,4],f32>) -> !torch.vtensor<[4,4,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VAL_3:.*]] = torch.constant.none + // CHECK: %[[VAL_4:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_5:.*]] = torch.aten.size.int %[[VAL_0]], %[[VAL_4]] : !torch.vtensor<[4,4,4],f32>, !torch.int -> !torch.int + // CHECK: %[[VAL_6:.*]] = 
torch.constant.int 1 + // CHECK: %[[VAL_7:.*]] = torch.aten.size.int %[[VAL_0]], %[[VAL_6]] : !torch.vtensor<[4,4,4],f32>, !torch.int -> !torch.int + // CHECK: %[[VAL_8:.*]] = torch.constant.int 2 + // CHECK: %[[VAL_9:.*]] = torch.aten.size.int %[[VAL_0]], %[[VAL_8]] : !torch.vtensor<[4,4,4],f32>, !torch.int -> !torch.int + // CHECK: %[[VAL_10:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_11:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_12:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_13:.*]] = torch.aten.size.int %[[VAL_1]], %[[VAL_12]] : !torch.vtensor<[2,1],si64>, !torch.int -> !torch.int + // CHECK: %[[VAL_14:.*]] = torch.aten.mul.int %[[VAL_11]], %[[VAL_13]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[VAL_15:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_16:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_17:.*]] = torch.aten.slice.Tensor %[[VAL_1]], %[[VAL_15]], %[[VAL_10]], %[[VAL_16]], %[[VAL_11]] : !torch.vtensor<[2,1],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,1],si64> + // CHECK: %[[VAL_18:.*]] = torch.aten.lt.Scalar %[[VAL_17]], %[[VAL_10]] : !torch.vtensor<[2,1],si64>, !torch.int -> !torch.vtensor<[2,1],i1> + // CHECK: %[[VAL_19:.*]] = torch.aten.add.Scalar %[[VAL_17]], %[[VAL_5]], %[[VAL_11]] : !torch.vtensor<[2,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,1],si64> + // CHECK: %[[VAL_20:.*]] = torch.aten.where.self %[[VAL_18]], %[[VAL_19]], %[[VAL_17]] : !torch.vtensor<[2,1],i1>, !torch.vtensor<[2,1],si64>, !torch.vtensor<[2,1],si64> -> !torch.vtensor<[2,1],si64> + // CHECK: %[[VAL_21:.*]] = torch.prim.ListConstruct %[[VAL_13]], %[[VAL_11]], %[[VAL_11]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[VAL_22:.*]] = torch.aten.view %[[VAL_20]], %[[VAL_21]] : !torch.vtensor<[2,1],si64>, !torch.list -> !torch.vtensor<[2,1,1],si64> + // CHECK: %[[VAL_23:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_24:.*]] = torch.aten.flatten.using_ints %[[VAL_22]], %[[VAL_10]], %[[VAL_23]] : !torch.vtensor<[2,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,1,1],si64> + // CHECK: %[[VAL_25:.*]] = torch.aten.flatten.using_ints %[[VAL_2]], %[[VAL_10]], %[[VAL_23]] : !torch.vtensor<[2,4,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,4,4],f32> + // CHECK: %[[VAL_26:.*]] = torch.prim.ListConstruct %[[VAL_14]], %[[VAL_7]], %[[VAL_9]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[VAL_27:.*]] = torch.constant.bool false + // CHECK: %[[VAL_28:.*]] = torch.aten.expand %[[VAL_24]], %[[VAL_26]], %[[VAL_27]] : !torch.vtensor<[2,1,1],si64>, !torch.list, !torch.bool -> !torch.vtensor<[2,4,4],si64> + // CHECK: %[[VAL_29:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_30:.*]] = torch.aten.flatten.using_ints %[[VAL_0]], %[[VAL_10]], %[[VAL_29]] : !torch.vtensor<[4,4,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[4,4,4],f32> + // CHECK: %[[VAL_31:.*]] = torch.constant.str "sum" + // CHECK: %[[VAL_32:.*]] = torch.constant.bool true + // CHECK: %[[VAL_33:.*]] = torch.aten.scatter_reduce.two %[[VAL_30]], %[[VAL_10]], %[[VAL_28]], %[[VAL_25]], %[[VAL_31]], %[[VAL_32]] : !torch.vtensor<[4,4,4],f32>, !torch.int, !torch.vtensor<[2,4,4],si64>, !torch.vtensor<[2,4,4],f32>, !torch.str, !torch.bool -> !torch.vtensor<[4,4,4],f32> + // CHECK: return %[[VAL_33]] : !torch.vtensor<[4,4,4],f32> + // CHECK: } + %none = torch.constant.none + %0 = torch.operator "onnx.ScatterND"(%arg0, %arg1, %arg2) {torch.onnx.reduction = "add"} : (!torch.vtensor<[4,4,4],f32>, !torch.vtensor<[2,1],si64>, !torch.vtensor<[2,4,4],f32>) -> 
!torch.vtensor<[4,4,4],f32> + return %0 : !torch.vtensor<[4,4,4],f32> +} + +// CHECK-LABEL: func.func @test_scatternd_mul( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,4,4],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,1],si64>, +// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[2,4,4],f32>) -> !torch.vtensor<[4,4,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +func.func @test_scatternd_mul(%arg0: !torch.vtensor<[4,4,4],f32>, %arg1: !torch.vtensor<[2,1],si64>, %arg2: !torch.vtensor<[2,4,4],f32>) -> !torch.vtensor<[4,4,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VAL_3:.*]] = torch.constant.none + // CHECK: %[[VAL_4:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_5:.*]] = torch.aten.size.int %[[VAL_0]], %[[VAL_4]] : !torch.vtensor<[4,4,4],f32>, !torch.int -> !torch.int + // CHECK: %[[VAL_6:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_7:.*]] = torch.aten.size.int %[[VAL_0]], %[[VAL_6]] : !torch.vtensor<[4,4,4],f32>, !torch.int -> !torch.int + // CHECK: %[[VAL_8:.*]] = torch.constant.int 2 + // CHECK: %[[VAL_9:.*]] = torch.aten.size.int %[[VAL_0]], %[[VAL_8]] : !torch.vtensor<[4,4,4],f32>, !torch.int -> !torch.int + // CHECK: %[[VAL_10:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_11:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_12:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_13:.*]] = torch.aten.size.int %[[VAL_1]], %[[VAL_12]] : !torch.vtensor<[2,1],si64>, !torch.int -> !torch.int + // CHECK: %[[VAL_14:.*]] = torch.aten.mul.int %[[VAL_11]], %[[VAL_13]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[VAL_15:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_16:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_17:.*]] = torch.aten.slice.Tensor %[[VAL_1]], %[[VAL_15]], %[[VAL_10]], %[[VAL_16]], %[[VAL_11]] : !torch.vtensor<[2,1],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,1],si64> + // CHECK: %[[VAL_18:.*]] = torch.aten.lt.Scalar %[[VAL_17]], %[[VAL_10]] : !torch.vtensor<[2,1],si64>, !torch.int -> !torch.vtensor<[2,1],i1> + // CHECK: %[[VAL_19:.*]] = torch.aten.add.Scalar %[[VAL_17]], %[[VAL_5]], %[[VAL_11]] : !torch.vtensor<[2,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,1],si64> + // CHECK: %[[VAL_20:.*]] = torch.aten.where.self %[[VAL_18]], %[[VAL_19]], %[[VAL_17]] : !torch.vtensor<[2,1],i1>, !torch.vtensor<[2,1],si64>, !torch.vtensor<[2,1],si64> -> !torch.vtensor<[2,1],si64> + // CHECK: %[[VAL_21:.*]] = torch.prim.ListConstruct %[[VAL_13]], %[[VAL_11]], %[[VAL_11]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[VAL_22:.*]] = torch.aten.view %[[VAL_20]], %[[VAL_21]] : !torch.vtensor<[2,1],si64>, !torch.list -> !torch.vtensor<[2,1,1],si64> + // CHECK: %[[VAL_23:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_24:.*]] = torch.aten.flatten.using_ints %[[VAL_22]], %[[VAL_10]], %[[VAL_23]] : !torch.vtensor<[2,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,1,1],si64> + // CHECK: %[[VAL_25:.*]] = torch.aten.flatten.using_ints %[[VAL_2]], %[[VAL_10]], %[[VAL_23]] : !torch.vtensor<[2,4,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,4,4],f32> + // CHECK: %[[VAL_26:.*]] = torch.prim.ListConstruct %[[VAL_14]], %[[VAL_7]], %[[VAL_9]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[VAL_27:.*]] = 
torch.constant.bool false + // CHECK: %[[VAL_28:.*]] = torch.aten.expand %[[VAL_24]], %[[VAL_26]], %[[VAL_27]] : !torch.vtensor<[2,1,1],si64>, !torch.list, !torch.bool -> !torch.vtensor<[2,4,4],si64> + // CHECK: %[[VAL_29:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_30:.*]] = torch.aten.flatten.using_ints %[[VAL_0]], %[[VAL_10]], %[[VAL_29]] : !torch.vtensor<[4,4,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[4,4,4],f32> + // CHECK: %[[VAL_31:.*]] = torch.constant.str "prod" + // CHECK: %[[VAL_32:.*]] = torch.constant.bool true + // CHECK: %[[VAL_33:.*]] = torch.aten.scatter_reduce.two %[[VAL_30]], %[[VAL_10]], %[[VAL_28]], %[[VAL_25]], %[[VAL_31]], %[[VAL_32]] : !torch.vtensor<[4,4,4],f32>, !torch.int, !torch.vtensor<[2,4,4],si64>, !torch.vtensor<[2,4,4],f32>, !torch.str, !torch.bool -> !torch.vtensor<[4,4,4],f32> + // CHECK: return %[[VAL_33]] : !torch.vtensor<[4,4,4],f32> + // CHECK: } + %none = torch.constant.none + %0 = torch.operator "onnx.ScatterND"(%arg0, %arg1, %arg2) {torch.onnx.reduction = "mul"} : (!torch.vtensor<[4,4,4],f32>, !torch.vtensor<[2,1],si64>, !torch.vtensor<[2,4,4],f32>) -> !torch.vtensor<[4,4,4],f32> + return %0 : !torch.vtensor<[4,4,4],f32> +} + +// CHECK-LABEL: func.func @test_scatternd_max( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,4,4],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,1],si64>, +// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[2,4,4],f32>) -> !torch.vtensor<[4,4,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +func.func @test_scatternd_max(%arg0: !torch.vtensor<[4,4,4],f32>, %arg1: !torch.vtensor<[2,1],si64>, %arg2: !torch.vtensor<[2,4,4],f32>) -> !torch.vtensor<[4,4,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VAL_3:.*]] = torch.constant.none + // CHECK: %[[VAL_4:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_5:.*]] = torch.aten.size.int %[[VAL_0]], %[[VAL_4]] : !torch.vtensor<[4,4,4],f32>, !torch.int -> !torch.int + // CHECK: %[[VAL_6:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_7:.*]] = torch.aten.size.int %[[VAL_0]], %[[VAL_6]] : !torch.vtensor<[4,4,4],f32>, !torch.int -> !torch.int + // CHECK: %[[VAL_8:.*]] = torch.constant.int 2 + // CHECK: %[[VAL_9:.*]] = torch.aten.size.int %[[VAL_0]], %[[VAL_8]] : !torch.vtensor<[4,4,4],f32>, !torch.int -> !torch.int + // CHECK: %[[VAL_10:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_11:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_12:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_13:.*]] = torch.aten.size.int %[[VAL_1]], %[[VAL_12]] : !torch.vtensor<[2,1],si64>, !torch.int -> !torch.int + // CHECK: %[[VAL_14:.*]] = torch.aten.mul.int %[[VAL_11]], %[[VAL_13]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[VAL_15:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_16:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_17:.*]] = torch.aten.slice.Tensor %[[VAL_1]], %[[VAL_15]], %[[VAL_10]], %[[VAL_16]], %[[VAL_11]] : !torch.vtensor<[2,1],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,1],si64> + // CHECK: %[[VAL_18:.*]] = torch.aten.lt.Scalar %[[VAL_17]], %[[VAL_10]] : !torch.vtensor<[2,1],si64>, !torch.int -> !torch.vtensor<[2,1],i1> + // CHECK: %[[VAL_19:.*]] = torch.aten.add.Scalar %[[VAL_17]], %[[VAL_5]], %[[VAL_11]] : !torch.vtensor<[2,1],si64>, !torch.int, 
!torch.int -> !torch.vtensor<[2,1],si64> + // CHECK: %[[VAL_20:.*]] = torch.aten.where.self %[[VAL_18]], %[[VAL_19]], %[[VAL_17]] : !torch.vtensor<[2,1],i1>, !torch.vtensor<[2,1],si64>, !torch.vtensor<[2,1],si64> -> !torch.vtensor<[2,1],si64> + // CHECK: %[[VAL_21:.*]] = torch.prim.ListConstruct %[[VAL_13]], %[[VAL_11]], %[[VAL_11]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[VAL_22:.*]] = torch.aten.view %[[VAL_20]], %[[VAL_21]] : !torch.vtensor<[2,1],si64>, !torch.list -> !torch.vtensor<[2,1,1],si64> + // CHECK: %[[VAL_23:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_24:.*]] = torch.aten.flatten.using_ints %[[VAL_22]], %[[VAL_10]], %[[VAL_23]] : !torch.vtensor<[2,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,1,1],si64> + // CHECK: %[[VAL_25:.*]] = torch.aten.flatten.using_ints %[[VAL_2]], %[[VAL_10]], %[[VAL_23]] : !torch.vtensor<[2,4,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,4,4],f32> + // CHECK: %[[VAL_26:.*]] = torch.prim.ListConstruct %[[VAL_14]], %[[VAL_7]], %[[VAL_9]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[VAL_27:.*]] = torch.constant.bool false + // CHECK: %[[VAL_28:.*]] = torch.aten.expand %[[VAL_24]], %[[VAL_26]], %[[VAL_27]] : !torch.vtensor<[2,1,1],si64>, !torch.list, !torch.bool -> !torch.vtensor<[2,4,4],si64> + // CHECK: %[[VAL_29:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_30:.*]] = torch.aten.flatten.using_ints %[[VAL_0]], %[[VAL_10]], %[[VAL_29]] : !torch.vtensor<[4,4,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[4,4,4],f32> + // CHECK: %[[VAL_31:.*]] = torch.constant.str "amax" + // CHECK: %[[VAL_32:.*]] = torch.constant.bool true + // CHECK: %[[VAL_33:.*]] = torch.aten.scatter_reduce.two %[[VAL_30]], %[[VAL_10]], %[[VAL_28]], %[[VAL_25]], %[[VAL_31]], %[[VAL_32]] : !torch.vtensor<[4,4,4],f32>, !torch.int, !torch.vtensor<[2,4,4],si64>, !torch.vtensor<[2,4,4],f32>, !torch.str, !torch.bool -> !torch.vtensor<[4,4,4],f32> + // CHECK: return %[[VAL_33]] : !torch.vtensor<[4,4,4],f32> + // CHECK: } + %none = torch.constant.none + %0 = torch.operator "onnx.ScatterND"(%arg0, %arg1, %arg2) {torch.onnx.reduction = "max"} : (!torch.vtensor<[4,4,4],f32>, !torch.vtensor<[2,1],si64>, !torch.vtensor<[2,4,4],f32>) -> !torch.vtensor<[4,4,4],f32> + return %0 : !torch.vtensor<[4,4,4],f32> +} + +// CHECK-LABEL: func.func @test_scatternd_min( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,4,4],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,1],si64>, +// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[2,4,4],f32>) -> !torch.vtensor<[4,4,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +func.func @test_scatternd_min(%arg0: !torch.vtensor<[4,4,4],f32>, %arg1: !torch.vtensor<[2,1],si64>, %arg2: !torch.vtensor<[2,4,4],f32>) -> !torch.vtensor<[4,4,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VAL_3:.*]] = torch.constant.none + // CHECK: %[[VAL_4:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_5:.*]] = torch.aten.size.int %[[VAL_0]], %[[VAL_4]] : !torch.vtensor<[4,4,4],f32>, !torch.int -> !torch.int + // CHECK: %[[VAL_6:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_7:.*]] = torch.aten.size.int %[[VAL_0]], %[[VAL_6]] : !torch.vtensor<[4,4,4],f32>, !torch.int -> !torch.int + // CHECK: %[[VAL_8:.*]] = 
torch.constant.int 2 + // CHECK: %[[VAL_9:.*]] = torch.aten.size.int %[[VAL_0]], %[[VAL_8]] : !torch.vtensor<[4,4,4],f32>, !torch.int -> !torch.int + // CHECK: %[[VAL_10:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_11:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_12:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_13:.*]] = torch.aten.size.int %[[VAL_1]], %[[VAL_12]] : !torch.vtensor<[2,1],si64>, !torch.int -> !torch.int + // CHECK: %[[VAL_14:.*]] = torch.aten.mul.int %[[VAL_11]], %[[VAL_13]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[VAL_15:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_16:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_17:.*]] = torch.aten.slice.Tensor %[[VAL_1]], %[[VAL_15]], %[[VAL_10]], %[[VAL_16]], %[[VAL_11]] : !torch.vtensor<[2,1],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,1],si64> + // CHECK: %[[VAL_18:.*]] = torch.aten.lt.Scalar %[[VAL_17]], %[[VAL_10]] : !torch.vtensor<[2,1],si64>, !torch.int -> !torch.vtensor<[2,1],i1> + // CHECK: %[[VAL_19:.*]] = torch.aten.add.Scalar %[[VAL_17]], %[[VAL_5]], %[[VAL_11]] : !torch.vtensor<[2,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,1],si64> + // CHECK: %[[VAL_20:.*]] = torch.aten.where.self %[[VAL_18]], %[[VAL_19]], %[[VAL_17]] : !torch.vtensor<[2,1],i1>, !torch.vtensor<[2,1],si64>, !torch.vtensor<[2,1],si64> -> !torch.vtensor<[2,1],si64> + // CHECK: %[[VAL_21:.*]] = torch.prim.ListConstruct %[[VAL_13]], %[[VAL_11]], %[[VAL_11]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[VAL_22:.*]] = torch.aten.view %[[VAL_20]], %[[VAL_21]] : !torch.vtensor<[2,1],si64>, !torch.list -> !torch.vtensor<[2,1,1],si64> + // CHECK: %[[VAL_23:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_24:.*]] = torch.aten.flatten.using_ints %[[VAL_22]], %[[VAL_10]], %[[VAL_23]] : !torch.vtensor<[2,1,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,1,1],si64> + // CHECK: %[[VAL_25:.*]] = torch.aten.flatten.using_ints %[[VAL_2]], %[[VAL_10]], %[[VAL_23]] : !torch.vtensor<[2,4,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,4,4],f32> + // CHECK: %[[VAL_26:.*]] = torch.prim.ListConstruct %[[VAL_14]], %[[VAL_7]], %[[VAL_9]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[VAL_27:.*]] = torch.constant.bool false + // CHECK: %[[VAL_28:.*]] = torch.aten.expand %[[VAL_24]], %[[VAL_26]], %[[VAL_27]] : !torch.vtensor<[2,1,1],si64>, !torch.list, !torch.bool -> !torch.vtensor<[2,4,4],si64> + // CHECK: %[[VAL_29:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_30:.*]] = torch.aten.flatten.using_ints %[[VAL_0]], %[[VAL_10]], %[[VAL_29]] : !torch.vtensor<[4,4,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[4,4,4],f32> + // CHECK: %[[VAL_31:.*]] = torch.constant.str "amin" + // CHECK: %[[VAL_32:.*]] = torch.constant.bool true + // CHECK: %[[VAL_33:.*]] = torch.aten.scatter_reduce.two %[[VAL_30]], %[[VAL_10]], %[[VAL_28]], %[[VAL_25]], %[[VAL_31]], %[[VAL_32]] : !torch.vtensor<[4,4,4],f32>, !torch.int, !torch.vtensor<[2,4,4],si64>, !torch.vtensor<[2,4,4],f32>, !torch.str, !torch.bool -> !torch.vtensor<[4,4,4],f32> + // CHECK: return %[[VAL_33]] : !torch.vtensor<[4,4,4],f32> + // CHECK: } + %none = torch.constant.none + %0 = torch.operator "onnx.ScatterND"(%arg0, %arg1, %arg2) {torch.onnx.reduction = "min"} : (!torch.vtensor<[4,4,4],f32>, !torch.vtensor<[2,1],si64>, !torch.vtensor<[2,4,4],f32>) -> !torch.vtensor<[4,4,4],f32> + return %0 : !torch.vtensor<[4,4,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_split_to_sequence_1 +func.func 
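For reference, the five ScatterND tests above all follow the same recipe: the [2,1] indices are sliced, wrapped to non-negative values, expanded to the shape of the updates, and then fed to aten.scatter.src (no reduction) or aten.scatter_reduce.two ("add"/"mul"/"max"/"min" map to "sum"/"prod"/"amax"/"amin"). A rough PyTorch sketch of that behavior for the dim-0 case exercised here (illustrative only, not part of the patch; the helper name and the dim-0 specialization are assumptions):

import torch

def onnx_scatter_nd_dim0(data, indices, updates, reduction="none"):
    idx = indices.view(-1)                                 # [N] slice indices
    idx = torch.where(idx < 0, idx + data.shape[0], idx)   # wrap negatives, as the lt/add/where ops above do
    idx = idx.view(-1, 1, 1).expand_as(updates)            # broadcast to the updates' shape
    if reduction == "none":
        return data.scatter(0, idx, updates)
    reduce = {"add": "sum", "mul": "prod", "max": "amax", "min": "amin"}[reduction]
    return data.scatter_reduce(0, idx, updates, reduce=reduce, include_self=True)

data = torch.zeros(4, 4, 4)
indices = torch.tensor([[1], [3]])
updates = torch.ones(2, 4, 4)
print(onnx_scatter_nd_dim0(data, indices, updates, reduction="add")[1, 0, 0])  # tensor(1.)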
@test_split_to_sequence_1(%arg0: !torch.vtensor<[3,6],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.list> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VAL_0:.*]]: !torch.vtensor<[3,6],f32> + // CHECK: %[[VAL_1:.*]]: !torch.vtensor<[1],si64>) -> !torch.list> + // CHECK: %[[VAL_2:.*]] = torch.constant.none + // CHECK: %[[VAL_3:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_4:.*]] = torch.aten.item %[[VAL_1]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[VAL_5:.*]] = torch.aten.split.Tensor %[[VAL_0]], %[[VAL_4]], %[[VAL_3]] : !torch.vtensor<[3,6],f32>, !torch.int, !torch.int -> !torch.list> + // CHECK: return %[[VAL_5]] : !torch.list> + %none = torch.constant.none + %int1 = torch.constant.int 1 + %0 = torch.aten.item %arg1 : !torch.vtensor<[1],si64> -> !torch.int + %1 = torch.aten.split.Tensor %arg0, %0, %int1 : !torch.vtensor<[3,6],f32>, !torch.int, !torch.int -> !torch.list> + return %1 : !torch.list> +} + +// ----- + +// CHECK-LABEL: func.func @test_split_to_sequence_2 +func.func @test_split_to_sequence_2(%arg0: !torch.vtensor<[2,6],f32>, %arg1: !torch.vtensor<[],si64>) -> !torch.list> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VAL_0:.*]]: !torch.vtensor<[2,6],f32> + // CHECK: %[[VAL_1:.*]]: !torch.vtensor<[],si64>) -> !torch.list> + // CHECK: %[[VAL_2:.*]] = torch.constant.none + // CHECK: %[[VAL_3:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_4:.*]] = torch.aten.item %[[VAL_1]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[VAL_5:.*]] = torch.aten.split.Tensor %[[VAL_0]], %[[VAL_4]], %[[VAL_3]] : !torch.vtensor<[2,6],f32>, !torch.int, !torch.int -> !torch.list> + // CHECK: return %[[VAL_5]] : !torch.list> + %none = torch.constant.none + %int0 = torch.constant.int 0 + %0 = torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int + %1 = torch.aten.split.Tensor %arg0, %0, %int0 : !torch.vtensor<[2,6],f32>, !torch.int, !torch.int -> !torch.list> + return %1 : !torch.list> +} + +// ----- + +// CHECK-LABEL: func.func @test_split_to_sequence_with_list( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,6],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2],si64>) -> !torch.list> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +// CHECK: %[[VAL_2:.*]] = torch.constant.none +// CHECK: %[[VAL_3:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_4:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_4]] : (!torch.int) -> !torch.list +// CHECK: %[[VAL_6:.*]] = torch.aten.split.sizes %[[VAL_0]], %[[VAL_5]], %[[VAL_3]] : !torch.vtensor<[4,6],f32>, !torch.list, !torch.int -> !torch.list> +// CHECK: return %[[VAL_6]] : !torch.list> + func.func @test_split_to_sequence_with_list(%arg0: !torch.vtensor<[4,6],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.list> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %none = torch.constant.none + %0 = torch.operator "onnx.SplitToSequence"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[4,6],f32>, !torch.vtensor<[2],si64>) -> 
!torch.list> + return %0 : !torch.list> + } + +// ----- + +// CHECK-LABEL: func.func @test_unique_not_sorted_without_axis +func.func @test_unique_not_sorted_without_axis(%arg0: !torch.vtensor<[6],f32>) -> (!torch.vtensor<[4],f32>, !torch.vtensor<[4],si64>, !torch.vtensor<[6],si64>, !torch.vtensor<[4],si64>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK: %[[FALSEVAL:.*]] = torch.constant.bool false + // CHECK: %[[TRUEVAL:.*]] = torch.constant.bool true + // CHECK: %[[NEGATIVEONE:.*]] = torch.constant.int -1 + // CHECK: %[[FLATTEN:.*]] = torch.aten.flatten.using_ints %arg0, %[[INT0_0]], %[[NEGATIVEONE]] : !torch.vtensor<[6],f32>, !torch.int, !torch.int -> !torch.vtensor<[6],f32> + // CHECK: %[[UNIQUEOUTPUT:.*]], %[[INVERSEINDEX:.*]], %[[COUNTS:.*]] = torch.aten.unique_dim %[[FLATTEN]], %[[INT0_0]], %[[FALSEVAL]], %[[TRUEVAL]], %[[TRUEVAL]] : !torch.vtensor<[6],f32>, !torch.int, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[4],f32>, !torch.vtensor<[6],si64>, !torch.vtensor<[4],si64> + // CHECK: %[[INPUTDIM0:.*]] = torch.constant.int 6 + // CHECK: %[[INT64TYPE:.*]] = torch.constant.int 4 + // CHECK: %[[NONEVAL:.*]] = torch.constant.none + // CHECK: %[[ARANGE:.*]] = torch.aten.arange %[[INPUTDIM0]], %[[INT64TYPE]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]] : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[6],si64> + // CHECK: %[[INT0_1:.*]] = torch.constant.int 0 + // CHECK: %[[FLIPDIMS:.*]] = torch.prim.ListConstruct %[[INT0_1]] : (!torch.int) -> !torch.list + // CHECK: %[[FLIPINVERSE:.*]] = torch.aten.flip %[[INVERSEINDEX]], %[[FLIPDIMS]] : !torch.vtensor<[6],si64>, !torch.list -> !torch.vtensor<[6],si64> + // CHECK: %[[FLIPPERM:.*]] = torch.aten.flip %[[ARANGE]], %[[FLIPDIMS]] : !torch.vtensor<[6],si64>, !torch.list -> !torch.vtensor<[6],si64> + // CHECK: %[[OUTPUTDIMZERO:.*]] = torch.constant.int 4 + // CHECK: %[[NEWEMPTYSIZE:.*]] = torch.prim.ListConstruct %[[OUTPUTDIMZERO]] : (!torch.int) -> !torch.list + // CHECK: %[[NEWEMPTY:.*]] = torch.aten.new_empty %[[FLIPINVERSE]], %[[NEWEMPTYSIZE]], %[[INT64TYPE]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]] : !torch.vtensor<[6],si64>, !torch.list, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[4],si64> + // CHECK: %[[SCATTER:.*]] = torch.aten.scatter.src %[[NEWEMPTY]], %[[INT0_0]], %[[FLIPINVERSE]], %[[FLIPPERM]] : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[6],si64>, !torch.vtensor<[6],si64> -> !torch.vtensor<[4],si64> + // CHECK: return %[[UNIQUEOUTPUT]], %[[SCATTER]], %[[INVERSEINDEX]], %[[COUNTS]] : !torch.vtensor<[4],f32>, !torch.vtensor<[4],si64>, !torch.vtensor<[6],si64>, !torch.vtensor<[4],si64> + %0:4 = torch.operator "onnx.Unique"(%arg0) {torch.onnx.sorted = 0 : si64} : (!torch.vtensor<[6],f32>) -> (!torch.vtensor<[4],f32>, !torch.vtensor<[4],si64>, !torch.vtensor<[6],si64>, !torch.vtensor<[4],si64>) + return %0#0, %0#1, %0#2, %0#3 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],si64>, !torch.vtensor<[6],si64>, !torch.vtensor<[4],si64> +} + +// ----- + +// CHECK-LABEL: func.func @test_unique_sorted_without_axis +func.func @test_unique_sorted_without_axis(%arg0: !torch.vtensor<[6],f32>) -> (!torch.vtensor<[4],f32>, !torch.vtensor<[4],si64>, !torch.vtensor<[6],si64>, !torch.vtensor<[4],si64>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 
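The SplitToSequence tests above pin down two lowerings: a scalar split size becomes aten.split.Tensor, and a 1-D tensor of per-chunk sizes becomes aten.split.sizes, both producing a list of tensors. In PyTorch terms this is just torch.split with either an int or a list (illustrative sketch; the concrete split sizes below are made up, since in the IR they are runtime operands):

import torch

x = torch.randn(3, 6)
chunks = torch.split(x, 2, dim=1)      # scalar split size, as in test_split_to_sequence_1
y = torch.randn(4, 6)
parts = torch.split(y, [2, 2], dim=0)  # per-chunk sizes, as in test_split_to_sequence_with_list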
11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK: %[[TRUEVAL_0:.*]] = torch.constant.bool true + // CHECK: %[[TRUEVAL_1:.*]] = torch.constant.bool true + // CHECK: %[[NEGATIVEONE:.*]] = torch.constant.int -1 + // CHECK: %[[FLATTEN:.*]] = torch.aten.flatten.using_ints %arg0, %[[INT0_0]], %[[NEGATIVEONE]] : !torch.vtensor<[6],f32>, !torch.int, !torch.int -> !torch.vtensor<[6],f32> + // CHECK: %[[UNIQUEOUTPUT:.*]], %[[INVERSEINDEX:.*]], %[[COUNTS:.*]] = torch.aten.unique_dim %[[FLATTEN]], %[[INT0_0]], %[[TRUEVAL_0]], %[[TRUEVAL_1]], %[[TRUEVAL_1]] : !torch.vtensor<[6],f32>, !torch.int, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[4],f32>, !torch.vtensor<[6],si64>, !torch.vtensor<[4],si64> + // CHECK: %[[INPUTDIM0:.*]] = torch.constant.int 6 + // CHECK: %[[INT64TYPE:.*]] = torch.constant.int 4 + // CHECK: %[[NONEVAL:.*]] = torch.constant.none + // CHECK: %[[ARANGE:.*]] = torch.aten.arange %[[INPUTDIM0]], %[[INT64TYPE]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]] : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[6],si64> + // CHECK: %[[INT0_1:.*]] = torch.constant.int 0 + // CHECK: %[[FLIPDIMS:.*]] = torch.prim.ListConstruct %[[INT0_1]] : (!torch.int) -> !torch.list + // CHECK: %[[FLIPINVERSE:.*]] = torch.aten.flip %[[INVERSEINDEX]], %[[FLIPDIMS]] : !torch.vtensor<[6],si64>, !torch.list -> !torch.vtensor<[6],si64> + // CHECK: %[[FLIPPERM:.*]] = torch.aten.flip %[[ARANGE]], %[[FLIPDIMS]] : !torch.vtensor<[6],si64>, !torch.list -> !torch.vtensor<[6],si64> + // CHECK: %[[OUTPUTDIMZERO:.*]] = torch.constant.int 4 + // CHECK: %[[NEWEMPTYSIZE:.*]] = torch.prim.ListConstruct %[[OUTPUTDIMZERO]] : (!torch.int) -> !torch.list + // CHECK: %[[NEWEMPTY:.*]] = torch.aten.new_empty %[[FLIPINVERSE]], %[[NEWEMPTYSIZE]], %[[INT64TYPE]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]] : !torch.vtensor<[6],si64>, !torch.list, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[4],si64> + // CHECK: %[[SCATTER:.*]] = torch.aten.scatter.src %[[NEWEMPTY]], %[[INT0_0]], %[[FLIPINVERSE]], %[[FLIPPERM]] : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[6],si64>, !torch.vtensor<[6],si64> -> !torch.vtensor<[4],si64> + // CHECK: return %[[UNIQUEOUTPUT]], %[[SCATTER]], %[[INVERSEINDEX]], %[[COUNTS]] : !torch.vtensor<[4],f32>, !torch.vtensor<[4],si64>, !torch.vtensor<[6],si64>, !torch.vtensor<[4],si64> + %0:4 = torch.operator "onnx.Unique"(%arg0) : (!torch.vtensor<[6],f32>) -> (!torch.vtensor<[4],f32>, !torch.vtensor<[4],si64>, !torch.vtensor<[6],si64>, !torch.vtensor<[4],si64>) + return %0#0, %0#1, %0#2, %0#3 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],si64>, !torch.vtensor<[6],si64>, !torch.vtensor<[4],si64> +} + +// ----- + +// CHECK-LABEL: func.func @test_unique_sorted_with_axis_3d +func.func @test_unique_sorted_with_axis_3d(%arg0: !torch.vtensor<[2,4,2],f32>) -> (!torch.vtensor<[2,3,2],f32>, !torch.vtensor<[3],si64>, !torch.vtensor<[4],si64>, !torch.vtensor<[3],si64>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[TRUEVAL_0:.*]] = torch.constant.bool true + // CHECK: %[[TRUEVAL_1:.*]] = torch.constant.bool true + // CHECK: %[[UNIQUEOUTPUT:.*]], %[[INVERSEINDEX:.*]], %[[COUNTS:.*]] = torch.aten.unique_dim 
%arg0, %[[INT1]], %[[TRUEVAL_0]], %[[TRUEVAL_1]], %[[TRUEVAL_1]] : !torch.vtensor<[2,4,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[2,3,2],f32>, !torch.vtensor<[4],si64>, !torch.vtensor<[3],si64> + // CHECK: %[[INPUTDIM0:.*]] = torch.constant.int 2 + // CHECK: %[[INT64TYPE:.*]] = torch.constant.int 4 + // CHECK: %[[NONEVAL:.*]] = torch.constant.none + // CHECK: %[[ARANGE:.*]] = torch.aten.arange %[[INPUTDIM0]], %[[INT64TYPE]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]] : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],si64> + // CHECK: %[[INTO_1:.*]] = torch.constant.int 0 + // CHECK: %[[FLIPDIMS:.*]] = torch.prim.ListConstruct %[[INTO_1]] : (!torch.int) -> !torch.list + // CHECK: %[[FLIPINVERSE:.*]] = torch.aten.flip %[[INVERSEINDEX]], %[[FLIPDIMS]] : !torch.vtensor<[4],si64>, !torch.list -> !torch.vtensor<[2,4,2],si64> + // CHECK: %[[FLIPPERM:.*]] = torch.aten.flip %[[ARANGE]], %[[FLIPDIMS]] : !torch.vtensor<[2],si64>, !torch.list -> !torch.vtensor<[2],si64> + // CHECK: %[[OUTPUTDIM0:.*]] = torch.constant.int 2 + // CHECK: %[[NEWEMPTYSIZE:.*]] = torch.prim.ListConstruct %[[OUTPUTDIM0]] : (!torch.int) -> !torch.list + // CHECK: %[[NEWEMPTY:.*]] = torch.aten.new_empty %[[FLIPINVERSE]], %[[NEWEMPTYSIZE]], %[[INT64TYPE]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]] : !torch.vtensor<[2,4,2],si64>, !torch.list, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],si64> + // CHECK: %[[SCATTER:.*]] = torch.aten.scatter.src %[[NEWEMPTY]], %[[INT0_0]], %[[FLIPINVERSE]], %[[FLIPPERM]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[2,4,2],si64>, !torch.vtensor<[2],si64> -> !torch.vtensor<[3],si64> + // CHECK: return %[[UNIQUEOUTPUT]], %[[SCATTER]], %[[INVERSEINDEX]], %[[COUNTS]] : !torch.vtensor<[2,3,2],f32>, !torch.vtensor<[3],si64>, !torch.vtensor<[4],si64>, !torch.vtensor<[3],si64> + %0:4 = torch.operator "onnx.Unique"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.sorted = 1 : si64} : (!torch.vtensor<[2,4,2],f32>) -> (!torch.vtensor<[2,3,2],f32>, !torch.vtensor<[3],si64>, !torch.vtensor<[4],si64>, !torch.vtensor<[3],si64>) + return %0#0, %0#1, %0#2, %0#3 : !torch.vtensor<[2,3,2],f32>, !torch.vtensor<[3],si64>, !torch.vtensor<[4],si64>, !torch.vtensor<[3],si64> +} + +// ----- + + +// CHECK-LABEL: func.func @test_unique_sorted_with_axis +func.func @test_unique_sorted_with_axis(%arg0: !torch.vtensor<[3,3],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[3],si64>, !torch.vtensor<[2],si64>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK: %[[INT0_1:.*]] = torch.constant.int 0 + // CHECK: %[[TRUEVAL:.*]] = torch.constant.bool true + // CHECK: %[[TRUEVAL_1:.*]] = torch.constant.bool true + // CHECK: %[[UNIQUEOUTPUT:.*]], %[[INVERSEINDEX:.*]], %[[COUNTS:.*]] = torch.aten.unique_dim %arg0, %[[INT0_1]], %[[TRUEVAL]], %[[TRUEVAL_1]], %[[TRUEVAL_1]] : !torch.vtensor<[3,3],f32>, !torch.int, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[2,3],f32>, !torch.vtensor<[3],si64>, !torch.vtensor<[2],si64> + // CHECK: %[[INPUTDIM0:.*]] = torch.constant.int 3 + // CHECK: %[[INT64TYPE:.*]] = torch.constant.int 4 + // CHECK: %[[NONEVAL:.*]] = torch.constant.none + // CHECK: %[[ARANGE:.*]] = torch.aten.arange %[[INPUTDIM0]], %[[INT64TYPE]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]] : !torch.int, 
!torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3],si64> + // CHECK: %[[INT0_2:.*]] = torch.constant.int 0 + // CHECK: %[[FLIPDIMS:.*]] = torch.prim.ListConstruct %[[INT0_2]] : (!torch.int) -> !torch.list + // CHECK: %[[FLIPINVERSE:.*]] = torch.aten.flip %[[INVERSEINDEX]], %[[FLIPDIMS]] : !torch.vtensor<[3],si64>, !torch.list -> !torch.vtensor<[3,3],si64> + // CHECK: %[[FLIPPERM:.*]] = torch.aten.flip %[[ARANGE]], %[[FLIPDIMS]] : !torch.vtensor<[3],si64>, !torch.list -> !torch.vtensor<[3],si64> + // CHECK: %[[OUTPUTDIM0:.*]] = torch.constant.int 2 + // CHECK: %[[NEWEMPTYSIZE:.*]] = torch.prim.ListConstruct %[[OUTPUTDIM0]] : (!torch.int) -> !torch.list + // CHECK: %[[NEWEMPTY:.*]] = torch.aten.new_empty %[[FLIPINVERSE]], %[[NEWEMPTYSIZE]], %[[INT64TYPE]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]] : !torch.vtensor<[3,3],si64>, !torch.list, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],si64> + // CHECK: %[[SCATTER:.*]] = torch.aten.scatter.src %[[NEWEMPTY]], %[[INT0_0]], %[[FLIPINVERSE]], %[[FLIPPERM]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[3,3],si64>, !torch.vtensor<[3],si64> -> !torch.vtensor<[2],si64> + // CHECK: return %[[UNIQUEOUTPUT]], %[[SCATTER]], %[[INVERSEINDEX]], %[[COUNTS]] : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[3],si64>, !torch.vtensor<[2],si64> + %0:4 = torch.operator "onnx.Unique"(%arg0) {torch.onnx.axis = 0 : si64, torch.onnx.sorted = 1 : si64} : (!torch.vtensor<[3,3],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[3],si64>, !torch.vtensor<[2],si64>) + return %0#0, %0#1, %0#2, %0#3 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[3],si64>, !torch.vtensor<[2],si64> +} + +// ----- + +// CHECK-LABEL: func.func @test_unique_sorted_with_negative_axis +func.func @test_unique_sorted_with_negative_axis(%arg0: !torch.vtensor<[3,3],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[3],si64>, !torch.vtensor<[2],si64>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK: %[[NEGATIVEONE:.*]] = torch.constant.int -1 + // CHECK: %[[TRUEVAL:.*]] = torch.constant.bool true + // CHECK: %[[TRUEVAL_1:.*]] = torch.constant.bool true + // CHECK: %[[UNIQUEOUTPUT:.*]], %[[INVERSEINDEX:.*]], %[[COUNTS:.*]] = torch.aten.unique_dim %arg0, %[[NEGATIVEONE]], %[[TRUEVAL]], %[[TRUEVAL_1]], %[[TRUEVAL_1]] : !torch.vtensor<[3,3],f32>, !torch.int, !torch.bool, !torch.bool, !torch.bool -> !torch.vtensor<[2,3],f32>, !torch.vtensor<[3],si64>, !torch.vtensor<[2],si64> + // CHECK: %[[INPUTDIM0:.*]] = torch.constant.int 3 + // CHECK: %[[INT64TYPE:.*]] = torch.constant.int 4 + // CHECK: %[[NONEVAL:.*]] = torch.constant.none + // CHECK: %[[ARANGE:.*]] = torch.aten.arange %[[INPUTDIM0]], %[[INT64TYPE]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]] : !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3],si64> + // CHECK: %[[INT0_1:.*]] = torch.constant.int 0 + // CHECK: %[[FLIPDIMS:.*]] = torch.prim.ListConstruct %[[INT0_1]] : (!torch.int) -> !torch.list + // CHECK: %[[FLIPINVERSE:.*]] = torch.aten.flip %[[INVERSEINDEX]], %[[FLIPDIMS]] : !torch.vtensor<[3],si64>, !torch.list -> !torch.vtensor<[3,3],si64> + // CHECK: %[[FLIPPERM:.*]] = torch.aten.flip %[[ARANGE]], %[[FLIPDIMS]] : !torch.vtensor<[3],si64>, !torch.list -> 
!torch.vtensor<[3],si64> + // CHECK: %[[OUTPUTDIM0:.*]] = torch.constant.int 2 + // CHECK: %[[NEWEMPTYSIZE:.*]] = torch.prim.ListConstruct %[[OUTPUTDIM0]] : (!torch.int) -> !torch.list + // CHECK: %[[NEWEMPTY:.*]] = torch.aten.new_empty %[[FLIPINVERSE]], %[[NEWEMPTYSIZE]], %[[INT64TYPE]], %[[NONEVAL]], %[[NONEVAL]], %[[NONEVAL]] : !torch.vtensor<[3,3],si64>, !torch.list, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],si64> + // CHECK: %[[SCATTER:.*]] = torch.aten.scatter.src %[[NEWEMPTY]], %[[INT0_0]], %[[FLIPINVERSE]], %[[FLIPPERM]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[3,3],si64>, !torch.vtensor<[3],si64> -> !torch.vtensor<[2],si64> + // CHECK: return %[[UNIQUEOUTPUT]], %[[SCATTER]], %[[INVERSEINDEX]], %[[COUNTS]] : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[3],si64>, !torch.vtensor<[2],si64> + %0:4 = torch.operator "onnx.Unique"(%arg0) {torch.onnx.axis = -1 : si64, torch.onnx.sorted = 1 : si64} : (!torch.vtensor<[3,3],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[3],si64>, !torch.vtensor<[2],si64>) + return %0#0, %0#1, %0#2, %0#3 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[3],si64>, !torch.vtensor<[2],si64> +} + +// ----- + +// CHECK-LABEL: func.func @test_scan_sum( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[3,2],f32>) -> (!torch.vtensor<[2],f32>, !torch.vtensor<[3,2],f32>) attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +func.func @test_scan_sum(%arg0: !torch.vtensor<[2],f32>, %arg1: !torch.vtensor<[3,2],f32>) -> (!torch.vtensor<[2],f32>, !torch.vtensor<[3,2],f32>) attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[VAL_2:.*]] = torch.constant.none + // CHECK: %[[VAL_3:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_4:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_5:.*]] = torch.constant.int 3 + // CHECK: %[[VAL_6:.*]] = torch.constant.int 2 + // CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_5]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[VAL_8:.*]] = torch.constant.none + // CHECK: %[[VAL_9:.*]] = torch.constant.int 6 + // CHECK: %[[VAL_10:.*]] = torch.aten.full %[[VAL_7]], %[[VAL_3]], %[[VAL_9]], %[[VAL_8]], %[[VAL_8]], %[[VAL_8]] : !torch.list, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,2],f32> + // CHECK: %[[VAL_11:.*]] = torch.aten.size.int %[[VAL_1]], %[[VAL_3]] : !torch.vtensor<[3,2],f32>, !torch.int -> !torch.int + // CHECK: %[[VAL_12:.*]] = torch.constant.bool true + // CHECK: %[[VAL_13:.*]]:2 = torch.prim.Loop %[[VAL_11]], %[[VAL_12]], init(%[[VAL_0]], %[[VAL_10]]) { + // CHECK: ^bb0(%[[VAL_14:.*]]: !torch.int, %[[VAL_15:.*]]: !torch.vtensor<[2],f32>, %[[VAL_16:.*]]: !torch.vtensor<[3,2],f32>): + // CHECK: %[[VAL_17:.*]] = torch.aten.select.int %[[VAL_1]], %[[VAL_3]], %[[VAL_14]] : !torch.vtensor<[3,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[2],f32> + // CHECK: %[[VAL_18:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_19:.*]] = torch.aten.add.Tensor %[[VAL_15]], %[[VAL_17]], %[[VAL_18]] : !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor<[2],f32> + // CHECK: %[[VAL_20:.*]] = torch.constant.none + // CHECK: 
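The four Unique tests above share one non-obvious step: aten.unique_dim returns values, inverse indices and counts, but ONNX Unique also wants the position of the first occurrence of each unique value. The lowering rebuilds those positions by flipping both the inverse map and an arange permutation and scattering one through the other, so the earliest occurrence is the last (winning) write. A hedged Python sketch of that trick for the flattened (no-axis) case, assuming last-write-wins behavior for repeated indices, which is what the emitted aten.scatter.src relies on:

import torch

def onnx_unique_flat(x, is_sorted=True):
    values, inverse, counts = torch.unique(
        x.flatten(), sorted=is_sorted, return_inverse=True, return_counts=True)
    perm = torch.arange(x.numel())
    first = torch.zeros(values.numel(), dtype=torch.long)
    # Reversing both tensors means the first occurrence is written last.
    first.scatter_(0, inverse.flip(0), perm.flip(0))
    return values, first, inverse, counts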
%[[VAL_21:.*]] = torch.aten.clone %[[VAL_19]], %[[VAL_20]] : !torch.vtensor<[2],f32>, !torch.none -> !torch.vtensor<[2],f32> + // CHECK: %[[VAL_22:.*]] = torch.aten.unsqueeze %[[VAL_21]], %[[VAL_3]] : !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor<[1,2],f32> + // CHECK: %[[VAL_23:.*]] = torch.aten.slice_scatter %[[VAL_16]], %[[VAL_22]], %[[VAL_3]], %[[VAL_14]], %[[VAL_14]], %[[VAL_4]] : !torch.vtensor<[3,2],f32>, !torch.vtensor<[1,2],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,2],f32> + // CHECK: torch.prim.Loop.condition %[[VAL_12]], iter(%[[VAL_19]], %[[VAL_23]] : !torch.vtensor<[2],f32>, !torch.vtensor<[3,2],f32>) + // CHECK: } : (!torch.int, !torch.bool, !torch.vtensor<[2],f32>, !torch.vtensor<[3,2],f32>) -> (!torch.vtensor<[2],f32>, !torch.vtensor<[3,2],f32>) + // CHECK: return %[[VAL_24:.*]]#0, %[[VAL_24]]#1 : !torch.vtensor<[2],f32>, !torch.vtensor<[3,2],f32> + // CHECK: } + %none = torch.constant.none + %0:2 = torch.operator "onnx.Scan"(%arg0, %arg1) {torch.onnx.num_scan_inputs = 1 : si64} : (!torch.vtensor<[2],f32>, !torch.vtensor<[3,2],f32>) -> (!torch.vtensor<[2],f32>, !torch.vtensor<[3,2],f32>) { + ^bb0(%arg2: !torch.vtensor<[2],f32>, %arg3: !torch.vtensor<[2],f32>): + %1 = torch.operator "onnx.Add"(%arg2, %arg3) : (!torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[2],f32> + %2 = torch.operator "onnx.Identity"(%1) : (!torch.vtensor<[2],f32>) -> !torch.vtensor<[2],f32> + torch.operator_terminator %1, %2 : !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32> + } + return %0#0, %0#1 : !torch.vtensor<[2],f32>, !torch.vtensor<[3,2],f32> +} + +// ----- + +// CHECK-LABEL: @test_thresholdedrelu +func.func @test_thresholdedrelu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 22 : si64} { + // CHECK: %[[FP2:.+]] = torch.constant.float 2.000000e+00 + // CHECK: %[[FP0:.+]] = torch.constant.float 0.000000e+00 + // CHECK: torch.aten.threshold %arg0, %[[FP2]], %[[FP0]] + %0 = torch.operator "onnx.ThresholdedRelu"(%arg0) {torch.onnx.alpha = 2.000000e+00 : f32} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} diff --git a/test/Conversion/TorchToArith/basic.mlir b/test/Conversion/TorchToArith/basic.mlir index 9f23229d5f3a..8fa13b47e588 100644 --- a/test/Conversion/TorchToArith/basic.mlir +++ b/test/Conversion/TorchToArith/basic.mlir @@ -16,8 +16,8 @@ func.func @torch.aten.dim(%arg0: !torch.vtensor<*,f32>) -> !torch.int { // CHECK-LABEL: func.func @torch.runtime.assert( // CHECK-SAME: %[[X:.*]]: !torch.int, // CHECK-SAME: %[[Y:.*]]: !torch.int) { -// CHECK: %[[X_I64:.*]] = torch_c.to_i64 %[[X]] -// CHECK: %[[Y_I64:.*]] = torch_c.to_i64 %[[Y]] +// CHECK-DAG: %[[X_I64:.*]] = torch_c.to_i64 %[[X]] +// CHECK-DAG: %[[Y_I64:.*]] = torch_c.to_i64 %[[Y]] // CHECK: %[[CMP:.*]] = arith.cmpi ne, %[[X_I64]], %[[Y_I64]] : i64 // CHECK: assert %[[CMP]], "x must not be equal to y" // CHECK: return @@ -30,8 +30,8 @@ func.func @torch.runtime.assert(%arg0: !torch.int, %arg1: !torch.int) { // CHECK-LABEL: func.func @torch.aten.ne.int( // CHECK-SAME: %[[LHS:.*]]: !torch.int, // CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool { -// CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]] -// CHECK: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]] +// CHECK-DAG: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]] +// CHECK-DAG: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]] // CHECK: %[[CMP:.*]] = arith.cmpi ne, %[[LHS_I64]], 
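The test_scan_sum lowering above turns onnx.Scan into a torch.prim.Loop that carries the state tensor and slice-scatters each per-iteration result into a pre-allocated output. In plain Python the computation the test encodes is a running sum over the rows of the scan input (sketch only, not part of the patch); the ThresholdedRelu test that follows is simpler still, mapping onnx.ThresholdedRelu with alpha = 2.0 to aten.threshold(x, 2.0, 0.0):

import torch

def scan_sum(init, xs):
    state = init
    out = torch.zeros_like(xs)          # plays the role of the aten.full accumulator
    for i in range(xs.shape[0]):        # the torch.prim.Loop trip count
        state = state + xs[i]           # the loop body's onnx.Add
        out[i] = state                  # the aten.slice_scatter per iteration
    return state, out

final, scanned = scan_sum(torch.zeros(2), torch.ones(3, 2))
# final == [3., 3.]; scanned rows are the running sums [1,1], [2,2], [3,3]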
%[[RHS_I64]] : i64 // CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]] // CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool @@ -40,11 +40,53 @@ func.func @torch.aten.ne.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.boo return %0 : !torch.bool } + +// CHECK-LABEL: func.func @torch.aten.ne.bool( +// CHECK-SAME: %[[LHS_TORCH:.*]]: !torch.bool, +// CHECK-SAME: %[[RHS_TORCH:.*]]: !torch.bool) -> !torch.bool { +// CHECK-DAG: %[[LHS:.*]] = torch_c.to_i1 %[[LHS_TORCH]] +// CHECK-DAG: %[[RHS:.*]] = torch_c.to_i1 %[[RHS_TORCH]] +// CHECK: %[[XOR:.*]] = arith.xori %[[LHS]], %[[RHS]] : i1 +// CHECK: %[[TORCH_BOOL:.*]] = torch_c.from_i1 %[[XOR]] +// CHECK: return %[[TORCH_BOOL]] : !torch.bool +func.func @torch.aten.ne.bool(%arg0: !torch.bool, %arg1: !torch.bool) -> !torch.bool { + %0 = torch.aten.ne.bool %arg0, %arg1 : !torch.bool, !torch.bool -> !torch.bool + return %0 : !torch.bool +} + + +// CHECK-LABEL: func.func @torch.aten.__and__.bool( +// CHECK-SAME: %[[LHS_TORCH:.*]]: !torch.bool, +// CHECK-SAME: %[[RHS_TORCH:.*]]: !torch.bool) -> !torch.bool { +// CHECK-DAG: %[[LHS:.*]] = torch_c.to_i1 %[[LHS_TORCH]] +// CHECK-DAG: %[[RHS:.*]] = torch_c.to_i1 %[[RHS_TORCH]] +// CHECK: %[[AND:.*]] = arith.andi %[[LHS]], %[[RHS]] : i1 +// CHECK: %[[TORCH_BOOL:.*]] = torch_c.from_i1 %[[AND]] +// CHECK: return %[[TORCH_BOOL]] : !torch.bool +func.func @torch.aten.__and__.bool(%arg0: !torch.bool, %arg1: !torch.bool) -> !torch.bool { + %0 = torch.aten.__and__.bool %arg0, %arg1 : !torch.bool, !torch.bool -> !torch.bool + return %0 : !torch.bool +} + + +// CHECK-LABEL: func.func @torch.aten.__or__.bool( +// CHECK-SAME: %[[LHS_TORCH:.*]]: !torch.bool, +// CHECK-SAME: %[[RHS_TORCH:.*]]: !torch.bool) -> !torch.bool { +// CHECK-DAG: %[[LHS:.*]] = torch_c.to_i1 %[[LHS_TORCH]] +// CHECK-DAG: %[[RHS:.*]] = torch_c.to_i1 %[[RHS_TORCH]] +// CHECK: %[[OR:.*]] = arith.ori %[[LHS]], %[[RHS]] : i1 +// CHECK: %[[TORCH_BOOL:.*]] = torch_c.from_i1 %[[OR]] +// CHECK: return %[[TORCH_BOOL]] : !torch.bool +func.func @torch.aten.__or__.bool(%arg0: !torch.bool, %arg1: !torch.bool) -> !torch.bool { + %0 = torch.aten.__or__.bool %arg0, %arg1 : !torch.bool, !torch.bool -> !torch.bool + return %0 : !torch.bool +} + // CHECK-LABEL: func.func @torch.aten.eq.int( // CHECK-SAME: %[[LHS:.*]]: !torch.int, // CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool { -// CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]] -// CHECK: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]] +// CHECK-DAG: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]] +// CHECK-DAG: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]] // CHECK: %[[CMP:.*]] = arith.cmpi eq, %[[LHS_I64]], %[[RHS_I64]] : i64 // CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]] // CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool @@ -56,8 +98,8 @@ func.func @torch.aten.eq.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.boo // CHECK-LABEL: func.func @torch.aten.gt.int( // CHECK-SAME: %[[LHS:.*]]: !torch.int, // CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool { -// CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]] -// CHECK: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]] +// CHECK-DAG: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]] +// CHECK-DAG: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]] // CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[LHS_I64]], %[[RHS_I64]] : i64 // CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]] // CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool @@ -69,8 +111,8 @@ func.func @torch.aten.gt.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.boo // CHECK-LABEL: func.func @torch.aten.ge.int( // 
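The new TorchToArith boolean tests in this hunk check straightforward scalar lowerings: aten.ne.bool becomes arith.xori, aten.__and__.bool becomes arith.andi, and aten.__or__.bool becomes arith.ori, each wrapped in torch_c conversions. The Python equivalents are trivial but make the intended semantics explicit (illustrative only):

def ne_bool(a: bool, b: bool) -> bool:
    return a != b     # xor on i1

def and_bool(a: bool, b: bool) -> bool:
    return a and b    # andi on i1

def or_bool(a: bool, b: bool) -> bool:
    return a or b     # ori on i1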
CHECK-SAME: %[[LHS:.*]]: !torch.int, // CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool { -// CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]] -// CHECK: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]] +// CHECK-DAG: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]] +// CHECK-DAG: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]] // CHECK: %[[CMP:.*]] = arith.cmpi sge, %[[LHS_I64]], %[[RHS_I64]] : i64 // CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]] // CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool @@ -83,8 +125,8 @@ func.func @torch.aten.ge.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.boo // CHECK-LABEL: func.func @torch.aten.lt.int( // CHECK-SAME: %[[LHS:.*]]: !torch.int, // CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool { -// CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]] -// CHECK: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]] +// CHECK-DAG: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]] +// CHECK-DAG: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]] // CHECK: %[[CMP:.*]] = arith.cmpi slt, %[[LHS_I64]], %[[RHS_I64]] : i64 // CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]] // CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool @@ -96,8 +138,8 @@ func.func @torch.aten.lt.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.boo // CHECK-LABEL: func.func @torch.aten.le.int( // CHECK-SAME: %[[LHS:.*]]: !torch.int, // CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool { -// CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]] -// CHECK: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]] +// CHECK-DAG: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]] +// CHECK-DAG: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]] // CHECK: %[[CMP:.*]] = arith.cmpi sle, %[[LHS_I64]], %[[RHS_I64]] : i64 // CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]] // CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool @@ -145,8 +187,8 @@ func.func @torch.constant.int() -> !torch.int { // CHECK-LABEL: func.func @torch.aten.add.int( // CHECK-SAME: %[[LHS:.*]]: !torch.int, // CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.int { -// CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]] -// CHECK: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]] +// CHECK-DAG: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]] +// CHECK-DAG: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]] // CHECK: %[[ADD:.*]] = arith.addi %[[LHS_I64:.*]], [[RHS_I64:.*]] : i64 // CHECK: %[[OUT:.*]] = torch_c.from_i64 %[[INT:.*]] // CHECK: return %[[OUT:.*]] : !torch.int @@ -158,8 +200,8 @@ func.func @torch.aten.add.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.in // CHECK-LABEL: func.func @torch.aten.sub.int( // CHECK-SAME: %[[LHS:.*]]: !torch.int, // CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.int { -// CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]] -// CHECK: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]] +// CHECK-DAG: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]] +// CHECK-DAG: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]] // CHECK: %[[SUB:.*]] = arith.subi %[[LHS_I64:.*]], [[RHS_I64:.*]] : i64 // CHECK: %[[OUT:.*]] = torch_c.from_i64 %[[INT:.*]] // CHECK: return %[[OUT:.*]] : !torch.int @@ -171,8 +213,8 @@ func.func @torch.aten.sub.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.in // CHECK-LABEL: func.func @torch.aten.sub.float( // CHECK-SAME: %[[LHS:.*]]: !torch.float, // CHECK-SAME: %[[RHS:.*]]: !torch.float) -> !torch.float { -// CHECK: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]] -// CHECK: %[[RHS_F64:.*]] = torch_c.to_f64 %[[RHS]] +// CHECK-DAG: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]] +// CHECK-DAG: %[[RHS_F64:.*]] = torch_c.to_f64 %[[RHS]] // CHECK: %[[SUB:.*]] = arith.subf %[[LHS_F64:.*]], 
[[RHS_F64:.*]] : f64 // CHECK: %[[OUT:.*]] = torch_c.from_f64 %[[SUB:.*]] // CHECK: return %[[OUT:.*]] : !torch.float @@ -184,8 +226,8 @@ func.func @torch.aten.sub.float(%arg0: !torch.float, %arg1: !torch.float) -> !to // CHECK-LABEL: func.func @torch.aten.mul.int( // CHECK-SAME: %[[LHS:.*]]: !torch.int, // CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.int { -// CHECK: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]] -// CHECK: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]] +// CHECK-DAG: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]] +// CHECK-DAG: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]] // CHECK: %[[MUL:.*]] = arith.muli %[[LHS_I64:.*]], [[RHS_I64:.*]] : i64 // CHECK: %[[OUT:.*]] = torch_c.from_i64 %[[MUL:.*]] // CHECK: return %[[OUT:.*]] : !torch.int @@ -194,11 +236,39 @@ func.func @torch.aten.mul.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.in return %0 : !torch.int } +// CHECK-LABEL: func.func @torch.aten.mul.int_float( +// CHECK-SAME: %[[LHS:.*]]: !torch.int, +// CHECK-SAME: %[[RHS:.*]]: !torch.float) -> !torch.float { +// CHECK-DAG: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]] +// CHECK-DAG: %[[RHS_F64:.*]] = torch_c.to_f64 %[[RHS]] +// CHECK: %[[LHS_F64:.*]] = arith.sitofp %[[LHS_I64]] : i64 to f64 +// CHECK: %[[MUL:.*]] = arith.mulf %[[LHS_F64]], %[[RHS_F64]] : f64 +// CHECK: %[[OUT:.*]] = torch_c.from_f64 %[[MUL]] +// CHECK: return %[[OUT]] : !torch.float +func.func @torch.aten.mul.int_float(%arg0: !torch.int, %arg1: !torch.float) -> !torch.float { + %0 = torch.aten.mul.int_float %arg0, %arg1 : !torch.int, !torch.float -> !torch.float + return %0 : !torch.float +} + +// CHECK-LABEL: func.func @torch.aten.mul.float_int( +// CHECK-SAME: %[[LHS:.*]]: !torch.float, +// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.float { +// CHECK-DAG: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]] +// CHECK-DAG: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]] +// CHECK: %[[RHS_F64:.*]] = arith.sitofp %[[RHS_I64]] : i64 to f64 +// CHECK: %[[MUL:.*]] = arith.mulf %[[LHS_F64]], %[[RHS_F64]] : f64 +// CHECK: %[[OUT:.*]] = torch_c.from_f64 %[[MUL]] +// CHECK: return %[[OUT]] : !torch.float +func.func @torch.aten.mul.float_int(%arg0: !torch.float, %arg1: !torch.int) -> !torch.float { + %0 = torch.aten.mul.float_int %arg0, %arg1 : !torch.float, !torch.int -> !torch.float + return %0 : !torch.float +} + // CHECK-LABEL: func.func @torch.aten.div.float( // CHECK-SAME: %[[LHS:.*]]: !torch.float, // CHECK-SAME: %[[RHS:.*]]: !torch.float) -> !torch.float { -// CHECK: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]] -// CHECK: %[[RHS_F64:.*]] = torch_c.to_f64 %[[RHS]] +// CHECK-DAG: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]] +// CHECK-DAG: %[[RHS_F64:.*]] = torch_c.to_f64 %[[RHS]] // CHECK: %[[SUB:.*]] = arith.divf %[[LHS_F64:.*]], [[RHS_F64:.*]] : f64 // CHECK: %[[OUT:.*]] = torch_c.from_f64 %[[SUB:.*]] // CHECK: return %[[OUT:.*]] : !torch.float @@ -210,8 +280,8 @@ func.func @torch.aten.div.float(%arg0: !torch.float, %arg1: !torch.float) -> !to // CHECK-LABEL: func.func @torch.aten.ge.float( // CHECK-SAME: %[[LHS:.*]]: !torch.float, // CHECK-SAME: %[[RHS:.*]]: !torch.float) -> !torch.bool { -// CHECK: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]] -// CHECK: %[[RHS_F64:.*]] = torch_c.to_f64 %[[RHS]] +// CHECK-DAG: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]] +// CHECK-DAG: %[[RHS_F64:.*]] = torch_c.to_f64 %[[RHS]] // CHECK: %[[CMP:.*]] = arith.cmpf uge, %[[LHS_F64]], %[[RHS_F64]] : f64 // CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]] // CHECK: return %[[CMP_TORCH_BOOL]] : !torch.bool @@ -223,8 +293,8 @@ func.func 
@torch.aten.ge.float(%arg0: !torch.float, %arg1: !torch.float) -> !tor // CHECK-LABEL: func.func @torch.aten.ge.float_int( // CHECK-SAME: %[[LHS:.*]]: !torch.float, // CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool { -// CHECK: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]] -// CHECK: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]] +// CHECK-DAG: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]] +// CHECK-DAG: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]] // CHECK: %[[RHS_F64:.*]] = arith.sitofp %[[RHS_I64]] : i64 to f64 // CHECK: %[[CMP:.*]] = arith.cmpf uge, %[[LHS_F64]], %[[RHS_F64]] : f64 // CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]] @@ -237,8 +307,8 @@ func.func @torch.aten.ge.float_int(%arg0: !torch.float, %arg1: !torch.int) -> !t // CHECK-LABEL: func.func @torch.aten.ne.float_int( // CHECK-SAME: %[[LHS:.*]]: !torch.float, // CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool { -// CHECK: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]] -// CHECK: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]] +// CHECK-DAG: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]] +// CHECK-DAG: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]] // CHECK: %[[RHS_F64:.*]] = arith.sitofp %[[RHS_I64]] : i64 to f64 // CHECK: %[[CMP:.*]] = arith.cmpf une, %[[LHS_F64]], %[[RHS_F64]] : f64 // CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]] @@ -263,8 +333,8 @@ func.func @torch.aten.ceil.float(%arg0: !torch.float) -> !torch.int { // CHECK-LABEL: func.func @torch.aten.gt.float_int( // CHECK-SAME: %[[LHS:.*]]: !torch.float, // CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.bool { -// CHECK: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]] -// CHECK: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]] +// CHECK-DAG: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]] +// CHECK-DAG: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]] // CHECK: %[[RHS_F64:.*]] = arith.sitofp %[[RHS_I64]] : i64 to f64 // CHECK: %[[CMP:.*]] = arith.cmpf ugt, %[[LHS_F64]], %[[RHS_F64]] : f64 // CHECK: %[[CMP_TORCH_BOOL:.*]] = torch_c.from_i1 %[[CMP]] @@ -326,3 +396,14 @@ func.func @torch.aten.Int.bool(%arg0: !torch.bool) -> !torch.int { %0 = torch.aten.Int.bool %arg0 : !torch.bool -> !torch.int return %0 : !torch.int } + +// CHECK-LABEL: func.func @torch.aten.Int.Scalar( +// CHECK-SAME: %[[ARG:.*]]: !torch.float) -> !torch.int { +// CHECK: %[[ARG_F64:.*]] = torch_c.to_f64 %[[ARG]] +// CHECK: %[[FPTOSI:.*]] = arith.fptosi %[[ARG_F64]] : f64 to i64 +// CHECK: %[[OUT:.*]] = torch_c.from_i64 %[[FPTOSI]] +// CHECK: return %[[OUT]] : !torch.int +func.func @torch.aten.Int.Scalar(%arg0: !torch.float) -> !torch.int { + %0 = torch.aten.Int.Scalar %arg0 : !torch.float -> !torch.int + return %0 : !torch.int +} diff --git a/test/Conversion/TorchToLinalg/basic.mlir b/test/Conversion/TorchToLinalg/basic.mlir index f063f234e4e5..1b61f75703f6 100644 --- a/test/Conversion/TorchToLinalg/basic.mlir +++ b/test/Conversion/TorchToLinalg/basic.mlir @@ -3,8 +3,8 @@ // CHECK-LABEL: func.func @torch.aten.mm$basic( // CHECK-SAME: %[[LHS_VTENSOR:.*]]: !torch.vtensor<[?,?],f32>, // CHECK-SAME: %[[RHS_VTENSOR:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32> { -// CHECK: %[[LHS:.*]] = torch_c.to_builtin_tensor %[[LHS_VTENSOR]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[RHS:.*]] = torch_c.to_builtin_tensor %[[RHS_VTENSOR]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK-DAG: %[[LHS:.*]] = torch_c.to_builtin_tensor %[[LHS_VTENSOR]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK-DAG: %[[RHS:.*]] = torch_c.to_builtin_tensor %[[RHS_VTENSOR]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[C0:.*]] = 
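Two other things happen in the TorchToArith file above: the paired torch_c.to_i64/to_f64 checks are relaxed to CHECK-DAG because the two operand conversions may be emitted in either order, and three new scalar conversions gain coverage. Their semantics, sketched in Python (the function names are made up; the lowering uses arith.sitofp, arith.mulf and arith.fptosi as the CHECK lines show):

def mul_int_float(a: int, b: float) -> float:
    return float(a) * b    # aten.mul.int_float: sitofp then mulf

def mul_float_int(a: float, b: int) -> float:
    return a * float(b)    # aten.mul.float_int: same pattern, operands swapped

def int_scalar(x: float) -> int:
    return int(x)          # aten.Int.Scalar: fptosi truncates toward zero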
arith.constant 0 : index // CHECK: %[[LHS_DIM_0:.*]] = tensor.dim %[[LHS]], %[[C0]] : tensor // CHECK: %[[C1:.*]] = arith.constant 1 : index @@ -55,7 +55,7 @@ func.func @torch.aten.mm$basic_strict(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! // ----- // CHECK-LABEL: func.func @torch.aten.mm$basic_unsigned( -// CHECK: linalg.matmul_unsigned +// CHECK: linalg.matmul {cast = #linalg.type_fn} func.func @torch.aten.mm$basic_unsigned(%arg0: !torch.vtensor<[?,?],ui32>, %arg1: !torch.vtensor<[?,?],ui32>) -> !torch.vtensor<[?,2],ui32> attributes {torch.assume_strict_symbolic_shapes} { @@ -339,3 +339,18 @@ func.func @torch.aten.cat(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtenso %1 = torch.aten.cat %0, %int0 : !torch.list, !torch.int -> !torch.vtensor<[?,?],f32> return %1 : !torch.vtensor<[?,?],f32> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.transpose$basic( +// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,4],f32> { +// CHECK: %[[IN_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[4,3],f32> -> tensor<4x3xf32> +// CHECK: %[[TRANSP:.*]] = linalg.transpose ins(%[[IN_0]] : tensor<4x3xf32>) outs(%1 : tensor<3x4xf32>) permutation = [1, 0] +// CHECK: %[[OUT_0:.*]] = torch_c.from_builtin_tensor %{{.*}} : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[OUT_0]] : !torch.vtensor<[3,4],f32> +func.func @torch.aten.transpose$basic(%arg0: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,4],f32> { + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %0 = torch.aten.transpose.int %arg0, %int0, %int1 : !torch.vtensor<[4,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} diff --git a/test/Conversion/TorchToLinalg/constraints.mlir b/test/Conversion/TorchToLinalg/constraints.mlir new file mode 100644 index 000000000000..19075d72103a --- /dev/null +++ b/test/Conversion/TorchToLinalg/constraints.mlir @@ -0,0 +1,30 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func.func @torch.aten.sym_constrain_range( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.int) -> !torch.int { +// CHECK: %[[VAL_1:.*]] = torch_c.to_i64 %[[VAL_0]] +// CHECK: %[[VAL_2:.*]] = torch.constant.int 7 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_4:.*]] = torch.constant.none +// CHECK: %[[VAL_5:.*]] = arith.constant 0 : i64 +// CHECK: %[[VAL_6:.*]] = arith.constant 9223372036854775807 : i64 +// CHECK: %[[VAL_7:.*]] = arith.cmpi sle, %[[VAL_5]], %[[VAL_1]] : i64 +// CHECK: %[[VAL_8:.*]] = arith.cmpi sle, %[[VAL_1]], %[[VAL_6]] : i64 +// CHECK: %[[VAL_9:.*]] = arith.andi %[[VAL_7]], %[[VAL_8]] : i1 +// CHECK: cf.assert %[[VAL_9]], "Size constraint failed. Expected range: [0, 9223372036854775807]" +// CHECK: %[[VAL_10:.*]] = arith.constant 0 : i64 +// CHECK: %[[VAL_11:.*]] = arith.constant 7 : i64 +// CHECK: %[[VAL_12:.*]] = arith.cmpi sle, %[[VAL_10]], %[[VAL_1]] : i64 +// CHECK: %[[VAL_13:.*]] = arith.cmpi sle, %[[VAL_1]], %[[VAL_11]] : i64 +// CHECK: %[[VAL_14:.*]] = arith.andi %[[VAL_12]], %[[VAL_13]] : i1 +// CHECK: cf.assert %[[VAL_14]], "Size constraint failed. 
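In the TorchToLinalg basic tests above, the matmul_unsigned check is updated because upstream linalg now expresses unsignedness through a cast attribute on linalg.matmul, and a new test checks that aten.transpose.int of a [4,3] tensor lowers to a single linalg.transpose with permutation [1, 0]. A quick PyTorch check of the expected shape (illustrative, not part of the patch):

import torch

x = torch.arange(12, dtype=torch.float32).reshape(4, 3)
y = x.transpose(0, 1)                    # what the lowered linalg.transpose computes
assert y.shape == (3, 4)
assert torch.equal(y, x.permute(1, 0))   # permutation = [1, 0]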
Expected range: [0, 7]" +// CHECK: return %[[VAL_0]] : !torch.int +// CHECK: } +func.func @torch.aten.sym_constrain_range(%arg0: !torch.int) -> !torch.int { + %int7 = torch.constant.int 7 + %int0 = torch.constant.int 0 + %none = torch.constant.none + torch.aten.sym_constrain_range %arg0, %int0, %none : !torch.int, !torch.int, !torch.none + torch.aten.sym_constrain_range %arg0, %int0, %int7 : !torch.int, !torch.int, !torch.int + return %arg0 : !torch.int +} diff --git a/test/Conversion/TorchToLinalg/convolution.mlir b/test/Conversion/TorchToLinalg/convolution.mlir index 1fead662183e..480b1eeb9ed2 100644 --- a/test/Conversion/TorchToLinalg/convolution.mlir +++ b/test/Conversion/TorchToLinalg/convolution.mlir @@ -16,3 +16,63 @@ func.func @torch.aten.convolution$nobias(%arg0: !torch.vtensor<[1,24,16,128,128] %4 = torch.aten.convolution %arg0, %arg1, %none, %2, %0, %1, %false, %3, %int1 : !torch.vtensor<[1,24,16,128,128],f16>, !torch.vtensor<[54,24,1,1,1],f16>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,54,16,128,128],f16> return %4 : !torch.vtensor<[1,54,16,128,128],f16> } + +// ----- + +// CHECK-LABEL: func.func @q_conv_test +// CHECK: %[[c3:.*]] = arith.constant 3 : i32 +// CHECK: %[[c7:.*]] = arith.constant 7 : i32 +// CHECK: %[[input:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[?,?,?,?],si8> -> tensor +// CHECK: %[[weight:.*]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[?,?,?,?],si8> -> tensor +// CHECK: %[[conv:.*]] = linalg.conv_2d_nchw_fchw_q {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} +// CHECK-SAME: ins(%[[input]], %[[weight]], %[[c7]], %[[c3]] : tensor, tensor, i32, i32) +// CHECK-SAME: outs(%[[convout:.*]] : tensor) -> tensor +func.func @q_conv_test(%arg0: !torch.vtensor<[?,?,?,?],si8>, %arg1: !torch.vtensor<[?,?,?,?],si8>, %arg2: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { + %false = torch.constant.bool false + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %float1.000000e-04 = torch.constant.float 1.000000e-04 + %int3 = torch.constant.int 3 + %int7 = torch.constant.int 7 + %float1.000000e-02 = torch.constant.float 1.000000e-02 + %int14 = torch.constant.int 14 + %0 = torch.aten.quantize_per_tensor %arg2, %float1.000000e-04, %int0, %int14 : !torch.vtensor<[?],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[?],!torch.qint32> + %1 = torch.aten.dequantize.self %0 : !torch.vtensor<[?],!torch.qint32> -> !torch.vtensor<[?],f32> + %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %3 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list + %4 = torch.prim.ListConstruct : () -> !torch.list + %5 = torch.aten._make_per_tensor_quantized_tensor %arg0, %float1.000000e-02, %int7 : !torch.vtensor<[?,?,?,?],si8>, !torch.float, !torch.int -> !torch.vtensor<[?,?,?,?],!torch.qint8> + %6 = torch.aten._make_per_tensor_quantized_tensor %arg1, %float1.000000e-02, %int3 : !torch.vtensor<[?,?,?,?],si8>, !torch.float, !torch.int -> !torch.vtensor<[?,?,?,?],!torch.qint8> + %7 = torch.aten.quantize_per_tensor %1, %float1.000000e-04, %int0, %int14 : !torch.vtensor<[?],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[?],!torch.qint32> + %8 = torch.aten.int_repr %7 : !torch.vtensor<[?],!torch.qint32> -> !torch.vtensor<[?],si32> + %9 = torch.aten.convolution %5, %6, %8, %2, %3, %2, %false, %4, %int1 : !torch.vtensor<[?,?,?,?],!torch.qint8>, !torch.vtensor<[?,?,?,?],!torch.qint8>, 
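The new constraints.mlir test above lowers each aten.sym_constrain_range to two signed comparisons and a cf.assert on their conjunction, with a default upper bound of 2^63 - 1 when none is given. A minimal Python sketch of the runtime check being generated (the helper is hypothetical):

def sym_constrain_range(value: int, min_val: int = 0, max_val: int = 2**63 - 1) -> None:
    assert min_val <= value <= max_val, \
        f"Size constraint failed. Expected range: [{min_val}, {max_val}]"

sym_constrain_range(5, 0, 7)   # passes, like the second assert in the test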
!torch.vtensor<[?],si32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[?,?,?,?],si32> + %10 = torch.aten._make_per_tensor_quantized_tensor %9, %float1.000000e-04, %int0 : !torch.vtensor<[?,?,?,?],si32>, !torch.float, !torch.int -> !torch.vtensor<[?,?,?,?],!torch.qint32> + %11 = torch.aten.dequantize.tensor %10 : !torch.vtensor<[?,?,?,?],!torch.qint32> -> !torch.vtensor<[?,?,?,?],f32> + return %11 : !torch.vtensor<[?,?,?,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @conv_broadcast( +// CHECK-SAME: %[[arg0:.*]]: !torch.vtensor<[1,80,3000],f32>, +// CHECK-SAME: %[[arg1:.*]]: !torch.vtensor<[1024,80,3],f32>, +// CHECK-SAME: %[[arg2:.*]]: !torch.vtensor<[1024],f32>) -> !torch.vtensor<[1,1024,3000],f32> { +// CHECK: %[[c0:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[input:.*]] = torch_c.to_builtin_tensor %[[arg0]] : !torch.vtensor<[1,80,3000],f32> -> tensor<1x80x3000xf32> +// CHECK-DAG: %[[weight:.*]] = torch_c.to_builtin_tensor %[[arg1]] : !torch.vtensor<[1024,80,3],f32> -> tensor<1024x80x3xf32> +// CHECK-DAG: %[[bias:.*]] = torch_c.to_builtin_tensor %[[arg2]] : !torch.vtensor<[1024],f32> -> tensor<1024xf32> +// CHECK: %[[padInput:.*]] = tensor.pad %[[input]] low[0, 0, 1] high[0, 0, 1] +// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x1024x3000xf32> +// CHECK: %[[broadcastBias:.*]] = linalg.broadcast ins(%[[bias]] : tensor<1024xf32>) outs(%[[EMPTY]] : tensor<1x1024x3000xf32>) dimensions = [0, 2] +// CHECK: %[[conv:.*]] = linalg.conv_1d_ncw_fcw {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>} +// CHECK-SAME: ins(%[[padInput:.*]], %[[weight]] : tensor<1x80x3002xf32>, tensor<1024x80x3xf32>) +// CHECK-SAME: outs(%[[broadcastBias]] : tensor<1x1024x3000xf32>) -> tensor<1x1024x3000xf32> +func.func @conv_broadcast(%arg0: !torch.vtensor<[1,80,3000],f32>, %arg1: !torch.vtensor<[1024,80,3],f32>, %arg2: !torch.vtensor<[1024],f32>) -> !torch.vtensor<[1,1024,3000],f32> { + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %false = torch.constant.bool false + %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list + %2 = torch.aten.convolution %arg0, %arg1, %arg2, %0, %0, %0, %false, %1, %int1 : !torch.vtensor<[1,80,3000],f32>, !torch.vtensor<[1024,80,3],f32>, !torch.vtensor<[1024],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1024,3000],f32> + return %2 : !torch.vtensor<[1,1024,3000],f32> +} diff --git a/test/Conversion/TorchToLinalg/datamovement.mlir b/test/Conversion/TorchToLinalg/datamovement.mlir new file mode 100644 index 000000000000..dd5e5c553d31 --- /dev/null +++ b/test/Conversion/TorchToLinalg/datamovement.mlir @@ -0,0 +1,34 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -canonicalize -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func.func @torch.aten.permute( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[64,32,16,8,4],f32>) -> !torch.vtensor<[64,8,4,32,16],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[64,32,16,8,4],f32> -> tensor<64x32x16x8x4xf32> +// CHECK: %[[VAL_2:.*]] = tensor.empty() : tensor<64x8x4x32x16xf32> +// CHECK: %[[VAL_3:.*]] = linalg.transpose ins(%[[VAL_1]] : tensor<64x32x16x8x4xf32>) outs(%[[VAL_2]] : tensor<64x8x4x32x16xf32>) permutation = [0, 3, 4, 1, 2] +// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<64x8x4x32x16xf32> -> 
!torch.vtensor<[64,8,4,32,16],f32> +// CHECK: return %[[VAL_4]] : !torch.vtensor<[64,8,4,32,16],f32> +// CHECK: } +func.func @torch.aten.permute(%arg0: !torch.vtensor<[64,32,16,8,4],f32>) -> !torch.vtensor<[64,8,4,32,16],f32> { + %int0 = torch.constant.int 0 + %int3 = torch.constant.int 3 + %int4 = torch.constant.int 4 + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %0 = torch.prim.ListConstruct %int0, %int3, %int4, %int1, %int2 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1 = torch.aten.permute %arg0, %0 : !torch.vtensor<[64,32,16,8,4],f32>, !torch.list -> !torch.vtensor<[64,8,4,32,16],f32> + return %1 : !torch.vtensor<[64,8,4,32,16],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.permute$rank0( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[],f32> -> tensor +// CHECK: %[[VAL_2:.*]] = torch_c.from_builtin_tensor %[[VAL_1]] : tensor -> !torch.vtensor<[],f32> +// CHECK: return %[[VAL_2]] : !torch.vtensor<[],f32> +// CHECK: } +func.func @torch.aten.permute$rank0(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> { + %0 = torch.prim.ListConstruct : () -> !torch.list + %1 = torch.aten.permute %arg0, %0 : !torch.vtensor<[],f32>, !torch.list -> !torch.vtensor<[],f32> + return %1 : !torch.vtensor<[],f32> +} diff --git a/test/Conversion/TorchToLinalg/elementwise.mlir b/test/Conversion/TorchToLinalg/elementwise.mlir index bed94f98da2b..c8fdeded44df 100644 --- a/test/Conversion/TorchToLinalg/elementwise.mlir +++ b/test/Conversion/TorchToLinalg/elementwise.mlir @@ -3,7 +3,7 @@ // CHECK-LABEL: func.func @elementwise$unary( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> { -// CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor +// CHECK-DAG: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor // CHECK: %[[INIT_TENSOR:.*]] = tensor.empty() : tensor // CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%[[BUILTIN_TENSOR]] : tensor) outs(%[[INIT_TENSOR]] : tensor) { // CHECK: ^bb0(%[[BBARG0:.*]]: f32, %{{.*}}: f32): @@ -24,8 +24,8 @@ func.func @elementwise$unary(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[] // CHECK-LABEL: func.func @elementwise$binary( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> { -// CHECK: %[[BUILTIN_ARG0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[BUILTIN_ARG1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?],f32> -> tensor +// CHECK-DAG: %[[BUILTIN_ARG0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK-DAG: %[[BUILTIN_ARG1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?],f32> -> tensor // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[ARG0_DIM0:.*]] = tensor.dim %[[BUILTIN_ARG0]], %[[C0]] : tensor // CHECK: %[[C1:.*]] = arith.constant 1 : index @@ -67,7 +67,7 @@ func.func @elementwise$ternary(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?],f32>, // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],f32> { // CHECK: %[[C1:.*]] = torch.constant.int 1 -// CHECK: %[[BUILTIN_C1:.*]] = torch_c.to_i64 
%[[C1]] +// CHECK: %[[BUILTIN_C1:.*]] = arith.constant 1 : i64 // CHECK: linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>] // CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32, %{{.*}}: f32): // CHECK: %[[ALPHA:.*]] = arith.sitofp %[[BUILTIN_C1]] : i64 to f32 @@ -102,3 +102,19 @@ func.func @elementwise_sinh(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3 %0 = torch.aten.sinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> return %0 : !torch.vtensor<[3],f32> } + +// ----- + +// CHECK-LABEL: func.func @elementwise_todtype_bf162f16( +// CHECK: linalg.generic +// CHECK: arith.extf +// CHECK-SAME: bf16 to f32 +// CHECK: arith.truncf +// CHECK-SAME: f32 to f16 +func.func @elementwise_todtype_bf162f16(%arg0: !torch.vtensor<[1,?,32,128],bf16>) -> !torch.vtensor<[1,?,32,128],f16> { + %int5 = torch.constant.int 5 + %false = torch.constant.bool false + %none = torch.constant.none + %0 = torch.aten.to.dtype %arg0, %int5, %false, %false, %none : !torch.vtensor<[1,?,32,128],bf16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,?,32,128],f16> + return %0 : !torch.vtensor<[1,?,32,128],f16> +} diff --git a/test/Conversion/TorchToLinalg/embeddingBag.mlir b/test/Conversion/TorchToLinalg/embeddingBag.mlir new file mode 100644 index 000000000000..05aa57fc751a --- /dev/null +++ b/test/Conversion/TorchToLinalg/embeddingBag.mlir @@ -0,0 +1,52 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d1)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d0)> +// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)> +// CHECK-LABEL: func.func @torchAtenEmbeddingBagPaddingIdx +// CHECK: %[[VAL_0:.*]]: !torch.vtensor<[1000000,64],f32> +// CHECK: %[[VAL_1:.*]]: !torch.vtensor<[204790],si64> +// CHECK: %[[VAL_2:.*]]: !torch.vtensor<[2048],si64> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[2048],si64> -> tensor<2048xi64> +// CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[204790],si64> -> tensor<204790xi64> +// CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1000000,64],f32> -> tensor<1000000x64xf32> +// CHECK-DAG: %[[VAL_6:.*]] = torch.constant.bool true +// CHECK-DAG: %[[VAL_7:.*]] = torch.constant.int 0 +// CHECK-DAG: %[[VAL_8:.*]] = torch.constant.bool true +func.func @torchAtenEmbeddingBagPaddingIdx(%weight: !torch.vtensor<[1000000,64],f32>, + %indices: !torch.vtensor<[204790],si64>, + %offsets: !torch.vtensor<[2048],si64>) -> (!torch.vtensor<[2048,64],f32>, + !torch.vtensor<[0],si64>, + !torch.vtensor<[2048],si64>, + !torch.vtensor<[2048],si64>) + { + %scale_grad_by_freq = torch.constant.bool true + %mode = torch.constant.int 0 + %sparse = torch.constant.bool true + %per_sample_weights = torch.constant.none + %include_last_offset = torch.constant.bool false + %padding_idx = torch.constant.none + %result0, %result1, %result2, %result3 = torch.aten.embedding_bag.padding_idx %weight, + %indices, + %offsets, + %scale_grad_by_freq, + %mode, + %sparse, + %per_sample_weights, + %include_last_offset, + %padding_idx : + !torch.vtensor<[1000000,64],f32>, + !torch.vtensor<[204790],si64>, + !torch.vtensor<[2048],si64>, + !torch.bool, + !torch.int, + !torch.bool, + !torch.none, + !torch.bool, + !torch.none -> !torch.vtensor<[2048,64],f32>, + !torch.vtensor<[0],si64>, + !torch.vtensor<[2048],si64>, + 
!torch.vtensor<[2048],si64> + + return %result0, %result1, %result2, %result3 : !torch.vtensor<[2048,64],f32>, !torch.vtensor<[0],si64>, !torch.vtensor<[2048],si64>, !torch.vtensor<[2048],si64> +} diff --git a/test/Conversion/TorchToLinalg/gridsampler.mlir b/test/Conversion/TorchToLinalg/gridsampler.mlir index 7c099c5ce4f6..2a291f721fed 100644 --- a/test/Conversion/TorchToLinalg/gridsampler.mlir +++ b/test/Conversion/TorchToLinalg/gridsampler.mlir @@ -5,9 +5,7 @@ // CHECK-DAG: %[[TC0:.*]] = torch_c.to_builtin_tensor %[[ARG0:.*]] : !torch.vtensor<[4,10,10,4],f32> -> tensor<4x10x10x4xf32> // CHECK-DAG: %[[TC1:.*]] = torch_c.to_builtin_tensor %[[ARG1:.*]] : !torch.vtensor<[4,6,8,2],f32> -> tensor<4x6x8x2xf32> // CHECK-DAG: %[[FALSE:.*]] = torch.constant.bool false -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index // CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: %[[CST1:.*]] = arith.constant 1.000000e+00 : f32 // CHECK-DAG: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32 diff --git a/test/Conversion/TorchToLinalg/pooling.mlir b/test/Conversion/TorchToLinalg/pooling.mlir index 494f603c296e..558c50c4f08f 100644 --- a/test/Conversion/TorchToLinalg/pooling.mlir +++ b/test/Conversion/TorchToLinalg/pooling.mlir @@ -7,13 +7,13 @@ func.func @forward_max_pool1d(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vten %int3 = torch.constant.int 3 %int4 = torch.constant.int 4 %false = torch.constant.bool false - // CHECK: %[[C1:.*]] = torch_c.to_i64 %int1 + // CHECK: %[[C1:.*]] = arith.constant 1 : i64 // CHECK: %[[NEUTRAL:.*]] = arith.constant 0xFF800000 : f32 // CHECK: %[[PADDED:.*]] = tensor.pad %{{.*}} low[0, 0, 3] high[0, 0, 3] // CHECK: %[[OUT:.*]] = linalg.fill ins(%[[NEUTRAL]] : f32) outs(%{{.*}} : tensor) -> tensor - // CHECK: %[[T1:.*]] = arith.index_cast %[[C1]] : i64 to index - // CHECK: %[[INIT:.*]] = tensor.empty(%[[T1]]) : tensor - // CHECK: linalg.pooling_ncw_max {dilations = dense<4> : vector<1xi64>, strides = dense<2> : vector<1xi64>} ins(%[[PADDED]], %[[INIT]] : tensor, tensor) outs(%[[OUT]] : tensor) -> tensor + // CHECK: %[[T1:.*]] = arith.constant 1 : index + // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<1xf32> + // CHECK: linalg.pooling_ncw_max {dilations = dense<4> : vector<1xi64>, strides = dense<2> : vector<1xi64>} ins(%[[PADDED]], %[[INIT]] : tensor, tensor<1xf32>) outs(%[[OUT]] : tensor) -> tensor %kernel_size = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list %stride = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list %padding = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list @@ -33,15 +33,15 @@ func.func @forward_max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vt %int7 = torch.constant.int 7 %int8 = torch.constant.int 8 %false = torch.constant.bool false - // CHECK: %[[C1:.*]] = torch_c.to_i64 %int1 - // CHECK: %[[C2:.*]] = torch_c.to_i64 %int2 + // CHECK: %[[C1:.*]] = arith.constant 1 : i64 + // CHECK: %[[C2:.*]] = arith.constant 2 : i64 // CHECK: %[[NEUTRAL:.*]] = arith.constant 0xFF800000 : f32 // CHECK: %[[PADDED:.*]] = tensor.pad %{{.*}} low[0, 0, 5, 6] high[0, 0, 5, 6] // CHECK: %[[OUT:.*]] = linalg.fill ins(%[[NEUTRAL]] : f32) outs(%{{.*}} : tensor) -> tensor - // CHECK: %[[T1:.*]] = arith.index_cast %[[C1]] : i64 to index - // CHECK: %[[T2:.*]] = arith.index_cast %[[C2]] : i64 to index - // CHECK: %[[INIT:.*]] = tensor.empty(%[[T1]], %[[T2]]) : tensor - // CHECK: linalg.pooling_nchw_max 
{dilations = dense<[7, 8]> : vector<2xi64>, strides = dense<[3, 4]> : vector<2xi64>} ins(%[[PADDED]], %[[INIT]] : tensor, tensor) outs(%[[OUT]] : tensor) -> tensor + // CHECK: %[[T1:.*]] = arith.constant 1 : index + // CHECK: %[[T2:.*]] = arith.constant 2 : index + // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<1x2xf32> + // CHECK: linalg.pooling_nchw_max {dilations = dense<[7, 8]> : vector<2xi64>, strides = dense<[3, 4]> : vector<2xi64>} ins(%[[PADDED]], %[[INIT]] : tensor, tensor<1x2xf32>) outs(%[[OUT]] : tensor) -> tensor %kernel_size = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list %stride = torch.prim.ListConstruct %int3, %int4 : (!torch.int, !torch.int) -> !torch.list %padding = torch.prim.ListConstruct %int5, %int6 : (!torch.int, !torch.int) -> !torch.list @@ -88,7 +88,7 @@ func.func @forward_max_pool3d(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch. // CHECK: } : tensor to tensor // CHECK: %[[OUTPUT_TENSOR:.*]] = linalg.fill ins(%[[MIN_VALUE:.*]] : f32) outs(%{{.*}} : tensor) -> tensor - // CHECK: %[[MAX_3D_POOL:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%[[PADDED_INPUT_TENSOR:.*]], %{{.*}} : tensor, tensor) outs(%[[OUTPUT_TENSOR:.*]] : tensor) { + // CHECK: %[[MAX_3D_POOL:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%[[PADDED_INPUT_TENSOR:.*]], %{{.*}} : tensor, tensor<8x8x8xf32>) outs(%[[OUTPUT_TENSOR:.*]] : tensor) { // CHECK-NEXT: ^bb0(%[[CURRENT_VALUE:.*]]: f32, %[[KERNEL:.*]]: f32, %[[ACC_OUT:.*]]: f32): // CHECK-NEXT: %[[MAXF:.*]] = arith.maximumf %[[CURRENT_VALUE:.*]], %[[ACC_OUT:.*]] : f32 // CHECK-NEXT: linalg.yield %[[MAXF:.*]] : f32 diff --git a/test/Conversion/TorchToLinalg/resize.mlir b/test/Conversion/TorchToLinalg/resize.mlir index 480454b3f1fc..1dfe45492312 100644 --- a/test/Conversion/TorchToLinalg/resize.mlir +++ b/test/Conversion/TorchToLinalg/resize.mlir @@ -3,78 +3,47 @@ // CHECK-LABEL: func.func @test_resize_sizes_linear func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4] ,si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[generic:.*]] = linalg.generic - // CHECK: %[[c2_i64:.*]] = arith.constant 2 : i64 - // CHECK: %[[c4_i64:.*]] = arith.constant 4 : i64 - // CHECK: %[[cst:.*]] = arith.constant 1.001000e+00 : f32 - // CHECK: %[[cst_4:.*]] = arith.constant 1.000000e+00 : f32 - // CHECK: %[[cst_5:.*]] = arith.constant 5.000000e-01 : f32 - // CHECK: %[[cst_6:.*]] = arith.constant 0.000000e+00 : f32 - // CHECK: %[[x13:.*]] = linalg.index 2 : index - // CHECK: %[[x14:.*]] = linalg.index 3 : index - // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64]] : i64 to f32 - // CHECK: %[[x16:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32 - // CHECK: %[[x17:.*]] = arith.divf %[[x16]], %[[x15]] : f32 - // CHECK: %[[x18:.*]] = arith.index_cast %[[x13]] : index to i64 - // CHECK: %[[x19:.*]] = arith.sitofp %[[x18]] : i64 to f32 - // CHECK: %[[x20:.*]] = arith.addf %[[x19]], %[[cst_5]] : f32 - // CHECK: %[[x21:.*]] = arith.divf %[[x20]], %[[x17]] : f32 - // CHECK: %[[x22:.*]] = arith.subf %[[x21]], %[[cst_5]] : f32 - // CHECK: %[[x23:.*]] = 
arith.maximumf %[[x22]], %[[cst_6]] : f32 - // CHECK: %[[x24:.*]] = arith.subf %[[x15]], %[[cst]] : f32 - // CHECK: %[[x25:.*]] = arith.minimumf %[[x23]], %[[x24]] : f32 - // CHECK: %[[x26:.*]] = arith.sitofp %[[c4_i64]] : i64 to f32 - // CHECK: %[[x27:.*]] = arith.sitofp %[[x7:.*]] : i64 to f32 - // CHECK: %[[x28:.*]] = arith.divf %[[x27]], %[[x26]] : f32 - // CHECK: %[[x29:.*]] = arith.index_cast %[[x14]] : index to i64 - // CHECK: %[[x30:.*]] = arith.sitofp %[[x29]] : i64 to f32 - // CHECK: %[[x31:.*]] = arith.addf %[[x30]], %[[cst_5]] : f32 - // CHECK: %[[x32:.*]] = arith.divf %[[x31]], %[[x28]] : f32 - // CHECK: %[[x33:.*]] = arith.subf %[[x32]], %[[cst_5]] : f32 - // CHECK: %[[x34:.*]] = arith.maximumf %[[x33]], %[[cst_6]] : f32 - // CHECK: %[[x35:.*]] = arith.subf %[[x26]], %[[cst]] : f32 - // CHECK: %[[x36:.*]] = arith.minimumf %[[x34]], %[[x35]] : f32 - // CHECK: %[[x37:.*]] = math.floor %[[x25]] : f32 - // CHECK: %[[x38:.*]] = arith.addf %[[cst_4]], %[[x25]] : f32 - // CHECK: %[[x39:.*]] = math.floor %[[x38]] : f32 - // CHECK: %[[x40:.*]] = math.floor %[[x36]] : f32 - // CHECK: %[[x41:.*]] = arith.addf %[[cst_4]], %[[x36]] : f32 - // CHECK: %[[x42:.*]] = math.floor %[[x41]] : f32 - // CHECK: %[[x43:.*]] = linalg.index 0 : index - // CHECK: %[[x44:.*]] = linalg.index 1 : index - // CHECK: %[[x45:.*]] = linalg.index 2 : index - // CHECK: %[[x46:.*]] = linalg.index 3 : index - // CHECK: %[[x47:.*]] = arith.fptosi %[[x37]] : f32 to i64 - // CHECK: %[[x48:.*]] = arith.index_cast %[[x47]] : i64 to index - // CHECK: %[[x49:.*]] = arith.fptosi %[[x40]] : f32 to i64 - // CHECK: %[[x50:.*]] = arith.index_cast %[[x49]] : i64 to index - // CHECK: %[[x51:.*]] = arith.fptosi %[[x39]] : f32 to i64 - // CHECK: %[[x52:.*]] = arith.index_cast %[[x51]] : i64 to index - // CHECK: %[[x53:.*]] = arith.fptosi %[[x42]] : f32 to i64 - // CHECK: %[[x54:.*]] = arith.index_cast %[[x53]] : i64 to index - // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x43]], %[[x44]], %[[x48]], %[[x50]]] : tensor<1x1x2x4xf32> - // CHECK: %[[extracted_7:.*]] = tensor.extract %[[x0]][%[[x43]], %[[x44]], %[[x48]], %[[x54]]] : tensor<1x1x2x4xf32> - // CHECK: %[[extracted_8:.*]] = tensor.extract %[[x0]][%[[x43]], %[[x44]], %[[x52]], %[[x50]]] : tensor<1x1x2x4xf32> - // CHECK: %[[extracted_9:.*]] = tensor.extract %[[x0]][%[[x43]], %[[x44]], %[[x52]], %[[x54]]] : tensor<1x1x2x4xf32> - // CHECK: %[[x55:.*]] = arith.subf %[[x42]], %[[x36]] : f32 - // CHECK: %[[x56:.*]] = arith.subf %[[x42]], %[[x40]] : f32 - // CHECK: %[[x57:.*]] = arith.divf %[[x55]], %[[x56]] : f32 - // CHECK: %[[x58:.*]] = arith.mulf %[[x57]], %extracted : f32 - // CHECK: %[[x59:.*]] = arith.subf %[[x36]], %[[x40]] : f32 - // CHECK: %[[x60:.*]] = arith.divf %[[x59]], %[[x56]] : f32 - // CHECK: %[[x61:.*]] = arith.mulf %[[x60]], %[[extracted_7]] : f32 - // CHECK: %[[x62:.*]] = arith.addf %[[x58]], %[[x61]] : f32 - // CHECK: %[[x63:.*]] = arith.mulf %[[x57]], %[[extracted_8]] : f32 - // CHECK: %[[x64:.*]] = arith.mulf %[[x60]], %[[extracted_9]] : f32 - // CHECK: %[[x65:.*]] = arith.addf %[[x63]], %[[x64]] : f32 - // CHECK: %[[x66:.*]] = arith.subf %[[x39]], %[[x25]] : f32 - // CHECK: %[[x67:.*]] = arith.subf %[[x39]], %[[x37]] : f32 - // CHECK: %[[x68:.*]] = arith.divf %[[x66]], %[[x67]] : f32 - // CHECK: %[[x69:.*]] = arith.mulf %[[x68]], %[[x62]] : f32 - // CHECK: %[[x70:.*]] = arith.subf %[[x25]], %[[x37]] : f32 - // CHECK: %[[x71:.*]] = arith.divf %[[x70]], %[[x67]] : f32 - // CHECK: %[[x72:.*]] = arith.mulf %[[x71]], %[[x65]] : f32 - // CHECK: 
%[[x73:.*]] = arith.addf %[[x69]], %[[x72]] : f32 + // CHECK: %[[x0:.*]] = torch_c.to_builtin_tensor %arg0 + // CHECK: %[[generic:.*]] = linalg.generic + // CHECK-DAG: %[[cst:.*]] = arith.constant 1.000000e+00 : f32 + // CHECK-DAG: %[[cst_4:.*]] = arith.constant 5.000000e-01 : f32 + // CHECK-DAG: %[[cst_5:.*]] = arith.constant 0.000000e+00 : f32 + // CHECK-DAG: %[[x15:.*]] = linalg.index 0 : index + // CHECK-DAG: %[[x16:.*]] = linalg.index 1 : index + // CHECK-DAG: %[[x17:.*]] = linalg.index 2 : index + // CHECK-DAG: %[[x18:.*]] = linalg.index 3 : index + // CHECK: %[[x19:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32 + // CHECK: %[[x20:.*]] = arith.sitofp %[[x8:.*]] : i64 to f32 + // CHECK-DAG: %[[x21:.*]] = arith.divf %[[x20]], %[[x19]] : f32 + // CHECK-DAG: %[[x22:.*]] = arith.index_cast %[[x17]] : index to i64 + // CHECK-DAG: %[[x23:.*]] = arith.sitofp %[[x22]] : i64 to f32 + // CHECK-DAG: %[[x24:.*]] = arith.addf %[[x23]], %[[cst_4]] : f32 + // CHECK-DAG: %[[x25:.*]] = arith.divf %[[x24]], %[[x21]] : f32 + // CHECK-DAG: %[[x26:.*]] = arith.subf %[[x25]], %[[cst_4]] : f32 + // CHECK-DAG: %[[x27:.*]] = arith.maximumf %[[x26]], %[[cst_5]] : f32 + // CHECK-DAG: %[[x28:.*]] = arith.subf %[[x19]], %cst_4 : f32 + // CHECK-DAG: %[[x29:.*]] = arith.minimumf %[[x27]], %[[x28]] : f32 + // CHECK-DAG: %[[x30:.*]] = math.floor %[[x29]] : f32 + // CHECK-DAG: %[[x31:.*]] = arith.addf %[[cst]], %[[x29]] : f32 + // CHECK-DAG: %[[x32:.*]] = math.floor %[[x31]] : f32 + // CHECK-DAG: %[[x33:.*]] = arith.fptosi %[[x30]] : f32 to i64 + // CHECK-DAG: %[[x34:.*]] = arith.index_cast %[[x33]] : i64 to index + // CHECK-DAG: %[[x35:.*]] = arith.minimumf %44, %42 : f32 + // CHECK-DAG: %[[x36:.*]] = arith.fptosi %[[x35]] : f32 to i64 + // CHECK-DAG: %[[x37:.*]] = arith.index_cast %[[x36]] : i64 to index + // CHECK: %[[extracted:.*]] = tensor.extract %[[x0]][%[[x15]], %[[x16]], %[[x34]], %[[low:.*]]] : tensor<1x1x2x4xf32> + // CHECK: %[[extracted_7:.*]] = tensor.extract %[[x0]][%[[x15]], %[[x16]], %[[x34]], %[[high:.*]]] : tensor<1x1x2x4xf32> + // CHECK: %[[extracted_8:.*]] = tensor.extract %[[x0]][%[[x15]], %[[x16]], %[[x37]], %[[low]]] : tensor<1x1x2x4xf32> + // CHECK: %[[extracted_9:.*]] = tensor.extract %[[x0]][%[[x15]], %[[x16]], %[[x37]], %[[high]]] : tensor<1x1x2x4xf32> + // CHECK: %[[dx0p00:.*]] = arith.mulf %[[dx0:.*]], %[[extracted]] + // CHECK: %[[dx1p01:.*]] = arith.mulf %[[dx1:.*]], %[[extracted_7]] + // CHECK: %[[sum:.*]] = arith.addf %[[dx0p00]], %[[dx1p01]] + // CHECK: %[[left:.*]] = arith.mulf %[[dy0:.*]], %[[sum]] + // CHECK: %[[dx0p10:.*]] = arith.mulf %[[dx0]], %[[extracted_8]] + // CHECK: %[[dx1p11:.*]] = arith.mulf %[[dx1]], %[[extracted_9]] + // CHECK: %[[sum2:.*]] = arith.addf %[[dx0p10]], %[[dx1p11]] + // CHECK: %[[right:.*]] = arith.mulf %[[dy1:.*]], %[[sum2]] + // CHECK: %[[retval:.*]] = arith.addf %[[left]], %[[right]] %none = torch.constant.none %none_0 = torch.constant.none %int0 = torch.constant.int 0 @@ -94,35 +63,105 @@ func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: // ----- +// CHECK-LABEL: func.func @test_resize_sizes_nearest func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[GENERIC:.*]] = linalg.generic - // CHECK: %[[c2_i64:.*]] = arith.constant 2 : i64 - // CHECK: 
%[[c4_i64:.*]] = arith.constant 4 : i64 + // CHECK: %[[x11:.*]] = linalg.index 0 : index + // CHECK: %[[x12:.*]] = linalg.index 1 : index // CHECK: %[[x13:.*]] = linalg.index 2 : index // CHECK: %[[x14:.*]] = linalg.index 3 : index - // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64]] : i64 to f32 - // CHECK: %[[x16:.*]] = arith.sitofp %[[c4_i64]] : i64 to f32 + // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32 // CHECK: %[[x19:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32 - // CHECK: %[[x20:.*]] = arith.sitofp %[[x7:.*]] : i64 to f32 // CHECK: %[[x21:.*]] = arith.divf %[[x19]], %[[x15]] : f32 - // CHECK: %[[x22:.*]] = arith.divf %[[x20]], %[[x16]] : f32 // CHECK: %[[x23:.*]] = arith.index_cast %[[x13]] : index to i64 // CHECK: %[[x24:.*]] = arith.sitofp %[[x23]] : i64 to f32 // CHECK: %[[x25:.*]] = arith.divf %[[x24]], %[[x21]] : f32 + // CHECK: %[[x26:.*]] = math.floor %[[x25]] : f32 + // CHECK: %[[x31:.*]] = arith.fptosi %[[x26]] : f32 to i64 + // CHECK: %[[x32:.*]] = arith.index_cast %[[x31]] : i64 to index + // CHECK: %[[x16:.*]] = arith.sitofp %[[c4_i64:.*]] : i64 to f32 + // CHECK: %[[x20:.*]] = arith.sitofp %[[x7:.*]] : i64 to f32 + // CHECK: %[[x22:.*]] = arith.divf %[[x20]], %[[x16]] : f32 // CHECK: %[[x26:.*]] = arith.index_cast %[[x14]] : index to i64 // CHECK: %[[x27:.*]] = arith.sitofp %[[x26]] : i64 to f32 // CHECK: %[[x28:.*]] = arith.divf %[[x27]], %[[x22]] : f32 + // CHECK: %[[x29:.*]] = math.floor %[[x28]] : f32 + // CHECK: %[[x33:.*]] = arith.fptosi %[[x29]] : f32 to i64 + // CHECK: %[[x34:.*]] = arith.index_cast %[[x33]] : i64 to index + // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x11]], %[[x12]], %[[x32]], %[[x34]]] : tensor<1x1x2x4xf32> + // CHECK: linalg.yield %[[extracted]] : f32 + %none = torch.constant.none + %none_0 = torch.constant.none + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %true = torch.constant.bool true + %str = torch.constant.str "nearest" + %int2 = torch.constant.int 2 + %0 = torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + %1 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + %int3 = torch.constant.int 3 + %2 = torch.aten.select.int %arg1, %int0, %int3 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + %3 = torch.aten.item %2 : !torch.vtensor<[1],si64> -> !torch.int + %4 = torch.prim.ListConstruct %1, %3 : (!torch.int, !torch.int) -> !torch.list + %5 = torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + return %5 : !torch.vtensor<[?,?,?,?],f32> + } + +// ----- + +// CHECK-LABEL: func.func @test_resize_nearest_1d +func.func @test_resize_nearest_1d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,?,?],f32> { + // CHECK: %[[GENERIC:.*]] = linalg.generic + // CHECK: %[[x11:.*]] = linalg.index 0 : index + // CHECK: %[[x12:.*]] = linalg.index 1 : index + // CHECK: %[[x13:.*]] = linalg.index 2 : index + // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32 + // CHECK: %[[x19:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32 + // CHECK: %[[x21:.*]] = arith.divf %[[x19]], %[[x15]] : f32 + // CHECK: %[[x23:.*]] = arith.index_cast %[[x13]] : index to i64 + // CHECK: %[[x24:.*]] = arith.sitofp %[[x23]] : i64 to f32 + // CHECK: %[[x25:.*]] = arith.divf %[[x24]], 
%[[x21]] : f32 // CHECK: %[[x29:.*]] = math.floor %[[x25]] : f32 - // CHECK: %[[x30:.*]] = math.floor %[[x28]] : f32 // CHECK: %[[x31:.*]] = arith.fptosi %[[x29]] : f32 to i64 // CHECK: %[[x32:.*]] = arith.index_cast %[[x31]] : i64 to index - // CHECK: %[[x33:.*]] = arith.fptosi %[[x30]] : f32 to i64 - // CHECK: %[[x34:.*]] = arith.index_cast %[[x33]] : i64 to index - // CHECK: %[[x35:.*]] = linalg.index 0 : index - // CHECK: %[[x36:.*]] = linalg.index 1 : index - // CHECK: %[[x37:.*]] = linalg.index 2 : index - // CHECK: %[[x38:.*]] = linalg.index 3 : index - // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x35]], %[[x36]], %[[x32]], %[[x34]]] : tensor<1x1x2x4xf32> + // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x11]], %[[x12]], %[[x32]]] : tensor + // CHECK: linalg.yield %[[extracted]] : f32 + %none = torch.constant.none + %none_0 = torch.constant.none + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %true = torch.constant.bool true + %str = torch.constant.str "nearest,floor" + %int2 = torch.constant.int 2 + %0 = torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + %1 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + %4 = torch.prim.ListConstruct %1 : (!torch.int) -> !torch.list + %5 = torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[?,?,?],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?],f32> + return %5 : !torch.vtensor<[?,?,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_resize_nearest_3d +func.func @test_resize_nearest_3d(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[5],si64>) -> !torch.vtensor<[?,?,?,?,?],f32> { + // CHECK: %[[GENERIC:.*]] = linalg.generic + // CHECK: %[[x11:.*]] = linalg.index 0 : index + // CHECK: %[[x12:.*]] = linalg.index 1 : index + // CHECK: %[[x13:.*]] = linalg.index 2 : index + // CHECK: %[[x14:.*]] = linalg.index 3 : index + // CHECK: %[[index4:.*]] = linalg.index 4 : index + // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32 + // CHECK: %[[x19:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32 + // CHECK: %[[x21:.*]] = arith.divf %[[x19]], %[[x15]] : f32 + // CHECK: %[[x23:.*]] = arith.index_cast %[[x13]] : index to i64 + // CHECK: %[[x24:.*]] = arith.sitofp %[[x23]] : i64 to f32 + // CHECK: %[[x25:.*]] = arith.divf %[[x24]], %[[x21]] : f32 + // CHECK: %[[floor:.*]] = math.floor %[[x25]] : f32 + // CHECK: %[[x31:.*]] = arith.fptosi %[[floor]] : f32 to i64 + // CHECK: %[[x32:.*]] = arith.index_cast %[[x31]] : i64 to index + // CHECK: %[[x34:.*]] = arith.index_cast %[[Wfptosi:.*]] : i64 to index + // CHECK: %[[x35:.*]] = arith.index_cast %[[Dfptosi:.*]] : i64 to index + // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x11]], %[[x12]], %[[x32]], %[[x34]], %[[x35]]] : tensor // CHECK: linalg.yield %[[extracted]] : f32 %none = torch.constant.none %none_0 = torch.constant.none @@ -131,6 +170,177 @@ func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1 %true = torch.constant.bool true %str = torch.constant.str "nearest" %int2 = torch.constant.int 2 + %0 = torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[5],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + %1 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + %int3 = torch.constant.int 3 + %2 = torch.aten.select.int %arg1, %int0, %int3 : !torch.vtensor<[5],si64>, 
!torch.int, !torch.int -> !torch.vtensor<[1],si64> + %3 = torch.aten.item %2 : !torch.vtensor<[1],si64> -> !torch.int + %int4 = torch.constant.int 4 + %4 = torch.aten.select.int %arg1, %int0, %int4 : !torch.vtensor<[5],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + %5 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + %6 = torch.prim.ListConstruct %1, %3, %5: (!torch.int, !torch.int, !torch.int) -> !torch.list + %7 = torch.aten.__interpolate.size_list_scale_list %arg0, %6, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[?,?,?,?,?],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?,?],f32> + return %7 : !torch.vtensor<[?,?,?,?,?],f32> + } + +// ----- + +// CHECK-LABEL: func.func @test_resize_nearest_ceil +func.func @test_resize_nearest_ceil(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,?,?],f32> { + // CHECK: %[[GENERIC:.*]] = linalg.generic + // CHECK: %[[x11:.*]] = linalg.index 0 : index + // CHECK: %[[x12:.*]] = linalg.index 1 : index + // CHECK: %[[x13:.*]] = linalg.index 2 : index + // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32 + // CHECK: %[[x19:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32 + // CHECK: %[[x21:.*]] = arith.divf %[[x19]], %[[x15]] : f32 + // CHECK: %[[x23:.*]] = arith.index_cast %[[x13]] : index to i64 + // CHECK: %[[x24:.*]] = arith.sitofp %[[x23]] : i64 to f32 + // CHECK: %[[cst:.*]] = arith.constant 5.000000e-01 : f32 + // CHECK: %[[add:.*]] = arith.addf %[[x24]], %[[cst]] : f32 + // CHECK: %[[x25:.*]] = arith.divf %[[add]], %[[x21]] : f32 + // CHECK: %[[sub:.*]] = arith.subf %[[x25]], %[[cst]] : f32 + // CHECK: %[[cst3:.*]] = arith.constant 1.000000e+00 : f32 + // CHECK: %[[nM1:.*]] = arith.subf %[[inputsizefp:.*]], %[[cst3]] + // CHECK: %[[ceil:.*]] = math.ceil %[[sub]] : f32 + // CHECK: %[[minindex:.*]] = arith.minimumf %[[ceil]], %[[nM1]] + // CHECK: %[[x31:.*]] = arith.fptosi %[[minindex]] : f32 to i64 + // CHECK: %[[x32:.*]] = arith.index_cast %[[x31]] : i64 to index + // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x11]], %[[x12]], %[[x32]]] : tensor + // CHECK: linalg.yield %[[extracted]] : f32 + %none = torch.constant.none + %none_0 = torch.constant.none + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %true = torch.constant.bool true + %str = torch.constant.str "nearest_half_pixel,ceil" + %int2 = torch.constant.int 2 + %0 = torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + %1 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + %4 = torch.prim.ListConstruct %1 : (!torch.int) -> !torch.list + %5 = torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[?,?,?],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?],f32> + return %5 : !torch.vtensor<[?,?,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_resize_scales_linear_half_pixel_symmetric +func.func @test_resize_scales_linear_half_pixel_symmetric(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4] +,f64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[generic:.*]] = linalg.generic + // CHECK: %[[cst7:.*]] = arith.constant 2.0 + // 
CHECK: %[[halfsize:.*]] = arith.divf %[[sizefp:.*]], %[[cst7]] + // CHECK: %[[modifier:.*]] = arith.subf %[[cstOne:.*]], %[[adjustment:.*]] + // CHECK: %[[offset:.*]] = arith.mulf %[[halfsize]], %[[modifier]] + // CHECK: %[[preClip:.*]] = arith.addf %[[offset]], %[[halfpixelbase:.*]] + // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x1:.*]], %[[x2:.*]], %[[x3:.*]], %[[x4:.*]]] : tensor<1x1x2x4xf32> + // CHECK: %[[extracted_7:.*]] = tensor.extract %[[x0]][%[[x1]], %[[x2]] + // CHECK: %[[extracted_8:.*]] = tensor.extract %[[x0]][%[[x1]], %[[x2]] + // CHECK: %[[extracted_9:.*]] = tensor.extract %[[x0]][%[[x1]], %[[x2]] + // CHECK: %[[dx0p00:.*]] = arith.mulf %[[dx0:.*]], %[[extracted]] + // CHECK: %[[dx1p01:.*]] = arith.mulf %[[dx1:.*]], %[[extracted_7]] + // CHECK: %[[sum:.*]] = arith.addf %[[dx0p00]], %[[dx1p01]] + // CHECK: %[[left:.*]] = arith.mulf %[[dy0:.*]], %[[sum]] + // CHECK: %[[dx0p10:.*]] = arith.mulf %[[dx0]], %[[extracted_8]] + // CHECK: %[[dx1p11:.*]] = arith.mulf %[[dx1]], %[[extracted_9]] + // CHECK: %[[sum2:.*]] = arith.addf %[[dx0p10]], %[[dx1p11]] + // CHECK: %[[right:.*]] = arith.mulf %[[dy1:.*]], %[[sum2]] + // CHECK: %[[retval:.*]] = arith.addf %[[left]], %[[right]] + %none = torch.constant.none + %none_0 = torch.constant.none + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %true = torch.constant.bool true + %str = torch.constant.str "bilinear_half_pixel_symmetric" + %int2 = torch.constant.int 2 + %0 = torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[4],f64>, !torch.int, !torch.int -> !torch.vtensor<[1],f64> + %1 = torch.aten.item %0 : !torch.vtensor<[1],f64> -> !torch.float + %int3 = torch.constant.int 3 + %2 = torch.aten.select.int %arg1, %int0, %int3 : !torch.vtensor<[4],f64>, !torch.int, !torch.int -> !torch.vtensor<[1],f64> + %3 = torch.aten.item %2 : !torch.vtensor<[1],f64> -> !torch.float + %4 = torch.prim.ListConstruct %1, %3 : (!torch.float, !torch.float) -> !torch.list + %5 = torch.aten.__interpolate.size_list_scale_list %arg0, %none_0, %4, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.list, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + return %5 : !torch.vtensor<[?,?,?,?],f32> + } + +// ----- + +// CHECK-LABEL: func.func @test_resize_nearest_half_pixel_round_prefer_floor +func.func @test_resize_nearest_half_pixel_round_prefer_floor(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,?,?],f32> { + // CHECK: %[[GENERIC:.*]] = linalg.generic + // CHECK: %[[x11:.*]] = linalg.index 0 : index + // CHECK: %[[x12:.*]] = linalg.index 1 : index + // CHECK: %[[x13:.*]] = linalg.index 2 : index + // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32 + // CHECK: %[[x19:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32 + // CHECK: %[[x21:.*]] = arith.divf %[[x19]], %[[x15]] : f32 + // CHECK: %[[x23:.*]] = arith.index_cast %[[x13]] : index to i64 + // CHECK: %[[x24:.*]] = arith.sitofp %[[x23]] : i64 to f32 + // CHECK: %[[cst:.*]] = arith.constant 5.000000e-01 : f32 + // CHECK: %[[add:.*]] = arith.addf %[[x24]], %[[cst]] : f32 + // CHECK: %[[x25:.*]] = arith.divf %[[add]], %[[x21]] : f32 + // CHECK: %[[sub:.*]] = arith.subf %[[x25]], %[[cst]] : f32 + // CHECK: %[[cst3:.*]] = arith.constant 5.000000e-01 : f32 + // CHECK: %[[floor:.*]] = math.floor %[[sub]] : f32 + // CHECK: %[[ceil:.*]] = math.ceil %[[sub]] : f32 + // CHECK: %[[sub2:.*]] = arith.subf %[[sub]], %[[floor]] : f32 + // CHECK: %[[cmpf:.*]] = arith.cmpf 
ule, %[[sub2]], %[[cst3]] : f32 + // CHECK: %[[select:.*]] = arith.select %[[cmpf]], %[[floor]], %[[ceil]] : f32 + // CHECK: %[[x31:.*]] = arith.fptosi %[[select]] : f32 to i64 + // CHECK: %[[x32:.*]] = arith.index_cast %[[x31]] : i64 to index + // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x11]], %[[x12]], %[[x32]]] : tensor + // CHECK: linalg.yield %[[extracted]] : f32 + %none = torch.constant.none + %none_0 = torch.constant.none + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %true = torch.constant.bool true + %str = torch.constant.str "nearest_half_pixel,round_prefer_floor" + %int2 = torch.constant.int 2 + %0 = torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + %1 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + %4 = torch.prim.ListConstruct %1 : (!torch.int) -> !torch.list + %5 = torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[?,?,?],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?],f32> + return %5 : !torch.vtensor<[?,?,?],f32> +} + +// CHECK-LABEL: func.func @test_resize_sizes_cubic +func.func @test_resize_sizes_cubic(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4] +,si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 +: si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[x1:.*]] = math.ceil %36 : f32 + // CHECK-DAG: %[[x_1:.*]] = arith.subf %[[x1]], %cst_5 : f32 + // CHECK-DAG: %[[x_2:.*]] = arith.subf %[[x_1]], %cst_5 : f32 + // CHECK-DAG: %[[x2:.*]] = arith.addf %[[x1]], %cst_5 : f32 + // CHECK-DAG: %[[y1:.*]] = math.ceil %28 : f32 + // CHECK-DAG: %[[y_1:.*]] = arith.subf %[[y1]], %cst_5 : f32 + // CHECK-DAG: %[[y_2:.*]] = arith.subf %[[y_1]], %cst_5 : f32 + // CHECK-DAG: %[[y2:.*]] = arith.addf %[[y1]], %cst_5 : f32 + // CHECK-DAG: %[[y2D:.*]] = arith.subf %28, %[[y2]] : f32 + // CHECK-DAG: %[[y2Dist:.*]] = math.absf %[[y2D]] : f32 + // CHECK-DAG: %[[y1D:.*]] = arith.subf %28, %[[y1]] : f32 + // CHECK-DAG: %[[y1Dist:.*]] = math.absf %[[y1D]] : f32 + // CHECK-DAG: %[[y_1D:.*]] = arith.subf %28, %[[y_1]] : f32 + // CHECK-DAG: %[[y_1Dist:.*]] = math.absf %[[y_1D]] : f32 + // CHECK-DAG: %[[y_2D:.*]] = arith.subf %28, %[[y_2]] : f32 + // CHECK-DAG: %[[y_2Dist:.*]] = math.absf %[[y_2D]] : f32 + // CHECK-DAG: %[[x2D:.*]] = arith.subf %36, %[[x2]] : f32 + // CHECK-DAG: %[[x2Dist:.*]] = math.absf %[[x2D]] : f32 + // CHECK-DAG: %[[x1D:.*]] = arith.subf %36, %[[x1]] : f32 + // CHECK-DAG: %[[x1Dist:.*]] = math.absf %[[x1D]] : f32 + // CHECK-DAG: %[[x_1D:.*]] = arith.subf %36, %[[x_1]] : f32 + // CHECK-DAG: %[[x_1Dist:.*]] = math.absf %[[x_1D]] : f32 + // CHECK-DAG: %[[x_2D:.*]] = arith.subf %36, %[[x_2]] : f32 + // CHECK-DAG: %[[x_2Dist:.*]] = math.absf %[[x_2D]] : f32 + // CHECK-DAG: %[[distSQ:.*]] = arith.mulf %52, %52 : f32 + // CHECK-DAG: %[[distCubed:.*]] = arith.mulf %[[distSQ]], %52 : f32 + %none = torch.constant.none + %none_0 = torch.constant.none + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %true = torch.constant.bool true + %str = torch.constant.str "cubic" + %int2 = torch.constant.int 2 %0 = torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> %1 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> 
!torch.int %int3 = torch.constant.int 3 @@ -139,4 +349,6 @@ func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1 %4 = torch.prim.ListConstruct %1, %3 : (!torch.int, !torch.int) -> !torch.list %5 = torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> return %5 : !torch.vtensor<[?,?,?,?],f32> - } +} + +// ----- diff --git a/test/Conversion/TorchToLinalg/sparse.mlir b/test/Conversion/TorchToLinalg/sparse.mlir index f343aedf5545..2ebaccc55a4a 100644 --- a/test/Conversion/TorchToLinalg/sparse.mlir +++ b/test/Conversion/TorchToLinalg/sparse.mlir @@ -24,8 +24,8 @@ func.func @sum(%arg0: !torch.vtensor<[64,64],f32,#CSR>) -> !torch.vtensor<[],f32 // CHECK-LABEL: func.func @SpMM( // CHECK-SAME: %[[A:.*]]: !torch.vtensor<[8,16],f32,#[[$CSR]]>, // CHECK-SAME: %[[B:.*]]: !torch.vtensor<[16,8],f32>) -> !torch.vtensor<[8,8],f32> -// CHECK: %[[S:.*]] = torch_c.to_builtin_tensor %[[A]] : !torch.vtensor<[8,16],f32,#[[$CSR]]> -> tensor<8x16xf32, #[[$CSR]]> -// CHECK: %[[T:.*]] = torch_c.to_builtin_tensor %[[B]] : !torch.vtensor<[16,8],f32> -> tensor<16x8xf32> +// CHECK-DAG: %[[S:.*]] = torch_c.to_builtin_tensor %[[A]] : !torch.vtensor<[8,16],f32,#[[$CSR]]> -> tensor<8x16xf32, #[[$CSR]]> +// CHECK-DAG: %[[T:.*]] = torch_c.to_builtin_tensor %[[B]] : !torch.vtensor<[16,8],f32> -> tensor<16x8xf32> // CHECK: linalg.matmul ins(%[[S]], %[[T]] : tensor<8x16xf32, #[[$CSR]]>, tensor<16x8xf32>) func.func @SpMM(%arg0: !torch.vtensor<[8,16],f32,#CSR>, %arg1: !torch.vtensor<[16,8],f32>) -> !torch.vtensor<[8,8],f32> { diff --git a/test/Conversion/TorchToLinalg/spectral.mlir b/test/Conversion/TorchToLinalg/spectral.mlir new file mode 100644 index 000000000000..abd45183bd84 --- /dev/null +++ b/test/Conversion/TorchToLinalg/spectral.mlir @@ -0,0 +1,64 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -canonicalize -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1)> +// CHECK: #map1 = affine_map<(d0, d1, d2) -> (d1, d2)> +// CHECK: #map2 = affine_map<(d0, d1, d2) -> (d0, d2)> + +// CHECK-LABEL: func.func @torch.aten.fft_rfft$2d_last_dim( +// CHECK-SAME: %arg0: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex> { +// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<{{.*}}> : tensor<9x5xcomplex> +// CHECK: %[[VAR0:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[16,9],f32> -> tensor<16x9xf32> +// CHECK-DAG: %[[VAR1:.*]] = tensor.empty() : tensor<16x5xcomplex> +// CHECK: %[[VAR2:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[VAR1]] : tensor<16x5xcomplex>) -> tensor<16x5xcomplex> +// CHECK: %[[VAR3:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel"]} ins(%[[VAR0]], %[[CST_0]] : tensor<16x9xf32>, tensor<9x5xcomplex>) outs(%[[VAR2]] : tensor<16x5xcomplex>) { +// CHECK: ^bb0(%in: f32, %in_1: complex, %out: complex): +// CHECK: %[[VAR5:.*]] = complex.re %in_1 : complex +// CHECK: %[[VAR6:.*]] = complex.im %in_1 : complex +// CHECK: %[[VAR7:.*]] = arith.mulf %in, %[[VAR5]] : f32 +// CHECK: %[[VAR8:.*]] = arith.mulf %in, %[[VAR6]] : f32 +// CHECK: %[[VAR9:.*]] = complex.create %[[VAR7]], %[[VAR8]] : complex +// CHECK: %[[VAR10:.*]] = complex.add %[[VAR9]], %out : complex +// CHECK: linalg.yield %[[VAR10]] : complex +// 
CHECK: } -> tensor<16x5xcomplex> +// CHECK: %[[VAR4:.*]] = torch_c.from_builtin_tensor %[[VAR3]] : tensor<16x5xcomplex> -> !torch.vtensor<[16,5],complex> +// CHECK: return %[[VAR4]] : !torch.vtensor<[16,5],complex> + +func.func @torch.aten.fft_rfft$2d_last_dim(%arg0: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex> { + %int-1 = torch.constant.int -1 + %none = torch.constant.none + %out = torch.aten.fft_rfft %arg0, %none, %int-1, %none : !torch.vtensor<[16,9],f32>, !torch.none, !torch.int, !torch.none -> !torch.vtensor<[16,5],complex> + return %out : !torch.vtensor<[16,5],complex> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.fft_rfft$2d_first_dim( +// CHECK-SAME: %arg0: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex> { +// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<{{.*}}> : tensor<36x19xcomplex> +// CHECK: %[[VAR0:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[36,23],f32> -> tensor<36x23xf32> +// CHECK-DAG: %[[VAR1:.*]] = tensor.empty() : tensor<23x36xf32> +// CHECK: %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[VAR0]] : tensor<36x23xf32>) outs(%[[VAR1]] : tensor<23x36xf32>) permutation = [1, 0] +// CHECK-DAG: %[[VAR2:.*]] = tensor.empty() : tensor<23x19xcomplex> +// CHECK: %[[VAR3:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[VAR2]] : tensor<23x19xcomplex>) -> tensor<23x19xcomplex> +// CHECK: %[[VAR4:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel"]} ins(%[[TRANSPOSED]], %[[CST_0]] : tensor<23x36xf32>, tensor<36x19xcomplex>) outs(%[[VAR3]] : tensor<23x19xcomplex>) { +// CHECK: ^bb0(%in: f32, %in_2: complex, %out: complex): +// CHECK: %[[VAR7:.*]] = complex.re %in_2 : complex +// CHECK: %[[VAR8:.*]] = complex.im %in_2 : complex +// CHECK: %[[VAR9:.*]] = arith.mulf %in, %[[VAR7]] : f32 +// CHECK: %[[VAR10:.*]] = arith.mulf %in, %[[VAR8]] : f32 +// CHECK: %[[VAR11:.*]] = complex.create %[[VAR9]], %[[VAR10]] : complex +// CHECK: %[[VAR12:.*]] = complex.add %[[VAR11]], %out : complex +// CHECK: linalg.yield %[[VAR12]] : complex +// CHECK: } -> tensor<23x19xcomplex> +// CHECK-DAG: %[[VAR5:.*]] = tensor.empty() : tensor<19x23xcomplex> +// CHECK: %[[TRANSPOSED_1:.*]] = linalg.transpose ins(%[[VAR4]] : tensor<23x19xcomplex>) outs(%[[VAR5]] : tensor<19x23xcomplex>) permutation = [1, 0] +// CHECK: %[[VAR6:.*]] = torch_c.from_builtin_tensor %[[TRANSPOSED_1]] : tensor<19x23xcomplex> -> !torch.vtensor<[19,23],complex> +// CHECK: return %[[VAR6]] : !torch.vtensor<[19,23],complex> +func.func @torch.aten.fft_rfft$2d_first_dim(%arg0: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex> { + %int0 = torch.constant.int 0 + %none = torch.constant.none + %out = torch.aten.fft_rfft %arg0, %none, %int0, %none : !torch.vtensor<[36,23],f32>, !torch.none, !torch.int, !torch.none -> !torch.vtensor<[19,23],complex> + return %out : !torch.vtensor<[19,23],complex> +} diff --git a/test/Conversion/TorchToLinalg/squeeze.mlir b/test/Conversion/TorchToLinalg/squeeze.mlir new file mode 100644 index 000000000000..a8922eed5a9d --- /dev/null +++ b/test/Conversion/TorchToLinalg/squeeze.mlir @@ -0,0 +1,17 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func @torch.aten.squeeze.dim$dynamic +func.func @torch.aten.squeeze.dim$dynamic(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, 
torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "tf2onnx", torch.onnx_meta.producer_version = "1.5.2"} { + // CHECK: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[?,?,?],f32> -> tensor + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C0_1:.*]] = arith.constant 0 : index + // CHECK: %[[DIM:.*]] = tensor.dim %[[BUILTIN_TENSOR]], %[[C0_1]] : tensor + // CHECK: %[[C1:.*]] = arith.constant 1 : index + // CHECK: %[[CMPI:.*]] = arith.cmpi eq, %[[DIM]], %[[C1]] : index + // CHECK: cf.assert %[[CMPI]], "Expected dynamic squeeze dim size to be statically 1" + // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[BUILTIN_TENSOR]] {{\[\[}}0, 1], [2]] : tensor into tensor + // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[COLLAPSED]] : tensor -> !torch.vtensor<[?,?],f32> + %int0 = torch.constant.int 0 + %1 = torch.aten.squeeze.dim %arg0, %int0 : !torch.vtensor<[?,?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32> + return %1 : !torch.vtensor<[?,?],f32> +} diff --git a/test/Conversion/TorchToLinalg/view.mlir b/test/Conversion/TorchToLinalg/view.mlir index 3d265a308a0d..2da7c0b74fc2 100644 --- a/test/Conversion/TorchToLinalg/view.mlir +++ b/test/Conversion/TorchToLinalg/view.mlir @@ -281,3 +281,30 @@ func.func @torch.aten.view$dynamicInferredSame(%arg0: !torch.vtensor<[10,?,2,3], %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[10,?,2,3],f32>, !torch.list -> !torch.vtensor<[2,5,?,6],f32> return %1 : !torch.vtensor<[2,5,?,6],f32> } + +// ----- + +// this is to check a path for unflatten.int with two dynamic reassociation dims +// the IR here is generated from the onnx.Gather conversion +// CHECK-LABEL: @gather_graph +// CHECK: %[[fromelt:.*]] = tensor.from_elements +// CHECK-SAME: tensor<3xi64> +// CHECK: %[[reshape:.*]] = tensor.reshape +// CHECK-SAME: (tensor, tensor<3xi64>) -> tensor +func.func @gather_graph(%arg0: !torch.vtensor<[5,3],f32>, %arg1: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?,3],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} { + %int-1 = torch.constant.int -1 + %int5 = torch.constant.int 5 + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %0 = torch.aten.lt.Scalar %arg1, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.vtensor<[?,?],i1> + %1 = torch.aten.add.Scalar %arg1, %int5, %int1 : !torch.vtensor<[?,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[?,?],si64> + %2 = torch.aten.where.self %0, %1, %arg1 : !torch.vtensor<[?,?],i1>, !torch.vtensor<[?,?],si64>, !torch.vtensor<[?,?],si64> -> !torch.vtensor<[?,?],si64> + %3 = torch.aten.size.int %2, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int + %4 = torch.aten.size.int %2, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int + %5 = torch.prim.ListConstruct %3, %4 : (!torch.int, !torch.int) -> !torch.list + %6 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list + %7 = torch.aten.view %2, %6 : !torch.vtensor<[?,?],si64>, !torch.list -> !torch.vtensor<[?],si64> + %8 = torch.aten.index_select %arg0, %int0, %7 : !torch.vtensor<[5,3],f32>, !torch.int, !torch.vtensor<[?],si64> -> !torch.vtensor<[?,3],f32> + %9 = torch.aten.unflatten.int %8, %int0, %5 : !torch.vtensor<[?,3],f32>, !torch.int, !torch.list -> !torch.vtensor<[?,?,3],f32> + return %9 : !torch.vtensor<[?,?,3],f32> +} diff --git a/test/Conversion/TorchToLinalg/view_strict.mlir 
b/test/Conversion/TorchToLinalg/view_strict.mlir index 8be9a2f9fb5a..a900fbb06927 100644 --- a/test/Conversion/TorchToLinalg/view_strict.mlir +++ b/test/Conversion/TorchToLinalg/view_strict.mlir @@ -7,10 +7,8 @@ // CHECK-LABEL: func.func @torch.aten.view$twotothree // CHECK: %[[ARG0:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[3,2],f32> -> tensor<3x2xf32> -// CHECK: %[[T3:.*]] = torch.constant.int 3 -// CHECK: %[[T2:.*]] = torch.constant.int 2 -// CHECK: %[[N2:.*]] = torch_c.to_i64 %[[T2]] -// CHECK: %[[N3:.*]] = torch_c.to_i64 %[[T3]] +// CHECK: %[[N2:.*]] = arith.constant 2 : i64 +// CHECK: %[[N3:.*]] = arith.constant 3 : i64 // CHECK: %[[ELEMENTS:.*]] = tensor.from_elements %[[N2]], %[[N3]] : tensor<2xi64> // CHECK: %[[RESHAPE:.*]] = tensor.reshape %[[ARG0]](%[[ELEMENTS]]) : (tensor<3x2xf32>, tensor<2xi64>) -> tensor<2x3xf32> func.func @torch.aten.view$twotothree(%arg0: !torch.vtensor<[3,2],f32>) -> !torch.vtensor<[2,3],f32> @@ -112,13 +110,12 @@ func.func @torch.aten.view$expandInferredDim(%arg0: !torch.vtensor<[2,6],f32>) - // reshape. Someday, this should generate flatten/unflatten. // CHECK-LABEL: func.func @torch.aten$dynamicValOutput // CHECK: %[[SELF:.*]] = torch_c.to_builtin_tensor %arg0 -// CHECK: %[[CONSTANT1:.*]] = torch.constant.int 1 // CHECK-DAG: %[[PROD1:.*]] = arith.constant 1 // CHECK-DAG: %[[ARG1_CVT:.*]] = torch_c.to_i64 %arg1 // CHECK-DAG: %[[PROD2:.*]] = arith.muli %[[PROD1]], %[[ARG1_CVT]] -// CHECK-DAG: %[[ONEI64:.*]] = torch_c.to_i64 %[[CONSTANT1]] +// CHECK-DAG: %[[ONEI64:.*]] = arith.constant 1 : i64 // CHECK-DAG: %[[PROD3:.*]] = arith.muli %[[PROD2]], %[[ONEI64]] -// CHECK-DAG: %[[ONEI64_0:.*]] = torch_c.to_i64 %[[CONSTANT1]] +// CHECK-DAG: %[[ONEI64_0:.*]] = arith.constant 1 : i64 // CHECK-DAG: %[[PROD4:.*]] = arith.muli %[[PROD3]], %[[ONEI64_0]] // CHECK-DAG: %[[INDEX0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[DIM0_INDEX:.*]] = tensor.dim %[[SELF]], %[[INDEX0]] : tensor @@ -134,8 +131,8 @@ func.func @torch.aten.view$expandInferredDim(%arg0: !torch.vtensor<[2,6],f32>) - // CHECK-DAG: %[[KNOWN2:.*]] = arith.muli %[[KNOWN1]], %[[DIM2]] : i64 // CHECK-DAG: %[[DIMINFER:.*]] = arith.divui %[[KNOWN2]], %[[PROD4]] : i64 // CHECK: %[[DIM0:.*]] = torch_c.to_i64 %arg1 -// CHECK: %[[DIM1:.*]] = torch_c.to_i64 %[[CONSTANT1]] -// CHECK: %[[DIM3:.*]] = torch_c.to_i64 %[[CONSTANT1]] +// CHECK: %[[DIM1:.*]] = arith.constant 1 : i64 +// CHECK: %[[DIM3:.*]] = arith.constant 1 : i64 // CHECK: %[[OUTPUT_DIMS:.*]] = tensor.from_elements %[[DIM0]], %[[DIM1]], %[[DIMINFER]], %[[DIM3]] : tensor<4xi64> // CHECK: tensor.reshape %[[SELF]](%[[OUTPUT_DIMS]]) : (tensor, tensor<4xi64>) -> tensor // diff --git a/test/Conversion/TorchToSCF/basic.mlir b/test/Conversion/TorchToSCF/basic.mlir index fa4f46f044ca..ccd1b7998e99 100644 --- a/test/Conversion/TorchToSCF/basic.mlir +++ b/test/Conversion/TorchToSCF/basic.mlir @@ -4,9 +4,9 @@ // CHECK-SAME: %[[VAL_0:.*]]: !torch.bool) -> !torch.int { // CHECK: %[[VAL_1:.*]] = torch_c.to_i1 %[[VAL_0]] // CHECK: %[[VAL_2:.*]] = torch.constant.int 2 -// CHECK: %[[VAL_3:.*]] = torch_c.to_i64 %[[VAL_2]] +// CHECK: %[[VAL_3:.*]] = arith.constant 2 : i64 // CHECK: %[[VAL_4:.*]] = torch.constant.int 1 -// CHECK: %[[VAL_5:.*]] = torch_c.to_i64 %[[VAL_4]] +// CHECK: %[[VAL_5:.*]] = arith.constant 1 : i64 // CHECK: %[[VAL_6:.*]] = scf.if %[[VAL_1]] -> (i64) { // CHECK: scf.yield %[[VAL_3]] : i64 // CHECK: } else { @@ -28,14 +28,14 @@ func.func @torch.prim.if(%arg0: !torch.bool) -> !torch.int { // CHECK-LABEL: func.func @aten.prim.if$nested( // 
CHECK-SAME: %[[VAL_0:.*]]: !torch.bool, // CHECK-SAME: %[[VAL_1:.*]]: !torch.bool) -> !torch.int { -// CHECK: %[[VAL_2:.*]] = torch_c.to_i1 %[[VAL_0]] -// CHECK: %[[VAL_3:.*]] = torch_c.to_i1 %[[VAL_1]] +// CHECK-DAG: %[[VAL_2:.*]] = torch_c.to_i1 %[[VAL_0]] +// CHECK-DAG: %[[VAL_3:.*]] = torch_c.to_i1 %[[VAL_1]] // CHECK: %[[VAL_4:.*]] = torch.constant.int 2 -// CHECK: %[[VAL_5:.*]] = torch_c.to_i64 %[[VAL_4]] +// CHECK: %[[VAL_5:.*]] = arith.constant 2 : i64 // CHECK: %[[VAL_6:.*]] = torch.constant.int 3 -// CHECK: %[[VAL_7:.*]] = torch_c.to_i64 %[[VAL_6]] +// CHECK: %[[VAL_7:.*]] = arith.constant 3 : i64 // CHECK: %[[VAL_8:.*]] = torch.constant.int 4 -// CHECK: %[[VAL_9:.*]] = torch_c.to_i64 %[[VAL_8]] +// CHECK: %[[VAL_9:.*]] = arith.constant 4 : i64 // CHECK: %[[VAL_10:.*]] = scf.if %[[VAL_2]] -> (i64) { // CHECK: %[[VAL_11:.*]] = scf.if %[[VAL_3]] -> (i64) { // CHECK: scf.yield %[[VAL_5]] : i64 @@ -124,8 +124,8 @@ func.func @torch.prim.loop$while(%arg0: !torch.int) -> !torch.float { // CHECK-NEXT: %[[VAL_1:.*]] = torch_c.to_f64 %[[TORCH_VAL_1]] // CHECK-NEXT: scf.yield %[[BLOCK_CONDITION]], %[[VAL_0]], %[[VAL_1]] : i1, f64, f64 // CHECK-NEXT: } -// CHECK-NEXT: %[[TORCH_LOOP_0:.*]] = torch_c.from_f64 %[[LOOP]]#0 -// CHECK-NEXT: %[[TORCH_LOOP_1:.*]] = torch_c.from_f64 %[[LOOP]]#1 +// CHECK-DAG: %[[TORCH_LOOP_0:.*]] = torch_c.from_f64 %[[LOOP]]#0 +// CHECK-DAG: %[[TORCH_LOOP_1:.*]] = torch_c.from_f64 %[[LOOP]]#1 // CHECK-NEXT: return %[[TORCH_LOOP_0]], %[[TORCH_LOOP_1]] : !torch.float, !torch.float func.func @torch.prim.loop$while_with_multiple_values() -> (!torch.float, !torch.float) { %float3.200000e00 = torch.constant.float 3.200000e+00 @@ -198,8 +198,8 @@ func.func @torch.prim.Loop$for(%arg0: !torch.int) -> !torch.float { // CHECK-NEXT: %[[VAL_1:.*]] = torch_c.to_f64 %[[TORCH_VAL_1]] // CHECK-NEXT: scf.yield %[[VAL_0]], %[[VAL_1]] : f64, f64 // CHECK-NEXT: } -// CHECK-NEXT: %[[RETURN_0:.*]] = torch_c.from_f64 %[[LOOP]]#0 -// CHECK-NEXT: %[[RETURN_1:.*]] = torch_c.from_f64 %[[LOOP]]#1 +// CHECK-DAG: %[[RETURN_0:.*]] = torch_c.from_f64 %[[LOOP]]#0 +// CHECK-DAG: %[[RETURN_1:.*]] = torch_c.from_f64 %[[LOOP]]#1 // CHECK-NEXT: return %[[RETURN_0]], %[[RETURN_1]] : !torch.float, !torch.float // CHECK-NEXT: } func.func @torch.prim.Loop$for_with_multiple_results(%arg0: !torch.int) -> (!torch.float, !torch.float) { diff --git a/test/Conversion/TorchToStablehlo/basic.mlir b/test/Conversion/TorchToStablehlo/basic.mlir index 30f8716ebdf0..c46328095440 100644 --- a/test/Conversion/TorchToStablehlo/basic.mlir +++ b/test/Conversion/TorchToStablehlo/basic.mlir @@ -40,10 +40,8 @@ func.func @torch.prim.NumToTensor.Scalar$basic() -> !torch.vtensor<[], si64> { // CHECK-LABEL: func.func @torch.aten.contiguous( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32> // CHECK: %int0 = torch.constant.int 0 -// CHECK: %[[VAL_2:.*]] = torch_c.from_builtin_tensor %[[VAL_1]] : tensor<4x64xf32> -> !torch.vtensor<[4,64],f32> -// CHECK: return %[[VAL_2]] : !torch.vtensor<[4,64],f32> +// CHECK: return %[[VAL_0]] : !torch.vtensor<[4,64],f32> func.func @torch.aten.contiguous(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[4,64],f32> { %int0 = torch.constant.int 0 %0 = torch.aten.contiguous %arg0, %int0 : !torch.vtensor<[4,64],f32>, !torch.int -> !torch.vtensor<[4,64],f32> @@ -294,8 +292,8 @@ func.func @torch.runtime.assert(%arg0: !torch.vtensor<[?,?],f32>) -> 
!torch.vten // CHECK-LABEL: func.func @torch.aten.bitwise_left_shift.Tensor( // CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[3,4],si32>, // CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[3,1],si32>) -> !torch.vtensor<[3,4],si32> { -// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0:.*]] : !torch.vtensor<[3,4],si32> -> tensor<3x4xi32> -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1:.*]] : !torch.vtensor<[3,1],si32> -> tensor<3x1xi32> +// CHECK-DAG: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0:.*]] : !torch.vtensor<[3,4],si32> -> tensor<3x4xi32> +// CHECK-DAG: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1:.*]] : !torch.vtensor<[3,1],si32> -> tensor<3x1xi32> // CHECK: %[[VAL_2:.*]] = stablehlo.broadcast_in_dim %[[VAL_1:.*]], dims = [0, 1] : (tensor<3x1xi32>) -> tensor<3x4xi32> // CHECK: %[[VAL_3:.*]] = stablehlo.shift_left %[[VAL_0:.*]], %[[VAL_2:.*]] : tensor<3x4xi32> // CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3:.*]] : tensor<3x4xi32> -> !torch.vtensor<[3,4],si32> @@ -310,8 +308,8 @@ func.func @torch.aten.bitwise_left_shift.Tensor(%arg0: !torch.vtensor<[3,4],si32 // CHECK-LABEL: func.func @torch.aten.bitwise_right_shift.Tensor( // CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[3,4],si64>, // CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[3,4],si64>) -> !torch.vtensor<[3,4],si64> { -// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0:.*]] : !torch.vtensor<[3,4],si64> -> tensor<3x4xi64> -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1:.*]] : !torch.vtensor<[3,4],si64> -> tensor<3x4xi64> +// CHECK-DAG: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0:.*]] : !torch.vtensor<[3,4],si64> -> tensor<3x4xi64> +// CHECK-DAG: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1:.*]] : !torch.vtensor<[3,4],si64> -> tensor<3x4xi64> // CHECK: %[[VAL_2:.*]] = stablehlo.shift_right_arithmetic %[[VAL_0:.*]], %[[VAL_1:.*]] : tensor<3x4xi64> // CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2:.*]] : tensor<3x4xi64> -> !torch.vtensor<[3,4],si64> // CHECK: return %[[VAL_3:.*]] : !torch.vtensor<[3,4],si64> @@ -319,3 +317,25 @@ func.func @torch.aten.bitwise_right_shift.Tensor(%arg0: !torch.vtensor<[3,4],si6 %0 = torch.aten.bitwise_right_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3,4],si64>, !torch.vtensor<[3,4],si64> -> !torch.vtensor<[3,4],si64> return %0 : !torch.vtensor<[3,4],si64> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.tril( +// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[2,3,5],f32>, +// CHECK-SAME: %[[ARG_1:.*]]: !torch.int) -> !torch.vtensor<[2,3,5],f32> +// CHECK-DAG: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[2,3,5],f32> -> tensor<2x3x5xf32> +// CHECK-DAG: %[[VAL_1:.*]] = torch_c.to_i64 %[[ARG_1]] +// CHECK: %[[VAL_2:.*]] = stablehlo.iota dim = 1 : tensor<3x5xi64> +// CHECK: %[[VAL_3:.*]] = stablehlo.iota dim = 0 : tensor<3x5xi64> +// CHECK: %[[VAL_4:.*]] = tensor.from_elements %[[VAL_1]] : tensor<1xi64> +// CHECK: %[[VAL_5:.*]] = chlo.broadcast_add %[[VAL_3]], %[[VAL_4]] {broadcast_dimensions = array} : (tensor<3x5xi64>, tensor<1xi64>) -> tensor<3x5xi64> +// CHECK: %[[VAL_6:.*]] = stablehlo.compare LE, %[[VAL_2]], %[[VAL_5]], SIGNED : (tensor<3x5xi64>, tensor<3x5xi64>) -> tensor<3x5xi1> +// CHECK: %[[VAL_7:.*]] = stablehlo.broadcast_in_dim %[[VAL_6]], dims = [1, 2] : (tensor<3x5xi1>) -> tensor<2x3x5xi1> +// CHECK: %[[VAL_8:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<2x3x5xf32> +// CHECK: %[[VAL_9:.*]] = stablehlo.select %[[VAL_7]], %[[VAL_0]], %[[VAL_8]] : tensor<2x3x5xi1>, 
tensor<2x3x5xf32> +// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<2x3x5xf32> -> !torch.vtensor<[2,3,5],f32> +// CHECK: return %[[VAL_10:.*]] : !torch.vtensor<[2,3,5],f32> +func.func @torch.aten.tril(%arg0: !torch.vtensor<[2,3,5],f32>, %arg1: !torch.int) -> !torch.vtensor<[2,3,5],f32> { + %0 = torch.aten.tril %arg0, %arg1:!torch.vtensor<[2,3,5],f32>, !torch.int -> !torch.vtensor<[2,3,5],f32> + return %0 : !torch.vtensor<[2,3,5],f32> +} diff --git a/test/Conversion/TorchToStablehlo/elementwise.mlir b/test/Conversion/TorchToStablehlo/elementwise.mlir index ad249d971bbe..104f6e0d8761 100644 --- a/test/Conversion/TorchToStablehlo/elementwise.mlir +++ b/test/Conversion/TorchToStablehlo/elementwise.mlir @@ -103,8 +103,7 @@ func.func @torch.aten.sigmoid$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch. // CHECK-LABEL: func.func @torch.aten.addscalar$basic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[INT9:.*]] = torch.constant.int 9 -// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]] +// CHECK: %[[T1:.*]] = arith.constant 9 : i64 // CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> // CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> @@ -124,10 +123,8 @@ func.func @torch.aten.addscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torc // CHECK-LABEL: func.func @torch.aten.addscalar$alpha( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[INT9:.*]] = torch.constant.int 9 -// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]] -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]] +// CHECK: %[[T1:.*]] = arith.constant 9 : i64 +// CHECK: %[[T2:.*]] = arith.constant 2 : i64 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> // CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> // CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor @@ -149,8 +146,8 @@ func.func @torch.aten.addscalar$alpha(%arg0: !torch.vtensor<[?,?],f32>) -> !torc // CHECK-LABEL: func.func @torch.aten.addtensor$basic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[T2:.*]] = chlo.broadcast_add %[[T0]], %[[T1]] : (tensor, tensor) -> tensor // CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor -> !torch.vtensor<[?,?],f32> @@ -165,10 +162,9 @@ func.func @torch.aten.addtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! 
// CHECK-LABEL: func.func @torch.aten.addtensor$alpha( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]] +// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[T2:.*]] = arith.constant 2 : i64 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64> // CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> // CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor @@ -186,8 +182,8 @@ func.func @torch.aten.addtensor$alpha(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! // CHECK-LABEL: func.func @torch.aten.addtensor$promote( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],si32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],si32> -> tensor -// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si64> -> tensor +// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],si32> -> tensor +// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si64> -> tensor // CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[T2:.*]] = stablehlo.convert %[[T0]] : (tensor) -> tensor // CHECK: %[[T3:.*]] = chlo.broadcast_add %[[T2]], %[[T1]] : (tensor, tensor) -> tensor @@ -204,8 +200,7 @@ func.func @torch.aten.addtensor$promote(%arg0: !torch.vtensor<[?,?],si32>, %arg1 // CHECK-LABEL: func.func @torch.aten.subscalar$basic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[INT9:.*]] = torch.constant.int 9 -// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]] +// CHECK: %[[T1:.*]] = arith.constant 9 : i64 // CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> // CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> @@ -225,8 +220,7 @@ func.func @torch.aten.subscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torc // CHECK-LABEL: func.func @torch.aten.rsubscalar$basic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[INT9:.*]] = torch.constant.int 9 -// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]] +// CHECK: %[[T1:.*]] = arith.constant 9 : i64 // CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> // CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> @@ -246,10 +240,8 @@ func.func @torch.aten.rsubscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !tor // CHECK-LABEL: func.func @torch.aten.subscalar$alpha( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[T0:.*]] = 
torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[INT9:.*]] = torch.constant.int 9 -// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]] -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]] +// CHECK: %[[T1:.*]] = arith.constant 9 : i64 +// CHECK: %[[T2:.*]] = arith.constant 2 : i64 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> // CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> // CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor @@ -271,8 +263,8 @@ func.func @torch.aten.subscalar$alpha(%arg0: !torch.vtensor<[?,?],f32>) -> !torc // CHECK-LABEL: func.func @torch.aten.subtensor$basic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[T2:.*]] = chlo.broadcast_subtract %[[T0]], %[[T1]] : (tensor, tensor) -> tensor // CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor -> !torch.vtensor<[?,?],f32> @@ -287,10 +279,9 @@ func.func @torch.aten.subtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! // CHECK-LABEL: func.func @torch.aten.subtensor$alpha( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]] +// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[T2:.*]] = arith.constant 2 : i64 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64> // CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> // CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor @@ -308,8 +299,8 @@ func.func @torch.aten.subtensor$alpha(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! 
// CHECK-LABEL: func.func @torch.aten.subtensor$promote( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],si32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],si32> -> tensor -// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si64> -> tensor +// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],si32> -> tensor +// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si64> -> tensor // CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[T2:.*]] = stablehlo.convert %[[T0]] : (tensor) -> tensor // CHECK: %[[T3:.*]] = chlo.broadcast_subtract %[[T2]], %[[T1]] : (tensor, tensor) -> tensor @@ -326,8 +317,7 @@ func.func @torch.aten.subtensor$promote(%arg0: !torch.vtensor<[?,?],si32>, %arg1 // CHECK-LABEL: func.func @torch.aten.mulscalar$basic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[INT9:.*]] = torch.constant.int 9 -// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]] +// CHECK: %[[T1:.*]] = arith.constant 9 : i64 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> // CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> // CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor @@ -344,8 +334,8 @@ func.func @torch.aten.mulscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torc // CHECK-LABEL: func.func @torch.aten.multensor$basic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[T2:.*]] = chlo.broadcast_multiply %[[T0]], %[[T1]] : (tensor, tensor) -> tensor // CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T3]] : !torch.vtensor<[?,?],f32> @@ -359,8 +349,7 @@ func.func @torch.aten.multensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! 
// CHECK-LABEL: func.func @torch.aten.divscalar$basic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[INT9:.*]] = torch.constant.int 9 -// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]] +// CHECK: %[[T1:.*]] = arith.constant 9 : i64 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> // CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> // CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor @@ -377,8 +366,8 @@ func.func @torch.aten.divscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torc // CHECK-LABEL: func.func @torch.aten.divtensor$basic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[T2:.*]] = chlo.broadcast_divide %[[T0]], %[[T1]] : (tensor, tensor) -> tensor // CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T3]] : !torch.vtensor<[?,?],f32> @@ -392,8 +381,7 @@ func.func @torch.aten.divtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! // CHECK-LABEL: func.func @torch.aten.gt.scalar( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[INT3:.*]] = torch.constant.int 3 -// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT3]] +// CHECK: %[[T1:.*]] = arith.constant 3 : i64 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> // CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : tensor<1xi64> // CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xi64>) -> tensor @@ -411,8 +399,8 @@ func.func @torch.aten.gt.scalar(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten // CHECK-LABEL: func.func @torch.aten.gt.tensor( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[64],f32> -> tensor<64xf32> +// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[64],f32> -> tensor<64xf32> // CHECK: %[[T2:.*]] = chlo.broadcast_compare %[[T0]], %[[T1]] {compare_type = #chlo, comparison_direction = #chlo} : (tensor, tensor<64xf32>) -> tensor // CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor -> !torch.vtensor<[?,?],i1> // CHECK: return %[[T3]] : !torch.vtensor<[?,?],i1> @@ -425,8 +413,8 @@ func.func @torch.aten.gt.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch. 
// CHECK-LABEL: func.func @torch.aten.lt.tensor( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[64],f32> -> tensor<64xf32> +// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[64],f32> -> tensor<64xf32> // CHECK: %[[T2:.*]] = chlo.broadcast_compare %[[T0]], %[[T1]] {compare_type = #chlo, comparison_direction = #chlo} : (tensor, tensor<64xf32>) -> tensor // CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor -> !torch.vtensor<[?,?],i1> // CHECK: return %[[T3]] : !torch.vtensor<[?,?],i1> @@ -439,8 +427,8 @@ func.func @torch.aten.lt.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch. // CHECK-LABEL: func.func @torch.aten.eq.tensor( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[64],f32> -> tensor<64xf32> +// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[64],f32> -> tensor<64xf32> // CHECK: %[[T2:.*]] = chlo.broadcast_compare %[[T0]], %[[T1]] {compare_type = #chlo, comparison_direction = #chlo} : (tensor, tensor<64xf32>) -> tensor // CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor -> !torch.vtensor<[?,?],i1> // CHECK: return %[[T3]] : !torch.vtensor<[?,?],i1> @@ -453,8 +441,8 @@ func.func @torch.aten.eq.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch. 
// CHECK-LABEL: func.func @torch.aten.ne.tensor( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[64],f32> -> tensor<64xf32> +// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[64],f32> -> tensor<64xf32> // CHECK: %[[T2:.*]] = chlo.broadcast_compare %[[T0]], %[[T1]] {compare_type = #chlo, comparison_direction = #chlo} : (tensor, tensor<64xf32>) -> tensor // CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor -> !torch.vtensor<[?,?],i1> // CHECK: return %[[T3]] : !torch.vtensor<[?,?],i1> @@ -500,8 +488,8 @@ func.func @torch.aten.relu(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[ // CHECK-LABEL: func.func @torch.aten.addscalar$variable( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.float) -> !torch.vtensor<[?,?],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[T1:.*]] = torch_c.to_f64 %[[ARG1]] +// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK-DAG: %[[T1:.*]] = torch_c.to_f64 %[[ARG1]] // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xf64> // CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xf64>) -> tensor<1xf32> // CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor @@ -521,9 +509,9 @@ func.func @torch.aten.addscalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1 // CHECK-LABEL: func.func @torch.aten.addtensor$variable( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG2:.*]]: !torch.float) -> !torch.vtensor<[?,?],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[T2:.*]] = torch_c.to_f64 %[[ARG2]] +// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK-DAG: %[[T2:.*]] = torch_c.to_f64 %[[ARG2]] // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]] : tensor<1xf64> // CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xf64>) -> tensor<1xf32> // CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor @@ -540,8 +528,8 @@ func.func @torch.aten.addtensor$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1 // CHECK-LABEL: func.func @torch.aten.mulscalar$variable( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.int) -> !torch.vtensor<[?,?],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[ARG1]] +// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK-DAG: %[[T1:.*]] = torch_c.to_i64 %[[ARG1]] // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> // CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> // CHECK: %[[T3:.*]] 
= stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor @@ -557,8 +545,8 @@ func.func @torch.aten.mulscalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1 // CHECK-LABEL: func.func @torch.aten.divscalar$variable( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.int) -> !torch.vtensor<[?,?],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[ARG1]] +// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK-DAG: %[[T1:.*]] = torch_c.to_i64 %[[ARG1]] // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> // CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> // CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor @@ -574,8 +562,8 @@ func.func @torch.aten.divscalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1 // CHECK-LABEL: func.func @torch.aten.gt.scalar$variable( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.int) -> !torch.vtensor<[?,?],i1> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[ARG1]] +// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK-DAG: %[[T1:.*]] = torch_c.to_i64 %[[ARG1]] // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> // CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : tensor<1xi64> // CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xi64>) -> tensor @@ -592,8 +580,8 @@ func.func @torch.aten.gt.scalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1 // CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$trunc( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor -// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?,?,?],f32> -> tensor +// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor +// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?,?,?],f32> -> tensor // CHECK: %[[STR:.*]] = torch.constant.str "trunc" // CHECK: %[[T2:.*]] = chlo.broadcast_divide %[[T0]], %[[T1]] : (tensor, tensor) -> tensor // CHECK: %[[T3:.*]] = stablehlo.sign %[[T2]] : tensor @@ -612,8 +600,8 @@ func.func @torch.aten.div.Tensor_mode$trunc(%arg0: !torch.vtensor<[?,?,?,?],f32> // CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$floor( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor -// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?,?,?],f32> -> tensor +// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor +// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?,?,?],f32> -> tensor // CHECK: %[[STR:.*]] = torch.constant.str "floor" // CHECK: %[[T2:.*]] = chlo.broadcast_divide %[[T0]], %[[T1]] : (tensor, tensor) -> tensor // CHECK: %[[T3:.*]] = stablehlo.floor %[[T2]] : tensor diff --git 
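// A short FileCheck sketch of the CHECK to CHECK-DAG switch used throughout the
// elementwise hunks above: the two torch_c.to_builtin_tensor conversions are
// independent, so the conversion pass may emit them in either order. Adjacent
// CHECK-DAG lines match in any order, and the following CHECK still anchors the
// op that consumes both results. Names here are illustrative, not from the patch.
//
//   // CHECK-DAG: %[[LHS:.*]] = torch_c.to_builtin_tensor %arg0
//   // CHECK-DAG: %[[RHS:.*]] = torch_c.to_builtin_tensor %arg1
//   // CHECK:     %[[SUM:.*]] = chlo.broadcast_add %[[LHS]], %[[RHS]]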
a/test/Conversion/TorchToStablehlo/gather.mlir b/test/Conversion/TorchToStablehlo/gather.mlir index df29bf1d4cca..14581bcc658c 100644 --- a/test/Conversion/TorchToStablehlo/gather.mlir +++ b/test/Conversion/TorchToStablehlo/gather.mlir @@ -2,8 +2,8 @@ // CHECK-LABEL: func.func @torch.aten.index_select$basic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,4],f32>, %[[ARG1:.*]]: !torch.vtensor<[2],si64>) -> !torch.vtensor<[2,4],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,4],f32> -> tensor -// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[2],si64> -> tensor<2xi64> +// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,4],f32> -> tensor +// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[2],si64> -> tensor<2xi64> // CHECK: %[[INT0:.*]] = torch.constant.int 0 // CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 // CHECK: %[[C1:.*]] = arith.constant 1 : index @@ -22,8 +22,8 @@ func.func @torch.aten.index_select$basic(%arg0: !torch.vtensor<[?,4],f32>, %arg1 // CHECK-LABEL: func.func @torch.aten.embedding$basic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?],si64>) -> !torch.vtensor<[?,?],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?],si64> -> tensor +// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?],si64> -> tensor // CHECK: %[[FALSE:.*]] = torch.constant.bool false // CHECK: %[[INT:.*]]-1 = torch.constant.int -1 // CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 @@ -44,8 +44,8 @@ func.func @torch.aten.embedding$basic(%weight: !torch.vtensor<[?,?],f32>, %indic // CHECK-LABEL: func.func @torch.aten.embedding$rank_two_indices( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,1],si64>) -> !torch.vtensor<[?,1,?],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,1],si64> -> tensor +// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,1],si64> -> tensor // CHECK: %[[FALSE:.*]] = torch.constant.bool false // CHECK: %[[INT:.*]]-1 = torch.constant.int -1 // CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 diff --git a/test/Conversion/TorchToStablehlo/linear.mlir b/test/Conversion/TorchToStablehlo/linear.mlir index 7f253a98df04..69ec4e2410eb 100644 --- a/test/Conversion/TorchToStablehlo/linear.mlir +++ b/test/Conversion/TorchToStablehlo/linear.mlir @@ -2,8 +2,8 @@ // CHECK-LABEL: func.func @torch.aten.mm$basic$static( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,3],f32>, %[[ARG1:.*]]: !torch.vtensor<[3,3],f32>) -> !torch.vtensor<[2,3],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32> -// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[3,3],f32> -> tensor<3x3xf32> +// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32> +// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[3,3],f32> -> tensor<3x3xf32> // 
CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor<2x3xf32>, tensor<3x3xf32>) -> tensor<2x3xf32> // CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<2x3xf32> to tensor<2x3xf32> // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32> @@ -17,8 +17,8 @@ func.func @torch.aten.mm$basic$static(%arg0: !torch.vtensor<[2,3],f32>, %arg1: ! // CHECK-LABEL: func.func @torch.aten.mm$basic$dynamic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,3],f32>, %[[ARG1:.*]]: !torch.vtensor<[3,?],f32>) -> !torch.vtensor<[?,?],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,3],f32> -> tensor -// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[3,?],f32> -> tensor<3x?xf32> +// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,3],f32> -> tensor +// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[3,?],f32> -> tensor<3x?xf32> // CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor, tensor<3x?xf32>) -> tensor // CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor to tensor // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor -> !torch.vtensor<[?,?],f32> @@ -32,19 +32,16 @@ func.func @torch.aten.mm$basic$dynamic(%arg0: !torch.vtensor<[?,3],f32>, %arg1: // CHECK-LABEL: func.func @torch.aten.bmm$basic$static( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[10,3,4],f32>, %[[ARG1:.*]]: !torch.vtensor<[10,4,5],f32>) -> !torch.vtensor<[10,3,5],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[10,3,4],f32> -> tensor<10x3x4xf32> -// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[10,4,5],f32> -> tensor<10x4x5xf32> +// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[10,3,4],f32> -> tensor<10x3x4xf32> +// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[10,4,5],f32> -> tensor<10x4x5xf32> // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[T2:.*]] = tensor.dim %[[T1]], %[[C0]] : tensor<10x4x5xf32> -// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<10x4x5xf32> -// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 // CHECK: %[[C2:.*]] = arith.constant 2 : index // CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C2]] : tensor<10x4x5xf32> -// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 -// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> -// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [0, 1, 2] : (tensor<10x4x5xf32>, tensor<3xi64>) -> tensor<10x4x5xf32> +// CHECK: %[[T8:.*]] = tensor.from_elements %[[T2]], %[[T4]], %[[T6]] : tensor<3xindex> +// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [0, 1, 2] : (tensor<10x4x5xf32>, tensor<3xindex>) -> tensor<10x4x5xf32> // CHECK: %[[T10:.*]] = stablehlo.dot_general %[[T0]], %[[T9]], batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<10x3x4xf32>, tensor<10x4x5xf32>) -> tensor<10x3x5xf32> // CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<10x3x5xf32> to tensor<10x3x5xf32> // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<10x3x5xf32> -> !torch.vtensor<[10,3,5],f32> @@ -58,19 +55,16 @@ func.func @torch.aten.bmm$basic$static(%arg0: !torch.vtensor<[10,3,4],f32>, %arg // CHECK-LABEL: func.func 
@torch.aten.bmm$basic$dynamic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,4],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,4,?],f32>) -> !torch.vtensor<[?,?,?],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,4],f32> -> tensor -// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,4,?],f32> -> tensor +// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,4],f32> -> tensor +// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,4,?],f32> -> tensor // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[T2:.*]] = tensor.dim %[[T1]], %[[C0]] : tensor -// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor -// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 // CHECK: %[[C2:.*]] = arith.constant 2 : index // CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C2]] : tensor -// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 -// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> -// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [0, 1, 2] : (tensor, tensor<3xi64>) -> tensor +// CHECK: %[[T8:.*]] = tensor.from_elements %[[T2]], %[[T4]], %[[T6]] : tensor<3xindex> +// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [0, 1, 2] : (tensor, tensor<3xindex>) -> tensor // CHECK: %[[T10:.*]] = stablehlo.dot_general %[[T0]], %[[T9]], batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor, tensor) -> tensor // CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor to tensor // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor -> !torch.vtensor<[?,?,?],f32> @@ -84,19 +78,16 @@ func.func @torch.aten.bmm$basic$dynamic(%arg0: !torch.vtensor<[?,?,4],f32>, %arg // CHECK-LABEL: func.func @torch.aten.matmul$basic$static( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[256,120],f32>, %[[ARG1:.*]]: !torch.vtensor<[4,120,256],f32>) -> !torch.vtensor<[4,256,256],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256,120],f32> -> tensor<256x120xf32> -// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[4,120,256],f32> -> tensor<4x120x256xf32> +// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256,120],f32> -> tensor<256x120xf32> +// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[4,120,256],f32> -> tensor<4x120x256xf32> // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[T2:.*]] = tensor.dim %[[T1]], %[[C0]] : tensor<4x120x256xf32> -// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 // CHECK: %[[C0_0:.*]] = arith.constant 0 : index // CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C0_0]] : tensor<256x120xf32> -// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[T6:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<256x120xf32> -// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 -// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> -// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T0]], %[[T8]], dims = [1, 2] : (tensor<256x120xf32>, tensor<3xi64>) -> tensor<4x256x120xf32> +// CHECK: %[[T8:.*]] = tensor.from_elements %[[T2]], %[[T4]], %[[T6]] : tensor<3xindex> +// CHECK: %[[T9:.*]] = 
stablehlo.dynamic_broadcast_in_dim %[[T0]], %[[T8]], dims = [1, 2] : (tensor<256x120xf32>, tensor<3xindex>) -> tensor<4x256x120xf32> // CHECK: %[[T10:.*]] = stablehlo.dot_general %[[T9]], %[[T1]], batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<4x256x120xf32>, tensor<4x120x256xf32>) -> tensor<4x256x256xf32> // CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<4x256x256xf32> to tensor<4x256x256xf32> // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<4x256x256xf32> -> !torch.vtensor<[4,256,256],f32> @@ -110,19 +101,16 @@ func.func @torch.aten.matmul$basic$static(%arg0: !torch.vtensor<[256,120],f32>, // CHECK-LABEL: func.func @torch.aten.matmul$basic$dynamic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,?,256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256,?],f32>) -> !torch.vtensor<[4,?,?],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,?,256],f32> -> tensor<4x?x256xf32> -// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256,?],f32> -> tensor<256x?xf32> +// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,?,256],f32> -> tensor<4x?x256xf32> +// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256,?],f32> -> tensor<256x?xf32> // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<4x?x256xf32> -// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 // CHECK: %[[C0_0:.*]] = arith.constant 0 : index // CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C0_0]] : tensor<256x?xf32> -// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<256x?xf32> -// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 -// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> -// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [1, 2] : (tensor<256x?xf32>, tensor<3xi64>) -> tensor<4x256x?xf32> +// CHECK: %[[T8:.*]] = tensor.from_elements %[[T2]], %[[T4]], %[[T6]] : tensor<3xindex> +// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [1, 2] : (tensor<256x?xf32>, tensor<3xindex>) -> tensor<4x256x?xf32> // CHECK: %[[T10:.*]] = stablehlo.dot_general %[[T0]], %[[T9]], batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<4x?x256xf32>, tensor<4x256x?xf32>) -> tensor<4x?x?xf32> // CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<4x?x?xf32> to tensor<4x?x?xf32> // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<4x?x?xf32> -> !torch.vtensor<[4,?,?],f32> @@ -136,16 +124,14 @@ func.func @torch.aten.matmul$basic$dynamic(%arg0: !torch.vtensor<[4,?,256],f32>, // CHECK-LABEL: func.func @torch.aten.matmul$3dx1d( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,?,256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256],f32>) -> !torch.vtensor<[1,?],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1,?,256],f32> -> tensor<1x?x256xf32> -// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32> +// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1,?,256],f32> -> tensor<1x?x256xf32> +// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32> // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[T2:.*]] = tensor.dim %[[T0]], 
%[[C0]] : tensor<1x?x256xf32> -// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 // CHECK: %[[C0_0:.*]] = arith.constant 0 : index // CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C0_0]] : tensor<256xf32> -// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 -// CHECK: %[[T6:.*]] = tensor.from_elements %[[T3]], %[[T5]] : tensor<2xi64> -// CHECK: %[[T7:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T6]], dims = [1] : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32> +// CHECK: %[[T6:.*]] = tensor.from_elements %[[T2]], %[[T4]] : tensor<2xindex> +// CHECK: %[[T7:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T6]], dims = [1] : (tensor<256xf32>, tensor<2xindex>) -> tensor<1x256xf32> // CHECK: %[[T8:.*]] = stablehlo.dot_general %[[T0]], %[[T7]], batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<1x?x256xf32>, tensor<1x256xf32>) -> tensor<1x?xf32> // CHECK: %[[T9:.*]] = tensor.cast %[[T8]] : tensor<1x?xf32> to tensor<1x?xf32> // CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<1x?xf32> -> !torch.vtensor<[1,?],f32> @@ -159,16 +145,14 @@ func.func @torch.aten.matmul$3dx1d(%arg0: !torch.vtensor<[1,?,256],f32>, %arg1: // CHECK-LABEL: func.func @torch.aten.matmul$1dx3d( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[256],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,256,?],f32>) -> !torch.vtensor<[?,?],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256],f32> -> tensor<256xf32> -// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,256,?],f32> -> tensor +// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256],f32> -> tensor<256xf32> +// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,256,?],f32> -> tensor // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[T2:.*]] = tensor.dim %[[T1]], %[[C0]] : tensor -// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 // CHECK: %[[C0_0:.*]] = arith.constant 0 : index // CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C0_0]] : tensor<256xf32> -// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 -// CHECK: %[[T6:.*]] = tensor.from_elements %[[T3]], %[[T5]] : tensor<2xi64> -// CHECK: %[[T7:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T0]], %[[T6]], dims = [1] : (tensor<256xf32>, tensor<2xi64>) -> tensor +// CHECK: %[[T6:.*]] = tensor.from_elements %[[T2]], %[[T4]] : tensor<2xindex> +// CHECK: %[[T7:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T0]], %[[T6]], dims = [1] : (tensor<256xf32>, tensor<2xindex>) -> tensor // CHECK: %[[T8:.*]] = stablehlo.dot_general %[[T7]], %[[T1]], batching_dims = [0] x [0], contracting_dims = [1] x [1] : (tensor, tensor) -> tensor // CHECK: %[[T9:.*]] = tensor.cast %[[T8]] : tensor to tensor // CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor -> !torch.vtensor<[?,?],f32> @@ -182,8 +166,8 @@ func.func @torch.aten.matmul$1dx3d(%arg0: !torch.vtensor<[256],f32>, %arg1: !tor // CHECK-LABEL: func.func @torch.aten.matmul$2dx1d( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256],f32>) -> !torch.vtensor<[?],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,256],f32> -> tensor -// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32> +// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,256],f32> -> tensor +// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] 
: !torch.vtensor<[256],f32> -> tensor<256xf32> // CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor, tensor<256xf32>) -> tensor // CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor to tensor // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor -> !torch.vtensor<[?],f32> @@ -197,8 +181,8 @@ func.func @torch.aten.matmul$2dx1d(%arg0: !torch.vtensor<[?,256],f32>, %arg1: !t // CHECK-LABEL: func.func @torch.aten.matmul$1dx2d( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256,?],f32>) -> !torch.vtensor<[?],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256],f32> -> tensor<256xf32> -// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256,?],f32> -> tensor<256x?xf32> +// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256],f32> -> tensor<256xf32> +// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256,?],f32> -> tensor<256x?xf32> // CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor<256xf32>, tensor<256x?xf32>) -> tensor // CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor to tensor // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor -> !torch.vtensor<[?],f32> @@ -212,8 +196,8 @@ func.func @torch.aten.matmul$1dx2d(%arg0: !torch.vtensor<[256],f32>, %arg1: !tor // CHECK-LABEL: func.func @torch.aten.matmul$1dx1d( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256],f32>) -> !torch.vtensor<[],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256],f32> -> tensor<256xf32> -// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32> +// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256],f32> -> tensor<256xf32> +// CHECK-DAG: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32> // CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor<256xf32>, tensor<256xf32>) -> tensor // CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor to tensor // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor -> !torch.vtensor<[],f32> @@ -227,19 +211,16 @@ func.func @torch.aten.matmul$1dx1d(%arg0: !torch.vtensor<[256],f32>, %arg1: !tor // CHECK-LABEL: func.func @torch.aten.matmul$proj( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,256],f32>) -> !torch.vtensor<[?,?,256],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,256],f32> -> tensor +// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,256],f32> -> tensor // CHECK: %[[T1:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<256x256xf32> // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor -// CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 // CHECK: %[[C0_0:.*]] = arith.constant 0 : index // CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C0_0]] : tensor<256x256xf32> -// CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<256x256xf32> -// CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 -// CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> -// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [1, 2] : (tensor<256x256xf32>, 
tensor<3xi64>) -> tensor +// CHECK: %[[T8:.*]] = tensor.from_elements %[[T2]], %[[T4]], %[[T6]] : tensor<3xindex> +// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [1, 2] : (tensor<256x256xf32>, tensor<3xindex>) -> tensor // CHECK: %[[T10:.*]] = stablehlo.dot_general %[[T0]], %[[T9]], batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor, tensor) -> tensor // CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor to tensor // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor -> !torch.vtensor<[?,?,256],f32> @@ -254,7 +235,7 @@ func.func @torch.aten.matmul$proj(%arg0: !torch.vtensor<[?,?,256],f32>) -> !torc // CHECK-LABEL: func.func @torch.aten.mm$proj( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,256],f32>) -> !torch.vtensor<[?,256],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,256],f32> -> tensor +// CHECK-DAG: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,256],f32> -> tensor // CHECK: %[[T1:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<256x256xf32> // CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor, tensor<256x256xf32>) -> tensor // CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor to tensor @@ -271,14 +252,13 @@ func.func @torch.aten.mm$proj(%arg0: !torch.vtensor<[?,256],f32>) -> !torch.vten // CHECK-LABEL: func.func @torch.aten.convolution( // CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?,?,?],f32>, // CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[?,?,3,3],f32>) -> !torch.vtensor<[?,?,?,?],f32> { -// CHECK: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor -// CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?,3,3],f32> -> tensor +// CHECK-DAG: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor +// CHECK-DAG: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?,3,3],f32> -> tensor // CHECK: %[[T_2:.*]] = torch.constant.none // CHECK: %[[T_4:.*]] = torch.constant.int 2 // CHECK: %[[T_5:.*]] = torch.constant.int 1 // CHECK: %[[T_6:.*]] = torch.constant.int 4 // CHECK: %[[T_7:.*]] = torch.constant.int 3 -// CHECK: %[[T_8:.*]] = torch_c.to_i64 %[[T_7]] // CHECK: %[[T_9:.*]] = torch.prim.ListConstruct %[[T_4]], %[[T_5]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_10:.*]] = torch.prim.ListConstruct %[[T_6]], %[[T_4]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_11:.*]] = torch.prim.ListConstruct %[[T_7]], %[[T_5]] : (!torch.int, !torch.int) -> !torch.list @@ -308,14 +288,12 @@ func.func @torch.aten.convolution(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: ! 
// CHECK-LABEL: func.func @torch.aten.convolution$bias( // CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?,?,?],f32>, %[[ARG_1:.*]]: !torch.vtensor<[?,?,3,3],f32>, // CHECK-SAME: %[[ARG_2:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { -// CHECK: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor -// CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?,3,3],f32> -> tensor -// CHECK: %[[T_2:.*]] = torch_c.to_builtin_tensor %[[ARG_2]] : !torch.vtensor<[?],f32> -> tensor +// CHECK-DAG: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor +// CHECK-DAG: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?,3,3],f32> -> tensor +// CHECK-DAG: %[[T_2:.*]] = torch_c.to_builtin_tensor %[[ARG_2]] : !torch.vtensor<[?],f32> -> tensor // CHECK: %int2 = torch.constant.int 2 // CHECK: %int1 = torch.constant.int 1 // CHECK: %int4 = torch.constant.int 4 -// CHECK: %int3 = torch.constant.int 3 -// CHECK: %[[T_3:.*]] = torch_c.to_i64 %int3 // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int4, %int2 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_6:.*]] = torch.prim.ListConstruct %int3, %int1 : (!torch.int, !torch.int) -> !torch.list @@ -325,10 +303,9 @@ func.func @torch.aten.convolution(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: ! // CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 1], pad = [[4, 4], [2, 2]], rhs_dilate = [3, 1]} {batch_group_count = 1 : i64, feature_group_count = 3 : i64} : (tensor, tensor) -> tensor // CHECK: %[[IDX_0:.*]] = arith.constant 0 : index // CHECK: %[[T_9:.*]] = tensor.dim %[[T_2]], %[[IDX_0]] : tensor -// CHECK: %[[T_10:.*]] = arith.index_cast %[[T_9]] : index to i64 -// CHECK: %[[VAL_0:.*]] = arith.constant 1 : i64 -// CHECK: %[[T_11:.*]] = tensor.from_elements %[[T_10]], %[[VAL_0]], %[[VAL_0]] : tensor<3xi64> -// CHECK: %[[T_12:.*]] = stablehlo.dynamic_reshape %[[T_2]], %[[T_11]] : (tensor, tensor<3xi64>) -> tensor +// CHECK: %[[VAL_0:.*]] = arith.constant 1 : index +// CHECK: %[[T_11:.*]] = tensor.from_elements %[[T_9]], %[[VAL_0]], %[[VAL_0]] : tensor<3xindex> +// CHECK: %[[T_12:.*]] = stablehlo.dynamic_reshape %[[T_2]], %[[T_11]] : (tensor, tensor<3xindex>) -> tensor // CHECK: %[[T_13:.*]] = chlo.broadcast_add %[[T_8]], %[[T_12]] : (tensor, tensor) -> tensor // CHECK: %[[T_14:.*]] = torch_c.from_builtin_tensor %[[T_13]] : tensor -> !torch.vtensor<[?,?,?,?],f32> // CHECK: return %[[T_14]] : !torch.vtensor<[?,?,?,?],f32> @@ -351,13 +328,12 @@ func.func @torch.aten.convolution$bias(%arg0: !torch.vtensor<[?,?,?,?],f32>, %ar // CHECK-LABEL: func.func @torch.aten.convolution$transposed_basic( // CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[1,2,7,7],f32>, // CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[2,4,3,3],f32>) -> !torch.vtensor<[1,4,9,9],f32> { -// CHECK: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[1,2,7,7],f32> -> tensor<1x2x7x7xf32> -// CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[2,4,3,3],f32> -> tensor<2x4x3x3xf32> +// CHECK-DAG: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[1,2,7,7],f32> -> tensor<1x2x7x7xf32> +// CHECK-DAG: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[2,4,3,3],f32> -> tensor<2x4x3x3xf32> // CHECK: %true = torch.constant.bool true // CHECK: %none = 
torch.constant.none // CHECK: %int0 = torch.constant.int 0 // CHECK: %int1 = torch.constant.int 1 -// CHECK: %[[T_2:.*]] = torch_c.to_i64 %int1 // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_5:.*]] = stablehlo.transpose %[[T_1]], dims = [2, 3, 1, 0] : (tensor<2x4x3x3xf32>) -> tensor<3x3x4x2xf32> @@ -382,13 +358,12 @@ func.func @torch.aten.convolution$transposed_basic(%arg0: !torch.vtensor<[1,2,7, // CHECK-LABEL: func.func @torch.aten.convolution$transposed_stride( // CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[1,2,7,7],f32>, // CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[2,4,3,3],f32>) -> !torch.vtensor<[1,4,15,15],f32> { -// CHECK: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[1,2,7,7],f32> -> tensor<1x2x7x7xf32> -// CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[2,4,3,3],f32> -> tensor<2x4x3x3xf32> +// CHECK-DAG: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[1,2,7,7],f32> -> tensor<1x2x7x7xf32> +// CHECK-DAG: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[2,4,3,3],f32> -> tensor<2x4x3x3xf32> // CHECK: %true = torch.constant.bool true // CHECK: %none = torch.constant.none // CHECK: %int0 = torch.constant.int 0 // CHECK: %int1 = torch.constant.int 1 -// CHECK: %[[T_2:.*]] = torch_c.to_i64 %int1 // CHECK: %int2 = torch.constant.int 2 // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list @@ -417,13 +392,12 @@ func.func @torch.aten.convolution$transposed_stride(%arg0: !torch.vtensor<[1,2,7 // CHECK-LABEL: func.func @torch.aten.convolution$transposed_outputpadding( // CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[1,2,7,7],f32>, // CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[2,4,3,3],f32>) -> !torch.vtensor<[1,4,16,16],f32> { -// CHECK: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[1,2,7,7],f32> -> tensor<1x2x7x7xf32> -// CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[2,4,3,3],f32> -> tensor<2x4x3x3xf32> +// CHECK-DAG: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[1,2,7,7],f32> -> tensor<1x2x7x7xf32> +// CHECK-DAG: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[2,4,3,3],f32> -> tensor<2x4x3x3xf32> // CHECK: %true = torch.constant.bool true // CHECK: %none = torch.constant.none // CHECK: %int0 = torch.constant.int 0 // CHECK: %int1 = torch.constant.int 1 -// CHECK: %[[T_2:.*]] = torch_c.to_i64 %int1 // CHECK: %int2 = torch.constant.int 2 // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list @@ -452,39 +426,31 @@ func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor // CHECK-LABEL: func.func @torch.aten.convolution$transposed_groups( // CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[1,2,7,7],f32>, // CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[2,2,3,3],f32>) -> !torch.vtensor<[1,4,15,15],f32> { -// CHECK: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[1,2,7,7],f32> -> tensor<1x2x7x7xf32> -// CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[2,2,3,3],f32> -> 
tensor<2x2x3x3xf32> +// CHECK-DAG: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[1,2,7,7],f32> -> tensor<1x2x7x7xf32> +// CHECK-DAG: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[2,2,3,3],f32> -> tensor<2x2x3x3xf32> // CHECK: %true = torch.constant.bool true // CHECK: %none = torch.constant.none // CHECK: %int0 = torch.constant.int 0 // CHECK: %int1 = torch.constant.int 1 // CHECK: %int2 = torch.constant.int 2 -// CHECK: %[[T_2:.*]] = torch_c.to_i64 %int2 -// CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_6:.*]] = stablehlo.transpose %[[T_1]], dims = [2, 3, 1, 0] : (tensor<2x2x3x3xf32>) -> tensor<3x3x2x2xf32> -// CHECK: %[[T_7:.*]] = stablehlo.reverse %6, dims = [0, 1] : tensor<3x3x2x2xf32> +// CHECK: %[[T_7:.*]] = stablehlo.reverse %[[T_6]], dims = [0, 1] : tensor<3x3x2x2xf32> // CHECK: %c0 = arith.constant 0 : index // CHECK: %dim = tensor.dim %[[T_7]], %c0 : tensor<3x3x2x2xf32> -// CHECK: %[[T_8:.*]] = arith.index_cast %dim : index to i64 // CHECK: %c1 = arith.constant 1 : index // CHECK: %dim_0 = tensor.dim %[[T_7]], %c1 : tensor<3x3x2x2xf32> -// CHECK: %[[T_9:.*]] = arith.index_cast %dim_0 : index to i64 // CHECK: %c2 = arith.constant 2 : index // CHECK: %dim_1 = tensor.dim %[[T_7]], %c2 : tensor<3x3x2x2xf32> -// CHECK: %[[T_10:.*]] = arith.index_cast %dim_1 : index to i64 // CHECK: %c3 = arith.constant 3 : index // CHECK: %dim_2 = tensor.dim %[[T_7]], %c3 : tensor<3x3x2x2xf32> -// CHECK: %[[T_11:.*]] = arith.index_cast %dim_2 : index to i64 -// CHECK: %c2_i64 = arith.constant 2 : i64 -// CHECK: %[[T_12:.*]] = arith.divsi %[[T_11]], %c2_i64 : i64 -// CHECK: %[[T_13:.*]] = arith.muli %[[T_10]], %c2_i64 : i64 -// CHECK: %from_elements = tensor.from_elements %[[T_8]], %[[T_9]], %[[T_10]], %c2_i64, %[[T_12]] : tensor<5xi64> -// CHECK: %[[T_14:.*]] = stablehlo.dynamic_reshape %[[T_7]], %from_elements : (tensor<3x3x2x2xf32>, tensor<5xi64>) -> tensor<3x3x2x2x1xf32> +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[T_12:.*]] = arith.divsi %dim_2, %[[C2]] : index +// CHECK: %[[T_13:.*]] = arith.muli %dim_1, %[[C2]] : index +// CHECK: %from_elements = tensor.from_elements %dim, %dim_0, %dim_1, %[[C2]], %[[T_12]] : tensor<5xindex> +// CHECK: %[[T_14:.*]] = stablehlo.dynamic_reshape %[[T_7]], %from_elements : (tensor<3x3x2x2xf32>, tensor<5xindex>) -> tensor<3x3x2x2x1xf32> // CHECK: %[[T_15:.*]] = stablehlo.transpose %[[T_14]], dims = [0, 1, 3, 2, 4] : (tensor<3x3x2x2x1xf32>) -> tensor<3x3x2x2x1xf32> -// CHECK: %from_elements_3 = tensor.from_elements %[[T_8]], %[[T_9]], %[[T_13]], %[[T_12]] : tensor<4xi64> -// CHECK: %[[T_16:.*]] = stablehlo.dynamic_reshape %[[T_15]], %from_elements_3 : (tensor<3x3x2x2x1xf32>, tensor<4xi64>) -> tensor<3x3x4x1xf32> +// CHECK: %[[from_elements_3:.*]] = tensor.from_elements %dim, %dim_0, %[[T_13]], %[[T_12]] : tensor<4xindex> +// CHECK: %[[T_16:.*]] = stablehlo.dynamic_reshape %[[T_15]], %[[from_elements_3]] : (tensor<3x3x2x2x1xf32>, tensor<4xindex>) -> tensor<3x3x4x1xf32> // CHECK: %[[T_17:.*]] = stablehlo.convolution(%[[T_0]], %[[T_16]]) // CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[0, 1, o, i]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, 
feature_group_count = 2 : i64} : (tensor<1x2x7x7xf32>, tensor<3x3x4x1xf32>) -> tensor<1x4x15x15xf32> // CHECK: %[[T_18:.*]] = torch_c.from_builtin_tensor %[[T_17]] : tensor<1x4x15x15xf32> -> !torch.vtensor<[1,4,15,15],f32> diff --git a/test/Conversion/TorchToStablehlo/pooling.mlir b/test/Conversion/TorchToStablehlo/pooling.mlir index 156c3ff51be2..f44d51c9fff7 100644 --- a/test/Conversion/TorchToStablehlo/pooling.mlir +++ b/test/Conversion/TorchToStablehlo/pooling.mlir @@ -83,18 +83,15 @@ func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) - // CHECK: %[[T5:.*]] = stablehlo.constant dense<0xFF800000> : tensor // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor -// CHECK: %[[T6:.*]] = arith.index_cast %[[DIM]] : index to i64 // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor -// CHECK: %[[T7:.*]] = arith.index_cast %[[DIM_0]] : index to i64 // CHECK: %[[C2:.*]] = arith.constant 2 : index // CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor -// CHECK: %[[T8:.*]] = arith.index_cast %[[DIM_1]] : index to i64 -// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T6]], %[[T7]], %[[T8]] : tensor<3xi64> -// CHECK: %[[T9:.*]] = arith.muli %[[T8]], %[[T7]] : i64 -// CHECK: %[[FROM_ELEMENTS_2:.*]] = tensor.from_elements %[[T6]], %[[T9]] : tensor<2xi64> -// CHECK: %[[T10:.*]] = stablehlo.dynamic_iota %[[FROM_ELEMENTS_2]], dim = 1 : (tensor<2xi64>) -> tensor -// CHECK: %[[T11:.*]] = stablehlo.dynamic_reshape %[[T10]], %[[FROM_ELEMENTS]] : (tensor, tensor<3xi64>) -> tensor +// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[DIM]], %[[DIM_0]], %[[DIM_1]] : tensor<3xindex> +// CHECK: %[[T9:.*]] = arith.muli %[[DIM_1]], %[[DIM_0]] : index +// CHECK: %[[FROM_ELEMENTS_2:.*]] = tensor.from_elements %[[DIM]], %[[T9]] : tensor<2xindex> +// CHECK: %[[T10:.*]] = stablehlo.dynamic_iota %[[FROM_ELEMENTS_2]], dim = 1 : (tensor<2xindex>) -> tensor +// CHECK: %[[T11:.*]] = stablehlo.dynamic_reshape %[[T10]], %[[FROM_ELEMENTS]] : (tensor, tensor<3xindex>) -> tensor // CHECK: %[[T12:.*]] = stablehlo.constant dense<0> : tensor // CHECK: %[[T13:.*]]:2 = "stablehlo.reduce_window"(%[[T0]], %[[T11]], %[[T5]], %[[T12]]) <{padding = dense<0> : tensor<3x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array}> ({ // CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor, %[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): @@ -106,8 +103,8 @@ func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) - // CHECK: %[[T21:.*]] = stablehlo.select %[[T18]], %[[T19]], %[[T20]] : tensor, tensor // CHECK: stablehlo.return %[[T17]], %[[T21]] : tensor, tensor // CHECK: }) : (tensor, tensor, tensor, tensor) -> (tensor, tensor) -// CHECK: %[[T14:.*]] = torch_c.from_builtin_tensor %[[T13]]#0 : tensor -> !torch.vtensor<[?,?,?],f32> -// CHECK: %[[T15:.*]] = torch_c.from_builtin_tensor %[[T13]]#1 : tensor -> !torch.vtensor<[?,?,?],si64> +// CHECK-DAG: %[[T14:.*]] = torch_c.from_builtin_tensor %[[T13]]#0 : tensor -> !torch.vtensor<[?,?,?],f32> +// CHECK-DAG: %[[T15:.*]] = torch_c.from_builtin_tensor %[[T13]]#1 : tensor -> !torch.vtensor<[?,?,?],si64> // CHECK: return %[[T14]], %[[T15]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],si64> func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32>) -> (!torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],si64>) { %int3 = torch.constant.int 3 @@ -146,18 +143,14 @@ 
func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32> // CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor // CHECK: %[[IDX_0:.*]] = arith.constant 0 : index // CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_1]], %[[IDX_0]] : tensor -// CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_8]] : index to i64 // CHECK: %[[IDX_1:.*]] = arith.constant 1 : index // CHECK: %[[VAL_10:.*]] = tensor.dim %[[VAL_1]], %[[IDX_1]] : tensor -// CHECK: %[[VAL_11:.*]] = arith.index_cast %[[VAL_10]] : index to i64 // CHECK: %[[IDX_2:.*]] = arith.constant 2 : index // CHECK: %[[VAL_12:.*]] = tensor.dim %[[VAL_1]], %[[IDX_2]] : tensor -// CHECK: %[[VAL_13:.*]] = arith.index_cast %[[VAL_12]] : index to i64 // CHECK: %[[IDX_3:.*]] = arith.constant 3 : index // CHECK: %[[VAL_14:.*]] = tensor.dim %[[VAL_1]], %[[IDX_3]] : tensor -// CHECK: %[[VAL_15:.*]] = arith.index_cast %[[VAL_14]] : index to i64 -// CHECK: %[[VAL_16:.*]] = tensor.from_elements %[[VAL_9]], %[[VAL_11]], %[[VAL_13]], %[[VAL_15]] : tensor<4xi64> -// CHECK: %[[VAL_17:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_7]], %[[VAL_16]], dims = [] : (tensor, tensor<4xi64>) -> tensor +// CHECK: %[[VAL_16:.*]] = tensor.from_elements %[[VAL_8]], %[[VAL_10]], %[[VAL_12]], %[[VAL_14]] : tensor<4xindex> +// CHECK: %[[VAL_17:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_7]], %[[VAL_16]], dims = [] : (tensor, tensor<4xindex>) -> tensor // CHECK: %[[VAL_18:.*]] = stablehlo.constant dense<0.000000e+00> : tensor // CHECK: %[[VAL_19:.*]] = "stablehlo.reduce_window"(%[[VAL_17]], %[[VAL_18]]) // CHECK{LITERAL}: <{padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array}> ({ diff --git a/test/Conversion/TorchToStablehlo/scatter.mlir b/test/Conversion/TorchToStablehlo/scatter.mlir index fe8ffb9ee205..937c14a69245 100644 --- a/test/Conversion/TorchToStablehlo/scatter.mlir +++ b/test/Conversion/TorchToStablehlo/scatter.mlir @@ -2,25 +2,23 @@ // CHECK-LABEL: func.func @forward( // CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?],si64>, %[[ARG_1:.*]]: !torch.vtensor<[?,?],si64>, %[[ARG_2:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> { -// CHECK: %[[VAR_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?],si64> -> tensor -// CHECK: %[[VAR_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?],si64> -> tensor -// CHECK: %[[VAR_2:.*]] = torch_c.to_builtin_tensor %[[ARG_2]] : !torch.vtensor<[?,?],si64> -> tensor +// CHECK-DAG: %[[VAR_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?],si64> -> tensor +// CHECK-DAG: %[[VAR_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?],si64> -> tensor +// CHECK-DAG: %[[VAR_2:.*]] = torch_c.to_builtin_tensor %[[ARG_2]] : !torch.vtensor<[?,?],si64> -> tensor // CHECK: %int0 = torch.constant.int 0 // CHECK: %[[INDEX_0:.*]] = arith.constant 0 : index // CHECK: %[[DIM_0:.*]] = tensor.dim %[[VAR_1]], %[[INDEX_0]] : tensor -// CHECK: %[[VAR_3:.*]] = arith.index_cast %[[DIM_0]] : index to i64 // CHECK: %[[INDEX_1:.*]] = arith.constant 1 : index // CHECK: %[[DIM_1:.*]] = tensor.dim %1, %[[INDEX_1]] : tensor -// CHECK: %[[VAR_4:.*]] = arith.index_cast %[[DIM_1]] : index to i64 -// CHECK: %[[CONSTANT_0:.*]] = arith.constant 0 : i64 -// CHECK: %[[CONSTANT_1:.*]] = arith.constant 1 : i64 -// CHECK: %[[FE_:.*]] = tensor.from_elements %[[CONSTANT_0]], %[[CONSTANT_0]] : tensor<2xi64> -// CHECK: %[[FE_1:.*]] = tensor.from_elements %[[CONSTANT_1]], 
%[[CONSTANT_1]] : tensor<2xi64> -// CHECK: %[[FE_2:.*]] = tensor.from_elements %[[VAR_3]], %[[VAR_4]] : tensor<2xi64> -// CHECK: %[[VAR_5:.*]] = stablehlo.real_dynamic_slice %[[VAR_2]], %[[FE_]], %[[FE_2]], %[[FE_1]] : (tensor, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor -// CHECK: %[[FE_3:.*]] = tensor.from_elements %[[VAR_3]], %[[VAR_4]], %[[CONSTANT_1]] : tensor<3xi64> -// CHECK: %[[VAR_6:.*]] = stablehlo.dynamic_reshape %1, %[[FE_3]] : (tensor, tensor<3xi64>) -> tensor -// CHECK: %[[VAR_7:.*]] = stablehlo.dynamic_iota %[[FE_3]], dim = 1 : (tensor<3xi64>) -> tensor +// CHECK: %[[CONSTANT_0:.*]] = arith.constant 0 : index +// CHECK: %[[CONSTANT_1:.*]] = arith.constant 1 : index +// CHECK: %[[FE_:.*]] = tensor.from_elements %[[CONSTANT_0]], %[[CONSTANT_0]] : tensor<2xindex> +// CHECK: %[[FE_1:.*]] = tensor.from_elements %[[CONSTANT_1]], %[[CONSTANT_1]] : tensor<2xindex> +// CHECK: %[[FE_2:.*]] = tensor.from_elements %[[DIM_0]], %[[DIM_1]] : tensor<2xindex> +// CHECK: %[[VAR_5:.*]] = stablehlo.real_dynamic_slice %[[VAR_2]], %[[FE_]], %[[FE_2]], %[[FE_1]] : (tensor, tensor<2xindex>, tensor<2xindex>, tensor<2xindex>) -> tensor +// CHECK: %[[FE_3:.*]] = tensor.from_elements %[[DIM_0]], %[[DIM_1]], %[[CONSTANT_1]] : tensor<3xindex> +// CHECK: %[[VAR_6:.*]] = stablehlo.dynamic_reshape %1, %[[FE_3]] : (tensor, tensor<3xindex>) -> tensor +// CHECK: %[[VAR_7:.*]] = stablehlo.dynamic_iota %[[FE_3]], dim = 1 : (tensor<3xindex>) -> tensor // CHECK: %[[VAR_8:.*]] = stablehlo.concatenate %[[VAR_6]], %[[VAR_7]], dim = 2 : (tensor, tensor) -> tensor // CHECK: %[[VAR_9:.*]] = "stablehlo.scatter"(%[[VAR_0]], %[[VAR_8]], %[[VAR_5]]) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false}> ({ // CHECK: ^bb0(%arg3: tensor, %[[ARG_4:.*]]: tensor): diff --git a/test/Conversion/TorchToStablehlo/view_like.mlir b/test/Conversion/TorchToStablehlo/view_like.mlir index ab54d2764b66..5e08f2d16c45 100644 --- a/test/Conversion/TorchToStablehlo/view_like.mlir +++ b/test/Conversion/TorchToStablehlo/view_like.mlir @@ -3,12 +3,9 @@ // CHECK-LABEL: func.func @torch.aten.slice.strided$slice_like( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?],f32> -> tensor -// CHECK: %[[INT0:.*]] = torch.constant.int 0 -// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT0]] -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]] -// CHECK: %[[INT10:.*]] = torch.constant.int 10 -// CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT10]] +// CHECK: %[[T1:.*]] = arith.constant 0 : i64 +// CHECK: %[[T2:.*]] = arith.constant 2 : i64 +// CHECK: %[[T3:.*]] = arith.constant 10 : i64 // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor // CHECK: %[[T4:.*]] = arith.index_cast %[[DIM]] : index to i64 @@ -42,7 +39,7 @@ // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T10]], %[[C0_I64_5]], %[[C0_I64_5]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T21]], %[[T18]], %[[T19]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[C1_I64]] : tensor<3xi64> -// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor +// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], 
%[[FROM_ELEMENTS]], %[[FROM_ELEMENTS_6]], %[[FROM_ELEMENTS_7]] : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor // CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor -> !torch.vtensor<[?,?,?],f32> // CHECK: return %[[T23]] : !torch.vtensor<[?,?,?],f32> func.func @torch.aten.slice.strided$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { @@ -58,12 +55,9 @@ func.func @torch.aten.slice.strided$slice_like(%arg0: !torch.vtensor<[?,?,?],f32 // CHECK-LABEL: func.func @torch.aten.slice.strided.static$slice_like( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[2,65,256],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,65,256],f32> -> tensor<4x65x256xf32> -// CHECK: %[[INT0:.*]] = torch.constant.int 0 -// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT0]] -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]] -// CHECK: %[[INT9223372036854775807:.*]] = torch.constant.int 9223372036854775807 -// CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT9223372036854775807]] +// CHECK: %[[T1:.*]] = arith.constant 0 : i64 +// CHECK: %[[T2:.*]] = arith.constant 2 : i64 +// CHECK: %[[T3:.*]] = arith.constant 9223372036854775807 : i64 // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<4x65x256xf32> // CHECK: %[[T4:.*]] = arith.index_cast %[[DIM]] : index to i64 @@ -97,7 +91,7 @@ func.func @torch.aten.slice.strided$slice_like(%arg0: !torch.vtensor<[?,?,?],f32 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T10]], %[[C0_I64_5]], %[[C0_I64_5]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T21]], %[[T18]], %[[T19]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[C1_I64]] : tensor<3xi64> -// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<2x65x256xf32> +// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS_6]], %[[FROM_ELEMENTS_7]] : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<2x65x256xf32> // CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor<2x65x256xf32> -> !torch.vtensor<[2,65,256],f32> // CHECK: return %[[T23]] : !torch.vtensor<[2,65,256],f32> func.func @torch.aten.slice.strided.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[2,65,256],f32> { @@ -113,12 +107,9 @@ func.func @torch.aten.slice.strided.static$slice_like(%arg0: !torch.vtensor<[4,6 // CHECK-LABEL: func.func @torch.aten.slice.last$slice_like( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,1,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?],f32> -> tensor -// CHECK: %[[INT0:.*]] = torch.constant.int 0 -// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT0]] -// CHECK: %[[INT1:.*]] = torch.constant.int 1 -// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT1]] -// CHECK: %[[INT:.*]]-1 = torch.constant.int -1 -// CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT]]-1 +// CHECK: %[[T1:.*]] = arith.constant 0 : i64 +// CHECK: %[[T2:.*]] = arith.constant 1 : i64 +// CHECK: %[[T3:.*]] = arith.constant -1 : i64 // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor // CHECK: %[[T4:.*]] 
= arith.index_cast %[[DIM]] : index to i64 @@ -152,7 +143,7 @@ func.func @torch.aten.slice.strided.static$slice_like(%arg0: !torch.vtensor<[4,6 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_5]], %[[T10]], %[[C0_I64_5]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T17]], %[[T21]], %[[T19]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[C1_I64]] : tensor<3xi64> -// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor +// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS_6]], %[[FROM_ELEMENTS_7]] : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor // CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor -> !torch.vtensor<[?,1,?],f32> // CHECK: return %[[T23]] : !torch.vtensor<[?,1,?],f32> func.func @torch.aten.slice.last$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,1,?],f32> { @@ -168,12 +159,9 @@ func.func @torch.aten.slice.last$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) // CHECK-LABEL: func.func @torch.aten.slice.last.static$slice_like( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,1,256],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,65,256],f32> -> tensor<4x65x256xf32> -// CHECK: %[[INT0:.*]] = torch.constant.int 0 -// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT0]] -// CHECK: %[[INT1:.*]] = torch.constant.int 1 -// CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT1]] -// CHECK: %[[INT:.*]]-1 = torch.constant.int -1 -// CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT]]-1 +// CHECK: %[[T1:.*]] = arith.constant 0 : i64 +// CHECK: %[[T2:.*]] = arith.constant 1 : i64 +// CHECK: %[[T3:.*]] = arith.constant -1 : i64 // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<4x65x256xf32> // CHECK: %[[T4:.*]] = arith.index_cast %[[DIM]] : index to i64 @@ -207,7 +195,7 @@ func.func @torch.aten.slice.last$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_5]], %[[T10]], %[[C0_I64_5]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T17]], %[[T21]], %[[T19]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[C1_I64]] : tensor<3xi64> -// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x1x256xf32> +// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS_6]], %[[FROM_ELEMENTS_7]] : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x1x256xf32> // CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor<4x1x256xf32> -> !torch.vtensor<[4,1,256],f32> // CHECK: return %[[T23]] : !torch.vtensor<[4,1,256],f32> func.func @torch.aten.slice.last.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,1,256],f32> { @@ -224,8 +212,7 @@ func.func @torch.aten.slice.last.static$slice_like(%arg0: !torch.vtensor<[4,65,2 // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?],f32> 
-> tensor // CHECK: %[[INT1:.*]] = torch.constant.int 1 -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT2]] +// CHECK: %[[T1:.*]] = arith.constant 2 : i64 // CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor @@ -247,7 +234,7 @@ func.func @torch.aten.slice.last.static$slice_like(%arg0: !torch.vtensor<[4,65,2 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_4]], %[[C0_I64]], %[[C0_I64_4]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_5:.*]] = tensor.from_elements %[[T3]], %[[T7]], %[[T5]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[C1_I64]] : tensor<3xi64> -// CHECK: %[[T8:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_5, %[[FROM_ELEMENTS]]_6 : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor +// CHECK: %[[T8:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_5, %[[FROM_ELEMENTS_6]] : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor // CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor -> !torch.vtensor<[?,?,?],f32> // CHECK: return %[[T9]] : !torch.vtensor<[?,?,?],f32> func.func @torch.aten.slice.none$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { @@ -264,8 +251,7 @@ func.func @torch.aten.slice.none$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,33,256],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,65,256],f32> -> tensor<4x65x256xf32> // CHECK: %[[INT1:.*]] = torch.constant.int 1 -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT2]] +// CHECK: %[[T1:.*]] = arith.constant 2 : i64 // CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<4x65x256xf32> @@ -287,7 +273,7 @@ func.func @torch.aten.slice.none$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_4]], %[[C0_I64]], %[[C0_I64_4]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_5:.*]] = tensor.from_elements %[[T3]], %[[T7]], %[[T5]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[C1_I64]] : tensor<3xi64> -// CHECK: %[[T8:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_5, %[[FROM_ELEMENTS]]_6 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x33x256xf32> +// CHECK: %[[T8:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_5, %[[FROM_ELEMENTS_6]] : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x33x256xf32> // CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<4x33x256xf32> -> !torch.vtensor<[4,33,256],f32> // CHECK: return %[[T9]] : !torch.vtensor<[4,33,256],f32> func.func @torch.aten.slice.none.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,33,256],f32> { @@ -310,11 +296,12 @@ func.func @torch.aten.slice.none.static$slice_like(%arg0: !torch.vtensor<[4,65,2 // CHECK: %[[T3:.*]] = torch_c.to_i64 %[[INT224]] // CHECK: %[[T4:.*]] = shape.shape_of %[[T0]] : tensor -> tensor<4xindex> // CHECK: %[[T5:.*]] = shape.num_elements %[[T4]] : tensor<4xindex> -> index 
-// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]], %[[T3]] : tensor<2xi64> -// CHECK: %[[T6:.*]] = stablehlo.compute_reshape_shape %[[T5]], %[[FROM_ELEMENTS]] : (index, tensor<2xi64>) -> tensor<2xi64> -// CHECK: %[[T7:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[T6]] : (tensor, tensor<2xi64>) -> tensor -// CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor -> !torch.vtensor<[?,224],f32> -// CHECK: return %[[T8]] : !torch.vtensor<[?,224],f32> +// CHECK: %[[T6:.*]] = arith.index_cast %[[T5]] : index to i64 +// CHECK: %[[T7:.*]] = arith.divui %[[T6]], %[[T3]] : i64 +// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T7]], %[[T3]] : tensor<2xi64> +// CHECK: %[[T8:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<2xi64>) -> tensor +// CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor -> !torch.vtensor<[?,224],f32> +// CHECK: return %[[T9]] : !torch.vtensor<[?,224],f32> func.func @torch.aten.view$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,224],f32> { %int-1 = torch.constant.int -1 %int224 = torch.constant.int 224 @@ -339,11 +326,14 @@ func.func @torch.aten.view$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch // CHECK: %[[T5:.*]] = torch_c.to_i64 %[[INT64]] // CHECK: %[[T6:.*]] = shape.shape_of %[[T0]] : tensor -> tensor<5xindex> // CHECK: %[[T7:.*]] = shape.num_elements %[[T6]] : tensor<5xindex> -> index -// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]], %[[T3]], %[[T4]], %[[T5]] : tensor<4xi64> -// CHECK: %[[T8:.*]] = stablehlo.compute_reshape_shape %[[T7]], %[[FROM_ELEMENTS]] : (index, tensor<4xi64>) -> tensor<4xi64> -// CHECK: %[[T9:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[T8]] : (tensor, tensor<4xi64>) -> tensor -// CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor -> !torch.vtensor<[?,120,4,64],f32> -// CHECK: return %[[T10]] : !torch.vtensor<[?,120,4,64],f32> +// CHECK: %[[T8:.*]] = arith.index_cast %[[T7]] : index to i64 +// CHECK: %[[T9:.*]] = arith.divui %[[T8]], %[[T3]] : i64 +// CHECK: %[[T10:.*]] = arith.divui %[[T9]], %[[T4]] : i64 +// CHECK: %[[T11:.*]] = arith.divui %[[T10]], %[[T5]] : i64 +// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T11]], %[[T3]], %[[T4]], %[[T5]] : tensor<4xi64> +// CHECK: %[[T12:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<4xi64>) -> tensor +// CHECK: %[[T13:.*]] = torch_c.from_builtin_tensor %[[T12]] : tensor -> !torch.vtensor<[?,120,4,64],f32> +// CHECK: return %[[T13]] : !torch.vtensor<[?,120,4,64],f32> func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.vtensor<[?,120,4,64],f32> { %int-1 = torch.constant.int -1 %int120 = torch.constant.int 120 @@ -389,15 +379,25 @@ func.func @torch.aten.view$to_rank0(%arg0: !torch.vtensor<[1],f32>) -> !torch.vt // ----- // CHECK-LABEL: func.func @torch.aten.squeeze.dim$0$static( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,1,2,1,2],f32> { +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,2,1,2],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,1,2,1,2],f32> -> tensor<2x1x2x1x2xf32> -// CHECK: %[[INT0:.*]] = torch.constant.int 0 -// CHECK: %[[T1:.*]] = torch_c.from_builtin_tensor %[[T0]] : tensor<2x1x2x1x2xf32> -> !torch.vtensor<[2,1,2,1,2],f32> -// CHECK: return %[[T1]] : !torch.vtensor<[2,1,2,1,2],f32> -func.func @torch.aten.squeeze.dim$0$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) -> 
!torch.vtensor<[2,1,2,1,2],f32> { - %int0 = torch.constant.int 0 - %0 = torch.aten.squeeze.dim %arg0, %int0 : !torch.vtensor<[2,1,2,1,2],f32>, !torch.int -> !torch.vtensor<[2,1,2,1,2],f32> - return %0 : !torch.vtensor<[2,1,2,1,2],f32> +// CHECK: %[[INT1:.*]] = torch.constant.int 1 +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<2x1x2x1x2xf32> +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<2x1x2x1x2xf32> +// CHECK: %[[C3:.*]] = arith.constant 3 : index +// CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C3:.*]] : tensor<2x1x2x1x2xf32> +// CHECK: %[[C4:.*]] = arith.constant 4 : index +// CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor<2x1x2x1x2xf32> +// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[DIM]], %[[DIM_0]], %[[DIM_1]], %[[DIM_2]] : tensor<4xindex> +// CHECK: %[[T1:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<2x1x2x1x2xf32>, tensor<4xindex>) -> tensor<2x2x1x2xf32> +// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<2x2x1x2xf32> -> !torch.vtensor<[2,2,1,2],f32> +// CHECK: return %[[T2]] : !torch.vtensor<[2,2,1,2],f32> +func.func @torch.aten.squeeze.dim$0$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,2,1,2],f32> { + %int1 = torch.constant.int 1 + %0 = torch.aten.squeeze.dim %arg0, %int1 : !torch.vtensor<[2,1,2,1,2],f32>, !torch.int -> !torch.vtensor<[2,2,1,2],f32> + return %0 : !torch.vtensor<[2,2,1,2],f32> } // ----- @@ -408,18 +408,14 @@ func.func @torch.aten.squeeze.dim$0$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32 // CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor -// CHECK: %[[T1:.*]] = arith.index_cast %[[DIM]] : index to i64 // CHECK: %[[C2:.*]] = arith.constant 2 : index // CHECK: %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor -// CHECK: %[[T2:.*]] = arith.index_cast %[[DIM_0]] : index to i64 // CHECK: %[[C3:.*]] = arith.constant 3 : index // CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C3]] : tensor -// CHECK: %[[T3:.*]] = arith.index_cast %[[DIM_1]] : index to i64 // CHECK: %[[C4:.*]] = arith.constant 4 : index // CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor -// CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64 -// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]], %[[T4]] : tensor<4xi64> -// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<4xi64>) -> tensor +// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[DIM]], %[[DIM_0]], %[[DIM_1]], %[[DIM_2]] : tensor<4xindex> +// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<4xindex>) -> tensor // CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor -> !torch.vtensor<[?,?,1,?],f32> // CHECK: return %[[T6]] : !torch.vtensor<[?,?,1,?],f32> func.func @torch.aten.squeeze.dim$1(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,?,1,?],f32> { @@ -436,18 +432,14 @@ func.func @torch.aten.squeeze.dim$1(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> ! 
// CHECK: %[[INT:.*]]-2 = torch.constant.int -2 // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor -// CHECK: %[[T1:.*]] = arith.index_cast %[[DIM]] : index to i64 // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor -// CHECK: %[[T2:.*]] = arith.index_cast %[[DIM_0]] : index to i64 // CHECK: %[[C2:.*]] = arith.constant 2 : index // CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor -// CHECK: %[[T3:.*]] = arith.index_cast %[[DIM_1]] : index to i64 // CHECK: %[[C4:.*]] = arith.constant 4 : index // CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor -// CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64 -// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]], %[[T4]] : tensor<4xi64> -// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<4xi64>) -> tensor +// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[DIM]], %[[DIM_0]], %[[DIM_1]], %[[DIM_2]] : tensor<4xindex> +// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<4xindex>) -> tensor // CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor -> !torch.vtensor<[?,1,?,?],f32> // CHECK: return %[[T6]] : !torch.vtensor<[?,1,?,?],f32> func.func @torch.aten.squeeze.dim$from_end(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,1,?,?],f32> { @@ -463,15 +455,12 @@ func.func @torch.aten.squeeze.dim$from_end(%arg0: !torch.vtensor<[?,1,?,1,?],f32 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,1,2,1,2],f32> -> tensor<2x1x2x1x2xf32> // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor<2x1x2x1x2xf32> -// CHECK: %[[T1:.*]] = arith.index_cast %[[DIM]] : index to i64 // CHECK: %[[C2:.*]] = arith.constant 2 : index // CHECK: %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor<2x1x2x1x2xf32> -// CHECK: %[[T2:.*]] = arith.index_cast %[[DIM_0]] : index to i64 // CHECK: %[[C4:.*]] = arith.constant 4 : index // CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor<2x1x2x1x2xf32> -// CHECK: %[[T3:.*]] = arith.index_cast %[[DIM_1]] : index to i64 -// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]] : tensor<3xi64> -// CHECK: %[[T4:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<2x1x2x1x2xf32>, tensor<3xi64>) -> tensor<2x2x2xf32> +// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[DIM]], %[[DIM_0]], %[[DIM_1]] : tensor<3xindex> +// CHECK: %[[T4:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<2x1x2x1x2xf32>, tensor<3xindex>) -> tensor<2x2x2xf32> // CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<2x2x2xf32> -> !torch.vtensor<[2,2,2],f32> // CHECK: return %[[T5]] : !torch.vtensor<[2,2,2],f32> func.func @torch.aten.squeeze$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,2,2],f32> { @@ -487,19 +476,15 @@ func.func @torch.aten.squeeze$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) -> // CHECK: %[[INT0:.*]] = torch.constant.int 0 // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor -// CHECK: %[[T1:.*]] = arith.index_cast %[[DIM]] : index to i64 // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor -// CHECK: %[[T2:.*]] = arith.index_cast %[[DIM_0]] : index to i64 // 
CHECK: %[[C2:.*]] = arith.constant 2 : index // CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor -// CHECK: %[[T3:.*]] = arith.index_cast %[[DIM_1]] : index to i64 // CHECK: %[[C3:.*]] = arith.constant 3 : index // CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[C3]] : tensor -// CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64 -// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 -// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[T2]], %[[T3]], %[[T4]] : tensor<5xi64> -// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<5xi64>) -> tensor<1x?x?x?x?xf32> +// CHECK: %[[C1_I64:.*]] = arith.constant 1 : index +// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C1_I64]], %[[DIM]], %[[DIM_0]], %[[DIM_1]], %[[DIM_2]] : tensor<5xindex> +// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<5xindex>) -> tensor<1x?x?x?x?xf32> // CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<1x?x?x?x?xf32> -> !torch.vtensor<[1,?,?,?,?],f32> // CHECK: return %[[T6]] : !torch.vtensor<[1,?,?,?,?],f32> func.func @torch.aten.unsqueeze$dim$0(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1,?,?,?,?],f32> { @@ -516,19 +501,15 @@ func.func @torch.aten.unsqueeze$dim$0(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> ! // CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor -// CHECK: %[[T1:.*]] = arith.index_cast %[[DIM]] : index to i64 // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor -// CHECK: %[[T2:.*]] = arith.index_cast %[[DIM_0]] : index to i64 // CHECK: %[[C2:.*]] = arith.constant 2 : index // CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor -// CHECK: %[[T3:.*]] = arith.index_cast %[[DIM_1]] : index to i64 // CHECK: %[[C3:.*]] = arith.constant 3 : index // CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[C3]] : tensor -// CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64 -// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 -// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[C1_I64]], %[[T2]], %[[T3]], %[[T4]] : tensor<5xi64> -// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<5xi64>) -> tensor +// CHECK: %[[C1_I64:.*]] = arith.constant 1 : index +// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[DIM]], %[[C1_I64]], %[[DIM_0]], %[[DIM_1]], %[[DIM_2]] : tensor<5xindex> +// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<5xindex>) -> tensor // CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor -> !torch.vtensor<[?,1,?,?,?],f32> // CHECK: return %[[T6]] : !torch.vtensor<[?,1,?,?,?],f32> func.func @torch.aten.unsqueeze$dim$1(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,1,?,?,?],f32> { @@ -545,19 +526,15 @@ func.func @torch.aten.unsqueeze$dim$1(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> ! 
// CHECK: %[[INT:.*]]-2 = torch.constant.int -2 // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor -// CHECK: %[[T1:.*]] = arith.index_cast %[[DIM]] : index to i64 // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[DIM_0:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor -// CHECK: %[[T2:.*]] = arith.index_cast %[[DIM_0]] : index to i64 // CHECK: %[[C2:.*]] = arith.constant 2 : index // CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C2]] : tensor -// CHECK: %[[T3:.*]] = arith.index_cast %[[DIM_1]] : index to i64 // CHECK: %[[C3:.*]] = arith.constant 3 : index // CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[C3]] : tensor -// CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64 -// CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 -// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]], %[[C1_I64]], %[[T4]] : tensor<5xi64> -// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<5xi64>) -> tensor +// CHECK: %[[C1_I64:.*]] = arith.constant 1 : index +// CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[DIM]], %[[DIM_0]], %[[DIM_1]], %[[C1_I64]], %[[DIM_2]] : tensor<5xindex> +// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<5xindex>) -> tensor // CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor -> !torch.vtensor<[?,?,?,1,?],f32> // CHECK: return %[[T6]] : !torch.vtensor<[?,?,?,1,?],f32> func.func @torch.aten.unsqueeze$from_end(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,1,?],f32> { diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index c6369e6fa769..941d710e4f2e 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -26,6 +26,219 @@ func.func @torch.aten.sigmoid$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch. 
// ----- +// CHECK-DAG: %[[VAL_2:.+]] = tosa.reshape %1 {new_shape = array} : (tensor<4x8xf32>) -> tensor<1x4x8xf32> +// CHECK-DAG: %[[VAL_3:.+]] = tosa.reshape %0 {new_shape = array} : (tensor<8x16xf32>) -> tensor<1x8x16xf32> +// CHECK-NEXT: %[[VAL_4:.+]] = tosa.matmul %[[VAL_2]], %[[VAL_3]] : (tensor<1x4x8xf32>, tensor<1x8x16xf32>) -> tensor<1x4x16xf32> +// CHECK-NEXT: %[[VAL_5:.+]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<1x4x16xf32>) -> tensor<4x16xf32> +func.func @torch.aten.mm$basic(%arg0: !torch.vtensor<[4,8],f32>, %arg1: !torch.vtensor<[8,16],f32>) -> !torch.vtensor<[4,16],f32> { + %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[4,8],f32>, !torch.vtensor<[8,16],f32> -> !torch.vtensor<[4,16],f32> + return %0 : !torch.vtensor<[4,16],f32> +} + +// ----- + +// CHECK: %[[VAL_2:.+]] = tosa.reshape %1 {new_shape = array} : (tensor<6xf32>) -> tensor<1x6xf32> +// CHECK-NEXT: %[[VAL_3:.+]] = tosa.reshape %0 {new_shape = array} : (tensor<6xf32>) -> tensor<6x1xf32> +// CHECK-NEXT: %[[VAL_4:.+]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<1x6xf32>) -> tensor<1x1x6xf32> +// CHECK-NEXT: %[[VAL_5:.+]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<6x1xf32>) -> tensor<1x6x1xf32> +// CHECK-NEXT: %[[VAL_6:.+]] = tosa.matmul %[[VAL_4]], %[[VAL_5]] : (tensor<1x1x6xf32>, tensor<1x6x1xf32>) -> tensor<1x1x1xf32> +// CHECK-NEXT: %[[VAL_7:.+]] = tosa.reshape %[[VAL_6]] {new_shape = array} : (tensor<1x1x1xf32>) -> tensor +func.func @torch.aten.matmul_1d(%arg0 : !torch.vtensor<[6],f32>, %arg1 : !torch.vtensor<[6],f32>) -> !torch.vtensor<[],f32> { + %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[6],f32>, !torch.vtensor<[6],f32> -> !torch.vtensor<[],f32> + return %0 : !torch.vtensor<[],f32> +} + +// ----- + +// CHECK: %[[VAL_2:.+]] = tosa.reshape %1 {new_shape = array} : (tensor<6xf32>) -> tensor<1x6xf32> +// CHECK-NEXT: %[[VAL_3:.+]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<1x6xf32>) -> tensor<1x1x6xf32> +// CHECK-NEXT: %[[VAL_4:.+]] = tosa.reshape %0 {new_shape = array} : (tensor<6x1xf32>) -> tensor<1x6x1xf32> +// CHECK-NEXT: %[[VAL_5:.+]] = tosa.matmul %[[VAL_3]], %[[VAL_4]] : (tensor<1x1x6xf32>, tensor<1x6x1xf32>) -> tensor<1x1x1xf32> +// CHECK-NEXT: %[[VAL_6:.+]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor<1x1x1xf32>) -> tensor<1xf32> +func.func @torch.aten.matmul_12d(%arg0 : !torch.vtensor<[6],f32>, %arg1 : !torch.vtensor<[6,1],f32>) -> !torch.vtensor<[1],f32> { + %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[6],f32>, !torch.vtensor<[6,1],f32> -> !torch.vtensor<[1],f32> + return %0 : !torch.vtensor<[1],f32> +} + +// ----- + +// CHECK: %[[VAL_2:.+]] = tosa.reshape %0 {new_shape = array} : (tensor<6xf32>) -> tensor<6x1xf32> +// CHECK-NEXT: %[[VAL_3:.+]] = tosa.reshape %1 {new_shape = array} : (tensor<2x6xf32>) -> tensor<1x2x6xf32> +// CHECK-NEXT: %[[VAL_4:.+]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<6x1xf32>) -> tensor<1x6x1xf32> +// CHECK-NEXT: %[[VAL_5:.+]] = tosa.matmul %[[VAL_3]], %[[VAL_4]] : (tensor<1x2x6xf32>, tensor<1x6x1xf32>) -> tensor<1x2x1xf32> +// CHECK-NEXT: %[[VAL_6:.+]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor<1x2x1xf32>) -> tensor<2xf32> +func.func @torch.aten.matmul_21d(%arg0 : !torch.vtensor<[2,6],f32>, %arg1 : !torch.vtensor<[6],f32>) -> !torch.vtensor<[2],f32> { + %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[2,6],f32>, !torch.vtensor<[6],f32> -> !torch.vtensor<[2],f32> + return %0 : !torch.vtensor<[2],f32> +} + +// ----- + +// CHECK: %[[VAL_2:.+]] = tosa.reshape %1 
{new_shape = array} : (tensor<2x6xf32>) -> tensor<1x2x6xf32> +// CHECK-NEXT: %[[VAL_3:.+]] = tosa.reshape %0 {new_shape = array} : (tensor<6x8xf32>) -> tensor<1x6x8xf32> +// CHECK-NEXT: %[[VAL_4:.+]] = tosa.matmul %[[VAL_2]], %[[VAL_3]] : (tensor<1x2x6xf32>, tensor<1x6x8xf32>) -> tensor<1x2x8xf32> +// CHECK-NEXT: %[[VAL_5:.+]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<1x2x8xf32>) -> tensor<2x8xf32> +func.func @torch.aten.mm_2d(%arg0 : !torch.vtensor<[2,6],f32>, %arg1 : !torch.vtensor<[6,8],f32>) -> !torch.vtensor<[2,8],f32> { + %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[2,6],f32>, !torch.vtensor<[6,8],f32> -> !torch.vtensor<[2,8],f32> + return %0 : !torch.vtensor<[2,8],f32> +} + +// ----- + +// CHECK: tosa.matmul{{.*}} : (tensor<1x4x8xbf16>, tensor<1x8x16xbf16>) -> tensor<1x4x16xf32> +func.func @torch.aten.mm_bf16(%arg0: !torch.vtensor<[4,8],bf16>, %arg1: !torch.vtensor<[8,16],bf16>) -> !torch.vtensor<[4,16],f32> { + %false = torch.constant.bool false + %none = torch.constant.none + %int6 = torch.constant.int 6 + %0 = torch.aten.to.dtype %arg0, %int6, %false, %false, %none : !torch.vtensor<[4,8],bf16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[4,8],f32> + %1 = torch.aten.to.dtype %arg1, %int6, %false, %false, %none : !torch.vtensor<[8,16],bf16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[8,16],f32> + %2 = torch.aten.mm %0, %1 : !torch.vtensor<[4,8],f32>, !torch.vtensor<[8,16],f32> -> !torch.vtensor<[4,16],f32> + return %2 : !torch.vtensor<[4,16],f32> +} + +// ----- + +// CHECK: tosa.matmul{{.*}} : (tensor<1x4x8xf16>, tensor<1x8x16xf16>) -> tensor<1x4x16xf32> +func.func @torch.aten.mm_f16(%arg0: !torch.vtensor<[4,8],f16>, %arg1: !torch.vtensor<[8,16],f16>) -> !torch.vtensor<[4,16],f32> { + %false = torch.constant.bool false + %none = torch.constant.none + %int6 = torch.constant.int 6 + %0 = torch.aten.to.dtype %arg0, %int6, %false, %false, %none : !torch.vtensor<[4,8],f16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[4,8],f32> + %1 = torch.aten.to.dtype %arg1, %int6, %false, %false, %none : !torch.vtensor<[8,16],f16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[8,16],f32> + %2 = torch.aten.mm %0, %1 : !torch.vtensor<[4,8],f32>, !torch.vtensor<[8,16],f32> -> !torch.vtensor<[4,16],f32> + return %2 : !torch.vtensor<[4,16],f32> +} + +// ----- + +// CHECK: tosa.matmul{{.*}} : (tensor<1x4x8xi8>, tensor<1x8x16xi8>) -> tensor<1x4x16xi32> +func.func @torch.aten.mm_i8(%arg0: !torch.vtensor<[4,8],si8>, %arg1: !torch.vtensor<[8,16],si8>) -> !torch.vtensor<[4,16],si32> { + %false = torch.constant.bool false + %none = torch.constant.none + %int3 = torch.constant.int 3 + %0 = torch.aten.to.dtype %arg0, %int3, %false, %false, %none : !torch.vtensor<[4,8],si8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[4,8],si32> + %1 = torch.aten.to.dtype %arg1, %int3, %false, %false, %none : !torch.vtensor<[8,16],si8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[8,16],si32> + %2 = torch.aten.mm %0, %1 : !torch.vtensor<[4,8],si32>, !torch.vtensor<[8,16],si32> -> !torch.vtensor<[4,16],si32> + return %2 : !torch.vtensor<[4,16],si32> +} + +// ----- + +// expected-error @+1 {{invalid dtype 'si48' for !torch.tensor type}} +func.func @torch.aten.mm_i16(%arg0: !torch.vtensor<[4,8],si16>, %arg1: !torch.vtensor<[8,16],si16>) -> !torch.vtensor<[4,16],si48> { + %false = torch.constant.bool false + %none = torch.constant.none + %int3 = torch.constant.int 3 + %0 
= torch.aten.to.dtype %arg0, %int3, %false, %false, %none : !torch.vtensor<[4,8],si16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[4,8],si48> + %1 = torch.aten.to.dtype %arg1, %int3, %false, %false, %none : !torch.vtensor<[8,16],si16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[8,16],si48> + %2 = torch.aten.mm %0, %1 : !torch.vtensor<[4,8],si48>, !torch.vtensor<[8,16],si48> -> !torch.vtensor<[4,16],si48> + return %2 : !torch.vtensor<[4,16],si48> +} + +// ----- + +// CHECK: %[[VAL_2:.+]] = tosa.cast %{{[0-9]+}} : (tensor<4x8xf32>) -> tensor<4x8xf16> +// CHECK-NEXT: %[[VAL_3:.+]] = tosa.cast %{{[0-9]+}} : (tensor<8x16xf32>) -> tensor<8x16xf16> +// CHECK-NEXT: %[[VAL_4:.+]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<4x8xf16>) -> tensor<1x4x8xf16> +// CHECK-NEXT: %[[VAL_5:.+]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<8x16xf16>) -> tensor<1x8x16xf16> +// CHECK-NEXT: %[[VAL_6:.+]] = tosa.matmul %[[VAL_4]], %[[VAL_5]] : (tensor<1x4x8xf16>, tensor<1x8x16xf16>) -> tensor<1x4x16xf16> +// CHECK-NEXT: %[[VAL_7:.+]] = tosa.reshape %[[VAL_6]] {new_shape = array} : (tensor<1x4x16xf16>) -> tensor<4x16xf16> + +func.func @torch.aten.mm_f32_to_f16(%arg0: !torch.vtensor<[4,8],f32>, %arg1: !torch.vtensor<[8,16],f32>) -> !torch.vtensor<[4,16],f16> { + %false = torch.constant.bool false + %none = torch.constant.none + %int5 = torch.constant.int 5 + %0 = torch.aten.to.dtype %arg0, %int5, %false, %false, %none : !torch.vtensor<[4,8],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[4,8],f16> + %1 = torch.aten.to.dtype %arg1, %int5, %false, %false, %none : !torch.vtensor<[8,16],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[8,16],f16> + %2 = torch.aten.mm %0, %1 : !torch.vtensor<[4,8],f16>, !torch.vtensor<[8,16],f16> -> !torch.vtensor<[4,16],f16> + return %2 : !torch.vtensor<[4,16],f16> +} + +// ----- + +// CHECK: %[[VAL_2:.+]] = tosa.reshape %1 {new_shape = array} : (tensor<10x10x6x2xf32>) -> tensor<100x6x2xf32> +// CHECK-NEXT: %[[VAL_3:.+]] = tosa.reshape %0 {new_shape = array} : (tensor<10x10x2x6xf32>) -> tensor<100x2x6xf32> +// CHECK-NEXT: %[[VAL_4:.+]] = tosa.matmul %[[VAL_2]], %[[VAL_3]] : (tensor<100x6x2xf32>, tensor<100x2x6xf32>) -> tensor<100x6x6xf32> +// CHECK-NEXT: %[[VAL_5:.+]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<100x6x6xf32>) -> tensor<10x10x6x6xf32> +func.func @torch.aten.matmul_4d(%arg0 : !torch.vtensor<[10,10,6,2],f32>, %arg1 : !torch.vtensor<[10,10,2,6],f32>) -> !torch.vtensor<[10,10,6,6],f32> { + %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[10,10,6,2],f32>, !torch.vtensor<[10,10,2,6],f32> -> !torch.vtensor<[10,10,6,6],f32> + return %0 : !torch.vtensor<[10,10,6,6],f32> +} + +// ----- + +// CHECK: %[[VAL_2:.+]] = tosa.reshape %1 {new_shape = array} : (tensor<10x6x2xf32>) -> tensor<1x10x6x2xf32> +// CHECK-NEXT: %[[VAL_3:.+]] = "tosa.const"() <{value = dense<[1, 0, 2, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK-NEXT: %[[VAL_4:.+]] = tosa.transpose %[[VAL_2]], %[[VAL_3]] : (tensor<1x10x6x2xf32>, tensor<4xi32>) -> tensor<10x1x6x2xf32> +// CHECK-NEXT: %[[VAL_5:.+]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<10x1x6x2xf32>) -> tensor<10x6x2xf32> +// CHECK-NEXT: %[[VAL_6:.+]] = "tosa.const"() <{value = dense<[1, 2, 0, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK-NEXT: %[[VAL_7:.+]] = tosa.transpose %0, %[[VAL_6]] : (tensor<10x10x2x6xf32>, tensor<4xi32>) -> tensor<10x2x10x6xf32> +// CHECK-NEXT: %[[VAL_8:.+]] = 
tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor<10x2x10x6xf32>) -> tensor<10x2x60xf32> +// CHECK-NEXT: %[[VAL_9:.+]] = tosa.matmul %[[VAL_5]], %[[VAL_8]] : (tensor<10x6x2xf32>, tensor<10x2x60xf32>) -> tensor<10x6x60xf32> +// CHECK-NEXT: %[[VAL_10:.+]] = tosa.reshape %[[VAL_9]] {new_shape = array} : (tensor<10x6x60xf32>) -> tensor<10x6x10x6xf32> +// CHECK-NEXT: %[[VAL_11:.+]] = "tosa.const"() <{value = dense<[2, 0, 1, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK-NEXT: %[[VAL_12:.+]] = tosa.transpose %[[VAL_10]], %[[VAL_11]] : (tensor<10x6x10x6xf32>, tensor<4xi32>) -> tensor<10x10x6x6xf32> +func.func @torch.aten.matmul_4d_broadcast(%arg0 : !torch.vtensor<[10,6,2],f32>, %arg1 : !torch.vtensor<[10,10,2,6],f32>) -> !torch.vtensor<[10,10,6,6],f32> { + %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[10,6,2],f32>, !torch.vtensor<[10,10,2,6],f32> -> !torch.vtensor<[10,10,6,6],f32> + return %0 : !torch.vtensor<[10,10,6,6],f32> +} + +// ----- + +// CHECK: %[[VAL_2:.+]] = tosa.reshape %1 {new_shape = array} : (tensor<4x1x5x6xf32>) -> tensor<1x20x6xf32> +// CHECK-NEXT: %[[VAL_3:.+]] = "tosa.const"() <{value = dense<[2, 0, 1, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK-NEXT: %[[VAL_4:.+]] = tosa.transpose %0, %[[VAL_3]] : (tensor<1x3x6x7xf32>, tensor<4xi32>) -> tensor<6x1x3x7xf32> +// CHECK-NEXT: %[[VAL_5:.+]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<6x1x3x7xf32>) -> tensor<1x6x21xf32> +// CHECK-NEXT: %[[VAL_6:.+]] = tosa.matmul %[[VAL_2]], %[[VAL_5]] : (tensor<1x20x6xf32>, tensor<1x6x21xf32>) -> tensor<1x20x21xf32> +// CHECK-NEXT: %[[VAL_7:.+]] = tosa.reshape %[[VAL_6]] {new_shape = array} : (tensor<1x20x21xf32>) -> tensor<4x5x3x7xf32> +// CHECK-NEXT: %[[VAL_8:.+]] = "tosa.const"() <{value = dense<[0, 2, 1, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK-NEXT: %[[VAL_9:.+]] = tosa.transpose %[[VAL_7]], %[[VAL_8]] : (tensor<4x5x3x7xf32>, tensor<4xi32>) -> tensor<4x3x5x7xf32> +func.func @torch.aten.matmul_4d_broadcast_2(%arg0 : !torch.vtensor<[4,1,5,6],f32>, %arg1 : !torch.vtensor<[1,3,6,7],f32>) -> !torch.vtensor<[4,3,5,7],f32> { + %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[4,1,5,6],f32>, !torch.vtensor<[1,3,6,7],f32> -> !torch.vtensor<[4,3,5,7],f32> + return %0 : !torch.vtensor<[4,3,5,7],f32> +} + +// ----- + +// CHECK: %[[VAL_2:.+]] = tosa.reshape %0 {new_shape = array} : (tensor<8x16xf32>) -> tensor<1x8x16xf32> +// CHECK-NEXT: %[[VAL_3:.+]] = tosa.reshape %1 {new_shape = array} : (tensor<100x4x8xf32>) -> tensor<1x400x8xf32> +// CHECK-NEXT: %[[VAL_4:.+]] = "tosa.const"() <{value = dense<[1, 0, 2]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK-NEXT: %[[VAL_5:.+]] = tosa.transpose %[[VAL_2]], %[[VAL_4]] : (tensor<1x8x16xf32>, tensor<3xi32>) -> tensor<8x1x16xf32> +// CHECK-NEXT: %[[VAL_6:.+]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor<8x1x16xf32>) -> tensor<1x8x16xf32> +// CHECK-NEXT: %[[VAL_7:.+]] = tosa.matmul %[[VAL_3]], %[[VAL_6]] : (tensor<1x400x8xf32>, tensor<1x8x16xf32>) -> tensor<1x400x16xf32> +// CHECK-NEXT: %[[VAL_8:.+]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor<1x400x16xf32>) -> tensor<100x4x16xf32> +func.func @torch.aten.matmul_3d_broadcast(%arg0 : !torch.vtensor<[100,4,8],f32>, %arg1 : !torch.vtensor<[8,16],f32>) -> !torch.vtensor<[100,4,16],f32> { + %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[100,4,8],f32>, !torch.vtensor<[8,16],f32> -> !torch.vtensor<[100,4,16],f32> + return %0 : !torch.vtensor<[100,4,16],f32> +} + +// ----- + +// CHECK: %[[VAL_2:.+]] = tosa.matmul %1, %0 : 
(tensor<100x4x8xf16>, tensor<100x8x16xf16>) -> tensor<100x4x16xf16> +func.func @torch.aten.bmm_3d_fp16(%arg0 : !torch.vtensor<[100,4,8],f16>, %arg1 : !torch.vtensor<[100,8,16],f16>) -> !torch.vtensor<[100,4,16],f16> { + %0 = torch.aten.bmm %arg0, %arg1 : !torch.vtensor<[100,4,8],f16>, !torch.vtensor<[100,8,16],f16> -> !torch.vtensor<[100,4,16],f16> + return %0 : !torch.vtensor<[100,4,16],f16> +} + +// ----- + +// CHECK: %[[VAL_2:.+]] = tosa.matmul %1, %0 : (tensor<100x4x8xbf16>, tensor<100x8x16xbf16>) -> tensor<100x4x16xf32> +// CHECK-NEXT: %[[VAL_3:.+]] = tosa.cast %[[VAL_2]] : (tensor<100x4x16xf32>) -> tensor<100x4x16xbf16> +func.func @torch.aten.bmm_3d_bf16(%arg0 : !torch.vtensor<[100,4,8],bf16>, %arg1 : !torch.vtensor<[100,8,16],bf16>) -> !torch.vtensor<[100,4,16],bf16> { + %0 = torch.aten.bmm %arg0, %arg1 : !torch.vtensor<[100,4,8],bf16>, !torch.vtensor<[100,8,16],bf16> -> !torch.vtensor<[100,4,16],bf16> + return %0 : !torch.vtensor<[100,4,16],bf16> +} + +// ----- + +// CHECK: %[[VAL_2:.+]] = tosa.matmul %1, %0 : (tensor<100x4x8xf32>, tensor<100x8x16xf32>) -> tensor<100x4x16xf32> +func.func @torch.aten.bmm_3d_fp32(%arg0 : !torch.vtensor<[100,4,8],f32>, %arg1 : !torch.vtensor<[100,8,16],f32>) -> !torch.vtensor<[100,4,16],f32> { + %0 = torch.aten.bmm %arg0, %arg1 : !torch.vtensor<[100,4,8],f32>, !torch.vtensor<[100,8,16],f32> -> !torch.vtensor<[100,4,16],f32> + return %0 : !torch.vtensor<[100,4,16],f32> +} + +// ----- + // CHECK-LABEL: func.func @torch.aten.relu$basic( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor @@ -45,12 +258,14 @@ func.func @torch.aten.relu$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vte // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_2:.*]] = torch.constant.float 1.000000e-01 // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<1.000000e-01> : tensor}> : () -> tensor -// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_5:.*]] = tosa.greater_equal %[[VAL_1]], %[[VAL_4]] : (tensor, tensor) -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_1]], %[[VAL_3]] {shift = 0 : i8} : (tensor, tensor) -> tensor -// CHECK: %[[VAL_7:.*]] = tosa.select %[[VAL_5]], %[[VAL_1]], %[[VAL_6]] : (tensor, tensor, tensor) -> tensor -// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,?],f32> -// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32> +// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_7:.*]] = tosa.greater_equal %[[VAL_1]], %[[VAL_6]] : (tensor, tensor<1x1xf32>) -> tensor +// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_1]], %[[VAL_4]] {shift = 0 : i8} : (tensor, tensor<1x1xf32>) -> tensor +// CHECK: %[[VAL_9:.*]] = tosa.select %[[VAL_7]], %[[VAL_1]], %[[VAL_8]] : (tensor, tensor, tensor) -> tensor +// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_10]] : !torch.vtensor<[?,?],f32> // CHECK: } func.func @torch.aten.leaky_relu$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { %fp0 = 
torch.constant.float 1.000000e-01 @@ -157,14 +372,15 @@ func.func @torch.aten.reciprocal$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !tor // CHECK-LABEL: func.func @torch.aten.add$basic( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { -// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_4:.*]] = torch.constant.int 1 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor, tensor) -> tensor -// CHECK: %[[VAL_7:.*]] = tosa.add %[[VAL_2]], %[[VAL_6]] : (tensor, tensor) -> tensor -// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,?],f32> -// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32> +// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_2]], %[[VAL_6]] {shift = 0 : i8} : (tensor, tensor<1x1xf32>) -> tensor +// CHECK: %[[VAL_8:.*]] = tosa.add %[[VAL_3]], %[[VAL_7]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_9]] : !torch.vtensor<[?,?],f32> // CHECK: } func.func @torch.aten.add$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> { %int1 = torch.constant.int 1 @@ -177,14 +393,15 @@ func.func @torch.aten.add$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch // CHECK-LABEL: func.func @torch.aten.sub$basic( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { -// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_4:.*]] = torch.constant.int 1 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor, tensor) -> tensor -// CHECK: %[[VAL_7:.*]] = tosa.sub %[[VAL_2]], %[[VAL_6]] : (tensor, tensor) -> tensor -// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,?],f32> -// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32> +// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_2]], %[[VAL_6]] {shift = 0 : i8} : (tensor, tensor<1x1xf32>) -> tensor +// CHECK: %[[VAL_8:.*]] = tosa.sub %[[VAL_3]], %[[VAL_7]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_9]] : !torch.vtensor<[?,?],f32> // CHECK: } 
func.func @torch.aten.sub$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> { %int1 = torch.constant.int 1 @@ -197,8 +414,8 @@ func.func @torch.aten.sub$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch // CHECK-LABEL: func.func @torch.aten.mul$basic( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { -// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK-DAG: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK-DAG: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_2]], %[[VAL_3]] {shift = 0 : i8} : (tensor, tensor) -> tensor // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32> @@ -213,10 +430,10 @@ func.func @torch.aten.mul$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch // CHECK-LABEL: func.func @torch.aten.div$basic( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { -// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_3]] : (tensor) -> tensor -// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_2]], %[[VAL_4]] {shift = 0 : i8} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_3]], %[[VAL_4]] {shift = 0 : i8} : (tensor, tensor) -> tensor // CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],f32> // CHECK: } @@ -227,6 +444,35 @@ func.func @torch.aten.div$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch // ----- +// CHECK-LABEL: func.func @torch.aten.rsqrt$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_2:.*]] = tosa.rsqrt %[[VAL_1]] : (tensor) -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32> +// CHECK: } +func.func @torch.aten.rsqrt$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + %0 = torch.aten.rsqrt %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_mean_dim$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor +// CHECK: %[[VAL_2:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_3:.*]] = 
torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list +// CHECK: %[[VAL_4:.*]] = torch.constant.bool false +// CHECK: %[[VAL_5:.*]] = torch.constant.none +// CHECK: %[[VAL_6:.*]] = tosa.reduce_sum %[[VAL_1]] {axis = 0 : i32} : (tensor) -> tensor<1x?x?x?xf32> +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_6]] {new_shape = array} : (tensor<1x?x?x?xf32>) -> tensor +// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<-1.08420217E-19> : tensor}> : () -> tensor +// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array} : (tensor) -> tensor<1x1x1xf32> +// CHECK: %[[VAL_10:.*]] = tosa.mul %[[VAL_7]], %[[VAL_9]] {shift = 0 : i8} : (tensor, tensor<1x1x1xf32>) -> tensor +// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor -> !torch.vtensor<[?,?,?],f32> +// CHECK: return %[[VAL_11]] : !torch.vtensor<[?,?,?],f32> +// CHECK: } func.func @test_reduce_mean_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { %dim0 = torch.constant.int 0 %reducedims = torch.prim.ListConstruct %dim0 : (!torch.int) -> !torch.list @@ -262,21 +508,24 @@ func.func @test_reduce_sum_dims$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> ! // ----- // CHECK-LABEL: func.func @test_linalg_vector_norm$basic( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,151,64],f32>) -> !torch.vtensor<[3,151,1],f32> { -// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[3,151,64],f32> -> tensor<3x151x64xf32> -// CHECK: %[[ARG1:.*]] = torch.constant.float 2.000000e+00 -// CHECK: %[[ARG2:.*]] = torch.constant.int -1 -// CHECK: %[[ARG3:.*]] = torch.constant.bool true -// CHECK: %[[ARG4:.*]] = torch.constant.none -// CHECK: %[[ARG5:.*]] = torch.prim.ListConstruct %[[ARG2]] : (!torch.int) -> !torch.list -// CHECK: %[[ARG6:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[ARG7:.*]] = tosa.abs %[[ARG0_BUILTIN]] : (tensor<3x151x64xf32>) -> tensor<3x151x64xf32> -// CHECK: %[[ARG8:.*]] = tosa.pow %[[ARG7]], %[[ARG6]] : (tensor<3x151x64xf32>, tensor) -> tensor<3x151x64xf32> -// CHECK: %[[ARG9:.*]] = tosa.reduce_sum %[[ARG8]] {axis = 2 : i32} : (tensor<3x151x64xf32>) -> tensor<3x151x1xf32> -// CHECK: %[[ARG10:.*]] = tosa.reciprocal %[[ARG6]] : (tensor) -> tensor -// CHECK: %[[ARG11:.*]] = tosa.pow %[[ARG9]], %[[ARG10]] : (tensor<3x151x1xf32>, tensor) -> tensor<3x151x1xf32> -// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[ARG11]] : tensor<3x151x1xf32> -> !torch.vtensor<[3,151,1],f32> -// CHECK: return %[[RESULT]] : !torch.vtensor<[3,151,1],f32> +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,151,64],f32>) -> !torch.vtensor<[3,151,1],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,151,64],f32> -> tensor<3x151x64xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.float 2.000000e+00 +// CHECK: %[[VAL_3:.*]] = torch.constant.int -1 +// CHECK: %[[VAL_4:.*]] = torch.constant.bool true +// CHECK: %[[VAL_5:.*]] = torch.constant.none +// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_3]] : (!torch.int) -> !torch.list +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor) -> tensor<1x1x1xf32> +// CHECK: %[[VAL_9:.*]] = tosa.abs %[[VAL_1]] : (tensor<3x151x64xf32>) -> tensor<3x151x64xf32> +// CHECK: %[[VAL_10:.*]] = tosa.pow %[[VAL_9]], %[[VAL_8]] : (tensor<3x151x64xf32>, tensor<1x1x1xf32>) -> tensor<3x151x64xf32> +// CHECK: %[[VAL_11:.*]] = 
tosa.reduce_sum %[[VAL_10]] {axis = 2 : i32} : (tensor<3x151x64xf32>) -> tensor<3x151x1xf32> +// CHECK: %[[VAL_12:.*]] = tosa.reciprocal %[[VAL_7]] : (tensor) -> tensor +// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array} : (tensor) -> tensor<1x1x1xf32> +// CHECK: %[[VAL_14:.*]] = tosa.pow %[[VAL_11]], %[[VAL_13]] : (tensor<3x151x1xf32>, tensor<1x1x1xf32>) -> tensor<3x151x1xf32> +// CHECK: %[[VAL_15:.*]] = torch_c.from_builtin_tensor %[[VAL_14]] : tensor<3x151x1xf32> -> !torch.vtensor<[3,151,1],f32> +// CHECK: return %[[VAL_15]] : !torch.vtensor<[3,151,1],f32> +// CHECK: } func.func @test_linalg_vector_norm$basic(%arg0: !torch.vtensor<[3,151,64],f32>) -> (!torch.vtensor<[3,151,1],f32>) { %float2.000000e00 = torch.constant.float 2.000000e+00 %int-1 = torch.constant.int -1 @@ -377,8 +626,8 @@ func.func @torch.aten.rsqrt$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vt // CHECK-LABEL: func.func @torch.aten.maximum$basic( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { -// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK-DAG: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK-DAG: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_4:.*]] = tosa.maximum %[[VAL_2]], %[[VAL_3]] : (tensor, tensor) -> tensor // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32> @@ -393,8 +642,8 @@ func.func @torch.aten.maximum$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !to // CHECK-LABEL: func.func @torch.aten.minimum$basic( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { -// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK-DAG: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK-DAG: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_4:.*]] = tosa.minimum %[[VAL_2]], %[[VAL_3]] : (tensor, tensor) -> tensor // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32> @@ -407,13 +656,14 @@ func.func @torch.aten.minimum$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !to // ----- // CHECK-LABEL: func.func @torch.aten.pow.Tensor_Scalar$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_2:.*]] = torch.constant.float 3.123400e+00 // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<3.123400e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.pow %[[VAL_1]], %[[VAL_3]] : (tensor, tensor) -> tensor -// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],f32> -// CHECK: return 
%[[VAL_5]] : !torch.vtensor<[?,?],f32> +// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_5:.*]] = tosa.pow %[[VAL_1]], %[[VAL_4]] : (tensor, tensor<1x1xf32>) -> tensor +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],f32> // CHECK: } func.func @torch.aten.pow.Tensor_Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { %fp0 = torch.constant.float 3.123400e+00 @@ -430,10 +680,12 @@ func.func @torch.aten.pow.Tensor_Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) // CHECK: %[[VAL_3:.*]] = torch.constant.float 6.432100e+00 // CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<3.123400e+00> : tensor}> : () -> tensor // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<6.432100e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_1]], %[[VAL_5]] {shift = 0 : i8} : (tensor, tensor) -> tensor -// CHECK: %[[VAL_7:.*]] = tosa.sub %[[VAL_4]], %[[VAL_6]] : (tensor, tensor) -> tensor -// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,?],f32> -// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32> +// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_1]], %[[VAL_7]] {shift = 0 : i8} : (tensor, tensor<1x1xf32>) -> tensor +// CHECK: %[[VAL_9:.*]] = tosa.sub %[[VAL_6]], %[[VAL_8]] : (tensor<1x1xf32>, tensor) -> tensor +// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_10]] : !torch.vtensor<[?,?],f32> // CHECK: } func.func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { %other = torch.constant.float 3.123400e+00 @@ -444,19 +696,21 @@ func.func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !to // ----- -// CHECK-LABEL: func.func @torch.aten.rsub.Scalar$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK-LABEL: func.func @torch.aten.rsub.Scalar$float_int( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_2:.*]] = torch.constant.float 3.123400e+00 // CHECK: %[[VAL_3:.*]] = torch.constant.int 1 // CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<3.123400e+00> : tensor}> : () -> tensor // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_1]], %[[VAL_5]] {shift = 0 : i8} : (tensor, tensor) -> tensor -// CHECK: %[[VAL_7:.*]] = tosa.sub %[[VAL_4]], %[[VAL_6]] : (tensor, tensor) -> tensor -// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,?],f32> -// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32> +// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_1]], %[[VAL_7]] {shift = 0 : i8} : (tensor, tensor<1x1xf32>) -> tensor +// CHECK: %[[VAL_9:.*]] = tosa.sub %[[VAL_6]], %[[VAL_8]] : 
(tensor<1x1xf32>, tensor) -> tensor +// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_10]] : !torch.vtensor<[?,?],f32> // CHECK: } -func.func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +func.func @torch.aten.rsub.Scalar$float_int(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { %other = torch.constant.float 3.123400e+00 %alpha = torch.constant.int 1 %0 = torch.aten.rsub.Scalar %arg0, %other, %alpha : !torch.vtensor<[?,?],f32>, !torch.float, !torch.int -> !torch.vtensor<[?,?],f32> @@ -468,8 +722,8 @@ func.func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !to // CHECK-LABEL: func.func @torch.aten.gt.Tensor$basic( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> { -// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK-DAG: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK-DAG: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_4:.*]] = tosa.greater %[[VAL_2]], %[[VAL_3]] : (tensor, tensor) -> tensor // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],i1> // CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1> @@ -484,8 +738,8 @@ func.func @torch.aten.gt.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! // CHECK-LABEL: func.func @torch.aten.lt.Tensor$basic( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> { -// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK-DAG: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK-DAG: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_4:.*]] = tosa.greater %[[VAL_3]], %[[VAL_2]] : (tensor, tensor) -> tensor // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],i1> // CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1> @@ -500,8 +754,8 @@ func.func @torch.aten.lt.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! 
// CHECK-LABEL: func.func @torch.aten.eq.Tensor$basic( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> { -// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK-DAG: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK-DAG: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_4:.*]] = tosa.equal %[[VAL_2]], %[[VAL_3]] : (tensor, tensor) -> tensor // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],i1> // CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1> @@ -545,14 +799,19 @@ func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !to // CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<4xf32>) -> tensor<4x1xf32> // CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<4xf32>) -> tensor<4x1xf32> // CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<9.99999974E-6> : tensor}> : () -> tensor -// CHECK: %[[VAL_13:.*]] = tosa.sub %[[VAL_1]], %[[VAL_8]] : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32> -// CHECK: %[[VAL_14:.*]] = tosa.add %[[VAL_9]], %[[VAL_12]] : (tensor<4x1xf32>, tensor) -> tensor<4x1xf32> -// CHECK: %[[VAL_15:.*]] = tosa.rsqrt %[[VAL_14]] : (tensor<4x1xf32>) -> tensor<4x1xf32> -// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_13]], %[[VAL_15]] {shift = 0 : i8} : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32> -// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_16]], %[[VAL_10]] {shift = 0 : i8} : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32> -// CHECK: %[[VAL_18:.*]] = tosa.add %[[VAL_17]], %[[VAL_11]] : (tensor<10x4x3xf32>, tensor<4x1xf32>) -> tensor<10x4x3xf32> -// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<10x4x3xf32> -> !torch.vtensor<[10,4,3],f32> -// CHECK: return %[[VAL_19]] : !torch.vtensor<[10,4,3],f32> +// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array} : (tensor<4x1xf32>) -> tensor<1x4x1xf32> +// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array} : (tensor<4x1xf32>) -> tensor<1x4x1xf32> +// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array} : (tensor) -> tensor<1x1x1xf32> +// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_10]] {new_shape = array} : (tensor<4x1xf32>) -> tensor<1x4x1xf32> +// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array} : (tensor<4x1xf32>) -> tensor<1x4x1xf32> +// CHECK: %[[VAL_18:.*]] = tosa.sub %[[VAL_1]], %[[VAL_13]] : (tensor<10x4x3xf32>, tensor<1x4x1xf32>) -> tensor<10x4x3xf32> +// CHECK: %[[VAL_19:.*]] = tosa.add %[[VAL_14]], %[[VAL_15]] : (tensor<1x4x1xf32>, tensor<1x1x1xf32>) -> tensor<1x4x1xf32> +// CHECK: %[[VAL_20:.*]] = tosa.rsqrt %[[VAL_19]] : (tensor<1x4x1xf32>) -> tensor<1x4x1xf32> +// CHECK: %[[VAL_21:.*]] = tosa.mul %[[VAL_18]], %[[VAL_20]] {shift = 0 : i8} : (tensor<10x4x3xf32>, tensor<1x4x1xf32>) -> tensor<10x4x3xf32> +// CHECK: %[[VAL_22:.*]] = tosa.mul %[[VAL_21]], %[[VAL_16]] {shift = 0 : i8} : (tensor<10x4x3xf32>, tensor<1x4x1xf32>) -> tensor<10x4x3xf32> +// CHECK: %[[VAL_23:.*]] = tosa.add %[[VAL_22]], %[[VAL_17]] : (tensor<10x4x3xf32>, tensor<1x4x1xf32>) -> tensor<10x4x3xf32> +// CHECK: %[[VAL_24:.*]] = 
torch_c.from_builtin_tensor %[[VAL_23]] : tensor<10x4x3xf32> -> !torch.vtensor<[10,4,3],f32> +// CHECK: return %[[VAL_24]] : !torch.vtensor<[10,4,3],f32> // CHECK: } func.func @torch.aten.native_batch_norm$basic(%arg0: !torch.vtensor<[10,4,3],f32> ) -> !torch.vtensor<[10,4,3],f32> { %0 = torch.vtensor.literal(dense<[5.000000e-01, 4.000000e-01, 3.000000e-01, 6.000000e-01]> : tensor<4xf32>) : !torch.vtensor<[4],f32> @@ -608,44 +867,46 @@ func.func @forward(%arg0: !torch.vtensor<[1,6,4],f32> ) -> !torch.vtensor<[1,2,3 // ----- -// CHECK-LABEL: func.func @forward( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,2,2,3],f32>, -// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,2,3],f32>, -// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[2,2,3],f32>) -> !torch.vtensor<[5,2,2,3],f32> { -// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,2,2,3],f32> -> tensor<5x2x2x3xf32> +// CHECK-LABEL: func.func @torch.aten.native_layer_norm$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,2,2,3],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,2,3],f32>, +// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[2,2,3],f32>) -> !torch.vtensor<[5,2,2,3],f32> { +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[2,2,3],f32> -> tensor<2x2x3xf32> // CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,2,3],f32> -> tensor<2x2x3xf32> -// CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[2,2,3],f32> -> tensor<2x2x3xf32> +// CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,2,2,3],f32> -> tensor<5x2x2x3xf32> // CHECK: %[[VAL_6:.*]] = torch.constant.float 5.000000e-01 // CHECK: %[[VAL_7:.*]] = torch.constant.int 3 // CHECK: %[[VAL_8:.*]] = torch.constant.int 2 // CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_8]], %[[VAL_8]], %[[VAL_7]] : (!torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<1.200000e+01> : tensor<1xf32>}> : () -> tensor<1xf32> // CHECK: %[[VAL_11:.*]] = tosa.reciprocal %[[VAL_10]] : (tensor<1xf32>) -> tensor<1xf32> -// CHECK: %[[VAL_12:.*]] = tosa.reduce_sum %[[VAL_3]] {axis = 3 : i32} : (tensor<5x2x2x3xf32>) -> tensor<5x2x2x1xf32> -// CHECK: %[[VAL_13:.*]] = tosa.reduce_sum %[[VAL_12]] {axis = 2 : i32} : (tensor<5x2x2x1xf32>) -> tensor<5x2x1x1xf32> -// CHECK: %[[VAL_14:.*]] = tosa.reduce_sum %[[VAL_13]] {axis = 1 : i32} : (tensor<5x2x1x1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array} : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_15]], %[[VAL_11]] {shift = 0 : i8} : (tensor<5x1x1x1xf32>, tensor<1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_17:.*]] = tosa.sub %[[VAL_3]], %[[VAL_16]] : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> -// CHECK: %[[VAL_18:.*]] = tosa.mul %[[VAL_17]], %[[VAL_17]] {shift = 0 : i8} : (tensor<5x2x2x3xf32>, tensor<5x2x2x3xf32>) -> tensor<5x2x2x3xf32> -// CHECK: %[[VAL_19:.*]] = tosa.reduce_sum %[[VAL_18]] {axis = 3 : i32} : (tensor<5x2x2x3xf32>) -> tensor<5x2x2x1xf32> -// CHECK: %[[VAL_20:.*]] = tosa.reduce_sum %[[VAL_19]] {axis = 2 : i32} : (tensor<5x2x2x1xf32>) -> tensor<5x2x1x1xf32> -// CHECK: %[[VAL_21:.*]] = tosa.reduce_sum %[[VAL_20]] {axis = 1 : i32} : (tensor<5x2x1x1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_22:.*]] = tosa.reshape %[[VAL_21]] {new_shape = array} : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_23:.*]] = 
tosa.mul %[[VAL_22]], %[[VAL_11]] {shift = 0 : i8} : (tensor<5x1x1x1xf32>, tensor<1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_24:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32> -// CHECK: %[[VAL_25:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32> -// CHECK: %[[VAL_26:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor}> : () -> tensor -// CHECK: %[[VAL_27:.*]] = tosa.sub %[[VAL_3]], %[[VAL_16]] : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> -// CHECK: %[[VAL_28:.*]] = tosa.add %[[VAL_23]], %[[VAL_26]] : (tensor<5x1x1x1xf32>, tensor) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_29:.*]] = tosa.rsqrt %[[VAL_28]] : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32> -// CHECK: %[[VAL_30:.*]] = tosa.mul %[[VAL_27]], %[[VAL_29]] {shift = 0 : i8} : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> -// CHECK: %[[VAL_31:.*]] = tosa.mul %[[VAL_30]], %[[VAL_24]] {shift = 0 : i8} : (tensor<5x2x2x3xf32>, tensor<1x2x2x3xf32>) -> tensor<5x2x2x3xf32> -// CHECK: %[[VAL_32:.*]] = tosa.add %[[VAL_31]], %[[VAL_25]] : (tensor<5x2x2x3xf32>, tensor<1x2x2x3xf32>) -> tensor<5x2x2x3xf32> -// CHECK: %[[VAL_33:.*]] = torch_c.from_builtin_tensor %[[VAL_32]] : tensor<5x2x2x3xf32> -> !torch.vtensor<[5,2,2,3],f32> -// CHECK: return %[[VAL_33]] : !torch.vtensor<[5,2,2,3],f32> -// CHECK: } -func.func @forward(%arg0: !torch.vtensor<[5,2,2,3],f32> , %arg1: !torch.vtensor<[2,2,3],f32> , %arg2: !torch.vtensor<[2,2,3],f32> ) -> !torch.vtensor<[5,2,2,3],f32> { +// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array} : (tensor<1xf32>) -> tensor<1x1x1x1xf32> +// CHECK: %[[VAL_13:.*]] = tosa.reduce_sum %[[VAL_5]] {axis = 3 : i32} : (tensor<5x2x2x3xf32>) -> tensor<5x2x2x1xf32> +// CHECK: %[[VAL_14:.*]] = tosa.reduce_sum %[[VAL_13]] {axis = 2 : i32} : (tensor<5x2x2x1xf32>) -> tensor<5x2x1x1xf32> +// CHECK: %[[VAL_15:.*]] = tosa.reduce_sum %[[VAL_14]] {axis = 1 : i32} : (tensor<5x2x1x1xf32>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array} : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_16]], %[[VAL_12]] {shift = 0 : i8} : (tensor<5x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_18:.*]] = tosa.sub %[[VAL_5]], %[[VAL_17]] : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> +// CHECK: %[[VAL_19:.*]] = tosa.mul %[[VAL_18]], %[[VAL_18]] {shift = 0 : i8} : (tensor<5x2x2x3xf32>, tensor<5x2x2x3xf32>) -> tensor<5x2x2x3xf32> +// CHECK: %[[VAL_20:.*]] = tosa.reduce_sum %[[VAL_19]] {axis = 3 : i32} : (tensor<5x2x2x3xf32>) -> tensor<5x2x2x1xf32> +// CHECK: %[[VAL_21:.*]] = tosa.reduce_sum %[[VAL_20]] {axis = 2 : i32} : (tensor<5x2x2x1xf32>) -> tensor<5x2x1x1xf32> +// CHECK: %[[VAL_22:.*]] = tosa.reduce_sum %[[VAL_21]] {axis = 1 : i32} : (tensor<5x2x1x1xf32>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_23:.*]] = tosa.reshape %[[VAL_22]] {new_shape = array} : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_24:.*]] = tosa.mul %[[VAL_23]], %[[VAL_12]] {shift = 0 : i8} : (tensor<5x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_25:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32> +// CHECK: %[[VAL_26:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<2x2x3xf32>) -> tensor<1x2x2x3xf32> +// CHECK: %[[VAL_27:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor}> : () -> 
tensor +// CHECK: %[[VAL_28:.*]] = tosa.reshape %[[VAL_27]] {new_shape = array} : (tensor) -> tensor<1x1x1x1xf32> +// CHECK: %[[VAL_29:.*]] = tosa.sub %[[VAL_5]], %[[VAL_17]] : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> +// CHECK: %[[VAL_30:.*]] = tosa.add %[[VAL_24]], %[[VAL_28]] : (tensor<5x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_31:.*]] = tosa.rsqrt %[[VAL_30]] : (tensor<5x1x1x1xf32>) -> tensor<5x1x1x1xf32> +// CHECK: %[[VAL_32:.*]] = tosa.mul %[[VAL_29]], %[[VAL_31]] {shift = 0 : i8} : (tensor<5x2x2x3xf32>, tensor<5x1x1x1xf32>) -> tensor<5x2x2x3xf32> +// CHECK: %[[VAL_33:.*]] = tosa.mul %[[VAL_32]], %[[VAL_25]] {shift = 0 : i8} : (tensor<5x2x2x3xf32>, tensor<1x2x2x3xf32>) -> tensor<5x2x2x3xf32> +// CHECK: %[[VAL_34:.*]] = tosa.add %[[VAL_33]], %[[VAL_26]] : (tensor<5x2x2x3xf32>, tensor<1x2x2x3xf32>) -> tensor<5x2x2x3xf32> +// CHECK: %[[VAL_35:.*]] = torch_c.from_builtin_tensor %[[VAL_34]] : tensor<5x2x2x3xf32> -> !torch.vtensor<[5,2,2,3],f32> +// CHECK: return %[[VAL_35]] : !torch.vtensor<[5,2,2,3],f32> +// CHECK: } +func.func @torch.aten.native_layer_norm$basic(%arg0: !torch.vtensor<[5,2,2,3],f32> , %arg1: !torch.vtensor<[2,2,3],f32> , %arg2: !torch.vtensor<[2,2,3],f32> ) -> !torch.vtensor<[5,2,2,3],f32> { %float5.000000e-01 = torch.constant.float 5.000000e-01 %int3 = torch.constant.int 3 %int2 = torch.constant.int 2 @@ -659,8 +920,8 @@ func.func @forward(%arg0: !torch.vtensor<[5,2,2,3],f32> , %arg1: !torch.vtensor< // CHECK-LABEL: func.func @torch.aten.ne.Tensor$basic( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> { -// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK-DAG: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK-DAG: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_4:.*]] = tosa.equal %[[VAL_2]], %[[VAL_3]] : (tensor, tensor) -> tensor // CHECK: %[[VAL_5:.*]] = tosa.logical_not %[[VAL_4]] : (tensor) -> tensor // CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor -> !torch.vtensor<[?,?],i1> @@ -676,8 +937,8 @@ func.func @torch.aten.ne.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! // CHECK-LABEL: func.func @torch.aten.logical_or$basic( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],i1>, // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> { -// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],i1> -> tensor -// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],i1> -> tensor +// CHECK-DAG: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],i1> -> tensor +// CHECK-DAG: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],i1> -> tensor // CHECK: %[[VAL_4:.*]] = tosa.logical_or %[[VAL_2]], %[[VAL_3]] : (tensor, tensor) -> tensor // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],i1> // CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1> @@ -696,8 +957,8 @@ func.func @torch.aten.logical_or$basic(%arg0: !torch.vtensor<[?,?],i1>, %arg1: ! 
// CHECK: %[[VAL_3:.*]] = torch.constant.int 2 // CHECK: %[[VAL_4:.*]] = torch.constant.int 0 // CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_3]], %[[VAL_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi64>}> : () -> tensor<3xi64> -// CHECK: %[[VAL_7:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_6]] : (tensor<3x4x2xf32>, tensor<3xi64>) -> tensor<3x2x4xf32> +// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK: %[[VAL_7:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_6]] : (tensor<3x4x2xf32>, tensor<3xi32>) -> tensor<3x2x4xf32> // CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<3x2x4xf32> -> !torch.vtensor<[3,2,4],f32> // CHECK: return %[[VAL_8]] : !torch.vtensor<[3,2,4],f32> // CHECK: } @@ -715,8 +976,8 @@ func.func @forward(%arg0: !torch.vtensor<[3,4,2],f32> ) -> !torch.vtensor<[3,2,4 // CHECK-LABEL: func.func @torch.aten.bitwise_and.Tensor$basic( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si32>, // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> { -// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor -// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si32> -> tensor +// CHECK-DAG: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor +// CHECK-DAG: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si32> -> tensor // CHECK: %[[VAL_4:.*]] = tosa.bitwise_and %[[VAL_2]], %[[VAL_3]] : (tensor, tensor) -> tensor // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],si32> // CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],si32> @@ -801,10 +1062,8 @@ func.func @torch.aten.unsqueeze$negative_dim(%arg0: !torch.vtensor<[4,3],si32> ) // CHECK-LABEL: func.func @torch.aten.contiguous$basic( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_2:.*]] = torch.constant.int 0 -// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_1]] : tensor -> !torch.vtensor<[?,?],f32> -// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_0]] : !torch.vtensor<[?,?],f32> // CHECK: } func.func @torch.aten.contiguous$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.vtensor<[?,?],f32> { %int0 = torch.constant.int 0 @@ -854,37 +1113,35 @@ func.func @torch.aten.dropout$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch // ----- // CHECK-LABEL: func.func @torch.aten.avg_pool2d$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,512,7,7],f32>) -> !torch.vtensor<[1,512,1,1],f32> { +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,512,7,7],f32>) -> !torch.vtensor<[1,512,1,1],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,512,7,7],f32> -> tensor<1x512x7x7xf32> // CHECK: %[[VAL_2:.*]] = torch.constant.int 7 // CHECK: %[[VAL_3:.*]] = torch.constant.int 1 // CHECK: %[[VAL_4:.*]] = torch.constant.int 0 // CHECK: %[[VAL_5:.*]] = torch.constant.bool false -// CHECK: %[[VAL_6:.*]] = torch.constant.bool true -// CHECK: %[[VAL_7:.*]] = torch.constant.none -// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_2]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: 
%[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_11]] : (tensor<1x512x7x7xf32>, tensor<4xi32>) -> tensor<1x7x7x512xf32> -// CHECK: %[[VAL_13:.*]] = tosa.avg_pool2d %[[VAL_12]] {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x512xf32>) -> tensor<1x1x1x512xf32> -// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK: %[[VAL_15:.*]] = tosa.transpose %[[VAL_13]], %[[VAL_14]] : (tensor<1x1x1x512xf32>, tensor<4xi32>) -> tensor<1x512x1x1xf32> -// CHECK: %[[VAL_16:.*]] = tensor.cast %[[VAL_15]] : tensor<1x512x1x1xf32> to tensor<1x512x1x1xf32> -// CHECK: %[[VAL_17:.*]] = torch_c.from_builtin_tensor %[[VAL_16]] : tensor<1x512x1x1xf32> -> !torch.vtensor<[1,512,1,1],f32> -// CHECK: return %[[VAL_17]] : !torch.vtensor<[1,512,1,1],f32> +// CHECK: %[[VAL_6:.*]] = torch.constant.none +// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_2]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_11:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_10]] : (tensor<1x512x7x7xf32>, tensor<4xi32>) -> tensor<1x7x7x512xf32> +// CHECK: %[[VAL_12:.*]] = tosa.avg_pool2d %[[VAL_11]] {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x512xf32>) -> tensor<1x1x1x512xf32> +// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_12]], %[[VAL_13]] : (tensor<1x1x1x512xf32>, tensor<4xi32>) -> tensor<1x512x1x1xf32> +// CHECK: %[[VAL_15:.*]] = tensor.cast %[[VAL_14]] : tensor<1x512x1x1xf32> to tensor<1x512x1x1xf32> +// CHECK: %[[VAL_16:.*]] = torch_c.from_builtin_tensor %[[VAL_15]] : tensor<1x512x1x1xf32> -> !torch.vtensor<[1,512,1,1],f32> +// CHECK: return %[[VAL_16]] : !torch.vtensor<[1,512,1,1],f32> // CHECK: } func.func @torch.aten.avg_pool2d$basic(%arg0: !torch.vtensor<[1,512,7,7],f32> ) -> !torch.vtensor<[1,512,1,1],f32> { %int7 = torch.constant.int 7 %int1 = torch.constant.int 1 %int0 = torch.constant.int 0 %false = torch.constant.bool false - %true = torch.constant.bool true %none = torch.constant.none %kernel = torch.prim.ListConstruct %int7, %int7 : (!torch.int, !torch.int) -> !torch.list %stride = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list %padding = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list - %0 = torch.aten.avg_pool2d %arg0, %kernel, %stride, %padding, %false, %true, %none : !torch.vtensor<[1,512,7,7],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,512,1,1],f32> + %0 = torch.aten.avg_pool2d %arg0, %kernel, %stride, %padding, %false, %false, %none : !torch.vtensor<[1,512,7,7],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.none -> 
!torch.vtensor<[1,512,1,1],f32> return %0 : !torch.vtensor<[1,512,1,1],f32> } @@ -892,15 +1149,15 @@ func.func @torch.aten.avg_pool2d$basic(%arg0: !torch.vtensor<[1,512,7,7],f32> ) // CHECK-LABEL: @torch.aten.max.dim$basic( // CHECK-SAME: %[[ARG0:.*]]: tensor<3x2x3xf32>) -// CHECK: %[[VAL_0:.*]] = torch_c.from_builtin_tensor %[[ARG0]] : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32> -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32> -// CHECK: %[[VAL_TRUE:.*]] = torch.constant.bool true -// CHECK: %[[VAL_I2:.*]] = torch.constant.int 2 -// CHECK: %[[VAL_2:.*]] = tosa.reduce_max %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32> -// CHECK: %[[VAL_3:.*]] = tosa.argmax %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi64> -// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<3x2xi64>) -> tensor<3x2x1xi64> -// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32> -// CHECK: %[[VAL_6:.*]] = torch_c.to_builtin_tensor %[[VAL_5]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32> +// CHECK-DAG: %[[VAL_0:.*]] = torch_c.from_builtin_tensor %[[ARG0]] : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32> +// CHECK-DAG: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32> +// CHECK-DAG: %[[VAL_TRUE:.*]] = torch.constant.bool true +// CHECK-DAG: %[[VAL_I2:.*]] = torch.constant.int 2 +// CHECK-DAG: %[[VAL_2:.*]] = tosa.reduce_max %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32> +// CHECK-DAG: %[[VAL_3:.*]] = tosa.argmax %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi64> +// CHECK-DAG: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<3x2xi64>) -> tensor<3x2x1xi64> +// CHECK-DAG: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32> +// CHECK-DAG: %[[VAL_6:.*]] = torch_c.to_builtin_tensor %[[VAL_5]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32> // CHECK: return %[[VAL_6]] : tensor<3x2x1xf32> func.func @torch.aten.max.dim$basic(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> { %0 = torch_c.from_builtin_tensor %arg0 : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32> @@ -1026,29 +1283,55 @@ func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[1,128],i1>) -> !torch.vten return %0 : !torch.vtensor<[1,128],si64> } +// ----- +// CHECK-LABEL: func.func @torch.aten.to.dtype$floatToInt( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,5],f32>) -> !torch.vtensor<[3,5],si64> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,5],f32> -> tensor<3x5xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 4 +// CHECK: %[[VAL_3:.*]] = torch.constant.bool false +// CHECK: %[[VAL_4:.*]] = torch.constant.none +// CHECK: %[[VAL_5:.*]] = tosa.floor %[[VAL_1]] : (tensor<3x5xf32>) -> tensor<3x5xf32> +// CHECK: %[[VAL_6:.*]] = tosa.ceil %[[VAL_1]] : (tensor<3x5xf32>) -> tensor<3x5xf32> +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_9:.*]] = tosa.greater %[[VAL_8]], %[[VAL_1]] : (tensor<1x1xf32>, tensor<3x5xf32>) -> tensor<3x5xi1> +// CHECK: %[[VAL_10:.*]] = tosa.select %[[VAL_9]], %[[VAL_6]], %[[VAL_5]] : (tensor<3x5xi1>, tensor<3x5xf32>, tensor<3x5xf32>) -> tensor<3x5xf32> +// CHECK: 
%[[VAL_11:.*]] = tosa.cast %[[VAL_10]] : (tensor<3x5xf32>) -> tensor<3x5xi64> +// CHECK: %[[VAL_12:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor<3x5xi64> -> !torch.vtensor<[3,5],si64> +// CHECK: return %[[VAL_12]] : !torch.vtensor<[3,5],si64> +// CHECK: } +func.func @torch.aten.to.dtype$floatToInt(%arg0: !torch.vtensor<[3,5],f32>) -> !torch.vtensor<[3,5],si64> { + %int4 = torch.constant.int 4 + %false = torch.constant.bool false + %none = torch.constant.none + %0 = torch.aten.to.dtype %arg0, %int4, %false, %false, %none : !torch.vtensor<[3,5],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,5],si64> + return %0 : !torch.vtensor<[3,5],si64> + } + // ----- // CHECK-LABEL: func.func @torch.aten.gather( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,4,3],f32>, // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1,4,2],si64>) -> !torch.vtensor<[1,4,2],f32> { -// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,4,3],f32> -> tensor<1x4x3xf32> -// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,4,2],si64> -> tensor<1x4x2xi64> +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,4,2],si64> -> tensor<1x4x2xi64> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,4,3],f32> -> tensor<1x4x3xf32> // CHECK: %[[VAL_4:.*]] = torch.constant.int -1 // CHECK: %[[VAL_5:.*]] = torch.constant.bool false -// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_3]] : (tensor<1x4x2xi64>) -> tensor<1x4x2xi32> +// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_2]] : (tensor<1x4x2xi64>) -> tensor<1x4x2xi32> // CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_6]] {new_shape = array} : (tensor<1x4x2xi32>) -> tensor<1x4x2x1xi32> // CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x4x2x1xi32>}> : () -> tensor<1x4x2x1xi32> // CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]]]]> : tensor<1x4x2x1xi32>}> : () -> tensor<1x4x2x1xi32> // CHECK: %[[VAL_10:.*]] = tosa.concat %[[VAL_8]], %[[VAL_9]], %[[VAL_7]] {axis = 3 : i32} : (tensor<1x4x2x1xi32>, tensor<1x4x2x1xi32>, tensor<1x4x2x1xi32>) -> tensor<1x4x2x3xi32> -// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<1x4x3xf32>) -> tensor<1x12x1xf32> +// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<1x4x3xf32>) -> tensor<1x12x1xf32> // CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_10]] {new_shape = array} : (tensor<1x4x2x3xi32>) -> tensor<8x3xi32> // CHECK: %[[VAL_13:.*]] = "tosa.const"() <{value = dense<[12, 3, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> -// CHECK: %[[VAL_14:.*]] = tosa.mul %[[VAL_12]], %[[VAL_13]] {shift = 0 : i8} : (tensor<8x3xi32>, tensor<3xi32>) -> tensor<8x3xi32> -// CHECK: %[[VAL_15:.*]] = tosa.reduce_sum %[[VAL_14]] {axis = 1 : i32} : (tensor<8x3xi32>) -> tensor<8x1xi32> -// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array} : (tensor<8x1xi32>) -> tensor<1x8xi32> -// CHECK: %[[VAL_17:.*]] = tosa.gather %[[VAL_11]], %[[VAL_16]] : (tensor<1x12x1xf32>, tensor<1x8xi32>) -> tensor<1x8x1xf32> -// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_17]] {new_shape = array} : (tensor<1x8x1xf32>) -> tensor<1x4x2xf32> -// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<1x4x2xf32> -> !torch.vtensor<[1,4,2],f32> -// CHECK: return %[[VAL_19]] : !torch.vtensor<[1,4,2],f32> +// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_13]] {new_shape = array} : 
(tensor<3xi32>) -> tensor<1x3xi32> +// CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_12]], %[[VAL_14]] {shift = 0 : i8} : (tensor<8x3xi32>, tensor<1x3xi32>) -> tensor<8x3xi32> +// CHECK: %[[VAL_16:.*]] = tosa.reduce_sum %[[VAL_15]] {axis = 1 : i32} : (tensor<8x3xi32>) -> tensor<8x1xi32> +// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array} : (tensor<8x1xi32>) -> tensor<1x8xi32> +// CHECK: %[[VAL_18:.*]] = tosa.gather %[[VAL_11]], %[[VAL_17]] : (tensor<1x12x1xf32>, tensor<1x8xi32>) -> tensor<1x8x1xf32> +// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array} : (tensor<1x8x1xf32>) -> tensor<1x4x2xf32> +// CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]] : tensor<1x4x2xf32> -> !torch.vtensor<[1,4,2],f32> +// CHECK: return %[[VAL_20]] : !torch.vtensor<[1,4,2],f32> // CHECK: } func.func @torch.aten.gather(%arg0: !torch.vtensor<[1,4,3],f32>, %arg1: !torch.vtensor<[1,4,2],si64>) -> !torch.vtensor<[1,4,2],f32> { %int-1 = torch.constant.int -1 @@ -1061,15 +1344,16 @@ func.func @torch.aten.gather(%arg0: !torch.vtensor<[1,4,3],f32>, %arg1: !torch.v // CHECK-LABEL: func.func @torch.aten.add$basic( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,2],si32>, // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,2],si32>) -> !torch.vtensor<[2,2],si64> { -// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,2],si32> -> tensor<2x2xi32> -// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,2],si32> -> tensor<2x2xi32> +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,2],si32> -> tensor<2x2xi32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,2],si32> -> tensor<2x2xi32> // CHECK: %[[VAL_4:.*]] = torch.constant.int 1 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor<2x2xi32>, tensor) -> tensor<2x2xi32> -// CHECK: %[[VAL_7:.*]] = tosa.add %[[VAL_2]], %[[VAL_6]] : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> -// CHECK: %[[VAL_8:.*]] = tosa.cast %[[VAL_7]] : (tensor<2x2xi32>) -> tensor<2x2xi64> -// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<2x2xi64> -> !torch.vtensor<[2,2],si64> -// CHECK: return %[[VAL_9]] : !torch.vtensor<[2,2],si64> +// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor) -> tensor<1x1xi32> +// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_2]], %[[VAL_6]] {shift = 0 : i8} : (tensor<2x2xi32>, tensor<1x1xi32>) -> tensor<2x2xi32> +// CHECK: %[[VAL_8:.*]] = tosa.add %[[VAL_3]], %[[VAL_7]] : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> +// CHECK: %[[VAL_9:.*]] = tosa.cast %[[VAL_8]] : (tensor<2x2xi32>) -> tensor<2x2xi64> +// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<2x2xi64> -> !torch.vtensor<[2,2],si64> +// CHECK: return %[[VAL_10]] : !torch.vtensor<[2,2],si64> // CHECK: } func.func @torch.aten.add$basic(%arg0: !torch.vtensor<[2, 2],si32>, %arg1: !torch.vtensor<[2, 2],si32>) -> !torch.vtensor<[2, 2],si64> { %int1 = torch.constant.int 1 @@ -1083,14 +1367,17 @@ func.func @torch.aten.add$basic(%arg0: !torch.vtensor<[2, 2],si32>, %arg1: !torc // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,128,128],si64> -> tensor<1x1x128x128xi64> // CHECK: %[[VAL_2:.*]] = torch.constant.int 1 // CHECK: %[[VAL_3:.*]] = torch.constant.int 256 -// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<256> : 
tensor}> : () -> tensor -// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_4]], %[[VAL_5]] {shift = 0 : i8} : (tensor, tensor) -> tensor -// CHECK: %[[VAL_7:.*]] = tosa.cast %[[VAL_1]] : (tensor<1x1x128x128xi64>) -> tensor<1x1x128x128xi32> -// CHECK: %[[VAL_8:.*]] = tosa.add %[[VAL_7]], %[[VAL_6]] : (tensor<1x1x128x128xi32>, tensor) -> tensor<1x1x128x128xi32> -// CHECK: %[[VAL_9:.*]] = tosa.cast %[[VAL_8]] : (tensor<1x1x128x128xi32>) -> tensor<1x1x128x128xi64> -// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<1x1x128x128xi64> -> !torch.vtensor<[1,1,128,128],si64> -// CHECK: return %[[VAL_10]] : !torch.vtensor<[1,1,128,128],si64> +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<256> : tensor}> : () -> tensor +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor) -> tensor<1x1x1x1xi64> +// CHECK: %[[VAL_5_cast:.*]] = tosa.cast %[[VAL_5]] : (tensor<1x1x1x1xi64>) -> tensor<1x1x1x1xi32> +// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<1> : tensor}> : () -> tensor +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_6]] {new_shape = array} : (tensor) -> tensor<1x1x1x1xi32> +// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_5_cast]], %[[VAL_7]] {shift = 0 : i8} : (tensor<1x1x1x1xi32>, tensor<1x1x1x1xi32>) -> tensor<1x1x1x1xi32> +// CHECK: %[[VAL_9:.*]] = tosa.cast %[[VAL_1]] : (tensor<1x1x128x128xi64>) -> tensor<1x1x128x128xi32> +// CHECK: %[[VAL_10:.*]] = tosa.add %[[VAL_9]], %[[VAL_8]] : (tensor<1x1x128x128xi32>, tensor<1x1x1x1xi32>) -> tensor<1x1x128x128xi32> +// CHECK: %[[VAL_11:.*]] = tosa.cast %[[VAL_10]] : (tensor<1x1x128x128xi32>) -> tensor<1x1x128x128xi64> +// CHECK: %[[VAL_12:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor<1x1x128x128xi64> -> !torch.vtensor<[1,1,128,128],si64> +// CHECK: return %[[VAL_12]] : !torch.vtensor<[1,1,128,128],si64> // CHECK: } func.func @torch.aten.Scalar$basic(%arg0: !torch.vtensor<[1,1,128,128],si64>) -> !torch.vtensor<[1,1,128,128],si64> { %int1 = torch.constant.int 1 @@ -1107,8 +1394,10 @@ func.func @torch.aten.Scalar$basic(%arg0: !torch.vtensor<[1,1,128,128],si64>) -> // CHECK: %[[VAL_3:.*]] = torch.constant.int 1 // CHECK: %[[VAL_4:.*]] = torch.constant.int 100 // CHECK: %[[VAL_5:.*]] = torch.constant.int -16 -// CHECK: %[[VAL_4:.*]] = tosa.slice %[[VAL_1]] {size = array, start = array} : (tensor<4x65x256xf32>) -> tensor<4x16x256xf32> -// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<4x16x256xf32> -> !torch.vtensor<[4,16,256],f32> +// CHECK: %[[VAL_1r:.*]] = tosa.reshape +// CHECK: %[[VAL_4:.*]] = tosa.slice %[[VAL_1r]] {size = array, start = array} : (tensor<4x65x1x256xf32>) -> tensor<4x16x1x256xf32> +// CHECK: %[[VAL_4r:.*]] = tosa.reshape +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4r]] : tensor<4x16x256xf32> -> !torch.vtensor<[4,16,256],f32> // CHECK: return %[[VAL_5]] : !torch.vtensor<[4,16,256],f32> // CHECK: } func.func @torch.aten.slice.negative_start(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,16,256],f32> { @@ -1192,14 +1481,15 @@ func.func @torch.aten.clamp.float(%arg0: !torch.vtensor<[1,1,128,128],f32>) -> ! 
// CHECK-LABEL: func.func @torch.aten.masked_fill.Scalar( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,12,128,128],f32>, // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1,1,128,128],i1>) -> !torch.vtensor<[1,12,128,128],f32> { -// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32> -// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,1,128,128],i1> -> tensor<1x1x128x128xi1> +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,1,128,128],i1> -> tensor<1x1x128x128xi1> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32> // CHECK: %[[VAL_4:.*]] = torch.constant.int 0 // CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<0> : tensor}> : () -> tensor // CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_5]] : (tensor) -> tensor -// CHECK: %[[VAL_7:.*]] = tosa.select %[[VAL_3]], %[[VAL_6]], %[[VAL_2]] : (tensor<1x1x128x128xi1>, tensor, tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32> -// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32> -// CHECK: return %[[VAL_8]] : !torch.vtensor<[1,12,128,128],f32> +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_6]] {new_shape = array} : (tensor) -> tensor<1x1x1x1xf32> +// CHECK: %[[VAL_8:.*]] = tosa.select %[[VAL_2]], %[[VAL_7]], %[[VAL_3]] : (tensor<1x1x128x128xi1>, tensor<1x1x1x1xf32>, tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32> +// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32> +// CHECK: return %[[VAL_9]] : !torch.vtensor<[1,12,128,128],f32> // CHECK: } func.func @torch.aten.masked_fill.Scalar(%arg0: !torch.vtensor<[1,12,128,128],f32>, %arg1: !torch.vtensor<[1,1,128,128],i1>) -> !torch.vtensor<[1,12,128,128],f32> { %int0 = torch.constant.int 0 @@ -1212,12 +1502,13 @@ func.func @torch.aten.masked_fill.Scalar(%arg0: !torch.vtensor<[1,12,128,128],f3 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,12,128,128],f32>, // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1,1,128,128],i1>, // CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1,12,128,128],f32> { -// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[],f32> -> tensor // CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,1,128,128],i1> -> tensor<1x1x128x128xi1> -// CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[],f32> -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.select %[[VAL_4]], %[[VAL_5]], %[[VAL_3]] : (tensor<1x1x128x128xi1>, tensor, tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32> -// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32> -// CHECK: return %[[VAL_7]] : !torch.vtensor<[1,12,128,128],f32> +// CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32> +// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor) -> tensor<1x1x1x1xf32> +// CHECK: %[[VAL_7:.*]] = tosa.select %[[VAL_4]], %[[VAL_6]], %[[VAL_5]] : (tensor<1x1x128x128xi1>, tensor<1x1x1x1xf32>, tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32> +// CHECK: 
%[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32> +// CHECK: return %[[VAL_8]] : !torch.vtensor<[1,12,128,128],f32> // CHECK: } func.func @torch.aten.masked_fill.Tensor(%arg0: !torch.vtensor<[1,12,128,128],f32>, %arg1: !torch.vtensor<[1,1,128,128],i1>, %arg2: !torch.vtensor<[],f32>) -> !torch.vtensor<[1,12,128,128],f32> { %0 = torch.aten.masked_fill.Tensor %arg0, %arg1, %arg2 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32> @@ -1242,12 +1533,13 @@ func.func @torch.aten.abs(%arg0: !torch.vtensor<[15,15],si64>) -> !torch.vtensor // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,5,5],i1>, // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1,12,5,5],f32>, // CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1,12,5,5],f32> { -// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,5,5],i1> -> tensor<1x1x5x5xi1> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[],f32> -> tensor // CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,12,5,5],f32> -> tensor<1x12x5x5xf32> -// CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[],f32> -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.select %[[VAL_3]], %[[VAL_4]], %[[VAL_5]] : (tensor<1x1x5x5xi1>, tensor<1x12x5x5xf32>, tensor) -> tensor<1x12x5x5xf32> -// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<1x12x5x5xf32> -> !torch.vtensor<[1,12,5,5],f32> -// CHECK: return %[[VAL_7]] : !torch.vtensor<[1,12,5,5],f32> +// CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,5,5],i1> -> tensor<1x1x5x5xi1> +// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor) -> tensor<1x1x1x1xf32> +// CHECK: %[[VAL_7:.*]] = tosa.select %[[VAL_5]], %[[VAL_4]], %[[VAL_6]] : (tensor<1x1x5x5xi1>, tensor<1x12x5x5xf32>, tensor<1x1x1x1xf32>) -> tensor<1x12x5x5xf32> +// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<1x12x5x5xf32> -> !torch.vtensor<[1,12,5,5],f32> +// CHECK: return %[[VAL_8]] : !torch.vtensor<[1,12,5,5],f32> // CHECK: } func.func @torch.aten.where.self(%arg0: !torch.vtensor<[1,1,5,5],i1>, %arg1: !torch.vtensor<[1,12,5,5],f32>, %arg2: !torch.vtensor<[],f32>) -> !torch.vtensor<[1,12,5,5],f32> { %0 = torch.aten.where.self %arg0, %arg1, %arg2 : !torch.vtensor<[1,1,5,5],i1>, !torch.vtensor<[1,12,5,5],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,5,5],f32> @@ -1260,13 +1552,14 @@ func.func @torch.aten.where.self(%arg0: !torch.vtensor<[1,1,5,5],i1>, %arg1: !to // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32> // CHECK: %[[VAL_2:.*]] = torch.constant.int 2 // CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_3]] : (tensor) -> tensor -// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_1]], %[[VAL_4]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor) -> tensor<2x4xf32> -// CHECK: %[[VAL_6:.*]] = tosa.floor %[[VAL_5]] : (tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_3]], %[[VAL_6]] {shift = 0 : i8} : (tensor, tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK: %[[VAL_8:.*]] = tosa.sub %[[VAL_1]], %[[VAL_7]] : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> -// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : 
tensor<2x4xf32> -> !torch.vtensor<[2,4],f32> -// CHECK: return %[[VAL_9]] : !torch.vtensor<[2,4],f32> +// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_5:.*]] = tosa.reciprocal %[[VAL_4]] : (tensor<1x1xf32>) -> tensor<1x1xf32> +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_1]], %[[VAL_5]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<1x1xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_7:.*]] = tosa.floor %[[VAL_6]] : (tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_4]], %[[VAL_7]] {shift = 0 : i8} : (tensor<1x1xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_9:.*]] = tosa.sub %[[VAL_1]], %[[VAL_8]] : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32> +// CHECK: return %[[VAL_10]] : !torch.vtensor<[2,4],f32> // CHECK: } func.func @torch.aten.remainder.Scalar(%arg0: !torch.vtensor<[2, 4],f32>) -> !torch.vtensor<[2, 4],f32> { %int2 = torch.constant.int 2 @@ -1276,29 +1569,1964 @@ func.func @torch.aten.remainder.Scalar(%arg0: !torch.vtensor<[2, 4],f32>) -> !to // ----- -// CHECK-LABEL: func.func @forward( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,5],f32>, -// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[5,5],i1> { -// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,5],f32> -> tensor<5x5xf32> -// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[5,5],f32> -> tensor<5x5xf32> +// CHECK-LABEL: func.func @torch.aten.isclose$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,5],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[5,5],i1> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[5,5],f32> -> tensor<5x5xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,5],f32> -> tensor<5x5xf32> // CHECK: %[[VAL_4:.*]] = torch.constant.float 1.000000e-08 // CHECK: %[[VAL_5:.*]] = torch.constant.float 1.000000e-05 // CHECK: %[[VAL_6:.*]] = torch.constant.bool false -// CHECK: %[[VAL_7:.*]] = tosa.sub %[[VAL_2]], %[[VAL_3]] : (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xf32> -// CHECK: %[[VAL_8:.*]] = tosa.abs %[[VAL_7]] : (tensor<5x5xf32>) -> tensor<5x5xf32> -// CHECK: %[[VAL_9:.*]] = tosa.abs %[[VAL_3]] : (tensor<5x5xf32>) -> tensor<5x5xf32> -// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<9.99999974E-6> : tensor}> : () -> tensor -// CHECK: %[[VAL_11:.*]] = tosa.mul %[[VAL_10]], %[[VAL_9]] {shift = 0 : i8} : (tensor, tensor<5x5xf32>) -> tensor<5x5xf32> -// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<9.99999993E-9> : tensor}> : () -> tensor -// CHECK: %[[VAL_13:.*]] = tosa.add %[[VAL_12]], %[[VAL_11]] : (tensor, tensor<5x5xf32>) -> tensor<5x5xf32> -// CHECK: %[[VAL_14:.*]] = tosa.greater_equal %[[VAL_13]], %[[VAL_8]] : (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xi1> -// CHECK: %[[VAL_15:.*]] = torch_c.from_builtin_tensor %[[VAL_14]] : tensor<5x5xi1> -> !torch.vtensor<[5,5],i1> -// CHECK: return %[[VAL_15]] : !torch.vtensor<[5,5],i1> -// CHECK: } -func.func @forward(%arg0: !torch.vtensor<[5,5],f32>, %arg1: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[5,5],i1> { +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<9.99999974E-6> : tensor}> : () -> tensor +// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<9.99999993E-9> : tensor}> : () -> tensor +// CHECK: 
%[[VAL_9:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_11:.*]] = tosa.sub %[[VAL_3]], %[[VAL_2]] : (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xf32> +// CHECK: %[[VAL_12:.*]] = tosa.abs %[[VAL_11]] : (tensor<5x5xf32>) -> tensor<5x5xf32> +// CHECK: %[[VAL_13:.*]] = tosa.abs %[[VAL_2]] : (tensor<5x5xf32>) -> tensor<5x5xf32> +// CHECK: %[[VAL_14:.*]] = tosa.mul %[[VAL_9]], %[[VAL_13]] {shift = 0 : i8} : (tensor<1x1xf32>, tensor<5x5xf32>) -> tensor<5x5xf32> +// CHECK: %[[VAL_15:.*]] = tosa.add %[[VAL_10]], %[[VAL_14]] : (tensor<1x1xf32>, tensor<5x5xf32>) -> tensor<5x5xf32> +// CHECK: %[[VAL_16:.*]] = tosa.greater_equal %[[VAL_15]], %[[VAL_12]] : (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xi1> +// CHECK: %[[VAL_17:.*]] = torch_c.from_builtin_tensor %[[VAL_16]] : tensor<5x5xi1> -> !torch.vtensor<[5,5],i1> +// CHECK: return %[[VAL_17]] : !torch.vtensor<[5,5],i1> +// CHECK: } +func.func @torch.aten.isclose$basic(%arg0: !torch.vtensor<[5,5],f32>, %arg1: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[5,5],i1> { %float1.000000e-08 = torch.constant.float 1.000000e-08 %float1.000000e-05 = torch.constant.float 1.000000e-05 %false = torch.constant.bool false %0 = torch.aten.isclose %arg0, %arg1, %float1.000000e-05, %float1.000000e-08, %false : !torch.vtensor<[5,5],f32>, !torch.vtensor<[5,5],f32>, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[5,5],i1> return %0 : !torch.vtensor<[5,5],i1> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.sin$basic( +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.sin %[[ARG_BUILTIN]] : (tensor) -> tensor +// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.sin$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + %0 = torch.aten.sin %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.cos$basic( +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[RESULT_BUILTIN:.*]] = tosa.cos %[[ARG_BUILTIN]] : (tensor) -> tensor +// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.cos$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + %0 = torch.aten.cos %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.__interpolate.size_list_scale_list.bilinear( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,16,135,240],f32>) -> !torch.vtensor<[1,16,270,480],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,16,135,240],f32> -> tensor<1x16x135x240xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.none +// CHECK: %[[VAL_3:.*]] = torch.constant.bool false +// CHECK: %[[VAL_4:.*]] = torch.constant.str "bilinear" +// CHECK: 
%[[VAL_5:.*]] = torch.constant.float 2.000000e+00 +// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_5]], %[[VAL_5]] : (!torch.float, !torch.float) -> !torch.list +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_8:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_7]] : (tensor<1x16x135x240xf32>, tensor<4xi32>) -> tensor<1x135x240x16xf32> +// CHECK: %[[VAL_9:.*]] = tosa.resize %[[VAL_8]] {border = array, mode = "BILINEAR", offset = array, scale = array} : (tensor<1x135x240x16xf32>) -> tensor<1x270x480x16xf32> +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_11:.*]] = tosa.transpose %[[VAL_9]], %[[VAL_10]] : (tensor<1x270x480x16xf32>, tensor<4xi32>) -> tensor<1x16x270x480xf32> +// CHECK: %[[VAL_12:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor<1x16x270x480xf32> -> !torch.vtensor<[1,16,270,480],f32> +// CHECK: return %[[VAL_12]] : !torch.vtensor<[1,16,270,480],f32> +// CHECK: } +func.func @torch.aten.__interpolate.size_list_scale_list.bilinear(%arg0: !torch.vtensor<[1,16,135,240],f32>) -> !torch.vtensor<[1,16,270,480],f32> { + %none = torch.constant.none + %false = torch.constant.bool false + %str = torch.constant.str "bilinear" + %float2.000000e00 = torch.constant.float 2.000000e+00 + %0 = torch.prim.ListConstruct %float2.000000e00, %float2.000000e00 : (!torch.float, !torch.float) -> !torch.list + %1 = torch.aten.__interpolate.size_list_scale_list %arg0, %none, %0, %str, %false, %none, %false : !torch.vtensor<[1,16,135,240],f32>, !torch.none, !torch.list, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[1,16,270,480],f32> + return %1 : !torch.vtensor<[1,16,270,480],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.__interpolate.size_list_scale_list.nearest( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,16,135,240],f32>) -> !torch.vtensor<[1,16,270,480],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,16,135,240],f32> -> tensor<1x16x135x240xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.none +// CHECK: %[[VAL_3:.*]] = torch.constant.bool false +// CHECK: %[[VAL_4:.*]] = torch.constant.str "nearest" +// CHECK: %[[VAL_5:.*]] = torch.constant.float 2.000000e+00 +// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_5]], %[[VAL_5]] : (!torch.float, !torch.float) -> !torch.list +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_8:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_7]] : (tensor<1x16x135x240xf32>, tensor<4xi32>) -> tensor<1x135x240x16xf32> +// CHECK: %[[VAL_9:.*]] = tosa.resize %[[VAL_8]] {border = array, mode = "NEAREST_NEIGHBOR", offset = array, scale = array} : (tensor<1x135x240x16xf32>) -> tensor<1x270x480x16xf32> +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_11:.*]] = tosa.transpose %[[VAL_9]], %[[VAL_10]] : (tensor<1x270x480x16xf32>, tensor<4xi32>) -> tensor<1x16x270x480xf32> +// CHECK: %[[VAL_12:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor<1x16x270x480xf32> -> !torch.vtensor<[1,16,270,480],f32> +// CHECK: return %[[VAL_12]] : !torch.vtensor<[1,16,270,480],f32> +// CHECK: } +func.func @torch.aten.__interpolate.size_list_scale_list.nearest(%arg0: !torch.vtensor<[1,16,135,240],f32>) -> !torch.vtensor<[1,16,270,480],f32> { + %none = torch.constant.none + %false = 
torch.constant.bool false + %str = torch.constant.str "nearest" + %float2.000000e00 = torch.constant.float 2.000000e+00 + %0 = torch.prim.ListConstruct %float2.000000e00, %float2.000000e00 : (!torch.float, !torch.float) -> !torch.list + %1 = torch.aten.__interpolate.size_list_scale_list %arg0, %none, %0, %str, %false, %none, %false : !torch.vtensor<[1,16,135,240],f32>, !torch.none, !torch.list, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[1,16,270,480],f32> + return %1 : !torch.vtensor<[1,16,270,480],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.tril$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,4],si32>) -> !torch.vtensor<[2,4],si32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,4],si32> -> tensor<2x4xi32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<{{\[\[}}1, 1, 0, 0], [1, 1, 1, 0]]> : tensor<2x4xi32>}> : () -> tensor<2x4xi32> +// CHECK: %[[VAL_4:.*]] = tosa.mul %[[VAL_1]], %[[VAL_3]] {shift = 0 : i8} : (tensor<2x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<2x4xi32> -> !torch.vtensor<[2,4],si32> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[2,4],si32> +// CHECK: } +func.func @torch.aten.tril$basic(%arg0: !torch.vtensor<[2,4], si32>) -> !torch.vtensor<[2,4], si32> { + %int0 = torch.constant.int 1 + %0 = torch.aten.tril %arg0, %int0 : !torch.vtensor<[2,4],si32>, !torch.int -> !torch.vtensor<[2,4],si32> + return %0 : !torch.vtensor<[2,4],si32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.min.dim$basic( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> { +// CHECK-DAG: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32> +// CHECK-DAG: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32> +// CHECK-DAG: %[[VAL_3:.*]] = torch.constant.bool true +// CHECK-DAG: %[[VAL_4:.*]] = torch.constant.int 2 +// CHECK-DAG: %[[VAL_5:.*]] = tosa.reduce_min %[[VAL_2]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32> +// CHECK-DAG: %[[VAL_6:.*]] = tosa.negate %[[VAL_2]] : (tensor<3x2x3xf32>) -> tensor<3x2x3xf32> +// CHECK-DAG: %[[VAL_7:.*]] = tosa.argmax %[[VAL_6]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2xi64> +// CHECK-DAG: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor<3x2xi64>) -> tensor<3x2x1xi64> +// CHECK-DAG: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32> +// CHECK-DAG: %[[VAL_10:.*]] = torch_c.to_builtin_tensor %[[VAL_9]] : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32> +// CHECK: return %[[VAL_10]] : tensor<3x2x1xf32> +// CHECK: } +func.func @torch.aten.min.dim$basic(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x1xf32> { + %0 = torch_c.from_builtin_tensor %arg0 : tensor<3x2x3xf32> -> !torch.vtensor<[3,2,3],f32> + %true = torch.constant.bool true + %int2 = torch.constant.int 2 + %values, %indices = torch.aten.min.dim %0, %int2, %true : !torch.vtensor<[3,2,3],f32>, !torch.int, !torch.bool -> !torch.vtensor<[3,2,1],f32>, !torch.vtensor<[3,2,1],si64> + %1 = torch_c.to_builtin_tensor %values : !torch.vtensor<[3,2,1],f32> -> tensor<3x2x1xf32> + return %1 : tensor<3x2x1xf32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.min$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,2,3],f32>) -> !torch.vtensor<[1],f32> { +// CHECK: %[[VAL_1:.*]] = 
torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32> +// CHECK: %[[VAL_2:.*]] = tosa.reduce_min %[[VAL_1]] {axis = 0 : i32} : (tensor<3x2x3xf32>) -> tensor<1x2x3xf32> +// CHECK: %[[VAL_3:.*]] = tosa.reduce_min %[[VAL_2]] {axis = 1 : i32} : (tensor<1x2x3xf32>) -> tensor<1x1x3xf32> +// CHECK: %[[VAL_4:.*]] = tosa.reduce_min %[[VAL_3]] {axis = 2 : i32} : (tensor<1x1x3xf32>) -> tensor<1x1x1xf32> +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<1x1x1xf32>) -> tensor<1xf32> +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1xf32> -> !torch.vtensor<[1],f32> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[1],f32> +// CHECK: } +func.func @torch.aten.min$basic(%arg0: !torch.vtensor<[3,2,3],f32>) -> !torch.vtensor<[1],f32> { + %0 = torch.aten.min %arg0: !torch.vtensor<[3,2,3],f32> -> !torch.vtensor<[1],f32> + return %0 : !torch.vtensor<[1],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.max$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,2,3],f32>) -> !torch.vtensor<[1],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32> +// CHECK: %[[VAL_2:.*]] = tosa.reduce_max %[[VAL_1]] {axis = 0 : i32} : (tensor<3x2x3xf32>) -> tensor<1x2x3xf32> +// CHECK: %[[VAL_3:.*]] = tosa.reduce_max %[[VAL_2]] {axis = 1 : i32} : (tensor<1x2x3xf32>) -> tensor<1x1x3xf32> +// CHECK: %[[VAL_4:.*]] = tosa.reduce_max %[[VAL_3]] {axis = 2 : i32} : (tensor<1x1x3xf32>) -> tensor<1x1x1xf32> +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<1x1x1xf32>) -> tensor<1xf32> +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1xf32> -> !torch.vtensor<[1],f32> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[1],f32> +// CHECK: } +func.func @torch.aten.max$basic(%arg0: !torch.vtensor<[3,2,3],f32>) -> !torch.vtensor<[1],f32> { + %0 = torch.aten.max %arg0: !torch.vtensor<[3,2,3],f32> -> !torch.vtensor<[1],f32> + return %0 : !torch.vtensor<[1],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.prod.dim_int$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,2,3],f32>) -> !torch.vtensor<[3,2,1],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,2,3],f32> -> tensor<3x2x3xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_3:.*]] = torch.constant.bool true +// CHECK: %[[VAL_4:.*]] = torch.constant.none +// CHECK: %[[VAL_5:.*]] = tosa.reduce_prod %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xf32>) -> tensor<3x2x1xf32> +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x2x1xf32> -> !torch.vtensor<[3,2,1],f32> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[3,2,1],f32> +// CHECK: } +func.func @torch.aten.prod.dim_int$basic(%arg0: !torch.vtensor<[3,2,3],f32>) -> !torch.vtensor<[3,2,1],f32> { + %dim = torch.constant.int 2 + %keepdims = torch.constant.bool true + %dtype = torch.constant.none + %0 = torch.aten.prod.dim_int %arg0, %dim, %keepdims, %dtype: !torch.vtensor<[3,2,3],f32> , !torch.int, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32> + return %0 : !torch.vtensor<[3,2,1],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.all.dim$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,2,3],i1>) -> !torch.vtensor<[3,2,1],i1> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,2,3],i1> -> tensor<3x2x3xi1> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_3:.*]] 
= torch.constant.bool true +// CHECK: %[[VAL_4:.*]] = tosa.reduce_all %[[VAL_1]] {axis = 2 : i32} : (tensor<3x2x3xi1>) -> tensor<3x2x1xi1> +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<3x2x1xi1> -> !torch.vtensor<[3,2,1],i1> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[3,2,1],i1> +// CHECK: } +func.func @torch.aten.all.dim$basic(%arg0: !torch.vtensor<[3,2,3],i1>) -> !torch.vtensor<[3,2,1],i1> { + %dim = torch.constant.int 2 + %keepdims = torch.constant.bool true + %0 = torch.aten.all.dim %arg0, %dim, %keepdims: !torch.vtensor<[3,2,3],i1> , !torch.int, !torch.bool -> !torch.vtensor<[3,2,1],i1> + return %0 : !torch.vtensor<[3,2,1],i1> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$float_trunc( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_4:.*]] = torch.constant.str "trunc" +// CHECK: %[[VAL_5:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor) -> tensor +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<-1.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_13:.*]] = tosa.greater_equal %[[VAL_6]], %[[VAL_11]] : (tensor, tensor<1x1xf32>) -> tensor +// CHECK: %[[VAL_14:.*]] = tosa.select %[[VAL_13]], %[[VAL_10]], %[[VAL_12]] : (tensor, tensor<1x1xf32>, tensor<1x1xf32>) -> tensor +// CHECK: %[[VAL_15:.*]] = tosa.abs %[[VAL_6]] : (tensor) -> tensor +// CHECK: %[[VAL_16:.*]] = tosa.floor %[[VAL_15]] : (tensor) -> tensor +// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_16]], %[[VAL_14]] {shift = 0 : i8} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_18:.*]] = torch_c.from_builtin_tensor %[[VAL_17]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_18]] : !torch.vtensor<[?,?],f32> +// CHECK: } +func.func @torch.aten.div.Tensor_mode$float_trunc(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> { + %str = torch.constant.str "trunc" + %0 = torch.aten.div.Tensor_mode %arg0, %arg1, %str : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32>, !torch.str -> !torch.vtensor<[?, ?],f32> + return %0 : !torch.vtensor<[?, ?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$int_trunc( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si64>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si64> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si64> -> tensor +// CHECK: %[[VAL_4:.*]] = torch.constant.str "trunc" +// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_3]] : (tensor) -> tensor 
+// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_2]] : (tensor) -> tensor +// CHECK: %[[VAL_7:.*]] = tosa.int_div %[[VAL_5]], %[[VAL_6]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_8:.*]] = tosa.cast %[[VAL_7]] : (tensor) -> tensor +// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor -> !torch.vtensor<[?,?],si64> +// CHECK: return %[[VAL_9]] : !torch.vtensor<[?,?],si64> +// CHECK: } +func.func @torch.aten.div.Tensor_mode$int_trunc(%arg0: !torch.vtensor<[?, ?],si64>, %arg1: !torch.vtensor<[?, ?],si64>) -> !torch.vtensor<[?, ?],si64> { + %str = torch.constant.str "trunc" + %0 = torch.aten.div.Tensor_mode %arg0, %arg1, %str : !torch.vtensor<[?, ?],si64>, !torch.vtensor<[?, ?],si64>, !torch.str -> !torch.vtensor<[?, ?],si64> + return %0 : !torch.vtensor<[?, ?],si64> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$float_floor( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_4:.*]] = torch.constant.str "floor" +// CHECK: %[[VAL_5:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor) -> tensor +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_7:.*]] = tosa.floor %[[VAL_6]] : (tensor) -> tensor +// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,?],f32> +// CHECK: } +func.func @torch.aten.div.Tensor_mode$float_floor(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> { + %str = torch.constant.str "floor" + %0 = torch.aten.div.Tensor_mode %arg0, %arg1, %str : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32>, !torch.str -> !torch.vtensor<[?, ?],f32> + return %0 : !torch.vtensor<[?, ?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$int_floor( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si64>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si64> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si64> -> tensor +// CHECK: %[[VAL_4:.*]] = torch.constant.str "floor" +// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_3]] : (tensor) -> tensor +// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_2]] : (tensor) -> tensor +// CHECK: %[[VAL_7:.*]] = tosa.int_div %[[VAL_5]], %[[VAL_6]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<0> : tensor}> : () -> tensor +// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<1> : tensor}> : () -> tensor +// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array} : (tensor) -> tensor<1x1xi32> +// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array} : (tensor) -> tensor<1x1xi32> +// CHECK: %[[VAL_12:.*]] = tosa.mul %[[VAL_5]], %[[VAL_6]] {shift = 0 : i8} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_13:.*]] = tosa.greater %[[VAL_11]], %[[VAL_12]] : (tensor<1x1xi32>, tensor) -> tensor +// CHECK: %[[VAL_14:.*]] = tosa.mul %[[VAL_7]], %[[VAL_6]] {shift = 0 : i8} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_15:.*]] = tosa.equal 
%[[VAL_14]], %[[VAL_5]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_16:.*]] = tosa.logical_not %[[VAL_15]] : (tensor) -> tensor +// CHECK: %[[VAL_17:.*]] = tosa.sub %[[VAL_7]], %[[VAL_10]] : (tensor, tensor<1x1xi32>) -> tensor +// CHECK: %[[VAL_18:.*]] = tosa.logical_and %[[VAL_13]], %[[VAL_16]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_19:.*]] = tosa.select %[[VAL_18]], %[[VAL_17]], %[[VAL_7]] : (tensor, tensor, tensor) -> tensor +// CHECK: %[[VAL_20:.*]] = tosa.cast %[[VAL_19]] : (tensor) -> tensor +// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor -> !torch.vtensor<[?,?],si64> +// CHECK: return %[[VAL_21]] : !torch.vtensor<[?,?],si64> +// CHECK: } +func.func @torch.aten.div.Tensor_mode$int_floor(%arg0: !torch.vtensor<[?, ?],si64>, %arg1: !torch.vtensor<[?, ?],si64>) -> !torch.vtensor<[?, ?],si64> { + %str = torch.constant.str "floor" + %0 = torch.aten.div.Tensor_mode %arg0, %arg1, %str : !torch.vtensor<[?, ?],si64>, !torch.vtensor<[?, ?],si64>, !torch.str -> !torch.vtensor<[?, ?],si64> + return %0 : !torch.vtensor<[?, ?],si64> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$float_basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_4:.*]] = torch.constant.str "" +// CHECK: %[[VAL_5:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor) -> tensor +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?],f32> +// CHECK: } +func.func @torch.aten.div.Tensor_mode$float_basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch.vtensor<[?, ?],f32>) -> !torch.vtensor<[?, ?],f32> { + %str = torch.constant.str "" + %0 = torch.aten.div.Tensor_mode %arg0, %arg1, %str : !torch.vtensor<[?, ?],f32>, !torch.vtensor<[?, ?],f32>, !torch.str -> !torch.vtensor<[?, ?],f32> + return %0 : !torch.vtensor<[?, ?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$int_basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si64>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si64> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si64> -> tensor +// CHECK: %[[VAL_4:.*]] = torch.constant.str "" +// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_3]] : (tensor) -> tensor +// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_2]] : (tensor) -> tensor +// CHECK: %[[VAL_7:.*]] = tosa.int_div %[[VAL_5]], %[[VAL_6]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_8:.*]] = tosa.cast %[[VAL_7]] : (tensor) -> tensor +// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor -> !torch.vtensor<[?,?],si64> +// CHECK: return %[[VAL_9]] : !torch.vtensor<[?,?],si64> +// CHECK: } +func.func @torch.aten.div.Tensor_mode$int_basic(%arg0: !torch.vtensor<[?, ?],si64>, %arg1: !torch.vtensor<[?, ?],si64>) -> !torch.vtensor<[?, ?],si64> { + %str = torch.constant.str "" + %0 = torch.aten.div.Tensor_mode %arg0, %arg1, %str : !torch.vtensor<[?, ?],si64>, !torch.vtensor<[?, 
?],si64>, !torch.str -> !torch.vtensor<[?, ?],si64> + return %0 : !torch.vtensor<[?, ?],si64> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.ge.Tensor$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.greater_equal %[[VAL_3]], %[[VAL_2]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],i1> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1> +// CHECK: } +func.func @torch.aten.ge.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> { + %0 = torch.aten.ge.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],i1> + return %0 : !torch.vtensor<[?,?],i1> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.remainder.Tensor( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,4],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,4],f32>) -> !torch.vtensor<[2,4],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32> +// CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_3]], %[[VAL_4]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_6:.*]] = tosa.floor %[[VAL_5]] : (tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_2]], %[[VAL_6]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_8:.*]] = tosa.sub %[[VAL_3]], %[[VAL_7]] : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32> +// CHECK: return %[[VAL_9]] : !torch.vtensor<[2,4],f32> +// CHECK: } +func.func @torch.aten.remainder.Tensor(%arg0: !torch.vtensor<[2, 4],f32>, %arg1: !torch.vtensor<[2, 4],f32>) -> !torch.vtensor<[2, 4],f32> { + %0 = torch.aten.remainder.Tensor %arg0, %arg1 : !torch.vtensor<[2, 4],f32>, !torch.vtensor<[2, 4],f32> -> !torch.vtensor<[2, 4],f32> + return %0 : !torch.vtensor<[2, 4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.fmod.Tensor( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,4],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,4],f32>) -> !torch.vtensor<[2,4],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32> +// CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_2]] : (tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_3]], %[[VAL_4]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<-1.000000e+00> : tensor}> : () -> tensor +// 
CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_6]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_12:.*]] = tosa.greater_equal %[[VAL_5]], %[[VAL_10]] : (tensor<2x4xf32>, tensor<1x1xf32>) -> tensor<2x4xi1> +// CHECK: %[[VAL_13:.*]] = tosa.select %[[VAL_12]], %[[VAL_9]], %[[VAL_11]] : (tensor<2x4xi1>, tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_14:.*]] = tosa.abs %[[VAL_5]] : (tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_15:.*]] = tosa.floor %[[VAL_14]] : (tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_15]], %[[VAL_13]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_2]], %[[VAL_16]] {shift = 0 : i8} : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_18:.*]] = tosa.sub %[[VAL_3]], %[[VAL_17]] : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32> +// CHECK: return %[[VAL_19]] : !torch.vtensor<[2,4],f32> +// CHECK: } +func.func @torch.aten.fmod.Tensor(%arg0: !torch.vtensor<[2, 4],f32>, %arg1: !torch.vtensor<[2, 4],f32>) -> !torch.vtensor<[2, 4],f32> { + %0 = torch.aten.fmod.Tensor %arg0, %arg1 : !torch.vtensor<[2, 4],f32>, !torch.vtensor<[2, 4],f32> -> !torch.vtensor<[2, 4],f32> + return %0 : !torch.vtensor<[2, 4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.logical_not( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,5],i1>) -> !torch.vtensor<[4,5],i1> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,5],i1> -> tensor<4x5xi1> +// CHECK: %[[VAL_2:.*]] = tosa.logical_not %[[VAL_1]] : (tensor<4x5xi1>) -> tensor<4x5xi1> +// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<4x5xi1> -> !torch.vtensor<[4,5],i1> +// CHECK: return %[[VAL_3]] : !torch.vtensor<[4,5],i1> +// CHECK: } +func.func @torch.aten.logical_not(%arg0: !torch.vtensor<[4,5],i1>) -> !torch.vtensor<[4,5],i1> { + %0 = torch.aten.logical_not %arg0 : !torch.vtensor<[4,5],i1> -> !torch.vtensor<[4,5],i1> + return %0 : !torch.vtensor<[4,5],i1> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.cos( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],f32> -> tensor<3x4xf32> +// CHECK: %[[VAL_2:.*]] = tosa.cos %[[VAL_1]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_3]] : !torch.vtensor<[3,4],f32> +// CHECK: } +func.func @torch.aten.cos(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { + %0 = torch.aten.cos %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.sin( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],f32> -> tensor<3x4xf32> +// CHECK: %[[VAL_2:.*]] = tosa.sin %[[VAL_1]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor 
%[[VAL_2]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
+// CHECK: return %[[VAL_3]] : !torch.vtensor<[3,4],f32>
+// CHECK: }
+func.func @torch.aten.sin(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
+  %0 = torch.aten.sin %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32>
+  return %0 : !torch.vtensor<[3,4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.pow.Scalar(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],f32> -> tensor<3x4xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.float 2.000000e+00
+// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array<i64: 1, 1>} : (tensor<f32>) -> tensor<1x1xf32>
+// CHECK: %[[VAL_5:.*]] = tosa.pow %[[VAL_4]], %[[VAL_1]] : (tensor<1x1xf32>, tensor<3x4xf32>) -> tensor<3x4xf32>
+// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
+// CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4],f32>
+// CHECK: }
+func.func @torch.aten.pow.Scalar(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> {
+  %float2.000000e00 = torch.constant.float 2.000000e+00
+  %0 = torch.aten.pow.Scalar %float2.000000e00, %arg0 : !torch.float, !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32>
+  return %0 : !torch.vtensor<[3,4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.pow.Tensor_Tensor$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
+// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[VAL_4:.*]] = tosa.pow %[[VAL_3]], %[[VAL_2]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
+// CHECK: }
+func.func @torch.aten.pow.Tensor_Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.pow.Tensor_Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.erf$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[VAL_2:.*]] = tosa.erf %[[VAL_1]] : (tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
+// CHECK: }
+func.func @torch.aten.erf$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.erf %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.bitwise_and.Scalar$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 2
+//
CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2> : tensor}> : () -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor) -> tensor<1x1xi64> +// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor<1x1xi64>) -> tensor<1x1xi32> +// CHECK: %[[VAL_6:.*]] = tosa.bitwise_and %[[VAL_1]], %[[VAL_5]] : (tensor, tensor<1x1xi32>) -> tensor +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor -> !torch.vtensor<[?,?],si32> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?],si32> +// CHECK: } +func.func @torch.aten.bitwise_and.Scalar$basic(%arg0: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> { + %int2 = torch.constant.int 2 + %0 = torch.aten.bitwise_and.Scalar %arg0, %int2 : !torch.vtensor<[?,?],si32>, !torch.int -> !torch.vtensor<[?,?],si32> + return %0 : !torch.vtensor<[?,?],si32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.le.Tensor$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.greater_equal %[[VAL_2]], %[[VAL_3]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],i1> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1> +// CHECK: } +func.func @torch.aten.le.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> { + %0 = torch.aten.le.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],i1> + return %0 : !torch.vtensor<[?,?],i1> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.le.Scalar$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_2:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2> : tensor}> : () -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor) -> tensor<1x1xi64> +// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor<1x1xi64>) -> tensor<1x1xf32> +// CHECK: %[[VAL_6:.*]] = tosa.greater_equal %[[VAL_5]], %[[VAL_1]] : (tensor<1x1xf32>, tensor) -> tensor +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor -> !torch.vtensor<[?,?],i1> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?],i1> +// CHECK: } +func.func @torch.aten.le.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> { + %int2 = torch.constant.int 2 + %0 = torch.aten.le.Scalar %arg0, %int2 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],i1> + return %0 : !torch.vtensor<[?,?],i1> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.logical_xor$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],i1>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],i1> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],i1> -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.logical_xor %[[VAL_3]], %[[VAL_2]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = 
torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],i1> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1> +// CHECK: } +func.func @torch.aten.logical_xor$basic(%arg0: !torch.vtensor<[?,?],i1>, %arg1: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> { + %0 = torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[?,?],i1>, !torch.vtensor<[?,?],i1> -> !torch.vtensor<[?,?],i1> + return %0 : !torch.vtensor<[?,?],i1> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.bitwise_left_shift.Tensor$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si32> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.logical_left_shift %[[VAL_3]], %[[VAL_2]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],si32> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],si32> +// CHECK: } +func.func @torch.aten.bitwise_left_shift.Tensor$basic(%arg0: !torch.vtensor<[?,?],si32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> { + %0 = torch.aten.bitwise_left_shift.Tensor %arg0, %arg1: !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],si32> + return %0: !torch.vtensor<[?,?],si32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.bitwise_right_shift.Tensor$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si32> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.arithmetic_right_shift %[[VAL_3]], %[[VAL_2]] {round = false} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],si32> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],si32> +// CHECK: } +func.func @torch.aten.bitwise_right_shift.Tensor$basic(%arg0: !torch.vtensor<[?,?],si32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> { + %0 = torch.aten.bitwise_right_shift.Tensor %arg0, %arg1: !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],si32> + return %0: !torch.vtensor<[?,?],si32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.diagonal$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,5,6],si32>) -> !torch.vtensor<[5,6,2],si32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,5,6],si32> -> tensor<3x4x5x6xi32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_4:.*]] = torch.constant.int -2 +// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<[2, 3, 1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_6:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_5]] : (tensor<3x4x5x6xi32>, tensor<4xi32>) -> tensor<5x6x4x3xi32> +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0, 0, 0], [0, 0, 0], [1, 0, 0], [0, 1, 0]]]]> : tensor<1x1x4x3xi32>}> : () -> tensor<1x1x4x3xi32> +// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_6]], %[[VAL_7]] {shift = 0 : i8} : (tensor<5x6x4x3xi32>, 
tensor<1x1x4x3xi32>) -> tensor<5x6x4x3xi32> +// CHECK: %[[VAL_9:.*]] = tosa.slice %[[VAL_8]] {size = array, start = array} : (tensor<5x6x4x3xi32>) -> tensor<5x6x2x3xi32> +// CHECK: %[[VAL_10:.*]] = tosa.reduce_sum %[[VAL_9]] {axis = 3 : i32} : (tensor<5x6x2x3xi32>) -> tensor<5x6x2x1xi32> +// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_10]] {new_shape = array} : (tensor<5x6x2x1xi32>) -> tensor<5x6x2xi32> +// CHECK: %[[VAL_12:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor<5x6x2xi32> -> !torch.vtensor<[5,6,2],si32> +// CHECK: return %[[VAL_12]] : !torch.vtensor<[5,6,2],si32> +// CHECK: } +func.func @torch.aten.diagonal$basic(%arg0: !torch.vtensor<[3,4,5,6], si32>) -> !torch.vtensor<[5,6,2], si32> { + %dim1 = torch.constant.int 1 + %dim2 = torch.constant.int 0 + %offset = torch.constant.int -2 + %0 = torch.aten.diagonal %arg0, %offset, %dim1, %dim2 : !torch.vtensor<[3,4,5,6],si32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[5,6,2],si32> + return %0 : !torch.vtensor<[5,6,2],si32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.index_select( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,5,6],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2],si64>) -> !torch.vtensor<[4,5,2],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2],si64> -> tensor<2xi64> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,5,6],f32> -> tensor<4x5x6xf32> +// CHECK: %[[VAL_4:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_2]] : (tensor<2xi64>) -> tensor<2xi32> +// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor<2xi32>) -> tensor<1x1x2xi32> +// CHECK: %[[VAL_7:.*]] = tosa.const_shape {value = dense<[4, 5, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_8:.*]] = tosa.tile %[[VAL_6]], %[[VAL_7]] : (tensor<1x1x2xi32>, !tosa.shape<3>) -> tensor<4x5x2xi32> +// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array} : (tensor<4x5x2xi32>) -> tensor<4x5x2x1xi32> +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]]], {{\[\[}}[1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]]], {{\[\[}}[2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]]], {{\[\[}}[3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32> +// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32> +// CHECK: %[[VAL_12:.*]] = tosa.concat %[[VAL_10]], %[[VAL_11]], %[[VAL_9]] {axis = 3 : i32} : (tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>) -> tensor<4x5x2x3xi32> +// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<4x5x6xf32>) -> tensor<1x120x1xf32> +// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array} : (tensor<4x5x2x3xi32>) -> tensor<40x3xi32> +// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{value = dense<[30, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK: 
%[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array} : (tensor<3xi32>) -> tensor<1x3xi32> +// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_14]], %[[VAL_16]] {shift = 0 : i8} : (tensor<40x3xi32>, tensor<1x3xi32>) -> tensor<40x3xi32> +// CHECK: %[[VAL_18:.*]] = tosa.reduce_sum %[[VAL_17]] {axis = 1 : i32} : (tensor<40x3xi32>) -> tensor<40x1xi32> +// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array} : (tensor<40x1xi32>) -> tensor<1x40xi32> +// CHECK: %[[VAL_20:.*]] = tosa.gather %[[VAL_13]], %[[VAL_19]] : (tensor<1x120x1xf32>, tensor<1x40xi32>) -> tensor<1x40x1xf32> +// CHECK: %[[VAL_21:.*]] = tosa.reshape %[[VAL_20]] {new_shape = array} : (tensor<1x40x1xf32>) -> tensor<4x5x2xf32> +// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<4x5x2xf32> -> !torch.vtensor<[4,5,2],f32> +// CHECK: return %[[VAL_22]] : !torch.vtensor<[4,5,2],f32> +// CHECK: } +func.func @torch.aten.index_select(%arg0: !torch.vtensor<[4,5,6],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[4,5,2],f32> { + %int2 = torch.constant.int 2 + %0 = torch.aten.index_select %arg0, %int2, %arg1 : !torch.vtensor<[4,5,6],f32>, !torch.int, !torch.vtensor<[2],si64> -> !torch.vtensor<[4,5,2],f32> + return %0 : !torch.vtensor<[4,5,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.fill.Scalar( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,12,128,128],f32>) -> !torch.vtensor<[1,12,128,128],f32> { +// CHECK: %[[VAL_1:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x12x128x128xf32>}> : () -> tensor<1x12x128x128xf32> +// CHECK: %[[VAL_3:.*]] = tosa.cast %[[VAL_2]] : (tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32> +// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32> +// CHECK: return %[[VAL_4]] : !torch.vtensor<[1,12,128,128],f32> +// CHECK: } +func.func @torch.aten.fill.Scalar(%arg0: !torch.vtensor<[1,12,128,128],f32>) -> !torch.vtensor<[1,12,128,128],f32> { + %int0 = torch.constant.int 0 + %0 = torch.aten.fill.Scalar %arg0, %int0 : !torch.vtensor<[1,12,128,128],f32>, !torch.int -> !torch.vtensor<[1,12,128,128],f32> + return %0 : !torch.vtensor<[1,12,128,128],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.fill.Tensor( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,12,128,128],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1],si32>) -> !torch.vtensor<[1,12,128,128],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1],si32> -> tensor<1xi32> +// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<1xi32>) -> tensor<1x1x1x1xi32> +// CHECK: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[1, 12, 128, 128]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_5:.*]] = tosa.tile %[[VAL_3]], %[[VAL_4]] : (tensor<1x1x1x1xi32>, !tosa.shape<4>) -> tensor<1x12x128x128xi32> +// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_5]] : (tensor<1x12x128x128xi32>) -> tensor<1x12x128x128xf32> +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[1,12,128,128],f32> +// CHECK: } +func.func @torch.aten.fill.Tensor(%arg0: !torch.vtensor<[1,12,128,128],f32>, %arg1: !torch.vtensor<[1],si32>) -> !torch.vtensor<[1,12,128,128],f32> { + %0 = torch.aten.fill.Tensor %arg0, %arg1 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1],si32> -> 
!torch.vtensor<[1,12,128,128],f32> + return %0 : !torch.vtensor<[1,12,128,128],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.flip( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,5],f32> -> tensor<3x4x5xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_5:.*]] = tosa.reverse %[[VAL_1]] {axis = 1 : i32} : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_6:.*]] = tosa.reverse %[[VAL_5]] {axis = 2 : i32} : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x4x5xf32> -> !torch.vtensor<[3,4,5],f32> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[3,4,5],f32> +// CHECK: } +func.func @torch.aten.flip(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> { + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %0 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list + %1 = torch.aten.flip %arg0, %0 : !torch.vtensor<[3,4,5],f32>, !torch.list -> !torch.vtensor<[3,4,5],f32> + return %1 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.round( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,5],f32> -> tensor<3x4x5xf32> +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor}> : () -> tensor +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor) -> tensor<1x1x1xf32> +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor) -> tensor<1x1x1xf32> +// CHECK: %[[VAL_6:.*]] = tosa.floor %[[VAL_1]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_7:.*]] = tosa.sub %[[VAL_1]], %[[VAL_6]] : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_8:.*]] = tosa.ceil %[[VAL_1]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_6]], %[[VAL_4]] {shift = 0 : i8} : (tensor<3x4x5xf32>, tensor<1x1x1xf32>) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_10:.*]] = tosa.floor %[[VAL_9]] : (tensor<3x4x5xf32>) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_11:.*]] = tosa.mul %[[VAL_10]], %[[VAL_5]] {shift = 0 : i8} : (tensor<3x4x5xf32>, tensor<1x1x1xf32>) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_12:.*]] = tosa.equal %[[VAL_6]], %[[VAL_11]] : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xi1> +// CHECK: %[[VAL_13:.*]] = tosa.equal %[[VAL_7]], %[[VAL_4]] : (tensor<3x4x5xf32>, tensor<1x1x1xf32>) -> tensor<3x4x5xi1> +// CHECK: %[[VAL_14:.*]] = tosa.greater %[[VAL_4]], %[[VAL_7]] : (tensor<1x1x1xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xi1> +// CHECK: %[[VAL_15:.*]] = tosa.logical_and %[[VAL_13]], %[[VAL_12]] : (tensor<3x4x5xi1>, tensor<3x4x5xi1>) -> tensor<3x4x5xi1> +// CHECK: %[[VAL_16:.*]] = tosa.logical_or %[[VAL_14]], %[[VAL_15]] : (tensor<3x4x5xi1>, tensor<3x4x5xi1>) -> tensor<3x4x5xi1> +// CHECK: %[[VAL_17:.*]] = tosa.select %[[VAL_16]], %[[VAL_6]], %[[VAL_8]] : (tensor<3x4x5xi1>, tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_18:.*]] = 
torch_c.from_builtin_tensor %[[VAL_17]] : tensor<3x4x5xf32> -> !torch.vtensor<[3,4,5],f32> +// CHECK: return %[[VAL_18]] : !torch.vtensor<[3,4,5],f32> +// CHECK: } +func.func @torch.aten.round(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> { + %0 = torch.aten.round %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +func.func @torch.aten.avg_pool2d.count_include_pad_unsupported_value(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> { + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %int3 = torch.constant.int 3 + %false= torch.constant.bool false + %count_include_pad = torch.constant.bool true + %divisor_override = torch.constant.none + + %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int> + %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int> + %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int> + // expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool2d' that was explicitly marked illegal}} + %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %count_include_pad, %divisor_override : !torch.vtensor<[1,192,35,35],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,192,35,35],f32> + return %3 : !torch.vtensor<[1,192,35,35],f32> +} + +// ----- + +func.func @torch.aten.avg_pool2d.divisor_override_unsupported_value(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> { + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %int3 = torch.constant.int 3 + %false= torch.constant.bool false + %count_include_pad = torch.constant.bool false + %divisor_override = torch.constant.int 9 + + %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int> + %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int> + %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int> + // expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool2d' that was explicitly marked illegal}} + %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %count_include_pad, %divisor_override : !torch.vtensor<[1,192,35,35],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.int -> !torch.vtensor<[1,192,35,35],f32> + return %3 : !torch.vtensor<[1,192,35,35],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.empty.memory_format$basic() -> !torch.vtensor<[3,4],si64> { +// CHECK: %[[VAL_0:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_1:.*]] = torch.constant.bool false +// CHECK: %[[VAL_2:.*]] = torch.constant.none +// CHECK: %[[VAL_3:.*]] = torch.constant.int 3 +// CHECK: %[[VAL_4:.*]] = torch.constant.int 4 +// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int> +// CHECK: %[[VAL_6:.*]] = torch.constant.device "cpu" +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<0> : tensor<3x4xi32>}> : () -> tensor<3x4xi32> +// CHECK: %[[VAL_8:.*]] = tosa.cast %[[VAL_7]] : (tensor<3x4xi32>) -> tensor<3x4xi64> +// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<0> : tensor<3x4xi64>}> : () -> tensor<3x4xi64> +// CHECK: %[[VAL_10:.*]] = tosa.cast %[[VAL_9]] : (tensor<3x4xi64>) -> tensor<3x4xi64> +// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<3x4xi64> -> !torch.vtensor<[3,4],si64> +// 
CHECK: return %[[VAL_11]] : !torch.vtensor<[3,4],si64> +// CHECK: } +func.func @torch.aten.empty.memory_format$basic() -> !torch.vtensor<[3,4],si64> { + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %none = torch.constant.none + %int3 = torch.constant.int 3 + %int4 = torch.constant.int 4 + %0 = torch.prim.ListConstruct %int3, %int4 : (!torch.int, !torch.int) -> !torch.list + %cpu = torch.constant.device "cpu" + %1 = torch.aten.empty.memory_format %0, %int4, %none, %cpu, %false, %none : !torch.list, !torch.int, !torch.none, !torch.Device, !torch.bool, !torch.none -> !torch.vtensor<[3,4],si64> + %2 = torch.aten.fill.Scalar %1, %int0 : !torch.vtensor<[3,4],si64>, !torch.int -> !torch.vtensor<[3,4],si64> + return %2 : !torch.vtensor<[3,4],si64> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.scatter.src$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[10,8,6],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,4,3],si64>, +// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[3,4,3],f32>) -> !torch.vtensor<[10,8,6],f32> { +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[3,4,3],f32> -> tensor<3x4x3xf32> +// CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[2,4,3],si64> -> tensor<2x4x3xi64> +// CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[10,8,6],f32> -> tensor<10x8x6xf32> +// CHECK: %[[VAL_6:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_7:.*]] = tosa.cast %[[VAL_4]] : (tensor<2x4x3xi64>) -> tensor<2x4x3xi32> +// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor<2x4x3xi32>) -> tensor<2x4x3x1xi32> +// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0], [0]], {{\[\[}}0], [0], [0]], {{\[\[}}0], [0], [0]], {{\[\[}}0], [0], [0]]], {{\[\[}}[1], [1], [1]], {{\[\[}}1], [1], [1]], {{\[\[}}1], [1], [1]], {{\[\[}}1], [1], [1]]]]> : tensor<2x4x3x1xi32>}> : () -> tensor<2x4x3x1xi32> +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [1], [2]], {{\[\[}}0], [1], [2]], {{\[\[}}0], [1], [2]], {{\[\[}}0], [1], [2]]], {{\[\[}}[0], [1], [2]], {{\[\[}}0], [1], [2]], {{\[\[}}0], [1], [2]], {{\[\[}}0], [1], [2]]]]> : tensor<2x4x3x1xi32>}> : () -> tensor<2x4x3x1xi32> +// CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_9]], %[[VAL_8]], %[[VAL_10]] {axis = 3 : i32} : (tensor<2x4x3x1xi32>, tensor<2x4x3x1xi32>, tensor<2x4x3x1xi32>) -> tensor<2x4x3x3xi32> +// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<3x4x3xf32>) -> tensor<1x36x1xf32> +// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor<10x8x6xf32>) -> tensor<1x480x1xf32> +// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array} : (tensor<2x4x3x3xi32>) -> tensor<24x3xi32> +// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{value = dense<[48, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array} : (tensor<3xi32>) -> tensor<1x3xi32> +// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_14]], %[[VAL_16]] {shift = 0 : i8} : (tensor<24x3xi32>, tensor<1x3xi32>) -> tensor<24x3xi32> +// CHECK: %[[VAL_18:.*]] = tosa.reduce_sum %[[VAL_17]] {axis = 1 : i32} : (tensor<24x3xi32>) -> tensor<24x1xi32> +// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array} : (tensor<24x1xi32>) -> tensor<1x24xi32> +// CHECK: %[[VAL_20:.*]] = tosa.scatter %[[VAL_13]], %[[VAL_19]], %[[VAL_12]] : (tensor<1x480x1xf32>, tensor<1x24xi32>, tensor<1x36x1xf32>) -> 
tensor<1x480x1xf32> +// CHECK: %[[VAL_21:.*]] = tosa.reshape %[[VAL_20]] {new_shape = array} : (tensor<1x480x1xf32>) -> tensor<10x8x6xf32> +// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<10x8x6xf32> -> !torch.vtensor<[10,8,6],f32> +// CHECK: return %[[VAL_22]] : !torch.vtensor<[10,8,6],f32> +// CHECK: } +func.func @torch.aten.scatter.src$basic(%arg0: !torch.vtensor<[10,8,6],f32>, %arg1: !torch.vtensor<[2,4,3],si64>, %arg2: !torch.vtensor<[3,4,3],f32>) -> !torch.vtensor<[10,8,6],f32> { + %int1 = torch.constant.int 1 + %0 = torch.aten.scatter.src %arg0, %int1, %arg1, %arg2 : !torch.vtensor<[10,8,6],f32>, !torch.int, !torch.vtensor<[2,4,3],si64>, !torch.vtensor<[3,4,3],f32> -> !torch.vtensor<[10,8,6],f32> + return %0 : !torch.vtensor<[10,8,6],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.slice_scatter$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[6,8],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[6,1],f32>) -> !torch.vtensor<[6,8],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[6,1],f32> -> tensor<6x1xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[6,8],f32> -> tensor<6x8xf32> +// CHECK: %[[VAL_4:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_5:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<0> : tensor<6x1xi32>}> : () -> tensor<6x1xi32> +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_6]] {new_shape = array} : (tensor<6x1xi32>) -> tensor<6x1x1xi32> +// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<{{\[\[}}[0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]], {{\[\[}}4]], {{\[\[}}5]]]> : tensor<6x1x1xi32>}> : () -> tensor<6x1x1xi32> +// CHECK: %[[VAL_9:.*]] = tosa.concat %[[VAL_8]], %[[VAL_7]] {axis = 2 : i32} : (tensor<6x1x1xi32>, tensor<6x1x1xi32>) -> tensor<6x1x2xi32> +// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<6x1xf32>) -> tensor<1x6x1xf32> +// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<6x8xf32>) -> tensor<1x48x1xf32> +// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array} : (tensor<6x1x2xi32>) -> tensor<6x2xi32> +// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{value = dense<[8, 1]> : tensor<2xi32>}> : () -> tensor<2xi32> +// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_13]] {new_shape = array} : (tensor<2xi32>) -> tensor<1x2xi32> +// CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_12]], %[[VAL_14]] {shift = 0 : i8} : (tensor<6x2xi32>, tensor<1x2xi32>) -> tensor<6x2xi32> +// CHECK: %[[VAL_16:.*]] = tosa.reduce_sum %[[VAL_15]] {axis = 1 : i32} : (tensor<6x2xi32>) -> tensor<6x1xi32> +// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array} : (tensor<6x1xi32>) -> tensor<1x6xi32> +// CHECK: %[[VAL_18:.*]] = tosa.scatter %[[VAL_11]], %[[VAL_17]], %[[VAL_10]] : (tensor<1x48x1xf32>, tensor<1x6xi32>, tensor<1x6x1xf32>) -> tensor<1x48x1xf32> +// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array} : (tensor<1x48x1xf32>) -> tensor<6x8xf32> +// CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]] : tensor<6x8xf32> -> !torch.vtensor<[6,8],f32> +// CHECK: return %[[VAL_20]] : !torch.vtensor<[6,8],f32> +// CHECK: } +func.func @torch.aten.slice_scatter$basic(%arg0: !torch.vtensor<[6,8],f32>, %arg1: !torch.vtensor<[6,1],f32>) -> !torch.vtensor<[6,8],f32> { + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %0 = torch.aten.slice_scatter %arg0, %arg1, %int1, %int0, %int1, %int1 : !torch.vtensor<[6,8],f32>, 
!torch.vtensor<[6,1],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[6,8],f32> + return %0 : !torch.vtensor<[6,8],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.diag_embed$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4,4],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_3:.*]] = torch.constant.int -2 +// CHECK: %[[VAL_4:.*]] = torch.constant.int -1 +// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]]], {{\[\[}}[0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]]]]> : tensor<2x3x4x1xi32>}> : () -> tensor<2x3x4x1xi32> +// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<2x3x4xf32>) -> tensor<2x3x4x1xf32> +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<2x3x4x4xf32>}> : () -> tensor<2x3x4x4xf32> +// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor<2x3x4x1xi32>) -> tensor<2x3x4x1x1xi32> +// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}[0]], {{\[\[}}0]], {{\[\[}}0]], {{\[\[}}0]]], {{\[\[}}[0]], {{\[\[}}0]], {{\[\[}}0]], {{\[\[}}0]]], {{\[\[}}[0]], {{\[\[}}0]], {{\[\[}}0]], {{\[\[}}0]]]], {{\[\[}}{{\[\[}}1]], {{\[\[}}1]], {{\[\[}}1]], {{\[\[}}1]]], {{\[\[}}[1]], {{\[\[}}1]], {{\[\[}}1]], {{\[\[}}1]]], {{\[\[}}[1]], {{\[\[}}1]], {{\[\[}}1]], {{\[\[}}1]]]]]> : tensor<2x3x4x1x1xi32>}> : () -> tensor<2x3x4x1x1xi32> +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}[0]], {{\[\[}}0]], {{\[\[}}0]], {{\[\[}}0]]], {{\[\[}}[1]], {{\[\[}}1]], {{\[\[}}1]], {{\[\[}}1]]], {{\[\[}}[2]], {{\[\[}}2]], {{\[\[}}2]], {{\[\[}}2]]]], {{\[\[}}{{\[\[}}0]], {{\[\[}}0]], {{\[\[}}0]], {{\[\[}}0]]], {{\[\[}}[1]], {{\[\[}}1]], {{\[\[}}1]], {{\[\[}}1]]], {{\[\[}}[2]], {{\[\[}}2]], {{\[\[}}2]], {{\[\[}}2]]]]]> : tensor<2x3x4x1x1xi32>}> : () -> tensor<2x3x4x1x1xi32> +// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}[0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]]], {{\[\[}}[0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]]], {{\[\[}}[0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]]]], {{\[\[}}{{\[\[}}0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]]], {{\[\[}}[0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]]], {{\[\[}}[0]], {{\[\[}}1]], {{\[\[}}2]], {{\[\[}}3]]]]]> : tensor<2x3x4x1x1xi32>}> : () -> tensor<2x3x4x1x1xi32> +// CHECK: %[[VAL_12:.*]] = tosa.concat %[[VAL_9]], %[[VAL_10]], %[[VAL_11]], %[[VAL_8]] {axis = 4 : i32} : (tensor<2x3x4x1x1xi32>, tensor<2x3x4x1x1xi32>, tensor<2x3x4x1x1xi32>, tensor<2x3x4x1x1xi32>) -> tensor<2x3x4x1x4xi32> +// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_6]] {new_shape = array} : (tensor<2x3x4x1xf32>) -> tensor<1x24x1xf32> +// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor<2x3x4x4xf32>) -> tensor<1x96x1xf32> +// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array} : (tensor<2x3x4x1x4xi32>) -> tensor<24x4xi32> +// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{value = dense<[48, 16, 4, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array} : (tensor<4xi32>) -> tensor<1x4xi32> +// CHECK: %[[VAL_18:.*]] = tosa.mul %[[VAL_15]], %[[VAL_17]] {shift = 0 : i8} : (tensor<24x4xi32>, tensor<1x4xi32>) -> tensor<24x4xi32> +// CHECK: 
%[[VAL_19:.*]] = tosa.reduce_sum %[[VAL_18]] {axis = 1 : i32} : (tensor<24x4xi32>) -> tensor<24x1xi32> +// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]] {new_shape = array} : (tensor<24x1xi32>) -> tensor<1x24xi32> +// CHECK: %[[VAL_21:.*]] = tosa.scatter %[[VAL_14]], %[[VAL_20]], %[[VAL_13]] : (tensor<1x96x1xf32>, tensor<1x24xi32>, tensor<1x24x1xf32>) -> tensor<1x96x1xf32> +// CHECK: %[[VAL_22:.*]] = tosa.reshape %[[VAL_21]] {new_shape = array} : (tensor<1x96x1xf32>) -> tensor<2x3x4x4xf32> +// CHECK: %[[VAL_23:.*]] = "tosa.const"() <{value = dense<[0, 1, 2, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_24:.*]] = tosa.transpose %[[VAL_22]], %[[VAL_23]] : (tensor<2x3x4x4xf32>, tensor<4xi32>) -> tensor<2x3x4x4xf32> +// CHECK: %[[VAL_25:.*]] = torch_c.from_builtin_tensor %[[VAL_24]] : tensor<2x3x4x4xf32> -> !torch.vtensor<[2,3,4,4],f32> +// CHECK: return %[[VAL_25]] : !torch.vtensor<[2,3,4,4],f32> +// CHECK: } +func.func @torch.aten.diag_embed$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,3,4,4],f32> { + %int0 = torch.constant.int 0 + %int-2 = torch.constant.int -2 + %int-1 = torch.constant.int -1 + %0 = torch.aten.diag_embed %arg0, %int0, %int-2, %int-1 : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,3,4,4],f32> + return %0 : !torch.vtensor<[2,3,4,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.index.Tensor_hacked_twin( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,4,2],si64>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[],si64>) -> !torch.vtensor<[4,2],si64> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,4,2],si64> -> tensor<2x4x2xi64> +// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_1]] : (!torch.vtensor<[],si64>) -> !torch.list +// CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[],si64> -> tensor +// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor) -> tensor +// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<0> : tensor}> : () -> tensor +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<2> : tensor}> : () -> tensor +// CHECK: %[[VAL_8:.*]] = tosa.add %[[VAL_7]], %[[VAL_5]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_9:.*]] = tosa.greater %[[VAL_6]], %[[VAL_5]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_10:.*]] = tosa.select %[[VAL_9]], %[[VAL_8]], %[[VAL_5]] : (tensor, tensor, tensor) -> tensor +// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_10]] {new_shape = array} : (tensor) -> tensor<1xi32> +// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<2x4x2xi64>) -> tensor<1x2x8xi64> +// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array} : (tensor<1xi32>) -> tensor<1x1xi32> +// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array} : (tensor<1xi32>) -> tensor<1x1xi32> +// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_13]], %[[VAL_15]] {shift = 0 : i8} : (tensor<1x1xi32>, tensor<1x1xi32>) -> tensor<1x1xi32> +// CHECK: %[[VAL_17:.*]] = tosa.reduce_sum %[[VAL_16]] {axis = 1 : i32} : (tensor<1x1xi32>) -> tensor<1x1xi32> +// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_17]] {new_shape = array} : (tensor<1x1xi32>) -> tensor<1x1xi32> +// CHECK: %[[VAL_19:.*]] = tosa.gather %[[VAL_12]], %[[VAL_18]] : (tensor<1x2x8xi64>, tensor<1x1xi32>) -> tensor<1x1x8xi64> +// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]] {new_shape = array} : (tensor<1x1x8xi64>) 
-> tensor<4x2xi64> +// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<4x2xi64> -> !torch.vtensor<[4,2],si64> +// CHECK: return %[[VAL_21]] : !torch.vtensor<[4,2],si64> +// CHECK: } +func.func @torch.aten.index.Tensor_hacked_twin(%arg0: !torch.vtensor<[2,4,2],si64>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[4,2],si64> { + %0 = torch.prim.ListConstruct %arg1 : (!torch.vtensor<[],si64>) -> !torch.list + %1 = torch.aten.index.Tensor_hacked_twin %arg0, %0 : !torch.vtensor<[2,4,2],si64>, !torch.list -> !torch.vtensor<[4,2],si64> + return %1 : !torch.vtensor<[4,2],si64> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.threshold_backward$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4],si64>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],si64> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4],si64> -> tensor<4xi64> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4],si64> -> tensor<4xi64> +// CHECK: %[[VAL_4:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1> : tensor<4xi64>}> : () -> tensor<4xi64> +// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<0> : tensor}> : () -> tensor +// CHECK: %[[VAL_7:.*]] = tosa.greater_equal %[[VAL_5]], %[[VAL_2]] : (tensor<4xi64>, tensor<4xi64>) -> tensor<4xi1> +// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_6]] {new_shape = array} : (tensor) -> tensor<1xi64> +// CHECK: %[[VAL_9:.*]] = tosa.select %[[VAL_7]], %[[VAL_8]], %[[VAL_3]] : (tensor<4xi1>, tensor<1xi64>, tensor<4xi64>) -> tensor<4xi64> +// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<4xi64> -> !torch.vtensor<[4],si64> +// CHECK: return %[[VAL_10]] : !torch.vtensor<[4],si64> +// CHECK: } +func.func @torch.aten.threshold_backward$basic(%arg0: !torch.vtensor<[4],si64>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],si64> { + %int1 = torch.constant.int 1 + %0 = torch.aten.threshold_backward %arg0, %arg1, %int1 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + return %0 : !torch.vtensor<[4],si64> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.threshold$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,5],si64>) -> !torch.vtensor<[4,5],si64> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,5],si64> -> tensor<4x5xi64> +// CHECK: %[[VAL_2:.*]] = torch.constant.float 5.000000e-01 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1xi64>}> : () -> tensor<1x1xi64> +// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<2> : tensor<1x1xi64>}> : () -> tensor<1x1xi64> +// CHECK: %[[VAL_6:.*]] = tosa.greater %[[VAL_1]], %[[VAL_4]] : (tensor<4x5xi64>, tensor<1x1xi64>) -> tensor<4x5xi1> +// CHECK: %[[VAL_7:.*]] = tosa.select %[[VAL_6]], %[[VAL_1]], %[[VAL_5]] : (tensor<4x5xi1>, tensor<4x5xi64>, tensor<1x1xi64>) -> tensor<4x5xi64> +// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<4x5xi64> -> !torch.vtensor<[4,5],si64> +// CHECK: return %[[VAL_8]] : !torch.vtensor<[4,5],si64> +// CHECK: } +func.func @torch.aten.threshold$basic(%arg0: !torch.vtensor<[4,5],si64>) -> !torch.vtensor<[4,5],si64> { + %float5.000000e-01 = torch.constant.float 5.000000e-01 + %int2 = torch.constant.int 2 + %0 = torch.aten.threshold %arg0, %float5.000000e-01, %int2 : !torch.vtensor<[4,5],si64>, !torch.float, !torch.int -> 
!torch.vtensor<[4,5],si64> + return %0 : !torch.vtensor<[4,5],si64> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.logical_and$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,5],i1>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4,5],i1>) -> !torch.vtensor<[4,5],i1> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4,5],i1> -> tensor<4x5xi1> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,5],i1> -> tensor<4x5xi1> +// CHECK: %[[VAL_4:.*]] = tosa.logical_and %[[VAL_3]], %[[VAL_2]] : (tensor<4x5xi1>, tensor<4x5xi1>) -> tensor<4x5xi1> +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<4x5xi1> -> !torch.vtensor<[4,5],i1> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[4,5],i1> +// CHECK: } +func.func @torch.aten.logical_and$basic(%arg0: !torch.vtensor<[4,5],i1>, %arg1: !torch.vtensor<[4,5],i1>) -> !torch.vtensor<[4,5],i1> { + %0 = torch.aten.logical_and %arg0, %arg1 : !torch.vtensor<[4,5],i1>, !torch.vtensor<[4,5],i1> -> !torch.vtensor<[4,5],i1> + return %0 : !torch.vtensor<[4,5],i1> +} + +// ----- + +// CHECK-LABEL: torch.aten.uniform$basic +// CHECK: tosa.const +func.func @torch.aten.uniform$basic(%arg0: !torch.vtensor<[3,4],f64>) -> (!torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64>) { + %float1.000000e00 = torch.constant.float 1.000000e+00 + %float1.000000e01 = torch.constant.float 1.000000e+01 + %none = torch.constant.none + %0 = torch.aten.uniform %arg0, %float1.000000e00, %float1.000000e01, %none : !torch.vtensor<[3,4],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[3,4],f64> + return %0, %0 : !torch.vtensor<[3,4],f64>, !torch.vtensor<[3,4],f64> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.as_strided$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[3,3],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,5],f32> -> tensor<5x5xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.none +// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_4:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_5:.*]] = torch.constant.int 3 +// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_5]], %[[VAL_5]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<5x5xf32>) -> tensor<25xf32> +// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<[0, 1, 2, 2, 3, 4, 4, 5, 6]> : tensor<9xi32>}> : () -> tensor<9xi32> +// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array} : (tensor<9xi32>) -> tensor<9x1xi32> +// CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_10]] {axis = 1 : i32} : (tensor<9x1xi32>) -> tensor<9x1xi32> +// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array} : (tensor<25xf32>) -> tensor<1x25x1xf32> +// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array} : (tensor<9x1xi32>) -> tensor<9x1xi32> +// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array} : (tensor<1xi32>) -> tensor<1x1xi32> +// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_13]], %[[VAL_15]] {shift = 0 : i8} : (tensor<9x1xi32>, tensor<1x1xi32>) -> tensor<9x1xi32> +// CHECK: %[[VAL_17:.*]] = tosa.reduce_sum %[[VAL_16]] {axis = 1 : i32} : (tensor<9x1xi32>) -> tensor<9x1xi32> +// CHECK: 
%[[VAL_18:.*]] = tosa.reshape %[[VAL_17]] {new_shape = array} : (tensor<9x1xi32>) -> tensor<1x9xi32> +// CHECK: %[[VAL_19:.*]] = tosa.gather %[[VAL_12]], %[[VAL_18]] : (tensor<1x25x1xf32>, tensor<1x9xi32>) -> tensor<1x9x1xf32> +// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]] {new_shape = array} : (tensor<1x9x1xf32>) -> tensor<9xf32> +// CHECK: %[[VAL_21:.*]] = tosa.reshape %[[VAL_20]] {new_shape = array} : (tensor<9xf32>) -> tensor<3x3xf32> +// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<3x3xf32> -> !torch.vtensor<[3,3],f32> +// CHECK: return %[[VAL_22]] : !torch.vtensor<[3,3],f32> +// CHECK: } +func.func @torch.aten.as_strided$basic(%arg0: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[3,3],f32> { + %none = torch.constant.none + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %int3 = torch.constant.int 3 + %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list + %2 = torch.aten.as_strided %arg0, %0, %1, %none : !torch.vtensor<[5,5],f32>, !torch.list, !torch.list, !torch.none -> !torch.vtensor<[3,3],f32> + return %2 : !torch.vtensor<[3,3],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.max_pool1d$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,64,112],f32>) -> !torch.vtensor<[1,64,56],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,64,112],f32> -> tensor<1x64x112xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.bool false +// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_4:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_5:.*]] = torch.constant.int 3 +// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_5]] : (!torch.int) -> !torch.list +// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_4]] : (!torch.int) -> !torch.list +// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_3]] : (!torch.int) -> !torch.list +// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]] : (!torch.int) -> !torch.list +// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<1x64x112xf32>) -> tensor<1x64x112x1xf32> +// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_10]], %[[VAL_11]] : (tensor<1x64x112x1xf32>, tensor<4xi32>) -> tensor<1x112x1x64xf32> +// CHECK: %[[VAL_13:.*]] = tosa.max_pool2d %[[VAL_12]] {kernel = array, pad = array, stride = array} : (tensor<1x112x1x64xf32>) -> tensor<1x56x1x64xf32> +// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_15:.*]] = tosa.transpose %[[VAL_13]], %[[VAL_14]] : (tensor<1x56x1x64xf32>, tensor<4xi32>) -> tensor<1x64x56x1xf32> +// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array} : (tensor<1x64x56x1xf32>) -> tensor<1x64x56xf32> +// CHECK: %[[VAL_17:.*]] = tensor.cast %[[VAL_16]] : tensor<1x64x56xf32> to tensor<1x64x56xf32> +// CHECK: %[[VAL_18:.*]] = torch_c.from_builtin_tensor %[[VAL_17]] : tensor<1x64x56xf32> -> !torch.vtensor<[1,64,56],f32> +// CHECK: return %[[VAL_18]] : !torch.vtensor<[1,64,56],f32> +// CHECK: } +func.func @torch.aten.max_pool1d$basic(%arg0: !torch.vtensor<[1,64,112],f32>) -> !torch.vtensor<[1,64,56],f32> { + %false = torch.constant.bool false + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %int3 = torch.constant.int 3 + %0 = torch.prim.ListConstruct 
%int3 : (!torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int2 : (!torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %3 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %4 = torch.aten.max_pool1d %arg0, %0, %1, %2, %3, %false : !torch.vtensor<[1,64,112],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[1,64,56],f32> + return %4 : !torch.vtensor<[1,64,56],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.avg_pool1d$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,512,10],f32> -> tensor<1x512x10xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_4:.*]] = torch.constant.bool false +// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list +// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list +// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_3]] : (!torch.int) -> !torch.list +// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<1x512x10xf32>) -> tensor<1x512x10x1xf32> +// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_10:.*]] = tosa.transpose %[[VAL_8]], %[[VAL_9]] : (tensor<1x512x10x1xf32>, tensor<4xi32>) -> tensor<1x10x1x512xf32> +// CHECK: %[[VAL_11:.*]] = tosa.avg_pool2d %[[VAL_10]] {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<1x10x1x512xf32>) -> tensor<1x10x1x512xf32> +// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_11]], %[[VAL_12]] : (tensor<1x10x1x512xf32>, tensor<4xi32>) -> tensor<1x512x10x1xf32> +// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_13]] {new_shape = array} : (tensor<1x512x10x1xf32>) -> tensor<1x512x10xf32> +// CHECK: %[[VAL_15:.*]] = tensor.cast %[[VAL_14]] : tensor<1x512x10xf32> to tensor<1x512x10xf32> +// CHECK: %[[VAL_16:.*]] = torch_c.from_builtin_tensor %[[VAL_15]] : tensor<1x512x10xf32> -> !torch.vtensor<[1,512,10],f32> +// CHECK: return %[[VAL_16]] : !torch.vtensor<[1,512,10],f32> +// CHECK: } +func.func @torch.aten.avg_pool1d$basic(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> { + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list + %3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %false : !torch.vtensor<[1,512,10],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32> + return %3 : !torch.vtensor<[1,512,10],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.clamp.Tensor$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,5],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1],f32>, +// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[1],f32>) -> (!torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>) { +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[1],f32> -> tensor<1xf32> +// CHECK: %[[VAL_4:.*]] = 
torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1],f32> -> tensor<1xf32> +// CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,5],f32> -> tensor<3x5xf32> +// CHECK: %[[VAL_6:.*]] = torch.constant.none +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<3.40282347E+38> : tensor}> : () -> tensor +// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<1xf32>) -> tensor<1x1xf32> +// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_10:.*]] = tosa.maximum %[[VAL_5]], %[[VAL_8]] : (tensor<3x5xf32>, tensor<1x1xf32>) -> tensor<3x5xf32> +// CHECK: %[[VAL_11:.*]] = tosa.minimum %[[VAL_10]], %[[VAL_9]] : (tensor<3x5xf32>, tensor<1x1xf32>) -> tensor<3x5xf32> +// CHECK: %[[VAL_12:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor<3x5xf32> -> !torch.vtensor<[3,5],f32> +// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{value = dense<-3.40282347E+38> : tensor}> : () -> tensor +// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_13]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<1xf32>) -> tensor<1x1xf32> +// CHECK: %[[VAL_16:.*]] = tosa.maximum %[[VAL_5]], %[[VAL_14]] : (tensor<3x5xf32>, tensor<1x1xf32>) -> tensor<3x5xf32> +// CHECK: %[[VAL_17:.*]] = tosa.minimum %[[VAL_16]], %[[VAL_15]] : (tensor<3x5xf32>, tensor<1x1xf32>) -> tensor<3x5xf32> +// CHECK: %[[VAL_18:.*]] = torch_c.from_builtin_tensor %[[VAL_17]] : tensor<3x5xf32> -> !torch.vtensor<[3,5],f32> +// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<1xf32>) -> tensor<1x1xf32> +// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<1xf32>) -> tensor<1x1xf32> +// CHECK: %[[VAL_21:.*]] = tosa.maximum %[[VAL_5]], %[[VAL_19]] : (tensor<3x5xf32>, tensor<1x1xf32>) -> tensor<3x5xf32> +// CHECK: %[[VAL_22:.*]] = tosa.minimum %[[VAL_21]], %[[VAL_20]] : (tensor<3x5xf32>, tensor<1x1xf32>) -> tensor<3x5xf32> +// CHECK: %[[VAL_23:.*]] = torch_c.from_builtin_tensor %[[VAL_22]] : tensor<3x5xf32> -> !torch.vtensor<[3,5],f32> +// CHECK: return %[[VAL_12]], %[[VAL_18]], %[[VAL_23]] : !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32> +// CHECK: } +func.func @torch.aten.clamp.Tensor$basic(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.vtensor<[1],f32>) -> (!torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>) { + %none = torch.constant.none + %0 = torch.aten.clamp.Tensor %arg0, %arg1, %none : !torch.vtensor<[3,5],f32>, !torch.vtensor<[1],f32>, !torch.none -> !torch.vtensor<[3,5],f32> + %1 = torch.aten.clamp.Tensor %arg0, %none, %arg2 : !torch.vtensor<[3,5],f32>, !torch.none, !torch.vtensor<[1],f32> -> !torch.vtensor<[3,5],f32> + %2 = torch.aten.clamp.Tensor %arg0, %arg1, %arg2 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32> -> !torch.vtensor<[3,5],f32> + return %0, %1, %2 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32>, !torch.vtensor<[3,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.prims.collapse$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,12],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_1]] {new_shape = 
array} : (tensor<2x3x4xf32>) -> tensor<2x12xf32> +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<2x12xf32> -> !torch.vtensor<[2,12],f32> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[2,12],f32> +// CHECK: } +func.func @torch.prims.collapse$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[2,12],f32> { + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %0 = torch.prims.collapse %arg0, %int1, %int2 : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,12],f32> + return %0 : !torch.vtensor<[2,12],f32> +} + +// ----- + +func.func @torch.aten.avg_pool1d.count_include_pad_unsupported_value(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> { + %int1 = torch.constant.int 1 + %int3 = torch.constant.int 3 + %false = torch.constant.bool false + %count_include_pad = torch.constant.bool true + %0 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list + %1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + // expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool1d' that was explicitly marked illegal}} + %3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %count_include_pad : !torch.vtensor<[1,512,10],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32> + return %3 : !torch.vtensor<[1,512,10],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.reflection_pad1d$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,2,4],f32>) -> !torch.vtensor<[1,2,8],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,2,4],f32> -> tensor<1x2x4xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 3 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_5:.*]] = tosa.slice %[[VAL_1]] {size = array, start = array} : (tensor<1x2x4xf32>) -> tensor<1x2x3xf32> +// CHECK: %[[VAL_6:.*]] = tosa.reverse %[[VAL_5]] {axis = 2 : i32} : (tensor<1x2x3xf32>) -> tensor<1x2x3xf32> +// CHECK: %[[VAL_7:.*]] = tosa.slice %[[VAL_1]] {size = array, start = array} : (tensor<1x2x4xf32>) -> tensor<1x2x1xf32> +// CHECK: %[[VAL_8:.*]] = tosa.reverse %[[VAL_7]] {axis = 2 : i32} : (tensor<1x2x1xf32>) -> tensor<1x2x1xf32> +// CHECK: %[[VAL_9:.*]] = tosa.concat %[[VAL_6]], %[[VAL_1]], %[[VAL_8]] {axis = 2 : i32} : (tensor<1x2x3xf32>, tensor<1x2x4xf32>, tensor<1x2x1xf32>) -> tensor<1x2x8xf32> +// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<1x2x8xf32> -> !torch.vtensor<[1,2,8],f32> +// CHECK: return %[[VAL_10]] : !torch.vtensor<[1,2,8],f32> +// CHECK: } +func.func @torch.aten.reflection_pad1d$basic(%arg0: !torch.vtensor<[1,2,4],f32>) -> !torch.vtensor<[1,2,8],f32> { + %int3 = torch.constant.int 3 + %int1 = torch.constant.int 1 + %0 = torch.prim.ListConstruct %int3, %int1 : (!torch.int, !torch.int) -> !torch.list + %1 = torch.aten.reflection_pad1d %arg0, %0 : !torch.vtensor<[1,2,4],f32>, !torch.list -> !torch.vtensor<[1,2,8],f32> + return %1 : !torch.vtensor<[1,2,8],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.reflection_pad2d$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,20,20],f32>) -> !torch.vtensor<[1,40,40],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,20,20],f32> -> tensor<1x20x20xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 
10 +// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_2]], %[[VAL_2]], %[[VAL_2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_4:.*]] = tosa.slice %[[VAL_1]] {size = array, start = array} : (tensor<1x20x20xf32>) -> tensor<1x20x10xf32> +// CHECK: %[[VAL_5:.*]] = tosa.reverse %[[VAL_4]] {axis = 2 : i32} : (tensor<1x20x10xf32>) -> tensor<1x20x10xf32> +// CHECK: %[[VAL_6:.*]] = tosa.slice %[[VAL_1]] {size = array, start = array} : (tensor<1x20x20xf32>) -> tensor<1x20x10xf32> +// CHECK: %[[VAL_7:.*]] = tosa.reverse %[[VAL_6]] {axis = 2 : i32} : (tensor<1x20x10xf32>) -> tensor<1x20x10xf32> +// CHECK: %[[VAL_8:.*]] = tosa.concat %[[VAL_5]], %[[VAL_1]], %[[VAL_7]] {axis = 2 : i32} : (tensor<1x20x10xf32>, tensor<1x20x20xf32>, tensor<1x20x10xf32>) -> tensor<1x20x40xf32> +// CHECK: %[[VAL_9:.*]] = tosa.slice %[[VAL_8]] {size = array, start = array} : (tensor<1x20x40xf32>) -> tensor<1x10x40xf32> +// CHECK: %[[VAL_10:.*]] = tosa.reverse %[[VAL_9]] {axis = 1 : i32} : (tensor<1x10x40xf32>) -> tensor<1x10x40xf32> +// CHECK: %[[VAL_11:.*]] = tosa.slice %[[VAL_8]] {size = array, start = array} : (tensor<1x20x40xf32>) -> tensor<1x10x40xf32> +// CHECK: %[[VAL_12:.*]] = tosa.reverse %[[VAL_11]] {axis = 1 : i32} : (tensor<1x10x40xf32>) -> tensor<1x10x40xf32> +// CHECK: %[[VAL_13:.*]] = tosa.concat %[[VAL_10]], %[[VAL_8]], %[[VAL_12]] {axis = 1 : i32} : (tensor<1x10x40xf32>, tensor<1x20x40xf32>, tensor<1x10x40xf32>) -> tensor<1x40x40xf32> +// CHECK: %[[VAL_14:.*]] = torch_c.from_builtin_tensor %[[VAL_13]] : tensor<1x40x40xf32> -> !torch.vtensor<[1,40,40],f32> +// CHECK: return %[[VAL_14]] : !torch.vtensor<[1,40,40],f32> +// CHECK: } +func.func @torch.aten.reflection_pad2d$basic(%arg0: !torch.vtensor<[1,20,20],f32>) -> !torch.vtensor<[1,40,40],f32> { + %int10 = torch.constant.int 10 + %0 = torch.prim.ListConstruct %int10, %int10, %int10, %int10 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1 = torch.aten.reflection_pad2d %arg0, %0 : !torch.vtensor<[1,20,20],f32>, !torch.list -> !torch.vtensor<[1,40,40],f32> + return %1 : !torch.vtensor<[1,40,40],f32> +} + + +// ----- +// CHECK-LABEL: func.func @torch.aten.reflection_pad3d$basic( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,5,7,3,4],f32>) -> !torch.vtensor<[4,5,11,7,8],f32> { +// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,5,7,3,4],f32> -> tensor<4x5x7x3x4xf32> +// CHECK: %[[VAL_1:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %[[VAL_1]], %[[VAL_1]], %[[VAL_1]], %[[VAL_1]], %[[VAL_1]], %[[VAL_1]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list +// CHECK: %[[SLICE_L:.*]] = tosa.slice %[[VAL_0]] {size = array, start = array} : (tensor<4x5x7x3x4xf32>) -> tensor<4x5x7x3x2xf32> +// CHECK: %[[REVERSE_L:.*]] = tosa.reverse %[[SLICE_L]] {axis = 4 : i32} : (tensor<4x5x7x3x2xf32>) -> tensor<4x5x7x3x2xf32> +// CHECK: %[[SLICE_R:.*]] = tosa.slice %[[VAL_0]] {size = array, start = array} : (tensor<4x5x7x3x4xf32>) -> tensor<4x5x7x3x2xf32> +// CHECK: %[[REVERSE_R:.*]] = tosa.reverse %[[SLICE_R]] {axis = 4 : i32} : (tensor<4x5x7x3x2xf32>) -> tensor<4x5x7x3x2xf32> +// CHECK: %[[CONCAT_LR:.*]] = tosa.concat %[[REVERSE_L]], %[[VAL_0]], %[[REVERSE_R]] {axis = 4 : i32} : (tensor<4x5x7x3x2xf32>, tensor<4x5x7x3x4xf32>, tensor<4x5x7x3x2xf32>) -> tensor<4x5x7x3x8xf32> +// CHECK: %[[SLICE_T:.*]] = tosa.slice %[[CONCAT_LR]] {size = array, start = array} : (tensor<4x5x7x3x8xf32>) -> 
tensor<4x5x7x2x8xf32> +// CHECK: %[[REVERSE_T:.*]] = tosa.reverse %[[SLICE_T]] {axis = 3 : i32} : (tensor<4x5x7x2x8xf32>) -> tensor<4x5x7x2x8xf32> +// CHECK: %[[SLICE_B:.*]] = tosa.slice %[[CONCAT_LR]] {size = array, start = array} : (tensor<4x5x7x3x8xf32>) -> tensor<4x5x7x2x8xf32> +// CHECK: %[[REVERSE_B:.*]] = tosa.reverse %[[SLICE_B]] {axis = 3 : i32} : (tensor<4x5x7x2x8xf32>) -> tensor<4x5x7x2x8xf32> +// CHECK: %[[CONCAT_TB:.*]] = tosa.concat %[[REVERSE_T]], %[[CONCAT_LR]], %[[REVERSE_B]] {axis = 3 : i32} : (tensor<4x5x7x2x8xf32>, tensor<4x5x7x3x8xf32>, tensor<4x5x7x2x8xf32>) -> tensor<4x5x7x7x8xf32> +// CHECK: %[[SLICE_F:.*]] = tosa.slice %[[CONCAT_TB]] {size = array, start = array} : (tensor<4x5x7x7x8xf32>) -> tensor<4x5x2x7x8xf32> +// CHECK: %[[REVERSE_F:.*]] = tosa.reverse %[[SLICE_F]] {axis = 2 : i32} : (tensor<4x5x2x7x8xf32>) -> tensor<4x5x2x7x8xf32> +// CHECK: %[[SLICE_BACK:.*]] = tosa.slice %[[CONCAT_TB]] {size = array, start = array} : (tensor<4x5x7x7x8xf32>) -> tensor<4x5x2x7x8xf32> +// CHECK: %[[REVERSE_BACK:.*]] = tosa.reverse %[[SLICE_BACK]] {axis = 2 : i32} : (tensor<4x5x2x7x8xf32>) -> tensor<4x5x2x7x8xf32> +// CHECK: %[[CONCAT_FB:.*]] = tosa.concat %[[REVERSE_F]], %[[CONCAT_TB]], %[[REVERSE_BACK]] {axis = 2 : i32} : (tensor<4x5x2x7x8xf32>, tensor<4x5x7x7x8xf32>, tensor<4x5x2x7x8xf32>) -> tensor<4x5x11x7x8xf32> +// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CONCAT_FB]] : tensor<4x5x11x7x8xf32> -> !torch.vtensor<[4,5,11,7,8],f32> +// CHECK: return %[[RESULT]] +func.func @torch.aten.reflection_pad3d$basic(%arg0: !torch.vtensor<[4,5,7,3,4],f32>) -> !torch.vtensor<[4,5,11,7,8],f32> { + %int2 = torch.constant.int 2 + %0 = torch.prim.ListConstruct %int2, %int2, %int2, %int2, %int2, %int2 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1 = torch.aten.reflection_pad3d %arg0, %0 : !torch.vtensor<[4,5,7,3,4],f32>, !torch.list -> !torch.vtensor<[4,5,11,7,8],f32> + return %1 : !torch.vtensor<[4,5,11,7,8],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.replication_pad2d$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,10,6],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,3,3],f32> -> tensor<1x1x3x3xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_4:.*]] = torch.constant.int 3 +// CHECK: %[[VAL_5:.*]] = torch.constant.int 4 +// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_3]], %[[VAL_4]], %[[VAL_5]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_7:.*]] = tosa.slice %[[VAL_1]] {size = array, start = array} : (tensor<1x1x3x3xf32>) -> tensor<1x1x3x1xf32> +// CHECK: %[[VAL_8:.*]] = tosa.slice %[[VAL_1]] {size = array, start = array} : (tensor<1x1x3x3xf32>) -> tensor<1x1x3x1xf32> +// CHECK: %[[VAL_9:.*]] = tosa.concat %[[VAL_7]], %[[VAL_1]], %[[VAL_8]], %[[VAL_8]] {axis = 3 : i32} : (tensor<1x1x3x1xf32>, tensor<1x1x3x3xf32>, tensor<1x1x3x1xf32>, tensor<1x1x3x1xf32>) -> tensor<1x1x3x6xf32> +// CHECK: %[[VAL_10:.*]] = tosa.slice %[[VAL_9]] {size = array, start = array} : (tensor<1x1x3x6xf32>) -> tensor<1x1x1x6xf32> +// CHECK: %[[VAL_11:.*]] = tosa.slice %[[VAL_9]] {size = array, start = array} : (tensor<1x1x3x6xf32>) -> tensor<1x1x1x6xf32> +// CHECK: %[[VAL_12:.*]] = tosa.concat %[[VAL_10]], %[[VAL_10]], %[[VAL_10]], %[[VAL_9]], %[[VAL_11]], %[[VAL_11]], %[[VAL_11]], %[[VAL_11]] {axis = 2 : i32} : 
(tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x3x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>) -> tensor<1x1x10x6xf32> +// CHECK: %[[VAL_13:.*]] = torch_c.from_builtin_tensor %[[VAL_12]] : tensor<1x1x10x6xf32> -> !torch.vtensor<[1,1,10,6],f32> +// CHECK: return %[[VAL_13]] : !torch.vtensor<[1,1,10,6],f32> +// CHECK: } +func.func @torch.aten.replication_pad2d$basic(%arg0: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,10,6],f32> { + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %int3 = torch.constant.int 3 + %int4 = torch.constant.int 4 + %0 = torch.prim.ListConstruct %int1, %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1 = torch.aten.replication_pad2d %arg0, %0 : !torch.vtensor<[1,1,3,3],f32>, !torch.list -> !torch.vtensor<[1,1,10,6],f32> + return %1 : !torch.vtensor<[1,1,10,6],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.outer$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3],f32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3,4],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4],f32> -> tensor<4xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3],f32> -> tensor<3xf32> +// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<3xf32>) -> tensor<3x1xf32> +// CHECK: %[[VAL_5:.*]] = tosa.const_shape {value = dense<[1, 4]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_6:.*]] = tosa.tile %[[VAL_4]], %[[VAL_5]] : (tensor<3x1xf32>, !tosa.shape<2>) -> tensor<3x4xf32> +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<4xf32>) -> tensor<1x4xf32> +// CHECK: %[[VAL_8:.*]] = tosa.const_shape {value = dense<[3, 1]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_9:.*]] = tosa.tile %[[VAL_7]], %[[VAL_8]] : (tensor<1x4xf32>, !tosa.shape<2>) -> tensor<3x4xf32> +// CHECK: %[[VAL_10:.*]] = tosa.mul %[[VAL_6]], %[[VAL_9]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_11]] : !torch.vtensor<[3,4],f32> +// CHECK: } +func.func @torch.aten.outer$basic(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3,4],f32> { + %0 = torch.aten.outer %arg0, %arg1 : !torch.vtensor<[3],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.prims.split_dim$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,8,3,3],si64>) -> !torch.vtensor<[1,2,2,2,3,3],si64> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,8,3,3],si64> -> tensor<1x8x3x3xi64> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<1x8x3x3xi64>) -> tensor<1x2x4x3x3xi64> +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<1x2x4x3x3xi64>) -> tensor<1x2x2x2x3x3xi64> +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1x2x2x2x3x3xi64> -> !torch.vtensor<[1,2,2,2,3,3],si64> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[1,2,2,2,3,3],si64> +// CHECK: } +func.func @torch.prims.split_dim$basic(%arg0: !torch.vtensor<[1,8,3,3],si64>) 
-> !torch.vtensor<[1,2,2,2,3,3],si64> { + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %0 = torch.prims.split_dim %arg0, %int1, %int2 : !torch.vtensor<[1,8,3,3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,2,4,3,3],si64> + %1 = torch.prims.split_dim %0, %int2, %int2 : !torch.vtensor<[1,2,4,3,3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,2,2,2,3,3],si64> + return %1 : !torch.vtensor<[1,2,2,2,3,3],si64> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.upsample_nearest2d$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,2,3],f64>) -> !torch.vtensor<[1,1,8,9],f64> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,2,3],f64> -> tensor<1x1x2x3xf64> +// CHECK: %[[VAL_2:.*]] = torch.constant.float 4.000000e+00 +// CHECK: %[[VAL_3:.*]] = torch.constant.float 3.000000e+00 +// CHECK: %[[VAL_4:.*]] = torch.constant.int 8 +// CHECK: %[[VAL_5:.*]] = torch.constant.int 9 +// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_5]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<1x1x2x3xf64>) -> tensor<1x1x6xf64> +// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<{{\[\[}}[0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 3, 3, 3, 4, 4, 4, 5, 5, 5, 3, 3, 3, 4, 4, 4, 5, 5, 5, 3, 3, 3, 4, 4, 4, 5, 5, 5]]]> : tensor<1x1x72xi32>}> : () -> tensor<1x1x72xi32> +// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array} : (tensor<1x1x72xi32>) -> tensor<1x1x72x1xi32> +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x72x1xi32>}> : () -> tensor<1x1x72x1xi32> +// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x72x1xi32>}> : () -> tensor<1x1x72x1xi32> +// CHECK: %[[VAL_12:.*]] = tosa.concat %[[VAL_10]], %[[VAL_11]], %[[VAL_9]] {axis = 3 : i32} : (tensor<1x1x72x1xi32>, tensor<1x1x72x1xi32>, tensor<1x1x72x1xi32>) -> tensor<1x1x72x3xi32> +// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor<1x1x6xf64>) -> tensor<1x6x1xf64> +// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array} : (tensor<1x1x72x3xi32>) -> tensor<72x3xi32> +// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{value = dense<[6, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array} : (tensor<3xi32>) -> tensor<1x3xi32> +// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_14]], %[[VAL_16]] {shift = 0 : i8} : (tensor<72x3xi32>, tensor<1x3xi32>) -> tensor<72x3xi32> +// CHECK: %[[VAL_18:.*]] = tosa.reduce_sum %[[VAL_17]] {axis = 1 : i32} : (tensor<72x3xi32>) -> tensor<72x1xi32> +// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array} : (tensor<72x1xi32>) -> tensor<1x72xi32> +// CHECK: %[[VAL_20:.*]] = tosa.gather %[[VAL_13]], %[[VAL_19]] : (tensor<1x6x1xf64>, tensor<1x72xi32>) -> tensor<1x72x1xf64> +// CHECK: %[[VAL_21:.*]] = tosa.reshape %[[VAL_20]] {new_shape = array} : (tensor<1x72x1xf64>) -> tensor<1x1x72xf64> +// CHECK: %[[VAL_22:.*]] = tosa.reshape %[[VAL_21]] {new_shape = array} : (tensor<1x1x72xf64>) -> tensor<1x1x8x9xf64> +// CHECK: %[[VAL_23:.*]] = torch_c.from_builtin_tensor %[[VAL_22]] : tensor<1x1x8x9xf64> -> !torch.vtensor<[1,1,8,9],f64> +// CHECK: return %[[VAL_23]] : !torch.vtensor<[1,1,8,9],f64> +// CHECK: } +func.func @torch.aten.upsample_nearest2d$basic(%arg0: !torch.vtensor<[1,1,2,3],f64>) -> 
!torch.vtensor<[1,1,8,9],f64> { + %float4.000000e00 = torch.constant.float 4.000000e+00 + %float3.000000e00 = torch.constant.float 3.000000e+00 + %int8 = torch.constant.int 8 + %int9 = torch.constant.int 9 + %0 = torch.prim.ListConstruct %int8, %int9 : (!torch.int, !torch.int) -> !torch.list + %1 = torch.aten.upsample_nearest2d %arg0, %0, %float4.000000e00, %float3.000000e00 : !torch.vtensor<[1,1,2,3],f64>, !torch.list, !torch.float, !torch.float -> !torch.vtensor<[1,1,8,9],f64> + return %1 : !torch.vtensor<[1,1,8,9],f64> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.upsample_nearest2d.vec$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,4,5],f32>) -> !torch.vtensor<[1,1,2,7],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,4,5],f32> -> tensor<1x1x4x5xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.none +// CHECK: %[[VAL_3:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_4:.*]] = torch.constant.int 7 +// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<1x1x4x5xf32>) -> tensor<1x1x20xf32> +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<{{\[\[}}[0, 0, 1, 2, 2, 3, 4, 10, 10, 11, 12, 12, 13, 14]]]> : tensor<1x1x14xi32>}> : () -> tensor<1x1x14xi32> +// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor<1x1x14xi32>) -> tensor<1x1x14x1xi32> +// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x14x1xi32>}> : () -> tensor<1x1x14x1xi32> +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x14x1xi32>}> : () -> tensor<1x1x14x1xi32> +// CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_9]], %[[VAL_10]], %[[VAL_8]] {axis = 3 : i32} : (tensor<1x1x14x1xi32>, tensor<1x1x14x1xi32>, tensor<1x1x14x1xi32>) -> tensor<1x1x14x3xi32> +// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_6]] {new_shape = array} : (tensor<1x1x20xf32>) -> tensor<1x20x1xf32> +// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array} : (tensor<1x1x14x3xi32>) -> tensor<14x3xi32> +// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[20, 20, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array} : (tensor<3xi32>) -> tensor<1x3xi32> +// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_13]], %[[VAL_15]] {shift = 0 : i8} : (tensor<14x3xi32>, tensor<1x3xi32>) -> tensor<14x3xi32> +// CHECK: %[[VAL_17:.*]] = tosa.reduce_sum %[[VAL_16]] {axis = 1 : i32} : (tensor<14x3xi32>) -> tensor<14x1xi32> +// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_17]] {new_shape = array} : (tensor<14x1xi32>) -> tensor<1x14xi32> +// CHECK: %[[VAL_19:.*]] = tosa.gather %[[VAL_12]], %[[VAL_18]] : (tensor<1x20x1xf32>, tensor<1x14xi32>) -> tensor<1x14x1xf32> +// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]] {new_shape = array} : (tensor<1x14x1xf32>) -> tensor<1x1x14xf32> +// CHECK: %[[VAL_21:.*]] = tosa.reshape %[[VAL_20]] {new_shape = array} : (tensor<1x1x14xf32>) -> tensor<1x1x2x7xf32> +// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<1x1x2x7xf32> -> !torch.vtensor<[1,1,2,7],f32> +// CHECK: return %[[VAL_22]] : !torch.vtensor<[1,1,2,7],f32> +// CHECK: } +func.func @torch.aten.upsample_nearest2d.vec$basic(%arg0: !torch.vtensor<[1,1,4,5],f32>) -> !torch.vtensor<[1,1,2,7],f32> { + %none = torch.constant.none + %int2 = torch.constant.int 2 + %int7 = torch.constant.int 7 + %0 = torch.prim.ListConstruct %int2, 
%int7 : (!torch.int, !torch.int) -> !torch.list + %1 = torch.aten.upsample_nearest2d.vec %arg0, %0, %none : !torch.vtensor<[1,1,4,5],f32>, !torch.list, !torch.none -> !torch.vtensor<[1,1,2,7],f32> + return %1 : !torch.vtensor<[1,1,2,7],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.gelu$tanh( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,3],f32>) -> !torch.vtensor<[5,3],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,3],f32> -> tensor<5x3xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.str "tanh" +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor<5x3xf32>}> : () -> tensor<5x3xf32> +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<5x3xf32>}> : () -> tensor<5x3xf32> +// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<3.000000e+00> : tensor<5x3xf32>}> : () -> tensor<5x3xf32> +// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<4.471500e-02> : tensor<5x3xf32>}> : () -> tensor<5x3xf32> +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<0.636619746> : tensor<5x3xf32>}> : () -> tensor<5x3xf32> +// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_3]], %[[VAL_1]] {shift = 0 : i8} : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32> +// CHECK: %[[VAL_9:.*]] = tosa.pow %[[VAL_7]], %[[VAL_3]] : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32> +// CHECK: %[[VAL_10:.*]] = tosa.pow %[[VAL_1]], %[[VAL_5]] : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32> +// CHECK: %[[VAL_11:.*]] = tosa.mul %[[VAL_6]], %[[VAL_10]] {shift = 0 : i8} : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32> +// CHECK: %[[VAL_12:.*]] = tosa.add %[[VAL_1]], %[[VAL_11]] : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32> +// CHECK: %[[VAL_13:.*]] = tosa.mul %[[VAL_9]], %[[VAL_12]] {shift = 0 : i8} : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32> +// CHECK: %[[VAL_14:.*]] = tosa.tanh %[[VAL_13]] : (tensor<5x3xf32>) -> tensor<5x3xf32> +// CHECK: %[[VAL_15:.*]] = tosa.add %[[VAL_4]], %[[VAL_14]] : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32> +// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_8]], %[[VAL_15]] {shift = 0 : i8} : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32> +// CHECK: %[[VAL_17:.*]] = torch_c.from_builtin_tensor %[[VAL_16]] : tensor<5x3xf32> -> !torch.vtensor<[5,3],f32> +// CHECK: return %[[VAL_17]] : !torch.vtensor<[5,3],f32> +// CHECK: } +func.func @torch.aten.gelu$tanh(%arg0: !torch.vtensor<[5,3],f32>) -> !torch.vtensor<[5,3],f32> { + %str = torch.constant.str "tanh" + %0 = torch.aten.gelu %arg0, %str : !torch.vtensor<[5,3],f32>, !torch.str -> !torch.vtensor<[5,3],f32> + return %0 : !torch.vtensor<[5,3],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.exp$int( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],si32> -> tensor<3x4xi32> +// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x4xi32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_3:.*]] = tosa.exp %[[VAL_2]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_4]] : !torch.vtensor<[3,4],f32> +// CHECK: } +func.func @torch.aten.exp$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { + %0 = torch.aten.exp %arg0 : !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + 
+// CHECK-LABEL: func.func @torch.aten.log10$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],f32> -> tensor<3x4xf32> +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<1.000000e+01> : tensor}> : () -> tensor +// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_4:.*]] = tosa.log %[[VAL_1]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_5:.*]] = tosa.log %[[VAL_3]] : (tensor<1x1xf32>) -> tensor<1x1xf32> +// CHECK: %[[VAL_6:.*]] = tosa.reciprocal %[[VAL_5]] : (tensor<1x1xf32>) -> tensor<1x1xf32> +// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_4]], %[[VAL_6]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<1x1xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_8]] : !torch.vtensor<[3,4],f32> +// CHECK: } +func.func @torch.aten.log10$basic(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { + %0 = torch.aten.log10 %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.log10$int( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],si32> -> tensor<3x4xi32> +// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x4xi32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<1.000000e+01> : tensor}> : () -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_5:.*]] = tosa.log %[[VAL_2]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_6:.*]] = tosa.log %[[VAL_4]] : (tensor<1x1xf32>) -> tensor<1x1xf32> +// CHECK: %[[VAL_7:.*]] = tosa.reciprocal %[[VAL_6]] : (tensor<1x1xf32>) -> tensor<1x1xf32> +// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_5]], %[[VAL_7]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<1x1xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_9]] : !torch.vtensor<[3,4],f32> +// CHECK: } +func.func @torch.aten.log10$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { + %0 = torch.aten.log10 %arg0 : !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.log1p$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],f32> -> tensor<3x4xf32> +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_4:.*]] = tosa.add %[[VAL_1]], %[[VAL_3]] : (tensor<3x4xf32>, tensor<1x1xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_5:.*]] = tosa.log %[[VAL_4]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4],f32> +// CHECK: } +func.func @torch.aten.log1p$basic(%arg0: !torch.vtensor<[3,4],f32>) -> 
!torch.vtensor<[3,4],f32> { + %0 = torch.aten.log1p %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.log1p$int( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],si32> -> tensor<3x4xi32> +// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x4xi32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_5:.*]] = tosa.add %[[VAL_2]], %[[VAL_4]] : (tensor<3x4xf32>, tensor<1x1xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_6:.*]] = tosa.log %[[VAL_5]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[3,4],f32> +// CHECK: } +func.func @torch.aten.log1p$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { + %0 = torch.aten.log1p %arg0 : !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.logit$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],f32> -> tensor<3x4xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.float 9.9999999999999995E-8 +// CHECK: %[[VAL_3:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 0.99999988 : f32, max_int = 0 : i64, min_fp = 1.000000e-07 : f32, min_int = 0 : i64} : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_6:.*]] = tosa.sub %[[VAL_5]], %[[VAL_3]] : (tensor<1x1xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_7:.*]] = tosa.reciprocal %[[VAL_6]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_3]], %[[VAL_7]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_9:.*]] = tosa.log %[[VAL_8]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_10]] : !torch.vtensor<[3,4],f32> +// CHECK: } +func.func @torch.aten.logit$basic(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { + %float9.999990e-08 = torch.constant.float 9.9999999999999995E-8 + %0 = torch.aten.logit %arg0, %float9.999990e-08 : !torch.vtensor<[3,4],f32>, !torch.float -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.logit$int( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],si32> -> tensor<3x4xi32> +// CHECK: %[[VAL_2:.*]] = torch.constant.float 9.9999999999999995E-8 +// CHECK: %[[VAL_3:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x4xi32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_4:.*]] = tosa.clamp %[[VAL_3]] {max_fp = 0.99999988 : f32, max_int = 0 : i64, min_fp = 1.000000e-07 : f32, min_int = 0 : i64} : 
(tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_7:.*]] = tosa.sub %[[VAL_6]], %[[VAL_4]] : (tensor<1x1xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_8:.*]] = tosa.reciprocal %[[VAL_7]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_4]], %[[VAL_8]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_10:.*]] = tosa.log %[[VAL_9]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_11]] : !torch.vtensor<[3,4],f32> +// CHECK: } +func.func @torch.aten.logit$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { + %float9.999990e-08 = torch.constant.float 9.9999999999999995E-8 + %0 = torch.aten.logit %arg0, %float9.999990e-08 : !torch.vtensor<[3,4],si32>, !torch.float -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.log$int( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],si32> -> tensor<3x4xi32> +// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x4xi32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_3:.*]] = tosa.log %[[VAL_2]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_4]] : !torch.vtensor<[3,4],f32> +// CHECK: } +func.func @torch.aten.log$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { + %0 = torch.aten.log %arg0 : !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.log2$int( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],si32> -> tensor<3x4xi32> +// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x4xi32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<0.693147182> : tensor<1x1xf32>}> : () -> tensor<1x1xf32> +// CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_3]] : (tensor<1x1xf32>) -> tensor<1x1xf32> +// CHECK: %[[VAL_5:.*]] = tosa.log %[[VAL_2]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_5]], %[[VAL_4]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<1x1xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[3,4],f32> +// CHECK: } +func.func @torch.aten.log2$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { + %0 = torch.aten.log2 %arg0 : !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.erf$int( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],si32> -> tensor<3x4xi32> +// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x4xi32>) -> 
tensor<3x4xf32> +// CHECK: %[[VAL_3:.*]] = tosa.erf %[[VAL_2]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_4]] : !torch.vtensor<[3,4],f32> +// CHECK: } +func.func @torch.aten.erf$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { + %0 = torch.aten.erf %arg0 : !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.lt.Scalar$intfloat( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],i1> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4],si64> -> tensor<4xi64> +// CHECK: %[[VAL_2:.*]] = torch.constant.float 1.100000e+00 +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<1.100000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<1.1000000238418579> : tensor}> : () -> tensor +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor) -> tensor<1xf64> +// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_1]] : (tensor<4xi64>) -> tensor<4xf64> +// CHECK: %[[VAL_7:.*]] = tosa.greater %[[VAL_5]], %[[VAL_6]] : (tensor<1xf64>, tensor<4xf64>) -> tensor<4xi1> +// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<4xi1> -> !torch.vtensor<[4],i1> +// CHECK: return %[[VAL_8]] : !torch.vtensor<[4],i1> +// CHECK: } +func.func @torch.aten.lt.Scalar$intfloat(%arg0: !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],i1> { + %float1.100000e00 = torch.constant.float 1.100000e+00 + %0 = torch.aten.lt.Scalar %arg0, %float1.100000e00 : !torch.vtensor<[4],si64>, !torch.float -> !torch.vtensor<[4],i1> + return %0 : !torch.vtensor<[4],i1> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.sigmoid$int( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,5],si32>) -> !torch.vtensor<[3,5],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,5],si32> -> tensor<3x5xi32> +// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x5xi32>) -> tensor<3x5xf32> +// CHECK: %[[VAL_3:.*]] = tosa.sigmoid %[[VAL_2]] : (tensor<3x5xf32>) -> tensor<3x5xf32> +// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<3x5xf32> -> !torch.vtensor<[3,5],f32> +// CHECK: return %[[VAL_4]] : !torch.vtensor<[3,5],f32> +// CHECK: } +func.func @torch.aten.sigmoid$int(%arg0: !torch.vtensor<[3,5],si32>) -> !torch.vtensor<[3,5],f32> { + %0 = torch.aten.sigmoid %arg0 : !torch.vtensor<[3,5],si32> -> !torch.vtensor<[3,5],f32> + return %0 : !torch.vtensor<[3,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.tan$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],f32> -> tensor<3x4xf32> +// CHECK: %[[VAL_2:.*]] = tosa.sin %[[VAL_1]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_3:.*]] = tosa.cos %[[VAL_1]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_3]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_2]], %[[VAL_4]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4],f32> +// CHECK: } +func.func 
@torch.aten.tan$basic(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { + %0 = torch.aten.tan %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.tan$int( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],si32> -> tensor<3x4xi32> +// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x4xi32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_3:.*]] = tosa.sin %[[VAL_2]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_4:.*]] = tosa.cos %[[VAL_2]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_5:.*]] = tosa.reciprocal %[[VAL_4]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[3,4],f32> +// CHECK: } +func.func @torch.aten.tan$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { + %0 = torch.aten.tan %arg0 : !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.tanh$int( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],si32> -> tensor<3x4xi32> +// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x4xi32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_3:.*]] = tosa.tanh %[[VAL_2]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_4]] : !torch.vtensor<[3,4],f32> +// CHECK: } +func.func @torch.aten.tanh$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { + %0 = torch.aten.tanh %arg0 : !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.pow.Tensor_Tensor$intfloat( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,5],si32>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[3,4,5],f32> -> tensor<3x4x5xf32> +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,5],si32> -> tensor<3x4x5xi32> +// CHECK: %[[VAL_4:.*]] = tosa.cast %[[VAL_3]] : (tensor<3x4x5xi32>) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_5:.*]] = tosa.pow %[[VAL_4]], %[[VAL_2]] : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32> +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4x5xf32> -> !torch.vtensor<[3,4,5],f32> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4,5],f32> +// CHECK: } +func.func @torch.aten.pow.Tensor_Tensor$intfloat(%arg0: !torch.vtensor<[3,4,5],si32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> { + %0 = torch.aten.pow.Tensor_Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],si32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.unfold$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[6,4],f32>) 
-> !torch.vtensor<[3,4,2],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[6,4],f32> -> tensor<6x4xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<{{\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5]]> : tensor<6x4xi32>}> : () -> tensor<6x4xi32> +// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<6x4xi32>) -> tensor<6x4x1xi32> +// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<{{\[\[}}[0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]]]> : tensor<6x4x1xi32>}> : () -> tensor<6x4x1xi32> +// CHECK: %[[VAL_7:.*]] = tosa.concat %[[VAL_5]], %[[VAL_6]] {axis = 2 : i32} : (tensor<6x4x1xi32>, tensor<6x4x1xi32>) -> tensor<6x4x2xi32> +// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<6x4xf32>) -> tensor<1x24x1xf32> +// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor<6x4x2xi32>) -> tensor<24x2xi32> +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<[4, 1]> : tensor<2xi32>}> : () -> tensor<2xi32> +// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_10]] {new_shape = array} : (tensor<2xi32>) -> tensor<1x2xi32> +// CHECK: %[[VAL_12:.*]] = tosa.mul %[[VAL_9]], %[[VAL_11]] {shift = 0 : i8} : (tensor<24x2xi32>, tensor<1x2xi32>) -> tensor<24x2xi32> +// CHECK: %[[VAL_13:.*]] = tosa.reduce_sum %[[VAL_12]] {axis = 1 : i32} : (tensor<24x2xi32>) -> tensor<24x1xi32> +// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_13]] {new_shape = array} : (tensor<24x1xi32>) -> tensor<1x24xi32> +// CHECK: %[[VAL_15:.*]] = tosa.gather %[[VAL_8]], %[[VAL_14]] : (tensor<1x24x1xf32>, tensor<1x24xi32>) -> tensor<1x24x1xf32> +// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array} : (tensor<1x24x1xf32>) -> tensor<6x4xf32> +// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array} : (tensor<6x4xf32>) -> tensor<3x2x4xf32> +// CHECK: %[[VAL_18:.*]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK: %[[VAL_19:.*]] = tosa.transpose %[[VAL_17]], %[[VAL_18]] : (tensor<3x2x4xf32>, tensor<3xi32>) -> tensor<3x4x2xf32> +// CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]] : tensor<3x4x2xf32> -> !torch.vtensor<[3,4,2],f32> +// CHECK: return %[[VAL_20]] : !torch.vtensor<[3,4,2],f32> +// CHECK: } +func.func @torch.aten.unfold$basic(%arg0: !torch.vtensor<[6,4],f32>) -> !torch.vtensor<[3,4,2],f32> { + %int0 = torch.constant.int 0 + %int2 = torch.constant.int 2 + %0 = torch.aten.unfold %arg0, %int0, %int2, %int2 : !torch.vtensor<[6,4],f32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,4,2],f32> + return %0 : !torch.vtensor<[3,4,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.unfold$rank_zero( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[],f32> -> tensor +// CHECK: %[[VAL_2:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor) -> tensor<1xf32> +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<1xf32> -> !torch.vtensor<[1],f32> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[1],f32> +// CHECK: } +func.func 
@torch.aten.unfold$rank_zero(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> { + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %0 = torch.aten.unfold %arg0, %int0, %int1, %int1 : !torch.vtensor<[],f32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],f32> + return %0 : !torch.vtensor<[1],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.expm1$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],f32> -> tensor<3x4xf32> +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_4:.*]] = tosa.exp %[[VAL_1]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_5:.*]] = tosa.sub %[[VAL_4]], %[[VAL_3]] : (tensor<3x4xf32>, tensor<1x1xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4],f32> +// CHECK: } +func.func @torch.aten.expm1$basic(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { + %0 = torch.aten.expm1 %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.expm1$int( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],si32> -> tensor<3x4xi32> +// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x4xi32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor) -> tensor<1x1xf32> +// CHECK: %[[VAL_5:.*]] = tosa.exp %[[VAL_2]] : (tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_6:.*]] = tosa.sub %[[VAL_5]], %[[VAL_4]] : (tensor<3x4xf32>, tensor<1x1xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[3,4],f32> +// CHECK: } +func.func @torch.aten.expm1$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { + %0 = torch.aten.expm1 %arg0 : !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.constant_pad_nd$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,20,20,4,4],f32>) -> !torch.vtensor<[1,1,20,20,4,5],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,20,20,4,4],f32> -> tensor<1x1x20x20x4x4xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.float 0xFFF0000000000000 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_4:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_6:.*]] = tosa.const_shape {value = dense<[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]> : tensor<12xindex>} : () -> !tosa.shape<12> +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<0xFF800000> : tensor}> : () -> tensor +// CHECK: %[[VAL_8:.*]] = tosa.pad %[[VAL_1]], %[[VAL_6]], %[[VAL_7]] : (tensor<1x1x20x20x4x4xf32>, !tosa.shape<12>, tensor) -> 
tensor<1x1x20x20x4x5xf32> +// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<1x1x20x20x4x5xf32> -> !torch.vtensor<[1,1,20,20,4,5],f32> +// CHECK: return %[[VAL_9]] : !torch.vtensor<[1,1,20,20,4,5],f32> +// CHECK: } +func.func @torch.aten.constant_pad_nd$basic(%arg0: !torch.vtensor<[1,1,20,20,4,4],f32>) -> !torch.vtensor<[1,1,20,20,4,5],f32> { + %float-Inf = torch.constant.float 0xFFF0000000000000 + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %0 = torch.prim.ListConstruct %int0, %int1 : (!torch.int, !torch.int) -> !torch.list + %1 = torch.aten.constant_pad_nd %arg0, %0, %float-Inf : !torch.vtensor<[1,1,20,20,4,4],f32>, !torch.list, !torch.float -> !torch.vtensor<[1,1,20,20,4,5],f32> + return %1 : !torch.vtensor<[1,1,20,20,4,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.convolution$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,2,10,20],f32>) -> !torch.vtensor<[5,10,14,24],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,2,10,20],f32> -> tensor<5x2x10x20xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.bool false +// CHECK: %[[VAL_3:.*]] = torch.constant.int 3 +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense_resource : tensor<10x2x3x3xf32>}> : () -> tensor<10x2x3x3xf32> +// CHECK: %[[VAL_5:.*]] = torch.constant.none +// CHECK: %[[VAL_6:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_6]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_6]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct : () -> !torch.list +// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<10xf32>}> : () -> tensor<10xf32> +// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_12]] : (tensor<5x2x10x20xf32>, tensor<4xi32>) -> tensor<5x10x20x2xf32> +// CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_4]], %[[VAL_12]] : (tensor<10x2x3x3xf32>, tensor<4xi32>) -> tensor<10x3x3x2xf32> +// CHECK: %[[VAL_15:.*]] = tosa.conv2d %[[VAL_13]], %[[VAL_14]], %[[VAL_11]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x10x20x2xf32>, tensor<10x3x3x2xf32>, tensor<10xf32>) -> tensor<5x14x24x10xf32> +// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_17:.*]] = tosa.transpose %[[VAL_15]], %[[VAL_16]] : (tensor<5x14x24x10xf32>, tensor<4xi32>) -> tensor<5x10x14x24xf32> +// CHECK: %[[VAL_18:.*]] = tensor.cast %[[VAL_17]] : tensor<5x10x14x24xf32> to tensor<5x10x14x24xf32> +// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<5x10x14x24xf32> -> !torch.vtensor<[5,10,14,24],f32> +// CHECK: return %[[VAL_19]] : !torch.vtensor<[5,10,14,24],f32> +// CHECK: } +func.func @torch.aten.convolution$basic(%arg0: !torch.vtensor<[5,2,10,20],f32>) -> !torch.vtensor<[5,10,14,24],f32> { + %false = torch.constant.bool false + %int3 = torch.constant.int 3 + %0 = torch.vtensor.literal(dense_resource : tensor<10x2x3x3xf32>) : !torch.vtensor<[10,2,3,3],f32> + %none = torch.constant.none + %int1 = torch.constant.int 1 + %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %2 = 
torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list + %3 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %4 = torch.prim.ListConstruct : () -> !torch.list + %5 = torch.aten.convolution %arg0, %0, %none, %1, %2, %3, %false, %4, %int1 : !torch.vtensor<[5,2,10,20],f32>, !torch.vtensor<[10,2,3,3],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[5,10,14,24],f32> + return %5 : !torch.vtensor<[5,10,14,24],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.convolution$depthwise( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,4,10,20],f32>) -> !torch.vtensor<[5,4,5,10],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,4,10,20],f32> -> tensor<5x4x10x20xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.bool false +// CHECK: %[[VAL_3:.*]] = torch.constant.int 4 +// CHECK: %[[VAL_4:.*]] = torch.constant.int 3 +// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense_resource : tensor<4x1x3x3xf32>}> : () -> tensor<4x1x3x3xf32> +// CHECK: %[[VAL_6:.*]] = torch.constant.none +// CHECK: %[[VAL_7:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_7]], %[[VAL_7]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_11:.*]] = torch.prim.ListConstruct : () -> !torch.list +// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<4xf32>}> : () -> tensor<4xf32> +// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_13]] : (tensor<5x4x10x20xf32>, tensor<4xi32>) -> tensor<5x10x20x4xf32> +// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{value = dense<[2, 3, 0, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_16:.*]] = tosa.transpose %[[VAL_5]], %[[VAL_15]] : (tensor<4x1x3x3xf32>, tensor<4xi32>) -> tensor<3x3x4x1xf32> +// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array} : (tensor<3x3x4x1xf32>) -> tensor<3x3x4x1xf32> +// CHECK: %[[VAL_18:.*]] = tosa.depthwise_conv2d %[[VAL_14]], %[[VAL_17]], %[[VAL_12]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x10x20x4xf32>, tensor<3x3x4x1xf32>, tensor<4xf32>) -> tensor<5x5x10x4xf32> +// CHECK: %[[VAL_19:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_20:.*]] = tosa.transpose %[[VAL_18]], %[[VAL_19]] : (tensor<5x5x10x4xf32>, tensor<4xi32>) -> tensor<5x4x5x10xf32> +// CHECK: %[[VAL_21:.*]] = tensor.cast %[[VAL_20]] : tensor<5x4x5x10xf32> to tensor<5x4x5x10xf32> +// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<5x4x5x10xf32> -> !torch.vtensor<[5,4,5,10],f32> +// CHECK: return %[[VAL_22]] : !torch.vtensor<[5,4,5,10],f32> +// CHECK: } +func.func @torch.aten.convolution$depthwise(%arg0: !torch.vtensor<[5,4,10,20],f32>) -> !torch.vtensor<[5,4,5,10],f32> { + %false = torch.constant.bool false + %int4 = torch.constant.int 4 + %int3 = torch.constant.int 3 + %0 = torch.vtensor.literal(dense_resource : tensor<4x1x3x3xf32>) : !torch.vtensor<[4,1,3,3],f32> + %none = torch.constant.none + %int2 = torch.constant.int 2 + %1 = torch.prim.ListConstruct %int2, %int2 
: (!torch.int, !torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list + %3 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list + %4 = torch.prim.ListConstruct : () -> !torch.list + %5 = torch.aten.convolution %arg0, %0, %none, %1, %2, %3, %false, %4, %int4 : !torch.vtensor<[5,4,10,20],f32>, !torch.vtensor<[4,1,3,3],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[5,4,5,10],f32> + return %5 : !torch.vtensor<[5,4,5,10],f32> +} + +// ----- diff --git a/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir b/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir index 08ec48448bab..68ca4f0e08ac 100644 --- a/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir +++ b/test/Conversion/TorchToTosa/torch-backend-to-tosa-backend-pipeline.mlir @@ -117,13 +117,26 @@ func.func @torch.aten.div.Tensor$mixed_type_int(%arg0: !torch.vtensor<[?, ?],si1 // ----- +// CHECK-LABEL: torch.aten.div.Scalar$int_input_fp_output +// CHECK-SAME: %[[VAL_0:.*]]: tensor +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<7.812500e-03> : tensor<1x1xf32>}> : () -> tensor<1x1xf32> +// CHECK: %[[VAL_3:.*]] = tosa.cast %[[VAL_0]] : (tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_3]], %[[VAL_1]] {shift = 0 : i8} : (tensor, tensor<1x1xf32>) -> tensor +func.func @torch.aten.div.Scalar$int_input_fp_output(%arg0: !torch.vtensor<[?, ?],si64>) -> !torch.vtensor<[?, ?],f32> { + %int128 = torch.constant.int 128 + %0 = torch.aten.div.Scalar %arg0, %int128 : !torch.vtensor<[?, ?],si64>, !torch.int -> !torch.vtensor<[?, ?],f32> + return %0 : !torch.vtensor<[?, ?],f32> +} + +// ----- + // CHECK-LABEL: torch.aten.pow.Tensor$mixed_type // CHECK-SAME: %[[VAL_0:.*]]: tensor -// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<3.123400e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32> -// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_0]] : (tensor) -> tensor +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<3.000000e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32> +// CHECK: %[[VAL_2:.*]] = tosa.cast %arg0 : (tensor) -> tensor // CHECK: %[[VAL_3:.*]] = tosa.pow %[[VAL_2]], %[[VAL_1]] : (tensor, tensor<1x1xf32>) -> tensor func.func @torch.aten.pow.Tensor$mixed_type(%arg0: !torch.vtensor<[?,?],f16>) -> !torch.vtensor<[?,?],f32> { - %fp0 = torch.constant.float 3.123400e+00 + %fp0 = torch.constant.float 3.000000e+00 %0 = torch.aten.pow.Tensor_Scalar %arg0, %fp0 : !torch.vtensor<[?,?],f16>, !torch.float -> !torch.vtensor<[?,?],f32> return %0 : !torch.vtensor<[?,?],f32> } diff --git a/test/Dialect/TMTensor/bufferize.mlir b/test/Dialect/TMTensor/bufferize.mlir index f36a2f521ad1..2d3a49c516ef 100644 --- a/test/Dialect/TMTensor/bufferize.mlir +++ b/test/Dialect/TMTensor/bufferize.mlir @@ -4,17 +4,17 @@ // CHECK-LABEL: func.func @scan_1d_inclusive( // CHECK-SAME: %[[IN_TENSOR:.*]]: tensor<128xi32>, %[[OUT_TENSOR:.*]]: tensor<128xi32>, // CHECK-SAME: %[[ACC_TENSOR:.*]]: tensor) -> (tensor<128xi32>, tensor) { -// CHECK: %[[IN_MEMREF:.*]] = bufferization.to_memref %[[IN_TENSOR]] : memref<128xi32> -// CHECK: %[[OUT_MEMREF_NEW:.*]] = memref.alloc() : memref<128xi32> -// CHECK: %[[ACC_MEMREF_NEW:.*]] = memref.alloc() : memref +// CHECK-DAG: %[[IN_MEMREF:.*]] = bufferization.to_memref %[[IN_TENSOR]] : tensor<128xi32> to memref<128xi32> +// CHECK-DAG: %[[OUT_MEMREF_NEW:.*]] = memref.alloc() : memref<128xi32> +// CHECK-DAG: 
%[[ACC_MEMREF_NEW:.*]] = memref.alloc() : memref +// CHECK-DAG: %[[OUT_TENSOR_NEW:.*]] = bufferization.to_tensor %[[OUT_MEMREF_NEW]] : memref<128xi32> to tensor<128xi32> +// CHECK-DAG: %[[ACC_TENSOR_NEW:.*]] = bufferization.to_tensor %[[ACC_MEMREF_NEW]] : memref to tensor // CHECK: tm_tensor.scan dimension(0) inclusive(true) ins(%[[IN_MEMREF]] : memref<128xi32>) // CHECK-SAME: outs(%[[OUT_MEMREF_NEW]], %[[ACC_MEMREF_NEW]] : memref<128xi32>, memref) { // CHECK: ^bb0(%[[OUT_PREV_ELEMENT:.*]]: i32, %[[IN_ELEMENT:.*]]: i32): // CHECK: %[[OUT_CURRENT_ELEMENT:.*]] = arith.addi %[[OUT_PREV_ELEMENT]], %[[IN_ELEMENT]] : i32 // CHECK: tm_tensor.yield %[[OUT_CURRENT_ELEMENT]] : i32 // CHECK: } -// CHECK: %[[OUT_TENSOR_NEW:.*]] = bufferization.to_tensor %[[OUT_MEMREF_NEW]] : memref<128xi32> -// CHECK: %[[ACC_TENSOR_NEW:.*]] = bufferization.to_tensor %[[ACC_MEMREF_NEW]] : memref // CHECK: return %[[OUT_TENSOR_NEW]], %[[ACC_TENSOR_NEW]] : tensor<128xi32>, tensor func.func @scan_1d_inclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tensor) -> (tensor<128xi32>, tensor) { %ret_out, %ret_acc = tm_tensor.scan dimension(0) inclusive(true) @@ -30,10 +30,12 @@ func.func @scan_1d_inclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: // CHECK-LABEL: func.func @scan_1d_exclusive( // CHECK-SAME: %[[IN_TENSOR:.*]]: tensor<128xi32>, %[[OUT_TENSOR:.*]]: tensor<128xi32>, // CHECK-SAME: %[[ACC_TENSOR:.*]]: tensor) -> (tensor<128xi32>, tensor) { -// CHECK: %[[IN_MEMREF:.*]] = bufferization.to_memref %[[IN_TENSOR]] : memref<128xi32> -// CHECK: %[[ACC_MEMREF:.*]] = bufferization.to_memref %[[ACC_TENSOR]] : memref -// CHECK: %[[OUT_MEMREF_NEW:.*]] = memref.alloc() : memref<128xi32> -// CHECK: %[[ACC_MEMREF_NEW:.*]] = memref.alloc() : memref +// CHECK-DAG: %[[IN_MEMREF:.*]] = bufferization.to_memref %[[IN_TENSOR]] : tensor<128xi32> to memref<128xi32> +// CHECK-DAG: %[[ACC_MEMREF:.*]] = bufferization.to_memref %[[ACC_TENSOR]] : tensor to memref +// CHECK-DAG: %[[OUT_MEMREF_NEW:.*]] = memref.alloc() : memref<128xi32> +// CHECK-DAG: %[[ACC_MEMREF_NEW:.*]] = memref.alloc() : memref +// CHECK-DAG: %[[OUT_TENSOR_NEW:.*]] = bufferization.to_tensor %[[OUT_MEMREF_NEW]] : memref<128xi32> to tensor<128xi32> +// CHECK-DAG: %[[ACC_TENSOR_NEW:.*]] = bufferization.to_tensor %[[ACC_MEMREF_NEW]] : memref to tensor // CHECK: memref.copy %[[ACC_MEMREF]], %[[ACC_MEMREF_NEW]] : memref to memref // CHECK: tm_tensor.scan dimension(0) inclusive(false) ins(%[[IN_MEMREF]] : memref<128xi32>) // CHECK-SAME: outs(%[[OUT_MEMREF_NEW]], %[[ACC_MEMREF_NEW]] : memref<128xi32>, memref) { @@ -41,8 +43,6 @@ func.func @scan_1d_inclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: // CHECK: %[[OUT_CURRENT_ELEMENT:.*]] = arith.addi %[[OUT_PREV_ELEMENT]], %[[IN_ELEMENT]] : i32 // CHECK: tm_tensor.yield %[[OUT_CURRENT_ELEMENT]] : i32 // CHECK: } -// CHECK: %[[OUT_TENSOR_NEW:.*]] = bufferization.to_tensor %[[OUT_MEMREF_NEW]] : memref<128xi32> -// CHECK: %[[ACC_TENSOR_NEW:.*]] = bufferization.to_tensor %[[ACC_MEMREF_NEW]] : memref // CHECK: return %[[OUT_TENSOR_NEW]], %[[ACC_TENSOR_NEW]] : tensor<128xi32>, tensor func.func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: tensor) -> (tensor<128xi32>, tensor) { %ret_out, %ret_acc = tm_tensor.scan dimension(0) inclusive(false) @@ -59,17 +59,17 @@ func.func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: // CHECK-SAME: %[[ORIG_TENSOR:.*]]: tensor<8xi32>, // CHECK-SAME: %[[INDICES_TENSOR:.*]]: tensor<3x1xi32>, // CHECK-SAME: %[[UPDATES_TENSOR:.*]]: 
tensor<3xi32>) -> tensor<8xi32> { -// CHECK: %[[UPDATES_MEMREF:.*]] = bufferization.to_memref %[[UPDATES_TENSOR]] : memref<3xi32> -// CHECK: %[[INDICES_MEMREF:.*]] = bufferization.to_memref %[[INDICES_TENSOR]] : memref<3x1xi32> -// CHECK: %[[ORIG_MEMREF:.*]] = bufferization.to_memref %[[ORIG_TENSOR]] : memref<8xi32> -// CHECK: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32> +// CHECK-DAG: %[[UPDATES_MEMREF:.*]] = bufferization.to_memref %[[UPDATES_TENSOR]] : tensor<3xi32> to memref<3xi32> +// CHECK-DAG: %[[INDICES_MEMREF:.*]] = bufferization.to_memref %[[INDICES_TENSOR]] : tensor<3x1xi32> to memref<3x1xi32> +// CHECK-DAG: %[[ORIG_MEMREF:.*]] = bufferization.to_memref %[[ORIG_TENSOR]] : tensor<8xi32> to memref<8xi32> +// CHECK-DAG: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32> +// CHECK-DAG: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ORIG_MEMREF_NEW]] : memref<8xi32> to tensor<8xi32> // CHECK: memref.copy %[[ORIG_MEMREF]], %[[ORIG_MEMREF_NEW]] : memref<8xi32> to memref<8xi32> // CHECK: tm_tensor.scatter {dimension_map = array} unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]] // CHECK-SAME: : memref<3xi32>, memref<3x1xi32>) outs(%[[ORIG_MEMREF_NEW]] : memref<8xi32>) { // CHECK: ^bb0(%[[UPDATE_SCALAR:.*]]: i32, %[[ORIG_SCALAR:.*]]: i32): // CHECK: tm_tensor.yield %[[UPDATE_SCALAR]] : i32 // CHECK: } -// CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ORIG_MEMREF_NEW]] : memref<8xi32> // CHECK: return %[[OUT_TENSOR]] : tensor<8xi32> func.func @scatter_update_scalar_1D( %original: tensor<8xi32>, %indices: tensor<3x1xi32>, @@ -87,10 +87,11 @@ func.func @scatter_update_scalar_1D( // CHECK-SAME: %[[ORIG_TENSOR:.*]]: tensor<8xi32>, // CHECK-SAME: %[[INDICES_TENSOR:.*]]: tensor<3x1xi32>, // CHECK-SAME: %[[UPDATES_TENSOR:.*]]: tensor<3xi32>) -> tensor<8xi32> { -// CHECK: %[[UPDATES_MEMREF:.*]] = bufferization.to_memref %[[UPDATES_TENSOR]] : memref<3xi32> -// CHECK: %[[INDICES_MEMREF:.*]] = bufferization.to_memref %[[INDICES_TENSOR]] : memref<3x1xi32> -// CHECK: %[[ORIG_MEMREF:.*]] = bufferization.to_memref %[[ORIG_TENSOR]] : memref<8xi32> -// CHECK: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32> +// CHECK-DAG: %[[UPDATES_MEMREF:.*]] = bufferization.to_memref %[[UPDATES_TENSOR]] : tensor<3xi32> to memref<3xi32> +// CHECK-DAG: %[[INDICES_MEMREF:.*]] = bufferization.to_memref %[[INDICES_TENSOR]] : tensor<3x1xi32> to memref<3x1xi32> +// CHECK-DAG: %[[ORIG_MEMREF:.*]] = bufferization.to_memref %[[ORIG_TENSOR]] : tensor<8xi32> to memref<8xi32> +// CHECK-DAG: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32> +// CHECK-DAG: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ORIG_MEMREF_NEW]] : memref<8xi32> to tensor<8xi32> // CHECK: memref.copy %[[ORIG_MEMREF]], %[[ORIG_MEMREF_NEW]] : memref<8xi32> to memref<8xi32> // CHECK: tm_tensor.scatter {dimension_map = array} unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]] // CHECK-SAME: : memref<3xi32>, memref<3x1xi32>) outs(%[[ORIG_MEMREF_NEW]] : memref<8xi32>) { @@ -99,7 +100,6 @@ func.func @scatter_update_scalar_1D( // CHECK: %[[ADD:.*]] = arith.addi %[[ORIG_SCALAR]], %[[CST1]] : i32 // CHECK: tm_tensor.yield %[[ADD]] : i32 // CHECK: } -// CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ORIG_MEMREF_NEW]] : memref<8xi32> // CHECK: return %[[OUT_TENSOR]] : tensor<8xi32> func.func @scatter_add_scalar_1D( %original: tensor<8xi32>, %indices: tensor<3x1xi32>, diff --git a/test/Dialect/Torch/adjust-calling-conventions.mlir b/test/Dialect/Torch/adjust-calling-conventions.mlir index 
5ee5bbf6f446..455a8e847486 100644 --- a/test/Dialect/Torch/adjust-calling-conventions.mlir +++ b/test/Dialect/Torch/adjust-calling-conventions.mlir @@ -29,71 +29,71 @@ func.func @call(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[2,3,?], return %arg0 : !torch.tensor } -// CHECK-LABEL: func.func @none_return() { -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: return -func.func @none_return() -> !torch.none { - %1 = torch.constant.none - return %1 : !torch.none -} +// COM: func.func @none_return() { +// COM: %[[NONE:.*]] = torch.constant.none +// COM: return +// func.func @none_return() -> !torch.none { +// %1 = torch.constant.none +// return %1 : !torch.none +// } -// CHECK-LABEL: func.func @none_call_return() { -// CHECK: call @none_return() : () -> () -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: "test.use"(%[[NONE]]) : (!torch.none) -> () -// CHECK: return -func.func @none_call_return() { - %0 = call @none_return() : () -> !torch.none - "test.use"(%0) : (!torch.none) -> () - return -} +// COM: func.func @none_call_return() { +// COM: call @none_return() : () -> () +// COM: %[[NONE:.*]] = torch.constant.none +// COM: "test.use"(%[[NONE]]) : (!torch.none) -> () +// COM: return +// func.func @none_call_return() { +// %0 = call @none_return() : () -> !torch.none +// "test.use"(%0) : (!torch.none) -> () +// return +// } -// CHECK-LABEL: func.func @tuple_return( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>, -// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) { -// CHECK: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor -// CHECK: %[[ARG0_NONVAL:.*]] = torch.copy.to_tensor %[[ARG0_ERASED]] : !torch.tensor -// CHECK: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor -// CHECK: %[[ARG1_NONVAL:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor -// CHECK: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[ARG0_NONVAL]], %[[ARG1_NONVAL]] : -// CHECK-SAME: !torch.tensor, !torch.tensor -> !torch.tuple -// CHECK: %[[CST0:.*]] = torch.constant.int 0 -// CHECK: %[[RET0:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST0]] : -// CHECK-SAME: !torch.tuple, !torch.int -> !torch.tensor -// CHECK: %[[CST1:.*]] = torch.constant.int 1 -// CHECK: %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] : -// CHECK-SAME: !torch.tuple, !torch.int -> !torch.tensor -// CHECK: return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor -func.func @tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}, - %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple { - %1 = torch.prim.TupleConstruct %arg0, %arg1 : !torch.tensor, !torch.tensor -> !torch.tuple - return %1 : !torch.tuple -} +// COM: func.func @tuple_return( +// COM: %[[ARG0:.*]]: !torch.vtensor<[?],f32>, +// COM: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) { +// COM: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor +// COM: %[[ARG0_NONVAL:.*]] = torch.copy.to_tensor %[[ARG0_ERASED]] : !torch.tensor +// COM: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor +// COM: %[[ARG1_NONVAL:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor +// COM: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[ARG0_NONVAL]], %[[ARG1_NONVAL]] : +// COM: !torch.tensor, !torch.tensor -> !torch.tuple +// 
COM: %[[CST0:.*]] = torch.constant.int 0 +// COM: %[[RET0:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST0]] : +// COM: !torch.tuple, !torch.int -> !torch.tensor +// COM: %[[CST1:.*]] = torch.constant.int 1 +// COM: %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] : +// COM: !torch.tuple, !torch.int -> !torch.tensor +// COM: return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor +// func.func @tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}, +// %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple { +// %1 = torch.prim.TupleConstruct %arg0, %arg1 : !torch.tensor, !torch.tensor -> !torch.tuple +// return %1 : !torch.tuple +// } -// CHECK-LABEL: func.func @call_tuple_return( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>, -// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) { -// CHECK: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor -// CHECK: %[[ARG0_NONVAL:.*]] = torch.copy.to_tensor %[[ARG0_ERASED]] : !torch.tensor -// CHECK: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor -// CHECK: %[[ARG1_NONVAL:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor -// CHECK: %[[ARG0_NONVAL_SHAPED:.*]] = torch.tensor_static_info_cast %[[ARG0_NONVAL]] : !torch.tensor to !torch.tensor<[?],f32> -// CHECK: %[[ARG0_VAL_SHAPED:.*]] = torch.copy.to_vtensor %[[ARG0_NONVAL_SHAPED]] : !torch.vtensor<[?],f32> -// CHECK: %[[ARG1_NONVAL_SHAPED:.*]] = torch.tensor_static_info_cast %[[ARG1_NONVAL]] : !torch.tensor to !torch.tensor<[?],f32> -// CHECK: %[[ARG1_VAL_SHAPED:.*]] = torch.copy.to_vtensor %[[ARG1_NONVAL_SHAPED]] : !torch.vtensor<[?],f32> -// CHECK: %[[RETS:.*]]:2 = call @tuple_return(%[[ARG0_VAL_SHAPED]], %[[ARG1_VAL_SHAPED]]) : -// CHECK-SAME: (!torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) -// CHECK: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[RETS]]#0, %[[RETS]]#1 : -// CHECK-SAME: !torch.tensor, !torch.tensor -> !torch.tuple -// CHECK: %[[CST0:.*]] = torch.constant.int 0 -// CHECK: %[[RET0:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST0]] : -// CHECK-SAME: !torch.tuple, !torch.int -> !torch.tensor -// CHECK: %[[CST1:.*]] = torch.constant.int 1 -// CHECK: %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] : -// CHECK-SAME: !torch.tuple, !torch.int -> !torch.tensor -// CHECK: return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor -func.func @call_tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}, - %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple { - %0 = call @tuple_return(%arg0, %arg1) : (!torch.tensor, !torch.tensor) -> !torch.tuple - return %0 : !torch.tuple -} +// COM: func.func @call_tuple_return( +// COM: %[[ARG0:.*]]: !torch.vtensor<[?],f32>, +// COM: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) { +// COM: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor +// COM: %[[ARG0_NONVAL:.*]] = torch.copy.to_tensor %[[ARG0_ERASED]] : !torch.tensor +// COM: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor +// COM: %[[ARG1_NONVAL:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor +// COM: %[[ARG0_NONVAL_SHAPED:.*]] = torch.tensor_static_info_cast %[[ARG0_NONVAL]] : !torch.tensor to !torch.tensor<[?],f32> +// 
COM: %[[ARG0_VAL_SHAPED:.*]] = torch.copy.to_vtensor %[[ARG0_NONVAL_SHAPED]] : !torch.vtensor<[?],f32> +// COM: %[[ARG1_NONVAL_SHAPED:.*]] = torch.tensor_static_info_cast %[[ARG1_NONVAL]] : !torch.tensor to !torch.tensor<[?],f32> +// COM: %[[ARG1_VAL_SHAPED:.*]] = torch.copy.to_vtensor %[[ARG1_NONVAL_SHAPED]] : !torch.vtensor<[?],f32> +// COM: %[[RETS:.*]]:2 = call @tuple_return(%[[ARG0_VAL_SHAPED]], %[[ARG1_VAL_SHAPED]]) : +// COM: (!torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) +// COM: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[RETS]]#0, %[[RETS]]#1 : +// COM: !torch.tensor, !torch.tensor -> !torch.tuple +// COM: %[[CST0:.*]] = torch.constant.int 0 +// COM: %[[RET0:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST0]] : +// COM: !torch.tuple, !torch.int -> !torch.tensor +// COM: %[[CST1:.*]] = torch.constant.int 1 +// COM: %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] : +// COM: !torch.tuple, !torch.int -> !torch.tensor +// COM: return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor +// func.func @call_tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}, +// %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple { +// %0 = call @tuple_return(%arg0, %arg1) : (!torch.tensor, !torch.tensor) -> !torch.tuple +// return %0 : !torch.tuple +// } diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index e7605f661698..d4afd67d65db 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -137,6 +137,46 @@ func.func @torch.aten.__isnot__$none_isnot_none(%arg0: !torch.none, %arg1: !torc return %0 : !torch.bool } +// CHECK-LABEL: func.func @torch.aten.eq.bool$same_value() -> !torch.bool { +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: return %[[TRUE]] : !torch.bool +func.func @torch.aten.eq.bool$same_value() -> !torch.bool { + %a = torch.constant.bool false + %b = torch.constant.bool false + %0 = torch.aten.eq.bool %a, %b: !torch.bool, !torch.bool -> !torch.bool + return %0 : !torch.bool +} + +// CHECK-LABEL: func.func @torch.aten.eq.bool$different_value() -> !torch.bool { +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: return %[[FALSE]] : !torch.bool +func.func @torch.aten.eq.bool$different_value() -> !torch.bool { + %a = torch.constant.bool true + %b = torch.constant.bool false + %0 = torch.aten.eq.bool %a, %b: !torch.bool, !torch.bool -> !torch.bool + return %0 : !torch.bool +} + +// CHECK-LABEL: func.func @torch.aten.eq.bool$same_operand( +// CHECK-SAME: %[[ARG0:.*]]: !torch.bool) -> !torch.bool { +// CHECK: %[[TRUE:.*]] = torch.constant.bool true +// CHECK: return %[[TRUE]] : !torch.bool +func.func @torch.aten.eq.bool$same_operand(%arg0: !torch.bool) -> !torch.bool { + %0 = torch.aten.eq.bool %arg0, %arg0: !torch.bool, !torch.bool -> !torch.bool + return %0 : !torch.bool +} + +// CHECK-LABEL: func.func @torch.aten.eq.bool$different_operand( +// CHECK-SAME: %[[ARG0:.*]]: !torch.bool) -> !torch.bool { +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[RET:.*]] = torch.aten.eq.bool %[[ARG0]], %[[FALSE]] : !torch.bool, !torch.bool -> !torch.bool +// CHECK: return %[[RET]] : !torch.bool +func.func @torch.aten.eq.bool$different_operand(%a: !torch.bool) -> !torch.bool { + %b = torch.constant.bool false + %0 = torch.aten.eq.bool %a, %b: !torch.bool, !torch.bool -> !torch.bool + return %0 : !torch.bool +} + // CHECK-LABEL: func.func @torch.aten.ne.bool() -> !torch.bool { // 
CHECK: %[[TRUE:.*]] = torch.constant.bool true // CHECK: return %[[TRUE]] : !torch.bool @@ -698,6 +738,20 @@ func.func @torch.aten.len.t$no_fold_list_mutated() -> !torch.int { return %2 : !torch.int } +// CHECK-LABEL: func.func @torch.aten.mul.left_t( +// CHECK: %[[C4:.*]] = torch.constant.int 4 +// CHECK: %[[C5:.*]] = torch.constant.int 5 +// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[C4]], %[[C5]], %[[C4]], %[[C5]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list +// CHECK: return %[[LIST]] : !torch.list +func.func @torch.aten.mul.left_t() -> !torch.list { + %int4 = torch.constant.int 4 + %int5 = torch.constant.int 5 + %int2 = torch.constant.int 2 + %0 = torch.prim.ListConstruct %int4, %int5 : (!torch.int, !torch.int) -> !torch.list + %1 = torch.aten.mul.left_t %0, %int2 : !torch.list, !torch.int -> !torch.list + return %1 : !torch.list +} + // CHECK-LABEL: func.func @torch.aten.__getitem__.t( // CHECK: %[[C5:.*]] = torch.constant.int 5 // CHECK: return %[[C5]] : !torch.int @@ -1168,6 +1222,29 @@ func.func @torch.aten.mul.int() -> !torch.int { return %ret : !torch.int } +// CHECK-LABEL: func.func @torch.aten.mul.int$canonicalize( +// CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.int { +// CHECK: %[[CST30:.*]] = torch.constant.int 30 +// CHECK: %[[RET:.*]] = torch.aten.mul.int %[[ARG]], %[[CST30]] : !torch.int, !torch.int -> !torch.int +// CHECK: return %[[RET]] : !torch.int +func.func @torch.aten.mul.int$canonicalize(%arg0: !torch.int) -> !torch.int { + %cst6 = torch.constant.int 6 + %cst5 = torch.constant.int 5 + %1 = torch.aten.mul.int %arg0, %cst5: !torch.int, !torch.int -> !torch.int + %ret = torch.aten.mul.int %1, %cst6: !torch.int, !torch.int -> !torch.int + return %ret : !torch.int +} + +// CHECK-LABEL: func.func @torch.aten.mul.int_float() -> !torch.float { +// CHECK: %[[CST6:.*]] = torch.constant.float 6.000000e+00 +// CHECK: return %[[CST6]] : !torch.float +func.func @torch.aten.mul.int_float() -> !torch.float { + %cst2 = torch.constant.int 2 + %cst3 = torch.constant.float 3.0 + %ret = torch.aten.mul.int_float %cst2, %cst3: !torch.int, !torch.float -> !torch.float + return %ret : !torch.float +} + // CHECK-LABEL: func.func @torch.aten.mul.float() -> !torch.float { // CHECK: %[[CST30:.*]] = torch.constant.float 3.000000e+01 // CHECK: return %[[CST30]] : !torch.float @@ -1178,6 +1255,16 @@ func.func @torch.aten.mul.float() -> !torch.float { return %ret : !torch.float } +// CHECK-LABEL: func.func @torch.aten.mul.float_int() -> !torch.float { +// CHECK: %[[CST6:.*]] = torch.constant.float 6.000000e+00 +// CHECK: return %[[CST6]] : !torch.float +func.func @torch.aten.mul.float_int() -> !torch.float { + %cst2 = torch.constant.float 2.0 + %cst3 = torch.constant.int 3 + %ret = torch.aten.mul.float_int %cst2, %cst3: !torch.float, !torch.int -> !torch.float + return %ret : !torch.float +} + // CHECK-LABEL: func.func @torch.aten.neg.float() -> !torch.float { // CHECK: %[[CST_6:.*]] = torch.constant.float -6.000000e+00 // CHECK: return %[[CST_6]] : !torch.float @@ -1207,6 +1294,16 @@ func.func @torch.aten.floordiv.int() -> !torch.int { return %ret : !torch.int } +// CHECK-LABEL: func.func @torch.aten.floordiv.int$canonicalize( +// CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.int { +// CHECK: return %[[ARG]] : !torch.int +func.func @torch.aten.floordiv.int$canonicalize(%arg0: !torch.int) -> !torch.int { + %cst6 = torch.constant.int 6 + %1 = torch.aten.mul.int %arg0, %cst6: !torch.int, !torch.int -> !torch.int + %ret = torch.aten.floordiv.int %1, %cst6: !torch.int, 
!torch.int -> !torch.int + return %ret : !torch.int +} + // CHECK-LABEL: func.func @torch.aten.remainder.int() -> !torch.int { // CHECK: %[[CST3:.*]] = torch.constant.int 3 // CHECK: return %[[CST3]] : !torch.int @@ -1507,20 +1604,67 @@ func.func @torch.aten.Float.Tensor(%arg0: !torch.float) -> !torch.float { } // CHECK-LABEL: func.func @torch.aten.squeeze$zero_rank( -// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> { -// CHECK-NEXT: return %[[ARG]] : !torch.tensor<[],f32> -func.func @torch.aten.squeeze$zero_rank(%arg0: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> { - %0 = torch.aten.squeeze %arg0 : !torch.tensor<[],f32> -> !torch.tensor<[],f32> - return %0 : !torch.tensor<[],f32> +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> { +// CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[],f32> +func.func @torch.aten.squeeze$zero_rank(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> { + %0 = torch.aten.squeeze %arg0 : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32> + return %0 : !torch.vtensor<[],f32> } // CHECK-LABEL: func.func @torch.aten.squeeze.dim$zero_rank( -// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> { -// CHECK-NEXT: return %[[ARG]] : !torch.tensor<[],f32> -func.func @torch.aten.squeeze.dim$zero_rank(%arg0: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> { +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> { +// CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[],f32> +func.func @torch.aten.squeeze.dim$zero_rank(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> { %int0 = torch.constant.int 0 - %0 = torch.aten.squeeze.dim %arg0, %int0 : !torch.tensor<[],f32>, !torch.int -> !torch.tensor<[],f32> - return %0 : !torch.tensor<[],f32> + %0 = torch.aten.squeeze.dim %arg0, %int0 : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32> + return %0 : !torch.vtensor<[],f32> +} + +// CHECK-LABEL: func.func @torch.aten.squeeze.dim$cst() -> !torch.vtensor<[2],si64> { +// CHECK-NEXT: %[[CST:.+]] = torch.vtensor.literal(dense<[127, 128]> : tensor<2xsi64>) : !torch.vtensor<[2],si64> +// CHECK-NEXT: return %[[CST]] +func.func @torch.aten.squeeze.dim$cst() -> !torch.vtensor<[2],si64> { + %int1 = torch.constant.int 1 + %0 = torch.vtensor.literal(dense<[[127], [128]]> : tensor<2x1xsi64>) : !torch.vtensor<[2,1],si64> + %1 = torch.aten.squeeze.dim %0, %int1 : !torch.vtensor<[2,1],si64>, !torch.int -> !torch.vtensor<[2],si64> + return %1 : !torch.vtensor<[2],si64> +} + +// CHECK-LABEL: func.func @torch.aten.squeeze.dim$cst_i1() -> !torch.vtensor<[3],i1> { +// CHECK-NEXT: %[[CST:.+]] = torch.vtensor.literal(dense<[true, false, true]> : tensor<3xi1>) : !torch.vtensor<[3],i1> +// CHECK-NEXT: return %[[CST]] +func.func @torch.aten.squeeze.dim$cst_i1() -> !torch.vtensor<[3],i1> { + %int1 = torch.constant.int 1 + %0 = torch.vtensor.literal(dense<[[true], [false], [true]]> : tensor<3x1xi1>) : !torch.vtensor<[3,1],i1> + %1 = torch.aten.squeeze.dim %0, %int1 : !torch.vtensor<[3,1],i1>, !torch.int -> !torch.vtensor<[3],i1> + return %1 : !torch.vtensor<[3],i1> +} + +// CHECK-LABEL: func.func @torch.aten.squeeze.dim$cst_splat_i1() -> !torch.vtensor<[3],i1> { +// CHECK-NEXT: %[[CST:.+]] = torch.vtensor.literal(dense : tensor<3xi1>) : !torch.vtensor<[3],i1> +// CHECK-NEXT: return %[[CST]] +func.func @torch.aten.squeeze.dim$cst_splat_i1() -> !torch.vtensor<[3],i1> { + %int1 = torch.constant.int 1 + %0 = torch.vtensor.literal(dense : tensor<3x1xi1>) : !torch.vtensor<[3,1],i1> + %1 = 
torch.aten.squeeze.dim %0, %int1 : !torch.vtensor<[3,1],i1>, !torch.int -> !torch.vtensor<[3],i1> + return %1 : !torch.vtensor<[3],i1> +} + +// CHECK-LABEL: func.func @torch.aten.squeeze.dim$same_shape( +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,1],si64>) -> !torch.vtensor<[2,1],si64> { +// CHECK-NEXT: return %[[ARG]] +func.func @torch.aten.squeeze.dim$same_shape(%arg0: !torch.vtensor<[2,1],si64>) -> !torch.vtensor<[2,1],si64> { + %int0 = torch.constant.int 0 + %0 = torch.aten.squeeze.dim %arg0, %int0 : !torch.vtensor<[2,1],si64>, !torch.int -> !torch.vtensor<[2,1],si64> + return %0 : !torch.vtensor<[2,1],si64> +} + +// CHECK-LABEL: func.func @torch.aten.squeeze.dim$not_fold +// CHECK: torch.aten.squeeze.dim +func.func @torch.aten.squeeze.dim$not_fold(%arg0: !torch.vtensor<[2,1],si64>) -> !torch.vtensor<[2],si64> { + %int1 = torch.constant.int 1 + %0 = torch.aten.squeeze.dim %arg0, %int1 : !torch.vtensor<[2,1],si64>, !torch.int -> !torch.vtensor<[2],si64> + return %0 : !torch.vtensor<[2],si64> } // CHECK-LABEL: func.func @torch.aten.tensor$one_elem( @@ -1534,6 +1678,16 @@ func.func @torch.aten.tensor$one_elem() -> (!torch.vtensor<[1],si64>) { return %67 : !torch.vtensor<[1],si64> } +// CHECK-LABEL: func.func @torch.aten.tensor$no_fold( +// CHECK: torch.aten.tensor %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.tensor +func.func @torch.aten.tensor$no_fold(%arg0: !torch.tensor) -> (!torch.tensor) { + %none = torch.constant.none + %false = torch.constant.bool false + %1 = torch.aten.size %arg0 : !torch.tensor -> !torch.list + %2 = torch.aten.tensor %1, %none, %none, %false : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.tensor + return %2 : !torch.tensor +} + // CHECK-LABEL: func.func @torch.aten.tensor.float( // CHECK-NEXT: torch.vtensor.literal(dense<1.000000e+01> : tensor) : !torch.vtensor<[],f32> func.func @torch.aten.tensor.float() -> !torch.vtensor<[],f32> { @@ -1602,6 +1756,82 @@ func.func @torch.aten.view$1D(%arg0: !torch.tensor<[?],f32>) -> !torch.tensor<[? 
return %1 : !torch.tensor<[?],f32> } +// CHECK-LABEL: func.func @torch.aten.view$fold_splat( +// CHECK: %[[SPLAT:.*]] = torch.vtensor.literal(dense<2> : tensor<2x4x1xsi64>) : !torch.vtensor<[2,4,1],si64> +// CHECK: return %[[SPLAT]] : !torch.vtensor<[2,4,1],si64> +func.func @torch.aten.view$fold_splat() -> !torch.vtensor<[2,4,1],si64> { + %int4 = torch.constant.int 4 + %int2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %0 = torch.vtensor.literal(dense<2> : tensor<8xsi64>) : !torch.vtensor<[8],si64> + %1 = torch.prim.ListConstruct %int2, %int4, %int1 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2 = torch.aten.view %0, %1 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[2,4,1],si64> + return %2 : !torch.vtensor<[2,4,1],si64> +} + +// CHECK-LABEL: func.func @torch.aten.view$fold_literal( +// CHECK: %[[LITERAL:.*]] = torch.vtensor.literal(dense<[ +// CHECK-SAME: [ +// CHECK-SAME: [0, 1], [2, 3], [4, 5], [6, 7]]]> : tensor<1x4x2xsi64>) : !torch.vtensor<[1,4,2],si64> +// CHECK: return %[[LITERAL]] : !torch.vtensor<[1,4,2],si64> +func.func @torch.aten.view$fold_literal() -> !torch.vtensor<[1,4,2],si64> { + %int4 = torch.constant.int 4 + %int2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %0 = torch.vtensor.literal(dense<[0,1,2,3,4,5,6,7]> : tensor<8xsi64>) : !torch.vtensor<[8],si64> + %1 = torch.prim.ListConstruct %int1, %int4, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %2 = torch.aten.view %0, %1 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[1,4,2],si64> + return %2 : !torch.vtensor<[1,4,2],si64> +} + +// CHECK-LABEL: func.func @torch.aten.transpose.int$fold_literal( +// CHECK: %[[LIT:.*]] = torch.vtensor.literal(dense<[ +// CHECK-SAME: [0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xsi64>) : !torch.vtensor<[2,4],si64> +// CHECK: return %[[LIT]] : !torch.vtensor<[2,4],si64> +func.func @torch.aten.transpose.int$fold_literal() -> !torch.vtensor<[2,4],si64> { + %int-1 = torch.constant.int -1 + %int0 = torch.constant.int 0 + %0 = torch.vtensor.literal(dense<[[0,1],[2,3],[4,5],[6,7]]> : tensor<4x2xsi64>) : !torch.vtensor<[4,2],si64> + %1 = torch.aten.transpose.int %0, %int-1, %int0 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,4], si64> + return %1 : !torch.vtensor<[2,4],si64> +} + +// CHECK-LABEL: func.func @torch.aten.transpose.int$fold_noop( +// CHECK: return %arg0 : !torch.vtensor<[?,?,?,?],f32> +func.func @torch.aten.transpose.int$fold_noop(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { + %int-1 = torch.constant.int -1 + %int3 = torch.constant.int 3 + %0 = torch.aten.transpose.int %arg0, %int-1, %int3 : !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?],f32> + return %0 : !torch.vtensor<[?,?,?,?],f32> +} + +// CHECK-LABEL: func.func @torch.aten.slice.Tensor$flip_slice_fold( +// CHECK: %[[LIT:.*]] = torch.vtensor.literal(dense<[ +// CHECK-SAME: [6, 7], [4, 5], [2, 3], [0, 1]]> : tensor<4x2xsi64>) : !torch.vtensor<[4,2],si64> +// CHECK: return %[[LIT]] : !torch.vtensor<[4,2],si64> +func.func @torch.aten.slice.Tensor$flip_slice_fold() -> !torch.vtensor<[4,2],si64> { + %int-9223372036854775807 = torch.constant.int -9223372036854775807 + %int-1 = torch.constant.int -1 + %int0 = torch.constant.int 0 + %0 = torch.vtensor.literal(dense<[[0, 1], [2, 3], [4, 5], [6, 7]]> : tensor<4x2xsi64>) : !torch.vtensor<[4,2],si64> + %1 = torch.aten.slice.Tensor %0, %int0, %int-1, %int-9223372036854775807, %int-1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, 
!torch.int, !torch.int -> !torch.vtensor<[4,2],si64> + return %1 : !torch.vtensor<[4,2],si64> +} + +// CHECK-LABEL: func.func @torch.aten.slice.Tensor$negative_two_stride_fold( +// CHECK: %[[LIT:.*]] = torch.vtensor.literal(dense<[ +// CHECK-SAME: [6, 7], [2, 3]]> : tensor<2x2xsi64>) : !torch.vtensor<[2,2],si64> +// CHECK: return %[[LIT]] : !torch.vtensor<[2,2],si64> +func.func @torch.aten.slice.Tensor$negative_two_stride_fold() -> !torch.vtensor<[2,2],si64> { + %int-5 = torch.constant.int -5 + %int-1 = torch.constant.int -1 + %int-2 = torch.constant.int -2 + %int0 = torch.constant.int 0 + %0 = torch.vtensor.literal(dense<[[0, 1], [2, 3], [4, 5], [6, 7]]> : tensor<4x2xsi64>) : !torch.vtensor<[4,2],si64> + %1 = torch.aten.slice.Tensor %0, %int0, %int-1, %int-5, %int-2 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,2],si64> + return %1 : !torch.vtensor<[2,2],si64> +} + // CHECK-LABEL: func.func @torch.aten.div.float$fold_zero_dividend( // CHECK: %[[CST0:.*]] = torch.constant.float 0.000000e+00 // CHECK: return %[[CST0]] : !torch.float @@ -1833,6 +2063,18 @@ func.func @prim.ListUnpack$fold_list(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !t return %1#0, %1#1 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32> } +// CHECK-LABEL: func.func @prim.ListUnpack$fold_list_cast( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,3],f32>, +// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[2,3],f32>) -> (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) { +// CHECK: %[[CAST0:.+]] = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[2,3],f32> to !torch.vtensor<[?,?],f32> +// CHECK: %[[CAST1:.+]] = torch.tensor_static_info_cast %arg1 : !torch.vtensor<[2,3],f32> to !torch.vtensor<[?,?],f32> +// CHECK: return %[[CAST0]], %[[CAST1]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> +func.func @prim.ListUnpack$fold_list_cast(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !torch.vtensor<[2,3],f32>) -> (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) { + %0 = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>) -> !torch.list + %1:2 = torch.prim.ListUnpack %0 : !torch.list -> !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> + return %1#0, %1#1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> +} + // CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { // CHECK: %[[CST:.*]] = torch.vtensor.literal(dense<3> : tensor) : !torch.vtensor<[],si64> // CHECK: return %[[CST]] : !torch.vtensor<[],si64> @@ -1907,6 +2149,17 @@ func.func @torch.aten.sub.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[], return %2 : !torch.vtensor<[],si64> } +// CHECK-LABEL: func.func @torch.aten.sub.Tensor$mixed_dtype() -> !torch.vtensor<[],f64> { +// CHECK: %[[CST:.*]] = torch.vtensor.literal(dense<2.750000e+01> : tensor) : !torch.vtensor<[],f64> +// CHECK: return %[[CST]] +func.func @torch.aten.sub.Tensor$mixed_dtype() -> !torch.vtensor<[],f64> { + %int1 = torch.constant.int 1 + %0 = torch.vtensor.literal(dense<28> : tensor) : !torch.vtensor<[],si64> + %1 = torch.vtensor.literal(dense<5.000000e-01> : tensor) : !torch.vtensor<[],f64> + %2 = torch.aten.sub.Tensor %0, %1, %int1 : !torch.vtensor<[],si64>, !torch.vtensor<[],f64>, !torch.int -> !torch.vtensor<[],f64> + return %2 : !torch.vtensor<[],f64> +} + // CHECK-LABEL: func.func @torch.aten.sub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { // CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<-6> : tensor) : 
!torch.vtensor<[],si64> // CHECK: return %[[CST]] : !torch.vtensor<[],si64> @@ -2129,15 +2382,15 @@ func.func @torch.aten.broadcast_to$fold_splat() -> !torch.vtensor<[3,4,2],f32> { // ----- -// CHECK-LABEL: @torch.aten.slice.tensor$fold_full_domain_slice +// CHECK-LABEL: @torch.aten.slice.tensor$not_fold_slice // CHECK-SAME: %[[ARG0:.+]]: !torch.vtensor<[4],f32> -// CHECK: return %[[ARG0]] : !torch.vtensor<[4],f32> -func.func @torch.aten.slice.tensor$fold_full_domain_slice(%arg0: !torch.vtensor<[4],f32>) -> !torch.vtensor<[4],f32> { +// CHECK: torch.aten.slice.Tensor +func.func @torch.aten.slice.tensor$not_fold_slice(%arg0: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3],f32> { %int1 = torch.constant.int 1 %int-1 = torch.constant.int -1 %int0 = torch.constant.int 0 - %0 = torch.aten.slice.Tensor %arg0, %int0, %int0, %int-1, %int1 : !torch.vtensor<[4], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4], f32> - return %0 : !torch.vtensor<[4],f32> + %0 = torch.aten.slice.Tensor %arg0, %int0, %int0, %int-1, %int1 : !torch.vtensor<[4], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3], f32> + return %0 : !torch.vtensor<[3],f32> } // CHECK-LABEL: @torch.aten.slice.tensor$fold_full_slice @@ -2151,6 +2404,32 @@ func.func @torch.aten.slice.tensor$fold_full_slice(%arg0: !torch.vtensor<[?],f32 return %0 : !torch.vtensor<[?],f32> } +// CHECK-LABEL: @torch.aten.slice.tensor$fold_oob_start +// CHECK: %[[LIT:.*]] = torch.vtensor.literal(dense<[0, 1, 2]> : tensor<3xsi64>) : !torch.vtensor<[3],si64> +// CHECK: return %[[LIT]] : !torch.vtensor<[3],si64> +func.func @torch.aten.slice.tensor$fold_oob_start() -> !torch.vtensor<[3],si64> { + %0 = torch.vtensor.literal(dense<[0,1,2,3]> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %int1 = torch.constant.int 1 + %int-1 = torch.constant.int -1 + %int-10 = torch.constant.int -10 + %int0 = torch.constant.int 0 + %1 = torch.aten.slice.Tensor %0, %int0, %int-10, %int-1, %int1 : !torch.vtensor<[4], si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3], si64> + return %1 : !torch.vtensor<[3],si64> +} + +// CHECK-LABEL: @torch.aten.slice.tensor$nofold_invalid_shape +// CHECK: %[[SLICE:.*]] = torch.aten.slice.Tensor +// CHECK: return %[[SLICE]] +func.func @torch.aten.slice.tensor$nofold_invalid_shape() -> !torch.vtensor<[4],si64> { + %0 = torch.vtensor.literal(dense<[0,1,2,3]> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %int1 = torch.constant.int 1 + %int-1 = torch.constant.int -1 + %int-10 = torch.constant.int -10 + %int0 = torch.constant.int 0 + %1 = torch.aten.slice.Tensor %0, %int0, %int-10, %int-1, %int1 : !torch.vtensor<[4], si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4], si64> + return %1 : !torch.vtensor<[4],si64> +} + // CHECK-LABEL: @torch.aten.slice.tensor$no_fold_step // CHECK: torch.aten.slice.Tensor func.func @torch.aten.slice.tensor$no_fold_step(%arg0: !torch.vtensor<[?],f32>, %dim: !torch.int) -> !torch.vtensor<[?],f32> { @@ -2199,7 +2478,10 @@ func.func @torch.aten.slice.tensor$fold_small() -> (!torch.vtensor<[2],si32>) { } // ----- - +// CHECK-LABEL: func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>) { +// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<1.600000e+01> : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32> +// CHECK: %[[CST0:.+]] = torch.vtensor.literal(dense<6.400000e+01> : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32> +// CHECK: return %[[CST]], %[[CST0]] func.func 
@torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1, 1],f32>, !torch.vtensor<[1, 1],f32>) { %tensor = torch.vtensor.literal(dense<[[2.0],[4.0],[8.0],[16.0],[32.0],[64.0],[128.0],[256.0],[512.0],[1024.0]]> : tensor<10x1xf32>) : !torch.vtensor<[10, 1],f32> %int0 = torch.constant.int 0 @@ -2214,6 +2496,18 @@ func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1, 1],f32>, return %0, %1 : !torch.vtensor<[1, 1],f32>, !torch.vtensor<[1, 1], f32> } +// ----- +// CHECK-LABEL: func.func @torch.aten.slice.tensor$fold_dim_0_non_contiguous() -> !torch.vtensor<[4,1],si64> { +// CHECK{LITERAL}: %0 = torch.vtensor.literal(dense<[[28], [14], [7], [4]]> : tensor<4x1xsi64>) : !torch.vtensor<[4,1],si64> +// CHECK: return %0 +func.func @torch.aten.slice.tensor$fold_dim_0_non_contiguous() -> (!torch.vtensor<[4,1],si64>) { + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %0 = torch.vtensor.literal(dense<[[28, 28], [14, 14], [7, 7], [4, 4]]> : tensor<4x2xsi64>) : !torch.vtensor<[4,2],si64> + %1 = torch.aten.slice.Tensor %0, %int1, %int1, %int2, %int1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,1],si64> + return %1 : !torch.vtensor<[4,1],si64> +} + // ----- // CHECK-LABEL: func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { @@ -3015,3 +3309,63 @@ func.func @torch.aten.max_pool2d_with_indices$canonicalize(%arg0: !torch.vtensor %result0, %result1 = torch.aten.max_pool2d_with_indices %arg0, %29, %30, %31, %31, %false : !torch.vtensor<[10,64,112,112],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[10,64,56,56],f32>, !torch.vtensor<[10,64,56,56],si64> return %result0 : !torch.vtensor<[10,64,56,56],f32> } + +// ----- + +// CHECK-LABEL: @torch.aten.max_pool3d_with_indices$canonicalize( +// CHECK: %[[ARG:.*]]: !torch.vtensor<[10,64,112,112,112],f32>) -> !torch.vtensor<[10,64,56,56,56],f32> { +// CHECK: %[[RET:.*]] = torch.aten.max_pool3d %[[ARG]] +// CHECK: return %[[RET]] : !torch.vtensor<[10,64,56,56,56],f32> +func.func @torch.aten.max_pool3d_with_indices$canonicalize(%arg0: !torch.vtensor<[10,64,112,112,112],f32>) -> !torch.vtensor<[10,64,56,56,56],f32> { + %false = torch.constant.bool false + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %int3 = torch.constant.int 3 + %29 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list + %30 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list + %31 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %result0, %result1 = torch.aten.max_pool3d_with_indices %arg0, %29, %30, %31, %31, %false : !torch.vtensor<[10,64,112,112,112],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[10,64,56,56,56],f32>, !torch.vtensor<[10,64,56,56,56],si64> + return %result0 : !torch.vtensor<[10,64,56,56,56],f32> +} + +// ----- + +// CHECK-LABEL: @torch.aten.clone$no_fold( +func.func @torch.aten.clone$no_fold(%arg0: !torch.vtensor<[1,2,50,4],f32>) -> (!torch.tensor) { + // CHECK: %{{.*}} = torch.aten.clone %{{.*}}, %{{.*}} : !torch.vtensor<[1,2,50,4],f32>, !torch.none -> !torch.vtensor + %none = torch.constant.none + %0 = torch.aten.clone %arg0, %none : !torch.vtensor<[1,2,50,4],f32>, !torch.none -> !torch.vtensor + %1 = torch.copy.to_tensor %0 : !torch.tensor + return %1 : !torch.tensor +} + +// ----- + +// CHECK-LABEL: @torch.symbolic_int$canonicalize( +// CHECK-SAME: %[[ARG0:.*]]: 
!torch.vtensor<[?],f32>, +// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> { +// CHECK: %[[S0:.*]] = torch.symbolic_int "s0" {min_val = 3, max_val = 6} : !torch.int +// CHECK-NOT: %[[S1:.*]] = torch.symbolic_int "s0 + 1" {min_val = 4, max_val = 7} : !torch.int +// CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> +// CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0 + 1)> : !torch.vtensor<[?],f32> +// CHECK: %[[V1:.*]] = torch.aten.slice.Tensor %[[ARG1]], {{.*}}, {{.*}}, {{.*}}, {{.*}} : !torch.vtensor<[?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?],f32> +// CHECK: torch.bind_symbolic_shape %[[V1]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> +// CHECK: %[[V2:.*]] = torch.aten.add.Tensor %[[ARG0]], {{.*}}, {{.*}} : !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32> +// CHECK: torch.bind_symbolic_shape %[[V2]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> +// CHECK: return %[[V2]] : !torch.vtensor<[?],f32> +func.func @torch.symbolic_int$canonicalize(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> { + %0 = torch.symbolic_int "s0" {min_val = 3, max_val = 6} : !torch.int + %1 = torch.symbolic_int "s0 + 1" {min_val = 4, max_val = 7} : !torch.int + torch.bind_symbolic_shape %arg0, [%0], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> + torch.bind_symbolic_shape %arg1, [%0], affine_map<()[s0] -> (s0 + 1)> : !torch.vtensor<[?],f32> + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %int9223372036854775807 = torch.constant.int 9223372036854775807 + %int1_0 = torch.constant.int 1 + %2 = torch.aten.slice.Tensor %arg1, %int0, %int1, %int9223372036854775807, %int1_0 : !torch.vtensor<[?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?],f32> + torch.bind_symbolic_shape %2, [%0], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> + %int1_1 = torch.constant.int 1 + %3 = torch.aten.add.Tensor %arg0, %2, %int1_1 : !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32> + torch.bind_symbolic_shape %3, [%0], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> + return %3 : !torch.vtensor<[?],f32> +} diff --git a/test/Dialect/Torch/decompose-complex-ops-illegal.mlir b/test/Dialect/Torch/decompose-complex-ops-illegal.mlir new file mode 100644 index 000000000000..773c0f5c3c30 --- /dev/null +++ b/test/Dialect/Torch/decompose-complex-ops-illegal.mlir @@ -0,0 +1,41 @@ +// RUN: torch-mlir-opt -torch-decompose-complex-ops -split-input-file %s | FileCheck %s + +func.func @torch.aten.pad.reflect(%input: !torch.tensor<[2],f32>, %pads: !torch.vtensor<[2],si64>) -> !torch.tensor<[4],f32> { + %int0 = torch.constant.int 0 + %float0.000000e00 = torch.constant.float 0.000000e+00 + %1 = torch.aten.select.int %pads, %int0, %int0 : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + %2 = torch.aten.item %1 : !torch.vtensor<[],si64> -> !torch.int + %pad = torch.prim.ListConstruct %2 : (!torch.int) -> !torch.list + %str = torch.constant.str "reflect" + // CHECK: torch.aten.pad %{{.*}} %{{.*}} %{{.*}} %{{.*}} : !torch.tensor<[2],f32>, !torch.list, !torch.str, !torch.float -> !torch.tensor<[4],f32> + %ret = torch.aten.pad %input, %pad, %str, %float0.000000e00 : !torch.tensor<[2],f32>, !torch.list, !torch.str, !torch.float -> !torch.tensor<[4],f32> + return %ret : 
!torch.tensor<[4],f32> +} + +// ----- + +func.func @torch.aten.pad.edge(%input: !torch.tensor<[2],f32>, %pads: !torch.vtensor<[2],si64>) -> !torch.tensor<[4],f32> { + %int0 = torch.constant.int 0 + %float0.000000e00 = torch.constant.float 0.000000e+00 + %1 = torch.aten.select.int %pads, %int0, %int0 : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + %2 = torch.aten.item %1 : !torch.vtensor<[],si64> -> !torch.int + %pad = torch.prim.ListConstruct %2 : (!torch.int) -> !torch.list + %str = torch.constant.str "edge" + // CHECK: torch.aten.pad %{{.*}} %{{.*}} %{{.*}} %{{.*}} : !torch.tensor<[2],f32>, !torch.list, !torch.str, !torch.float -> !torch.tensor<[4],f32> + %ret = torch.aten.pad %input, %pad, %str, %float0.000000e00 : !torch.tensor<[2],f32>, !torch.list, !torch.str, !torch.float -> !torch.tensor<[4],f32> + return %ret : !torch.tensor<[4],f32> +} + +// ----- + +func.func @torch.aten.pad.wrap(%input: !torch.tensor<[2],f32>, %pads: !torch.vtensor<[2],si64>) -> !torch.tensor<[4],f32> { + %int0 = torch.constant.int 0 + %float0.000000e00 = torch.constant.float 0.000000e+00 + %1 = torch.aten.select.int %pads, %int0, %int0 : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + %2 = torch.aten.item %1 : !torch.vtensor<[],si64> -> !torch.int + %pad = torch.prim.ListConstruct %2 : (!torch.int) -> !torch.list + %str = torch.constant.str "wrap" + // CHECK: torch.aten.pad %{{.*}} %{{.*}} %{{.*}} %{{.*}} : !torch.tensor<[2],f32>, !torch.list, !torch.str, !torch.float -> !torch.tensor<[4],f32> + %ret = torch.aten.pad %input, %pad, %str, %float0.000000e00 : !torch.tensor<[2],f32>, !torch.list, !torch.str, !torch.float -> !torch.tensor<[4],f32> + return %ret : !torch.tensor<[4],f32> +} diff --git a/test/Dialect/Torch/decompose-complex-ops-legal.mlir b/test/Dialect/Torch/decompose-complex-ops-legal.mlir index 9cf4c3e9babd..27a5b5647c94 100644 --- a/test/Dialect/Torch/decompose-complex-ops-legal.mlir +++ b/test/Dialect/Torch/decompose-complex-ops-legal.mlir @@ -8,3 +8,17 @@ func.func @torch.aten.softmax.int$cst_dim(%t: !torch.tensor<[2,3],f32>) -> !torc %ret = torch.aten.softmax.int %t, %dim, %none : !torch.tensor<[2,3],f32>, !torch.int, !torch.none -> !torch.tensor<[2,3],f32> return %ret : !torch.tensor<[2,3],f32> } + +// ----- + +func.func @torch.aten.pad.constant(%input: !torch.tensor<[2],f32>, %pads: !torch.vtensor<[2],si64>) -> !torch.tensor<[4],f32> { + %int0 = torch.constant.int 0 + %float0.000000e00 = torch.constant.float 0.000000e+00 + %1 = torch.aten.select.int %pads, %int0, %int0 : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + %2 = torch.aten.item %1 : !torch.vtensor<[],si64> -> !torch.int + %pad = torch.prim.ListConstruct %2 : (!torch.int) -> !torch.list + %str = torch.constant.str "constant" + // CHECK: torch.aten.constant_pad_nd %{{.*}}, %{{.*}}, %{{.*}} : !torch.tensor<[2],f32>, !torch.list, !torch.float -> !torch.tensor<[4],f32> + %ret = torch.aten.pad %input, %pad, %str, %float0.000000e00 : !torch.tensor<[2],f32>, !torch.list, !torch.str, !torch.float -> !torch.tensor<[4],f32> + return %ret : !torch.tensor<[4],f32> +} diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir index 530160f990ae..21cf9a7d908f 100644 --- a/test/Dialect/Torch/decompose-complex-ops.mlir +++ b/test/Dialect/Torch/decompose-complex-ops.mlir @@ -26,33 +26,17 @@ func.func @matmul_decompose_3d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch } // ----- -// CHECK-LABEL: func 
@torch.aten.adaptive_avg_pool2d$output_size_divisible_by_input( -// CHECK-SAME: %[[SELF:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { -// CHECK-DAG: %[[CST0:.*]] = torch.constant.int 0 -// CHECK-DAG: %[[CST2:.*]] = torch.constant.int 2 -// CHECK-DAG: %[[CST3:.*]] = torch.constant.int 3 -// CHECK-DAG: %[[CST7:.*]] = torch.constant.int 7 -// CHECK-DAG: %[[FALSE:.*]] = torch.constant.bool false -// CHECK-DAG: %[[TRUE:.*]] = torch.constant.bool true -// CHECK-DAG: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[SELF]], %[[CST2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int -// CHECK: %[[DIM3:.*]] = torch.aten.size.int %[[SELF]], %[[CST3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int -// CHECK: %[[REMAINER1:.*]] = torch.aten.remainder.int %[[DIM2]], %[[CST7]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[COND1:.*]] = torch.aten.eq.int %[[REMAINER1]], %[[CST0]] : !torch.int, !torch.int -> !torch.bool -// CHECK: torch.runtime.assert %[[COND1]], "unimplemented: only support cases input size is an integer multiple of output size" -// CHECK: %[[STRIDE1:.*]] = torch.aten.floordiv.int %[[DIM2]], %[[CST7]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[REMAINER2:.*]] = torch.aten.remainder.int %[[DIM3]], %[[CST7]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[COND2:.*]] = torch.aten.eq.int %[[REMAINER2]], %[[CST0]] : !torch.int, !torch.int -> !torch.bool -// CHECK: torch.runtime.assert %[[COND2]], "unimplemented: only support cases input size is an integer multiple of output size" -// CHECK: %[[STRIDE2:.*]] = torch.aten.floordiv.int %[[DIM3]], %[[CST7]] : !torch.int, !torch.int -> !torch.int -// CHECK: %[[KERNEL_SIZE:.*]] = torch.prim.ListConstruct %[[STRIDE1]], %[[STRIDE2]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[CST0]], %[[CST0]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[AVG_POOL:.*]] = torch.aten.avg_pool2d %[[SELF]], %[[KERNEL_SIZE]], %[[KERNEL_SIZE]], %[[PADDING]], %[[FALSE]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[?,?,?,?],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?,?],f32> -func.func @torch.aten.adaptive_avg_pool2d$output_size_divisible_by_input(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { - %int7 = torch.constant.int 7 - %output_size = torch.prim.ListConstruct %int7, %int7 : (!torch.int, !torch.int) -> !torch.list - %0 = torch.aten.adaptive_avg_pool2d %arg0, %output_size : !torch.vtensor<[?,?,?,?],f32>, !torch.list -> !torch.vtensor<[?,?,?,?],f32> - return %0 : !torch.vtensor<[?,?,?,?],f32> + +// CHECK-LABEL: func.func @argmax_rank_1 +// CHECK: %[[I0:.*]] = torch.constant.int 0 +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[VALUES:.*]], %[[INDICES:.*]] = torch.aten.max.dim %arg0, %[[I0]], %[[FALSE]] : !torch.vtensor<[20],si32>, !torch.int, !torch.bool -> !torch.vtensor<[],si32>, !torch.vtensor<[],si64> +// CHECK: return %[[INDICES]] : !torch.vtensor<[],si64> +func.func @argmax_rank_1(%arg0: !torch.vtensor<[20],si32>) -> !torch.vtensor<[],si64> { + %none = torch.constant.none + %false = torch.constant.bool false + %7 = torch.aten.argmax %arg0, %none, %false : !torch.vtensor<[20],si32>, !torch.none, !torch.bool -> !torch.vtensor<[],si64> + return %7 : !torch.vtensor<[],si64> } // ----- @@ -78,3 +62,221 @@ func.func @torch.aten.type_as$fold(%arg0: !torch.tensor<[?], f16>, %arg1: !torch %0 = torch.aten.type_as 
%arg0, %arg1 : !torch.tensor<[?], f16>, !torch.tensor<[?,?],f16> -> !torch.tensor<[?], f16> return %0 : !torch.tensor<[?], f16> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.one_hot$fold( +// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[3],si64>, %arg1: !torch.int) -> !torch.vtensor<[3,?],si64> { +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[INT4:.*]] = torch.constant.int 4 +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK: %[[INT1:.*]] = torch.constant.int 1 +// CHECK: %[[ARANGE:.*]] = torch.aten.arange.start_step %[[INT0]], %arg1, %[[INT1]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],si64> +// CHECK: %[[UNSQUEEZE:.*]] = torch.aten.unsqueeze %[[ARG_0]], %[[INT1]] : !torch.vtensor<[3],si64>, !torch.int -> !torch.vtensor<[3,1],si64> +// CHECK: %[[EQ:.*]] = torch.aten.eq.Tensor %[[UNSQUEEZE]], %[[ARANGE]] : !torch.vtensor<[3,1],si64>, !torch.vtensor<[?],si64> -> !torch.vtensor<[3,?],i1> +// CHECK: %[[RESULT:.*]] = torch.aten.to.dtype %[[EQ]], %[[INT4]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,?],si64> +// CHECK: return %[[RESULT:.*]] : !torch.vtensor<[3,?],si64> +func.func @torch.aten.one_hot$fold(%arg0: !torch.vtensor<[3],si64>, %arg1: !torch.int) -> !torch.vtensor<[3,?],si64> { + %0 = torch.aten.one_hot %arg0, %arg1 : !torch.vtensor<[3],si64>, !torch.int -> !torch.vtensor<[3,?],si64> + return %0 : !torch.vtensor<[3,?],si64> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams( +// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?,?,?],f32>, %[[ARG_1:.*]]: !torch.vtensor<[1],f32>, +// CHECK-SAME: %[[ARG_2:.*]]: !torch.vtensor<[1],si32>, %[[ARG_3:.*]]: !torch.vtensor<[1],si64>) -> !torch.vtensor<[?,?,?,?],f32> { +// CHECK: %[[CONST1:.*]] = torch.constant.int 127 +// CHECK: %[[CONST2:.*]] = torch.constant.int -128 +// CHECK: %[[RESULT:.*]] = torch.aten.fake_quantize_per_tensor_affine.tensor_qparams %[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[CONST2]], %[[CONST1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],si32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?],f32> +// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?,?],f32> +func.func @torch.aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.vtensor<[1],si32>, %arg3: !torch.vtensor<[1],si64>) -> !torch.vtensor<[?,?,?,?],f32> { + %int127 = torch.constant.int 127 + %int-128 = torch.constant.int -128 + %0:2 = torch.aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams %arg0, %arg1, %arg2, %arg3, %int-128, %int127 : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],si32>, !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],i1> + return %0#0 : !torch.vtensor<[?,?,?,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.fake_quantize_per_channel_affine_cachemask( +// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?,?,?],f32>, %[[ARG_1:.*]]: !torch.vtensor<[?],f32>, +// CHECK-SAME: %[[ARG_2:.*]]: !torch.vtensor<[?],si32>) -> !torch.vtensor<[?,?,?,?],f32> { +// CHECK: %[[CONST0:.*]] = torch.constant.int 0 +// CHECK: %[[CONST1:.*]] = torch.constant.int 127 +// CHECK: %[[CONST2:.*]] = 
torch.constant.int -128 +// CHECK: %[[RESULT:.*]] = torch.aten.fake_quantize_per_channel_affine %[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[CONST0]], %[[CONST2]], %[[CONST1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?],f32>, !torch.vtensor<[?],si32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?],f32> +// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?,?],f32> +func.func @torch.aten.fake_quantize_per_channel_affine_cachemask(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?],f32>, %arg2: !torch.vtensor<[?],si32>) -> !torch.vtensor<[?,?,?,?],f32> { + %int0 = torch.constant.int 0 + %int127 = torch.constant.int 127 + %int-128 = torch.constant.int -128 + %0:2 = torch.aten.fake_quantize_per_channel_affine_cachemask %arg0, %arg1, %arg2, %int0, %int-128, %int127 : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?],f32>, !torch.vtensor<[?],si32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],i1> + return %0#0 : !torch.vtensor<[?,?,?,?],f32> +} + +// ----- + +// CHECK-LABEL: test_einsum_inner_prod +func.func @test_einsum_inner_prod(%arg0: !torch.vtensor<[5],f64>, %arg1: !torch.vtensor<[5],f64>) -> !torch.vtensor<[],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64} { + // CHECK-DAG: %[[INT5:.+]] = torch.constant.int 5 + // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1 + // CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[LHS_LIST:.+]] = torch.prim.ListConstruct %[[INT0]] + // CHECK: %[[LHS_PERM:.+]] = torch.aten.permute %arg0, %[[LHS_LIST]] + // CHECK: %[[RHS_LIST:.+]] = torch.prim.ListConstruct %[[INT0]] + // CHECK: %[[RHS_PERM:.+]] = torch.aten.permute %arg1, %[[RHS_LIST]] + // CHECK: %[[LHS_SHP:.+]] = torch.prim.ListConstruct %[[INT1]], %[[INT1]], %[[INT5]] + // CHECK: %[[LHS_VIEW:.+]] = torch.aten.view %[[LHS_PERM]], %[[LHS_SHP]] + // CHECK: %[[RHS_SHP:.+]] = torch.prim.ListConstruct %[[INT1]], %[[INT5]], %[[INT1]] + // CHECK: %[[RHS_VIEW:.+]] = torch.aten.view %[[RHS_PERM]], %[[RHS_SHP]] + // CHECK: %[[BMM:.+]] = torch.aten.bmm %[[LHS_VIEW]], %[[RHS_VIEW]] + // CHECK: %[[EMPTY:.+]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[OUT_VIEW:.+]] = torch.aten.view %[[BMM]], %[[EMPTY]] + // CHECK: %[[EMPTY:.+]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[OUT_PERM:.+]] = torch.aten.permute %[[OUT_VIEW]], %[[EMPTY]] + // CHECK: return %[[OUT_PERM]] + %0 = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[5],f64>, !torch.vtensor<[5],f64>) -> !torch.list + %str = torch.constant.str "i,i" + %none_0 = torch.constant.none + %1 = torch.aten.einsum %str, %0, %none_0 : !torch.str, !torch.list, !torch.none -> !torch.vtensor<[],f64> + return %1 : !torch.vtensor<[],f64> +} + +// ----- + +// CHECK: func.func @torch.aten.fmod_int(%[[ARG0:.+]]: !torch.vtensor<[?],si32>, %[[ARG1:.+]]: !torch.vtensor<[1],si32>) -> !torch.vtensor<[?],si32> { +// CHECK: %[[FLOAT1:.+]] = torch.constant.float 1.000000e+00 +// CHECK: %[[V0:.+]] = torch.aten.div.Tensor %[[ARG0]], %[[ARG1]] : !torch.vtensor<[?],si32>, !torch.vtensor<[1],si32> -> !torch.vtensor<[?],si32> +// CHECK: %[[V1:.+]] = torch.aten.mul.Tensor %[[V0]], %[[ARG1]] : !torch.vtensor<[?],si32>, !torch.vtensor<[1],si32> -> !torch.vtensor<[?],si32> +// CHECK: %[[V2:.+]] = torch.aten.sub.Tensor %[[ARG0]], %[[V1]], %[[FLOAT1]] : !torch.vtensor<[?],si32>, !torch.vtensor<[?],si32>, !torch.float -> !torch.vtensor<[?],si32> +// CHECK: return %[[V2]] : !torch.vtensor<[?],si32> +func.func 
@torch.aten.fmod_int(%arg0: !torch.vtensor<[?],si32>, %arg1: !torch.vtensor<[1],si32>) -> !torch.vtensor<[?],si32> { + %0 = torch.aten.fmod.Tensor %arg0, %arg1 : !torch.vtensor<[?],si32>, !torch.vtensor<[1],si32> -> !torch.vtensor<[?],si32> + return %0 : !torch.vtensor<[?],si32> +} + +// ----- + +// CHECK: func.func @torch.aten.fmod_float(%[[ARG0:.+]]: !torch.vtensor<[?],f16>, %[[ARG1:.+]]: !torch.vtensor<[1],f16>) -> !torch.vtensor<[?],f16> { +// CHECK: %[[FLOAT1:.+]] = torch.constant.float 1.000000e+00 +// CHECK: %[[V0:.+]] = torch.vtensor.literal(dense<-1> : tensor) : !torch.vtensor<[],si64> +// CHECK: %[[V1:.+]] = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> +// CHECK: %[[NONE:.+]] = torch.constant.none +// CHECK: %[[FALSE:.+]] = torch.constant.bool false +// CHECK: %[[INT5:.+]] = torch.constant.int 5 +// CHECK: %[[V2:.+]] = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> +// CHECK: %[[INT0:.+]] = torch.constant.int 0 +// CHECK: %[[V3:.+]] = torch.aten.div.Tensor %[[ARG0]], %[[ARG1]] : !torch.vtensor<[?],f16>, !torch.vtensor<[1],f16> -> !torch.vtensor<[?],f16> +// CHECK: %[[V4:.+]] = torch.aten.gt.Scalar %[[V3]], %[[INT0]] : !torch.vtensor<[?],f16>, !torch.int -> !torch.vtensor<[?],i1> +// CHECK: %[[V5:.+]] = torch.aten.lt.Scalar %[[V3]], %[[INT0]] : !torch.vtensor<[?],f16>, !torch.int -> !torch.vtensor<[?],i1> +// CHECK: %[[V6:.+]] = torch.aten.to.dtype %[[V2]], %[[INT5]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f16> +// CHECK: %[[V7:.+]] = torch.aten.to.dtype %[[V1]], %[[INT5]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f16> +// CHECK: %[[V8:.+]] = torch.aten.where.self %[[V4]], %[[V6]], %[[V7]] : !torch.vtensor<[?],i1>, !torch.vtensor<[],f16>, !torch.vtensor<[],f16> -> !torch.vtensor<[?],f16> +// CHECK: %[[V9:.+]] = torch.aten.to.dtype %[[V0]], %[[INT5]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f16> +// CHECK: %[[V10:.+]] = torch.aten.where.self %[[V5]], %[[V9]], %[[V8]] : !torch.vtensor<[?],i1>, !torch.vtensor<[],f16>, !torch.vtensor<[?],f16> -> !torch.vtensor<[?],f16> +// CHECK: %[[V11:.+]] = torch.aten.abs %[[V3]] : !torch.vtensor<[?],f16> -> !torch.vtensor<[?],f16> +// CHECK: %[[V12:.+]] = torch.aten.floor %[[V11]] : !torch.vtensor<[?],f16> -> !torch.vtensor<[?],f16> +// CHECK: %[[V13:.+]] = torch.aten.mul.Tensor %[[V10]], %[[V12]] : !torch.vtensor<[?],f16>, !torch.vtensor<[?],f16> -> !torch.vtensor<[?],f16> +// CHECK: %[[V14:.+]] = torch.aten.mul.Tensor %[[V13]], %[[ARG1]] : !torch.vtensor<[?],f16>, !torch.vtensor<[1],f16> -> !torch.vtensor<[?],f16> +// CHECK: %[[V15:.+]] = torch.aten.sub.Tensor %[[ARG0]], %[[V14]], %[[FLOAT1]] : !torch.vtensor<[?],f16>, !torch.vtensor<[?],f16>, !torch.float -> !torch.vtensor<[?],f16> +// CHECK: return %[[V15]] : !torch.vtensor<[?],f16> +func.func @torch.aten.fmod_float(%arg0: !torch.vtensor<[?],f16>, %arg1: !torch.vtensor<[1],f16>) -> !torch.vtensor<[?],f16> { + %0 = torch.aten.fmod.Tensor %arg0, %arg1 : !torch.vtensor<[?],f16>, !torch.vtensor<[1],f16> -> !torch.vtensor<[?],f16> + return %0 : !torch.vtensor<[?],f16> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.fft_rfft$2d_last_dim( +// CHECK-SAME: %arg0: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex> { +// CHECK-DAG: %[[INT2:.*]] = torch.constant.int 2 +// 
CHECK-DAG: %[[INT5:.*]] = torch.constant.int 5 +// CHECK-DAG: %[[INT16:.*]] = torch.constant.int 16 +// CHECK: %[[VAR0:.*]] = torch.vtensor.literal(dense<{{.*}}> : tensor<9x10xf32>) : !torch.vtensor<[9,10],f32> +// CHECK: %[[VAR1:.*]] = torch.aten.mm %arg0, %[[VAR0]] : !torch.vtensor<[16,9],f32>, !torch.vtensor<[9,10],f32> -> !torch.vtensor<[16,10],f32> +// CHECK: %[[VAR2:.*]] = torch.prim.ListConstruct %[[INT16]], %[[INT5]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAR3:.*]] = torch.aten.view %[[VAR1]], %[[VAR2]] : !torch.vtensor<[16,10],f32>, !torch.list -> !torch.vtensor<[16,5,2],f32> +// CHECK: %[[VAR4:.*]] = torch.aten.view_as_complex %[[VAR3]] : !torch.vtensor<[16,5,2],f32> -> !torch.vtensor<[16,5],complex> +// CHECK: return %[[VAR4]] : !torch.vtensor<[16,5],complex> +func.func @torch.aten.fft_rfft$2d_last_dim(%arg0: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex> { + %int-1 = torch.constant.int -1 + %none = torch.constant.none + %out = torch.aten.fft_rfft %arg0, %none, %int-1, %none : !torch.vtensor<[16,9],f32>, !torch.none, !torch.int, !torch.none -> !torch.vtensor<[16,5],complex> + return %out : !torch.vtensor<[16,5],complex> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.fft_rfft$2d_first_dim( +// CHECK-SAME: %arg0: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex> { +// CHECK-DAG: %[[INT2:.*]] = torch.constant.int 2 +// CHECK-DAG: %[[INT19:.*]] = torch.constant.int 19 +// CHECK-DAG: %[[INT23:.*]] = torch.constant.int 23 +// CHECK-DAG: %[[VAR0:.*]] = torch.vtensor.literal(dense<{{.*}}> : tensor<36x38xf32>) : !torch.vtensor<[36,38],f32> +// CHECK-DAG: %[[INT0:.*]] = torch.constant.int 0 +// CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1 +// CHECK: %[[VAR1:.*]] = torch.aten.transpose.int %arg0, %[[INT0]], %[[INT1]] : !torch.vtensor<[36,23],f32>, !torch.int, !torch.int -> !torch.vtensor<[23,36],f32> +// CHECK: %[[VAR2:.*]] = torch.aten.mm %[[VAR1]], %[[VAR0]] : !torch.vtensor<[23,36],f32>, !torch.vtensor<[36,38],f32> -> !torch.vtensor<[23,38],f32> +// CHECK: %[[VAR3:.*]] = torch.prim.ListConstruct %[[INT23]], %[[INT19]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAR4:.*]] = torch.aten.view %[[VAR2]], %[[VAR3]] : !torch.vtensor<[23,38],f32>, !torch.list -> !torch.vtensor<[23,19,2],f32> +// CHECK: %[[VAR5:.*]] = torch.aten.view_as_complex %[[VAR4]] : !torch.vtensor<[23,19,2],f32> -> !torch.vtensor<[23,19],complex> +// CHECK: %[[VAR6:.*]] = torch.aten.transpose.int %[[VAR5]], %[[INT0]], %[[INT1]] : !torch.vtensor<[23,19],complex>, !torch.int, !torch.int -> !torch.vtensor<[19,23],complex> +// CHECK: return %[[VAR6]] : !torch.vtensor<[19,23],complex> +func.func @torch.aten.fft_rfft$2d_first_dim(%arg0: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex> { + %int0 = torch.constant.int 0 + %none = torch.constant.none + %out = torch.aten.fft_rfft %arg0, %none, %int0, %none : !torch.vtensor<[36,23],f32>, !torch.none, !torch.int, !torch.none -> !torch.vtensor<[19,23],complex> + return %out : !torch.vtensor<[19,23],complex> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.sym_constrain_range_for_size( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.int) -> !torch.int { +// CHECK: %[[VAL_1:.*]] = torch.constant.int 7 +// CHECK: %[[VAL_2:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_3:.*]] = torch.constant.none +// CHECK: torch.aten.sym_constrain_range %[[VAL_0]], %[[VAL_2]], %[[VAL_3]] : !torch.int, !torch.int, !torch.none +// CHECK: torch.aten.sym_constrain_range %[[VAL_0]], 
%[[VAL_2]], %[[VAL_1]] : !torch.int, !torch.int, !torch.int +// CHECK: return %[[VAL_0]] : !torch.int +// CHECK: } +func.func @torch.aten.sym_constrain_range_for_size(%arg0: !torch.int) -> !torch.int { + %int7 = torch.constant.int 7 + %int0 = torch.constant.int 0 + %none = torch.constant.none + torch.aten.sym_constrain_range_for_size %arg0, %none, %none : !torch.int, !torch.none, !torch.none + torch.aten.sym_constrain_range_for_size %arg0, %int0, %int7 : !torch.int, !torch.int, !torch.int + return %arg0 : !torch.int +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten._assert_scalar( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.int) -> !torch.int { +// CHECK: %[[VAL_1:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_2:.*]] = torch.constant.int 3 +// CHECK: %[[VAL_3:.*]] = torch.aten.ge.int %[[VAL_0]], %[[VAL_2]] : !torch.int, !torch.int -> !torch.bool +// CHECK: %[[VAL_4:.*]] = torch.aten.Int.bool %[[VAL_3]] : !torch.bool -> !torch.int +// CHECK: %[[VAL_5:.*]] = torch.aten.Bool.int %[[VAL_4]] : !torch.int -> !torch.bool +// CHECK: torch.runtime.assert %[[VAL_5]], "Runtime assertion failed for expression u0 >= 3 on node 'ge_1'" +// CHECK: %[[VAL_6:.*]] = torch.aten.gt.int %[[VAL_0]], %[[VAL_1]] : !torch.int, !torch.int -> !torch.bool +// CHECK: %[[VAL_7:.*]] = torch.aten.Int.bool %[[VAL_6]] : !torch.bool -> !torch.int +// CHECK: %[[VAL_8:.*]] = torch.aten.Bool.int %[[VAL_7]] : !torch.int -> !torch.bool +// CHECK: torch.runtime.assert %[[VAL_8]], "Runtime assertion failed for expression 2 < u0 on node 'gt_1'" +// CHECK: return %[[VAL_0]] : !torch.int +// CHECK: } +func.func @torch.aten._assert_scalar(%arg0: !torch.int) -> !torch.int { + %str = torch.constant.str "Runtime assertion failed for expression 2 < u0 on node 'gt_1'" + %int2 = torch.constant.int 2 + %str_0 = torch.constant.str "Runtime assertion failed for expression u0 >= 3 on node 'ge_1'" + %int3 = torch.constant.int 3 + %0 = torch.aten.ge.int %arg0, %int3 : !torch.int, !torch.int -> !torch.bool + %1 = torch.aten.Int.bool %0 : !torch.bool -> !torch.int + torch.aten._assert_scalar %1, %str_0 : !torch.int, !torch.str + %2 = torch.aten.gt.int %arg0, %int2 : !torch.int, !torch.int -> !torch.bool + %3 = torch.aten.Int.bool %2 : !torch.bool -> !torch.int + torch.aten._assert_scalar %3, %str : !torch.int, !torch.str + return %arg0 : !torch.int +} diff --git a/test/Dialect/Torch/fuse-quantized-ops.mlir b/test/Dialect/Torch/fuse-quantized-ops.mlir index 594295d4e86d..cb39cbd53ece 100644 --- a/test/Dialect/Torch/fuse-quantized-ops.mlir +++ b/test/Dialect/Torch/fuse-quantized-ops.mlir @@ -82,6 +82,48 @@ func.func @matmul_commuting(%arg0: !torch.vtensor<[2,128,32,32],si8>) -> !torch. 
// ----- +// CHECK-LABEL: func.func @mm_pad_commute +func.func @mm_pad_commute(%arg0: !torch.vtensor<[8,8],si8>, %arg1: !torch.vtensor<[11,4],si8>) -> !torch.vtensor<[9,4],f32> { + // CHECK-DAG: %[[cstQuart:.*]] = torch.constant.float 2.500000e-01 + // CHECK-DAG: %[[int7:.*]] = torch.constant.int 7 + // CHECK-DAG: %[[none:.*]] = torch.constant.none + // CHECK-DAG: %[[qMax:.*]] = torch.constant.float 1.270000e+02 + // CHECK-DAG: %[[qMin:.*]] = torch.constant.float -1.280000e+02 + // CHECK-DAG: %[[padVal:.*]] = torch.constant.float 8.000000e+00 + // CHECK-DAG: %[[str:.*]] = torch.constant.str "constant" + // CHECK-DAG: %[[cstHalf:.*]] = torch.constant.float 5.000000e-01 + // CHECK-DAG: %[[int0:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[int1:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[int2:.*]] = torch.constant.int 2 + // CHECK: %[[PadList:.*]] = torch.prim.ListConstruct %[[int1]], %[[int2]], %[[int0]], %[[int1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[EmptyList:.*]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[Rank0:.*]] = torch.aten.full %[[EmptyList]], %[[padVal]], %[[int7]], %[[none]], %[[none]], %[[none]] : !torch.list, !torch.float, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f64> + // CHECK: %[[Clamp:.*]] = torch.aten.clamp %[[Rank0]], %[[qMin]], %[[qMax]] : !torch.vtensor<[],f64>, !torch.float, !torch.float -> !torch.vtensor<[],f64> + // CHECK: %[[Item:.*]] = torch.aten.item %[[Clamp]] : !torch.vtensor<[],f64> -> !torch.float + // CHECK: %[[NewPad:.*]] = torch.aten.pad %arg0, %[[PadList]], %[[str]], %[[Item]] : !torch.vtensor<[8,8],si8>, !torch.list, !torch.str, !torch.float -> !torch.vtensor<[9,11],si8> + // CHECK: %[[NewMPTQT:.*]] = torch.aten._make_per_tensor_quantized_tensor %[[NewPad]], %[[cstHalf]], %[[int1]] : !torch.vtensor<[9,11],si8>, !torch.float, !torch.int -> !torch.vtensor<[9,11],!torch.qint8> + // CHECK: %[[OtherMPTQT:.*]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[cstHalf]], %[[int0]] : !torch.vtensor<[11,4],si8>, !torch.float, !torch.int -> !torch.vtensor<[11,4],!torch.qint8> + // CHECK: %[[MM:.*]] = torch.aten.mm %[[NewMPTQT]], %[[OtherMPTQT]] : !torch.vtensor<[9,11],!torch.qint8>, !torch.vtensor<[11,4],!torch.qint8> -> !torch.vtensor<[9,4],!torch.qint32> + %scale = torch.constant.float 0.5 + %false = torch.constant.bool false + %zero = torch.constant.int 0 + %one = torch.constant.int 1 + %two = torch.constant.int 2 + %floatpad = torch.constant.float 3.5 + %zp = torch.constant.int -128 + %6 = torch.aten._make_per_tensor_quantized_tensor %arg0, %scale, %one : !torch.vtensor<[8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[8,8],!torch.qint8> + %7 = torch.aten.dequantize.tensor %6 : !torch.vtensor<[8,8],!torch.qint8> -> !torch.vtensor<[8,8],f32> + %list = torch.prim.ListConstruct %one, %two, %zero, %one : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %str = torch.constant.str "constant" + %pad = torch.aten.pad %7, %list, %str, %floatpad : !torch.vtensor<[8,8],f32>, !torch.list, !torch.str, !torch.float -> !torch.vtensor<[9,11],f32> + %12 = torch.aten._make_per_tensor_quantized_tensor %arg1, %scale, %zero : !torch.vtensor<[11,4],si8>, !torch.float, !torch.int -> !torch.vtensor<[11,4],!torch.qint8> + %13 = torch.aten.dequantize.tensor %12 : !torch.vtensor<[11,4],!torch.qint8> -> !torch.vtensor<[11,4],f32> + %16 = torch.aten.mm %pad, %13 : !torch.vtensor<[9,11],f32>, !torch.vtensor<[11,4],f32> -> !torch.vtensor<[9,4],f32> + return %16 : 
!torch.vtensor<[9,4],f32> +} + +// ----- + // CHECK-LABEL: @convolution_bias func.func @convolution_bias(%arg0: !torch.vtensor<[1,3,8,8],si8>, %arg1: !torch.vtensor<[3,3,2,2],si8>, %arg2 : !torch.vtensor<[3], f32>) -> !torch.vtensor<[1,3,7,7],f32> { %scale = torch.constant.float 0.5 diff --git a/test/Dialect/Torch/invalid.mlir b/test/Dialect/Torch/invalid.mlir index 63aa1e3755a9..8f38c66ad154 100644 --- a/test/Dialect/Torch/invalid.mlir +++ b/test/Dialect/Torch/invalid.mlir @@ -375,3 +375,30 @@ func.func @foo(%arg0: !torch.vtensor<[64,64],f32,#SV>) -> !torch.vtensor<[64,64] // expected-error @+1 {{invalid sparsity encoding attribute}} func.func private @tensor.sparse() -> !torch.vtensor<[64,64],f32,12345> + + +// ----- + +func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> { + %0 = torch.symbolic_int "s0" {min_val = 3, max_val = 6} : !torch.int + // expected-error @+1 {{op requires equal number of shape symbol args and symbol args to the attached affine map, since they are 1:1 mapped}} + torch.bind_symbolic_shape %arg0, [], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> + return %arg0 : !torch.vtensor<[?],f32> +} + +// ----- + +// Verifier should not fail here since the op does not require shapeSymbols. +func.func @torch.symbolic_int$no_shape_symbols_no_symbols_in_map(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> { + torch.bind_symbolic_shape %arg0, [], affine_map<()[] -> (1)> : !torch.vtensor<[?],f32> + return %arg0 : !torch.vtensor<[?],f32> +} + +// ----- + +func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> { + %int0 = torch.constant.int 0 + // expected-error @+1 {{shape symbol must be produced by a SymbolicIntOp}} + torch.bind_symbolic_shape %arg0, [%int0], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> + return %arg0 : !torch.vtensor<[?],f32> +} diff --git a/test/Dialect/Torch/match-quantized-customs-ops.mlir b/test/Dialect/Torch/match-quantized-customs-ops.mlir index 4196e688157f..1dc89a639335 100644 --- a/test/Dialect/Torch/match-quantized-customs-ops.mlir +++ b/test/Dialect/Torch/match-quantized-customs-ops.mlir @@ -40,3 +40,24 @@ func.func @dequantize_per_tensor(%arg0: !torch.vtensor<[1,3,8,8],si8>) -> !torch %13 = torch.operator "torch.quantized_decomposed.dequantize_per_tensor"(%arg0, %float, %zp, %min, %max, %dtype) : (!torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.vtensor<[1,3,8,8],f32> return %13 : !torch.vtensor<[1,3,8,8],f32> } + +// ----- + +// CHECK-LABEL: func.func @dequantize_per_channel +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[32,3,8,8],si8>, +// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[32],f32>, +// CHECK-SAME: %[[ARG2:.*]]: !torch.vtensor<[32],si8>) -> !torch.vtensor<[32,3,8,8],f32> { +func.func @dequantize_per_channel(%arg0: !torch.vtensor<[32,3,8,8],si8>, %arg1: !torch.vtensor<[32],f32>, %arg2: !torch.vtensor<[32],si8>) -> !torch.vtensor<[32,3,8,8],f32> { + %min = torch.constant.int -128 + %max = torch.constant.int 127 + %dtype = torch.constant.int 1 + %axis = torch.constant.int 0 + // CHECK-DAG: %[[MIN:.+]] = torch.constant.int -128 + // CHECK-DAG: %[[MAX:.+]] = torch.constant.int 127 + // CHECK-DAG: %[[AXIS:.+]] = torch.constant.int 0 + // CHECK-DAG: %[[CLAMP:.+]] = torch.aten.clamp %arg0, %[[MIN]], %[[MAX]] : !torch.vtensor<[32,3,8,8],si8>, !torch.int, !torch.int -> !torch.vtensor<[32,3,8,8],si8> + // CHECK-DAG: %[[QINT:.+]] = 
torch.aten._make_per_channel_quantized_tensor %[[CLAMP]], %[[ARG1]], %[[ARG2]], %[[AXIS]] : !torch.vtensor<[32,3,8,8],si8>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],si8>, !torch.int -> !torch.vtensor<[32,3,8,8],!torch.qint8> + // CHECK: %[[DEQUANT:.+]] = torch.aten.dequantize.self %[[QINT]] : !torch.vtensor<[32,3,8,8],!torch.qint8> -> !torch.vtensor<[32,3,8,8],f32> + %13 = torch.operator "torch.quantized_decomposed.dequantize_per_channel"(%arg0, %arg1, %arg2, %axis, %min, %max, %dtype) : (!torch.vtensor<[32,3,8,8],si8>, !torch.vtensor<[32],f32>, !torch.vtensor<[32],si8>, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.vtensor<[32,3,8,8],f32> + return %13 : !torch.vtensor<[32,3,8,8],f32> +} diff --git a/test/Dialect/Torch/ops.mlir b/test/Dialect/Torch/ops.mlir index ecf5e626fb1d..a47cbf83a318 100644 --- a/test/Dialect/Torch/ops.mlir +++ b/test/Dialect/Torch/ops.mlir @@ -93,6 +93,9 @@ func.func @torch.prim.If(%arg0: !torch.bool, %arg1: !torch.int) -> !torch.int { // CHECK: %int-3 = torch.constant.int -3 %int-3 = torch.constant.int -3 +// CHECK: %int5 = torch.constant.int 5 {test = "value"} +%int5 = torch.constant.int 5 {test = "value"} + // CHECK: %float1.000000e00 = torch.constant.float 1.000000e+00 %float1.000000e00 = torch.constant.float 1.000000e+00 // CHECK: %float-1.000000e00 = torch.constant.float -1.000000e+00 @@ -171,6 +174,7 @@ func.func @number_type_subtypes(%arg0: !torch.tensor, %arg1: !torch.list, % func.func private @tensor_legal_dtype$torch.qint8() -> !torch.tensor<*,!torch.qint8> func.func private @tensor_legal_dtype$torch.quint8() -> !torch.tensor<*,!torch.quint8> +func.func private @tensor_legal_dtype$torch.qint16() -> !torch.tensor<*,!torch.qint16> func.func @prim_list_construct$valid_shape_subtype(%arg0: !torch.vtensor<[1,53,56,96],f16>, %arg1: !torch.vtensor<[1,3,56,96],f16>) -> !torch.list> { %arg2 = "torch.prim.ListConstruct"(%arg0, %arg1) : (!torch.vtensor<[1,53,56,96],f16>, !torch.vtensor<[1,3,56,96],f16>) -> !torch.list> @@ -184,5 +188,20 @@ func.func @torch.permute$negative_index_valid (%arg0: !torch.vtensor<[1,2,3],f32 %int1 = torch.constant.int 1 %perm = torch.prim.ListConstruct %int0, %int1, %intm1 : (!torch.int, !torch.int, !torch.int) -> !torch.list %3 = torch.aten.permute %arg0, %perm : !torch.vtensor<[1,2,3],f32>, !torch.list -> !torch.vtensor<[1,2,3],f32> - return %3 : !torch.vtensor<[1,2,3],f32> + return %3 : !torch.vtensor<[1,2,3],f32> +} + +// Check fake quantize ops +func.func @torch.aten.fake_quantize_per_channel_affine (%arg0: !torch.vtensor<[3,3],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],si32>) -> !torch.vtensor<[3,3],f32> { + %int0 = torch.constant.int 0 + %int255 = torch.constant.int 255 + %1 = torch.aten.fake_quantize_per_channel_affine %arg0, %arg1, %arg2, %int0, %int0, %int255 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],si32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,3],f32> + return %1 : !torch.vtensor<[3,3],f32> +} + +func.func @torch.aten.fake_quantize_per_tensor_affine.tensor_qparams (%arg0: !torch.vtensor<[3,3],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],si32>) -> !torch.vtensor<[3,3],f32> { + %int0 = torch.constant.int 0 + %int255 = torch.constant.int 255 + %1 = torch.aten.fake_quantize_per_tensor_affine.tensor_qparams %arg0, %arg1, %arg2, %int0, %int255 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],si32>, !torch.int, !torch.int -> !torch.vtensor<[3,3],f32> + return %1 : !torch.vtensor<[3,3],f32> } diff --git 
a/test/Dialect/Torch/reduce-op-variants.mlir b/test/Dialect/Torch/reduce-op-variants.mlir index 94bec8aa2160..ee5d3ff3519e 100644 --- a/test/Dialect/Torch/reduce-op-variants.mlir +++ b/test/Dialect/Torch/reduce-op-variants.mlir @@ -193,7 +193,7 @@ func.func @torch.aten.bernoulli_.float(%t: !torch.tensor) -> !torch.tensor { // CHECK: %[[FALSE:.+]] = torch.constant.bool false // CHECK: %[[NONE0:.+]] = torch.constant.none // CHECK: %[[NONE1:.+]] = torch.constant.none -// CHECK: %[[ATTEN:.+]] = torch.aten.scaled_dot_product_attention %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[NONE0]], %[[ZERO]], %[[FALSE]], %[[NONE1]] +// CHECK: %[[ATTEN:.+]] = torch.aten.scaled_dot_product_attention %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[NONE0]], %[[ZERO]], %[[FALSE]], %[[NONE1]], %[[FALSE]] // CHECK: return %[[ATTEN]] func.func @scaled_dot_product_flash_attention_for_cpu(%arg0: !torch.vtensor<[1,1,5,5],f32>, %arg1: !torch.vtensor<[1,1,5,5],f32>, %arg2: !torch.vtensor<[1,1,5,5],f32>) -> !torch.vtensor<[1,1,5,5],f32> { %float0.000000e00 = torch.constant.float 0.000000e+00 diff --git a/test/Dialect/Torch/scalarize-shapes.mlir b/test/Dialect/Torch/scalarize-shapes.mlir index db8d71576ca3..00975a2405be 100644 --- a/test/Dialect/Torch/scalarize-shapes.mlir +++ b/test/Dialect/Torch/scalarize-shapes.mlir @@ -12,7 +12,13 @@ func.func @shape_as_tensor(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtenso // CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[I5]], %[[SZ1]], %[[SZ2]] // CHECK-DAG: %[[TENSOR:.+]] = torch.aten.tensor %[[LIST]], %[[NONE]], %[[NONE]], %[[FALSE]] // CHECK: return %[[TENSOR]] : !torch.vtensor<[3],si32> + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %literal1 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64> %0 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[5,?,?],f32> -> !torch.vtensor<[3],si32> + %1 = torch.aten.index_select %0, %int0, %literal1: !torch.vtensor<[3],si32>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si32> + %2 = torch.aten.item %1 : !torch.vtensor<[1],si32> -> !torch.int + %3 = torch.prim.ListConstruct %2 : (!torch.int) -> !torch.list return %0 : !torch.vtensor<[3],si32> } @@ -20,20 +26,88 @@ func.func @shape_as_tensor(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtenso // CHECK-LABEL: @shape_as_tensor_dim func.func @shape_as_tensor_dim(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtensor<[],si32> { - // CHECK: %[[FALSE:.+]] = torch.constant.bool false - // CHECK: %[[NONE:.+]] = torch.constant.none // CHECK: %[[INT1:.+]] = torch.constant.int 1 - // CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[INT1]] - // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT1]] - // CHECK: %[[TENSOR:.+]] = torch.aten.full %[[LIST]], %[[SZ]], %[[NONE]], %[[NONE]], %[[NONE]], %[[FALSE]] + // CHECK-DAG: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[INT1]] + // CHECK: %[[TENSOR:.+]] = torch.prim.NumToTensor.Scalar %[[SZ]] : !torch.int -> !torch.vtensor<[],si32> // CHECK: return %[[TENSOR]] : !torch.vtensor<[],si32> %shape = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[5,?,?],f32> -> !torch.vtensor<[3],si32> %dim = torch.constant.int 0 %idx = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si32> %select = torch.aten.index_select %shape, %dim, %idx : !torch.vtensor<[3],si32>, !torch.int, !torch.vtensor<[],si32> -> !torch.vtensor<[],si32> + %item = torch.aten.item %select : !torch.vtensor<[],si32> -> !torch.int + %list = torch.prim.ListConstruct %item : (!torch.int) -> !torch.list return %select : 
!torch.vtensor<[],si32> } +// ----- + +// CHECK-LABEL: @cast_int_int +func.func @cast_int_int(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtensor<[],si64> { + // CHECK: %[[I1:.*]] = torch.constant.int 1 + // CHECK: %[[SZE:.*]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[5,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[TENSOR:.*]] = torch.prim.NumToTensor.Scalar %[[SZE]] : !torch.int -> !torch.vtensor<[],si64> + // CHECK: return %[[TENSOR]] : !torch.vtensor<[],si64> + %int4 = torch.constant.int 4 + %false = torch.constant.bool false + %none = torch.constant.none + %shape = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[5,?,?],f32> -> !torch.vtensor<[3],si32> + %cast_shape = torch.aten.to.dtype %shape, %int4, %false, %false, %none : !torch.vtensor<[3],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3],si64> + %dim = torch.constant.int 0 + %idx = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si32> + %select = torch.aten.index_select %cast_shape, %dim, %idx : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[],si32> -> !torch.vtensor<[],si64> + %item = torch.aten.item %select : !torch.vtensor<[],si64> -> !torch.int + %list = torch.prim.ListConstruct %item : (!torch.int) -> !torch.list + return %select : !torch.vtensor<[],si64> +} + +// ----- + +// CHECK-LABEL: @cast_int_float +func.func @cast_int_float(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtensor<[],f32> { + // CHECK: %[[I1:.*]] = torch.constant.int 1 + // CHECK: %[[SZE:.*]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[5,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[FLOAT:.*]] = torch.aten.Float.Scalar %[[SZE]] : !torch.int -> !torch.float + // CHECK: %[[TENSOR:.*]] = torch.prim.NumToTensor.Scalar %[[FLOAT]] : !torch.float -> !torch.vtensor<[],f32> + // CHECK: return %[[TENSOR]] : !torch.vtensor<[],f32> + %int6 = torch.constant.int 6 + %false = torch.constant.bool false + %none = torch.constant.none + %shape = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[5,?,?],f32> -> !torch.vtensor<[3],si32> + %cast_shape = torch.aten.to.dtype %shape, %int6, %false, %false, %none : !torch.vtensor<[3],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3],f32> + %dim = torch.constant.int 0 + %idx = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si32> + %select = torch.aten.index_select %cast_shape, %dim, %idx : !torch.vtensor<[3],f32>, !torch.int, !torch.vtensor<[],si32> -> !torch.vtensor<[],f32> + %item = torch.aten.item %select : !torch.vtensor<[],f32> -> !torch.float + %item_int = torch.aten.Int.Scalar %item : !torch.float -> !torch.int + %list = torch.prim.ListConstruct %item_int : (!torch.int) -> !torch.list + return %select : !torch.vtensor<[],f32> +} + +// ----- + +// CHECK-LABEL: @cast_int_float_static +func.func @cast_int_float_static(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtensor<[3],f32> { + // CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00 + // CHECK: %[[FLOAT2:.*]] = torch.constant.float 2.000000e+00 + // CHECK: %[[FLOAT3:.*]] = torch.constant.float 3.000000e+00 + // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[FLOAT1:.*]], %[[FLOAT2:.*]], %[[FLOAT3:.*]] : (!torch.float, !torch.float, !torch.float) -> !torch.list + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[TENSOR:.*]] = torch.aten.tensor %[[LIST]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[3],f32> + // 
CHECK: return %[[TENSOR]] : !torch.vtensor<[3],f32> + %int6 = torch.constant.int 6 + %false = torch.constant.bool false + %none = torch.constant.none + %shape = torch.vtensor.literal(dense<[1,2,3]> : tensor<3xsi64>) : !torch.vtensor<[3],si64> + %cast_shape = torch.aten.to.dtype %shape, %int6, %false, %false, %none : !torch.vtensor<[3],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3],f32> + %dim = torch.constant.int 0 + %idx0 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %select0 = torch.aten.index_select %cast_shape, %dim, %idx0 : !torch.vtensor<[3],f32>, !torch.int, !torch.vtensor<[],si64> -> !torch.vtensor<[],f32> + %item0 = torch.aten.item %select0 : !torch.vtensor<[],f32> -> !torch.float + %item_int0 = torch.aten.Int.Scalar %item0 : !torch.float -> !torch.int + %list = torch.prim.ListConstruct %item_int0 : (!torch.int) -> !torch.list + return %cast_shape : !torch.vtensor<[3],f32> +} // ----- @@ -47,9 +121,116 @@ func.func @shape_as_tensor_dim_item(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !tor %idx = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si32> %select = torch.aten.index_select %shape, %dim, %idx : !torch.vtensor<[3],si32>, !torch.int, !torch.vtensor<[],si32> -> !torch.vtensor<[],si32> %out = torch.aten.item %select : !torch.vtensor<[],si32> -> !torch.int + %list = torch.prim.ListConstruct %out : (!torch.int) -> !torch.list return %out : !torch.int } +// ----- + +// CHECK-LABEL: @literal_item +func.func @literal_item() -> !torch.int { + // CHECK: %int2 = torch.constant.int 2 + // CHECK: return %int2 : !torch.int + %shape = torch.vtensor.literal(dense<[1,2,3]> : tensor<3xsi32>) : !torch.vtensor<[3],si32> + %dim = torch.constant.int 0 + %idx = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si32> + %select = torch.aten.index_select %shape, %dim, %idx : !torch.vtensor<[3],si32>, !torch.int, !torch.vtensor<[],si32> -> !torch.vtensor<[],si32> + %out = torch.aten.item %select : !torch.vtensor<[],si32> -> !torch.int + %list = torch.prim.ListConstruct %out : (!torch.int) -> !torch.list + return %out : !torch.int +} + +// ----- + +// CHECK-LABEL: @arith_prop +func.func @arith_prop(%arg0 : !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + // CHECK: %[[float0:.*]] = torch.constant.float 0.000000e+00 + // CHECK: %[[int0:.*]] = torch.constant.int 0 + // CHECK: %[[x0:.*]] = torch.aten.size.int %arg0, %[[int0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[int1:.*]] = torch.constant.int 1 + // CHECK: %[[x1:.*]] = torch.aten.size.int %arg0, %[[int1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[int12:.*]] = torch.constant.int 12 + // CHECK: %[[int1_0:.*]] = torch.constant.int 1 + // CHECK: %[[x2:.*]] = torch.aten.floordiv.int %[[x0]], %[[int12]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[x3:.*]] = torch.aten.floordiv.int %[[x1]], %[[int1_0]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[int12_1:.*]] = torch.constant.int 12 + // CHECK: %[[x4:.*]] = torch.aten.mul.int %[[x2]], %[[int12_1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[x5:.*]] = torch.aten.sub.int %[[x0]], %[[x4]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[x6:.*]] = torch.aten.sub.int %[[x1]], %[[x3]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[x7:.*]] = torch.prim.ListConstruct %[[x6]], %[[x5]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[x8:.*]] = torch.aten.constant_pad_nd %arg0, %[[x7]], %[[float0]] : !torch.vtensor<[?,?],f32>, 
!torch.list, !torch.float -> !torch.vtensor<[?,?],f32> + // CHECK: return %[[x8]] : !torch.vtensor<[?,?],f32> + %0 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %1 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + %float0.000000e00 = torch.constant.float 0.000000e+00 + %int1 = torch.constant.int 1 + %2 = torch.vtensor.literal(dense<[12, 1]> : tensor<2xsi64>) : !torch.vtensor<[2],si64> + %int0 = torch.constant.int 0 + %3 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[2],si64> + %4 = torch.aten.div.Tensor %3, %2 : !torch.vtensor<[2],si64>, !torch.vtensor<[2],si64> -> !torch.vtensor<[2],si64> + %5 = torch.aten.mul.Tensor %4, %2 : !torch.vtensor<[2],si64>, !torch.vtensor<[2],si64> -> !torch.vtensor<[2],si64> + %6 = torch.aten.sub.Tensor %3, %5, %int1 : !torch.vtensor<[2],si64>, !torch.vtensor<[2],si64>, !torch.int -> !torch.vtensor<[2],si64> + %7 = torch.aten.index_select %6, %int0, %1 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %8 = torch.aten.index_select %6, %int0, %0 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %9 = torch.aten.item %7 : !torch.vtensor<[],si64> -> !torch.int + %10 = torch.aten.item %8 : !torch.vtensor<[],si64> -> !torch.int + %11 = torch.prim.ListConstruct %10, %9 : (!torch.int, !torch.int) -> !torch.list + %12 = torch.aten.constant_pad_nd %arg0, %11, %float0.000000e00 : !torch.vtensor<[?,?],f32>, !torch.list, !torch.float -> !torch.vtensor<[?,?],f32> + return %12 : !torch.vtensor<[?,?],f32> +} + +// ----- + +// CHECK-LABEL: @broadcast_prop +func.func @broadcast_prop(%arg0 : !torch.vtensor<[?,?],f32>) -> !torch.int { + // CHECK: %[[I0:.*]] = torch.constant.int 0 + // CHECK: %[[SZE:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int + // CHECK: return %[[SZE]] : !torch.int + %dim = torch.constant.int 0 + %size = torch.aten.size.int %arg0, %dim : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int + %shape = torch.prim.NumToTensor.Scalar %size : !torch.int -> !torch.vtensor<[],si32> + %int3 = torch.constant.int 3 + %idx = torch.vtensor.literal(dense<-1> : tensor) : !torch.vtensor<[],si32> + %bcastlist = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list + %bcast = torch.aten.broadcast_to %shape, %bcastlist : !torch.vtensor<[],si32>, !torch.list -> !torch.vtensor<[3],si32> + %select = torch.aten.index_select %bcast, %dim, %idx : !torch.vtensor<[3],si32>, !torch.int, !torch.vtensor<[],si32> -> !torch.vtensor<[],si32> + %out = torch.aten.item %select : !torch.vtensor<[],si32> -> !torch.int + %list = torch.prim.ListConstruct %out : (!torch.int) -> !torch.list + return %out : !torch.int +} + +// ----- + +// CHECK-LABEL: @eq_int_fold +func.func @eq_int_fold(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,1],f32> { + // CHECK: %[[int1:.*]] = torch.constant.int 1 + // CHECK: %[[int0:.*]] = torch.constant.int 0 + // CHECK: %[[sze0:.*]] = torch.aten.size.int %arg0, %[[int0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[sze1:.*]] = torch.aten.size.int %arg0, %[[int1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[mul:.*]] = torch.aten.mul.int %[[sze0]], %[[sze1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[gt0:.*]] = torch.aten.gt.int %[[sze0]], %[[int0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.runtime.assert %[[gt0]], "Expected dim size > 0." 
+ // CHECK: %[[gt1:.*]] = torch.aten.gt.int %[[sze1]], %[[int0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.runtime.assert %[[gt1]], "Expected dim size > 0." + // CHECK: %[[list:.*]] = torch.prim.ListConstruct %[[mul]], %[[int1]] : (!torch.int, !torch.int) -> !torch.list<int> + // CHECK: %[[view:.*]] = torch.aten.view %arg0, %[[list]] : !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,1],f32> + // CHECK: return %[[view:.*]] : !torch.vtensor<[?,1],f32> + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int + %1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int + %2 = torch.aten.mul.int %0, %1 : !torch.int, !torch.int -> !torch.int + %3 = torch.aten.eq.int %2, %int0 : !torch.int, !torch.int -> !torch.bool + %4 = torch.aten.Int.bool %3 : !torch.bool -> !torch.int + %5 = torch.prim.NumToTensor.Scalar %4 : !torch.int -> !torch.vtensor<[],i1> + %6 = torch.prim.NumToTensor.Scalar %0 : !torch.int -> !torch.vtensor<[],si64> + %7 = torch.prim.NumToTensor.Scalar %2 : !torch.int -> !torch.vtensor<[],si64> + %8 = torch.aten.where.self %5, %6, %7 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %9 = torch.aten.item %8 : !torch.vtensor<[],si64> -> !torch.int + %10 = torch.prim.ListConstruct %9, %int1 : (!torch.int, !torch.int) -> !torch.list<int> + %11 = torch.aten.view %arg0, %10 : !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,1],f32> + return %11 : !torch.vtensor<[?,1],f32> +} // ----- @@ -64,11 +245,473 @@ func.func @shape_as_tensor_slice(%arg0 : !torch.vtensor<[5,?,?,?],f32>) -> !torc // CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[SZ1]], %[[SZ3]] // CHECK-DAG: %[[TENSOR:.+]] = torch.aten.tensor %[[LIST]], %[[NONE]], %[[NONE]], %[[FALSE]] // CHECK: return %[[TENSOR]] + %idx = torch.vtensor.literal(dense<1> : tensor<si32>) : !torch.vtensor<[],si32> %shape = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[5,?,?,?],f32> -> !torch.vtensor<[4],si32> %dim = torch.constant.int 0 %start = torch.constant.int 1 %end = torch.constant.int 5 %step = torch.constant.int 2 %slice = torch.aten.slice.Tensor %shape, %dim, %start, %end, %step : !torch.vtensor<[4], si32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2], si32> + %select = torch.aten.index_select %slice, %dim, %idx : !torch.vtensor<[2],si32>, !torch.int, !torch.vtensor<[],si32> -> !torch.vtensor<[],si32> + %item = torch.aten.item %select : !torch.vtensor<[],si32> -> !torch.int + %list = torch.prim.ListConstruct %item : (!torch.int) -> !torch.list<int> return %slice : !torch.vtensor<[2],si32> } + + +// ----- + +// CHECK-LABEL: @view_as_flatten_static +func.func @view_as_flatten_static(%arg0: !torch.vtensor<[?,?,16,64],f32>) -> !torch.vtensor<[?,?,1024],f32> { + // CHECK-DAG: %[[TWO:.*]] = torch.constant.int 2 + // CHECK-DAG: %[[THREE:.*]] = torch.constant.int 3 + // CHECK-DAG: %[[FLAT:.*]] = torch.aten.flatten.using_ints %arg0, %[[TWO]], %[[THREE]] : !torch.vtensor<[?,?,16,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,1024],f32> + // CHECK: return %[[FLAT]] : !torch.vtensor<[?,?,1024],f32> + %int1024 = torch.constant.int 1024 + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,16,64],f32>, !torch.int -> !torch.int + %1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,16,64],f32>, !torch.int -> !torch.int + %2 = 
torch.prim.ListConstruct %0, %1, %int1024 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?,16,64],f32>, !torch.list -> !torch.vtensor<[?,?,1024],f32> + return %3 : !torch.vtensor<[?,?,1024],f32> +} + + +// ----- + +// CHECK-LABEL: @view_as_unflatten_static +func.func @view_as_unflatten_static(%arg0: !torch.vtensor<[?,?,1024],f32>) -> !torch.vtensor<[?,?,16,64],f32> { + // CHECK-DAG: %[[TWO:.*]] = torch.constant.int 2 + // CHECK-DAG: %[[CST16:.*]] = torch.constant.int 16 + // CHECK-DAG: %[[CST64:.*]] = torch.constant.int 64 + // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[CST16]], %[[CST64]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[FLAT:.*]] = torch.aten.unflatten.int %arg0, %[[TWO]], %[[LIST]] : !torch.vtensor<[?,?,1024],f32>, !torch.int, !torch.list -> !torch.vtensor<[?,?,16,64],f32> + // CHECK: return %[[FLAT]] : !torch.vtensor<[?,?,16,64],f32> + %int16 = torch.constant.int 16 + %int64 = torch.constant.int 64 + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,1024],f32>, !torch.int -> !torch.int + %1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,1024],f32>, !torch.int -> !torch.int + %2 = torch.prim.ListConstruct %0, %1, %int16, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?,1024],f32>, !torch.list -> !torch.vtensor<[?,?,16,64],f32> + return %3 : !torch.vtensor<[?,?,16,64],f32> +} + + +// ----- + +// CHECK-LABEL: @view_as_flatten_dynamic +func.func @view_as_flatten_dynamic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { + // CHECK-DAG: %[[TWO:.*]] = torch.constant.int 2 + // CHECK-DAG: %[[THREE:.*]] = torch.constant.int 3 + // CHECK-DAG: %[[FLAT:.*]] = torch.aten.flatten.using_ints %arg0, %[[TWO]], %[[THREE]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?],f32> + // CHECK: return %[[FLAT]] : !torch.vtensor<[?,?,?],f32> + %int-1 = torch.constant.int -1 + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + %1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + %2 = torch.prim.ListConstruct %0, %1, %int-1 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?,?,?],f32>, !torch.list -> !torch.vtensor<[?,?,?],f32> + return %3 : !torch.vtensor<[?,?,?],f32> +} + +// ----- + +// CHECK-LABEL: @view_as_flatten_mid +func.func @view_as_flatten_mid(%arg0: !torch.vtensor<[?,?,?,?,2,4],f32>) -> !torch.vtensor<[?,?,?,4],f32> { + // CHECK-DAG: %[[TWO:.*]] = torch.constant.int 2 + // CHECK-DAG: %[[FOUR:.*]] = torch.constant.int 4 + // CHECK-DAG: %[[FLAT:.*]] = torch.aten.flatten.using_ints %arg0, %[[TWO]], %[[FOUR]] : !torch.vtensor<[?,?,?,?,2,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,4],f32> + // CHECK: return %[[FLAT]] : !torch.vtensor<[?,?,?,4],f32> + %int-1 = torch.constant.int -1 + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %int4 = torch.constant.int 4 + %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,?,?,2,4],f32>, !torch.int -> !torch.int + %1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,?,?,2,4],f32>, !torch.int -> !torch.int + %2 = torch.prim.ListConstruct %0, %1, %int-1, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int) -> 
!torch.list + %3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?,?,?,2,4],f32>, !torch.list -> !torch.vtensor<[?,?,?,4],f32> + return %3 : !torch.vtensor<[?,?,?,4],f32> +} + + +// ----- + +// CHECK-LABEL: @unsqueeze_squeeze_combo +func.func @unsqueeze_squeeze_combo(%arg0: !torch.vtensor<[?,?,16,64],f32>) -> !torch.int { + // CHECK: %int0 = torch.constant.int 0 + // CHECK: %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,16,64],f32>, !torch.int -> !torch.int + // CHECK: return %0 : !torch.int + %0 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + %1 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + %2 = torch.vtensor.literal(dense<1024> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %3 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,?,16,64],f32> -> !torch.vtensor<[4],si64> + %4 = torch.aten.index_select %3, %int0, %1 : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + %5 = torch.aten.squeeze.dim %4, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64> + %6 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,?,16,64],f32> -> !torch.vtensor<[4],si64> + %7 = torch.aten.index_select %6, %int0, %0 : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + %8 = torch.aten.squeeze.dim %7, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64> + %9 = torch.aten.unsqueeze %5, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> + %10 = torch.aten.unsqueeze %8, %int0 : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> + %11 = torch.prim.ListConstruct %9, %10, %2 : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.list + %12 = torch.aten.cat %11, %int0 : !torch.list, !torch.int -> !torch.vtensor<[3],si64> + %13 = torch.aten.slice.Tensor %12, %int0, %int0, %int1, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + %14 = torch.aten.item %13 : !torch.vtensor<[1],si64> -> !torch.int + %list = torch.prim.ListConstruct %14 : (!torch.int) -> !torch.list + return %14 : !torch.int +} + + +// ----- + +// CHECK-LABEL: @eq_tensor_and_where_self +func.func @eq_tensor_and_where_self(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[4],si64> { + // CHECK: %[[I1:.*]] = torch.constant.int 1 + // CHECK: %[[I0:.*]] = torch.constant.int 0 + // CHECK: %[[DIM1:.*]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int + // CHECK: %[[DIM0:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int + // CHECK: %[[I1_0:.*]] = torch.constant.int 1 + // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[I1_0]], %[[DIM1]], %[[DIM1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[none:.*]] = torch.constant.none + // CHECK: %[[false:.*]] = torch.constant.bool false + // CHECK: %[[TENSOR:.*]] = torch.aten.tensor %[[LIST]], %[[none]], %[[none]], %[[false]] : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[4],si64> + // CHECK: return %[[TENSOR]] : !torch.vtensor<[4],si64> + %none = torch.constant.none + %0 = torch.vtensor.literal(dense<-1> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %1 = torch.vtensor.literal(dense<1> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %idx = torch.vtensor.literal(dense<1> 
: tensor) : !torch.vtensor<[],si64> + %false = torch.constant.bool false + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %2 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int + %3 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int + %4 = torch.prim.ListConstruct %3, %int1, %2, %2 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5 = torch.aten.tensor %4, %none, %none, %false : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[4],si64> + %6 = torch.aten.eq.Tensor %5, %0 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],i1> + %7 = torch.aten.where.self %6, %1, %5 : !torch.vtensor<[4],i1>, !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],si64> + %select = torch.aten.index_select %7, %int0, %idx : !torch.vtensor<[4],si64>, !torch.int, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> + %item = torch.aten.item %select : !torch.vtensor<[],si64> -> !torch.int + %list = torch.prim.ListConstruct %item : (!torch.int) -> !torch.list + return %7 : !torch.vtensor<[4],si64> +} + + +// ----- + +// CHECK-LABEL: @eq_tensor_from_tensor_and_literal +func.func @eq_tensor_from_tensor_and_literal(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[4],i1> { + // CHECK: %[[int1:.*]] = torch.constant.int 1 + // CHECK: %[[int0:.*]] = torch.constant.int 0 + // CHECK: %[[int1_0:.*]] = torch.constant.int 1 + // CHECK: %[[int0_1:.*]] = torch.constant.int 0 + // CHECK: %[[int0_2:.*]] = torch.constant.int 0 + // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[int0]], %[[int1_0]], %[[int0_1]], %[[int0_2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[none:.*]] = torch.constant.none + // CHECK: %[[false:.*]] = torch.constant.bool false + // CHECK: %[[TENSOR:.*]] = torch.aten.tensor %[[LIST]], %[[none]], %[[none]], %[[false]] : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[4],i1> + // CHECK: return %[[TENSOR]] : !torch.vtensor<[4],i1> + %none = torch.constant.none + %0 = torch.vtensor.literal(dense<-1> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %1 = torch.vtensor.literal(dense<1> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %idx = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %false = torch.constant.bool false + %int1 = torch.constant.int 1 + %int-1 = torch.constant.int -1 + %int0 = torch.constant.int 0 + %2 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int + %3 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int + %4 = torch.prim.ListConstruct %3, %int-1, %2, %2 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %5 = torch.aten.tensor %4, %none, %none, %false : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[4],si64> + %6 = torch.aten.eq.Tensor %5, %0 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],i1> + %select = torch.aten.index_select %6, %int0, %idx : !torch.vtensor<[4],i1>, !torch.int, !torch.vtensor<[],si64> -> !torch.vtensor<[],i1> + %item = torch.aten.item %select : !torch.vtensor<[],i1> -> !torch.int + %list = torch.prim.ListConstruct %item : (!torch.int) -> !torch.list + return %6 : !torch.vtensor<[4],i1> +} + + + +// ----- + +// CHECK-LABEL: @squeeze_dim_full_fold +func.func @squeeze_dim_full_fold(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.list { + // CHECK: %[[I0:.*]] = torch.constant.int 0 + // 
CHECK: %[[SZE:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int + // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[SZE]] : (!torch.int) -> !torch.list<int> + // CHECK: return %[[LIST]] : !torch.list<int> + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %none = torch.constant.none + %false = torch.constant.bool false + %51 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int + %55 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int> + %56 = torch.aten.full %55, %51, %none, %none, %none, %false : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1],si64> + %57 = torch.aten.squeeze.dim %56, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64> + %58 = torch.aten.item %57 : !torch.vtensor<[],si64> -> !torch.int + %59 = torch.prim.ListConstruct %58 : (!torch.int) -> !torch.list<int> + return %59 : !torch.list<int> +} + +// ----- + +// CHECK-LABEL: @pytorch_dynamic_pad_export_view$prop( +func.func @pytorch_dynamic_pad_export_view$prop(%arg0: !torch.vtensor<[?,144,?,?],f32>) -> !torch.vtensor<[4,2],si64> { + // CHECK: %[[I0:.*]] = torch.constant.int 0 + // CHECK: %[[x0:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I2:.*]] = torch.constant.int 2 + // CHECK: %[[x1:.*]] = torch.aten.size.int %arg0, %[[I2]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I3:.*]] = torch.constant.int 3 + // CHECK: %[[x2:.*]] = torch.aten.size.int %arg0, %[[I3]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I144:.*]] = torch.constant.int 144 + // CHECK: %[[I0_0:.*]] = torch.constant.int 0 + // CHECK: %[[I0_1:.*]] = torch.constant.int 0 + // CHECK: %[[I0_2:.*]] = torch.constant.int 0 + // CHECK: %[[I0_3:.*]] = torch.constant.int 0 + // CHECK: %[[x3:.*]] = torch.prim.ListConstruct %[[x0]], %[[I144]], %[[x1]], %[[x2]], %[[I0_0]], %[[I0_1]], %[[I0_2]], %[[I0_3]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> + // CHECK: %[[none:.*]] = torch.constant.none + // CHECK: %[[false:.*]] = torch.constant.bool false + // CHECK: %[[x4:.*]] = torch.aten.tensor %[[x3]], %[[none]], %[[none]], %[[false]] : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[4,2],si64> + // CHECK: return %[[x4]] : !torch.vtensor<[4,2],si64> + %0 = torch.vtensor.literal(dense<0> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %1 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + %2 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + %int-9223372036854775807 = torch.constant.int -9223372036854775807 + %int-1 = torch.constant.int -1 + %int2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %3 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,144,?,?],f32> -> !torch.vtensor<[4],si64> + %4 = torch.prim.ListConstruct %3, %0 : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.list<vtensor> + %5 = torch.aten.cat %4, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[8],si64> + %6 = torch.prim.ListConstruct %int-1, %int2 : (!torch.int, !torch.int) -> !torch.list<int> + %7 = torch.aten.view %5, %6 : !torch.vtensor<[8],si64>, !torch.list<int> -> !torch.vtensor<[4,2],si64> + %8 = torch.aten.slice.Tensor %7, %int0, %int-1, %int-9223372036854775807, %int-1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, 
!torch.int, !torch.int -> !torch.vtensor<[4,2],si64> + %9 = torch.aten.transpose.int %8, %int0, %int1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,4],si64> + %10 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list + %11 = torch.aten.view %9, %10 : !torch.vtensor<[2,4],si64>, !torch.list -> !torch.vtensor<[8],si64> + %12 = torch.aten.index_select %11, %int0, %1 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + %13 = torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int + %14 = torch.aten.index_select %11, %int0, %2 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + %15 = torch.aten.item %14 : !torch.vtensor<[1],si64> -> !torch.int + %16 = torch.prim.ListConstruct %13, %15: (!torch.int, !torch.int) -> !torch.list + return %7 : !torch.vtensor<[4,2],si64> +} + +// ----- + +// CHECK-LABEL: @pytorch_dynamic_pad_export_slice$prop( +func.func @pytorch_dynamic_pad_export_slice$prop(%arg0: !torch.vtensor<[?,144,?,?],f32>) -> !torch.vtensor<[4,2],si64> { + // CHECK: %[[I0:.*]] = torch.constant.int 0 + // CHECK: %[[x0:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I2:.*]] = torch.constant.int 2 + // CHECK: %[[x1:.*]] = torch.aten.size.int %arg0, %[[I2]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I3:.*]] = torch.constant.int 3 + // CHECK: %[[x2:.*]] = torch.aten.size.int %arg0, %[[I3]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I0_0:.*]] = torch.constant.int 0 + // CHECK: %[[I0_1:.*]] = torch.constant.int 0 + // CHECK: %[[I0_2:.*]] = torch.constant.int 0 + // CHECK: %[[I0_3:.*]] = torch.constant.int 0 + // CHECK: %[[I144:.*]] = torch.constant.int 144 + // CHECK: %[[x3:.*]] = torch.prim.ListConstruct %[[I0_0]], %[[I0_1]], %[[I0_2]], %[[I0_3]], %[[x1]], %[[x2]], %[[x0]], %[[I144]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[none:.*]] = torch.constant.none + // CHECK: %[[false:.*]] = torch.constant.bool false + // CHECK: %[[x4:.*]] = torch.aten.tensor %[[x3]], %[[none]], %[[none]], %[[false]] : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[4,2],si64> + // CHECK: return %[[x4]] : !torch.vtensor<[4,2],si64> + %0 = torch.vtensor.literal(dense<0> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %1 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + %2 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + %int-9223372036854775807 = torch.constant.int -9223372036854775807 + %int-1 = torch.constant.int -1 + %int2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %3 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,144,?,?],f32> -> !torch.vtensor<[4],si64> + %4 = torch.prim.ListConstruct %3, %0 : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.list + %5 = torch.aten.cat %4, %int0 : !torch.list, !torch.int -> !torch.vtensor<[8],si64> + %6 = torch.prim.ListConstruct %int-1, %int2 : (!torch.int, !torch.int) -> !torch.list + %7 = torch.aten.view %5, %6 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[4,2],si64> + %8 = torch.aten.slice.Tensor %7, %int0, %int-1, %int-9223372036854775807, %int-1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,2],si64> + %9 = 
torch.aten.transpose.int %8, %int0, %int1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,4],si64> + %10 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list + %11 = torch.aten.view %9, %10 : !torch.vtensor<[2,4],si64>, !torch.list -> !torch.vtensor<[8],si64> + %12 = torch.aten.index_select %11, %int0, %1 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + %13 = torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int + %14 = torch.aten.index_select %11, %int0, %2 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + %15 = torch.aten.item %14 : !torch.vtensor<[1],si64> -> !torch.int + %16 = torch.prim.ListConstruct %13, %15: (!torch.int, !torch.int) -> !torch.list + return %8 : !torch.vtensor<[4,2],si64> +} + +// ----- + +// CHECK-LABEL: @pytorch_dynamic_pad_export_transpose$prop( +func.func @pytorch_dynamic_pad_export_transpose$prop(%arg0: !torch.vtensor<[?,144,?,?],f32>) -> !torch.vtensor<[2,4],si64> { + // CHECK: %[[I0:.*]] = torch.constant.int 0 + // CHECK: %[[DIM0:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I2:.*]] = torch.constant.int 2 + // CHECK: %[[DIM2:.*]] = torch.aten.size.int %arg0, %[[I2]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I3:.*]] = torch.constant.int 3 + // CHECK: %[[DIM3:.*]] = torch.aten.size.int %arg0, %[[I3]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I0_0:.*]] = torch.constant.int 0 + // CHECK: %[[I0_1:.*]] = torch.constant.int 0 + // CHECK: %[[I0_2:.*]] = torch.constant.int 0 + // CHECK: %[[I0_3:.*]] = torch.constant.int 0 + // CHECK: %[[DIM1:.*]] = torch.constant.int 144 + // CHECK: %[[x3:.*]] = torch.prim.ListConstruct %[[I0_0]], %[[I0_1]], %[[DIM2]], %[[DIM0]], %[[I0_2]], %[[I0_3]], %[[DIM3]], %[[DIM1]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[none:.*]] = torch.constant.none + // CHECK: %[[false:.*]] = torch.constant.bool false + // CHECK: %[[x4:.*]] = torch.aten.tensor %[[x3]], %[[none]], %[[none]], %[[false]] : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[2,4],si64> + // CHECK: return %[[x4]] : !torch.vtensor<[2,4],si64> + %0 = torch.vtensor.literal(dense<0> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %1 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + %2 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + %int-9223372036854775807 = torch.constant.int -9223372036854775807 + %int-1 = torch.constant.int -1 + %int2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %3 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,144,?,?],f32> -> !torch.vtensor<[4],si64> + %4 = torch.prim.ListConstruct %3, %0 : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.list + %5 = torch.aten.cat %4, %int0 : !torch.list, !torch.int -> !torch.vtensor<[8],si64> + %6 = torch.prim.ListConstruct %int-1, %int2 : (!torch.int, !torch.int) -> !torch.list + %7 = torch.aten.view %5, %6 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[4,2],si64> + %8 = torch.aten.slice.Tensor %7, %int0, %int-1, %int-9223372036854775807, %int-1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,2],si64> + %9 = torch.aten.transpose.int %8, %int0, %int1 : 
!torch.vtensor<[4,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,4],si64> + %10 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list + %11 = torch.aten.view %9, %10 : !torch.vtensor<[2,4],si64>, !torch.list -> !torch.vtensor<[8],si64> + %12 = torch.aten.index_select %11, %int0, %1 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + %13 = torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int + %14 = torch.aten.index_select %11, %int0, %2 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + %15 = torch.aten.item %14 : !torch.vtensor<[1],si64> -> !torch.int + %16 = torch.prim.ListConstruct %13, %15: (!torch.int, !torch.int) -> !torch.list + return %9 : !torch.vtensor<[2,4],si64> +} + +// ----- + +// CHECK-LABEL: @pytorch_dynamic_pad_export_full( +func.func @pytorch_dynamic_pad_export_full(%arg0: !torch.vtensor<[?,144,?,?],f32>) -> !torch.list { + // CHECK: %[[I2:.*]] = torch.constant.int 2 + // CHECK: %[[DIM2:.*]] = torch.aten.size.int %arg0, %[[I2]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I0:.*]] = torch.constant.int 0 + // CHECK: %[[x1:.*]] = torch.prim.ListConstruct %[[DIM2]], %[[I0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: return %[[x1]] : !torch.list + %0 = torch.vtensor.literal(dense<0> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %1 = torch.vtensor.literal(dense<2> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + %2 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + %int-9223372036854775807 = torch.constant.int -9223372036854775807 + %int-1 = torch.constant.int -1 + %int2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %3 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,144,?,?],f32> -> !torch.vtensor<[4],si64> + %4 = torch.prim.ListConstruct %3, %0 : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.list + %5 = torch.aten.cat %4, %int0 : !torch.list, !torch.int -> !torch.vtensor<[8],si64> + %6 = torch.prim.ListConstruct %int-1, %int2 : (!torch.int, !torch.int) -> !torch.list + %7 = torch.aten.view %5, %6 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[4,2],si64> + %8 = torch.aten.slice.Tensor %7, %int0, %int-1, %int-9223372036854775807, %int-1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,2],si64> + %9 = torch.aten.transpose.int %8, %int0, %int1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,4],si64> + %10 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list + %11 = torch.aten.view %9, %10 : !torch.vtensor<[2,4],si64>, !torch.list -> !torch.vtensor<[8],si64> + %12 = torch.aten.index_select %11, %int0, %1 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + %13 = torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int + %14 = torch.aten.index_select %11, %int0, %2 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + %15 = torch.aten.item %14 : !torch.vtensor<[1],si64> -> !torch.int + %16 = torch.prim.ListConstruct %13, %15: (!torch.int, !torch.int) -> !torch.list + return %16 : !torch.list +} + +// ----- + +// CHECK-LABEL: @transpose$prop_3d_0_1 +func.func @transpose$prop_3d_0_1(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[2,2,2],si64> { + // CHECK: %[[I0:.*]] = torch.constant.int 0 + // CHECK: 
%[[SIZE0_0:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I1:.*]] = torch.constant.int 1 + // CHECK: %[[SIZE0_1:.*]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I2:.*]] = torch.constant.int 2 + // CHECK: %[[SIZE0_2:.*]] = torch.aten.size.int %arg0, %[[I2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I3:.*]] = torch.constant.int 3 + // CHECK: %[[SIZE0_3:.*]] = torch.aten.size.int %arg0, %[[I3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I0_0:.*]] = torch.constant.int 0 + // CHECK: %[[SIZE1_0:.*]] = torch.aten.size.int %arg1, %[[I0_0]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I1_1:.*]] = torch.constant.int 1 + // CHECK: %[[SIZE1_1:.*]] = torch.aten.size.int %arg1, %[[I1_1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I2_2:.*]] = torch.constant.int 2 + // CHECK: %[[SIZE1_2:.*]] = torch.aten.size.int %arg1, %[[I2_2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I3_3:.*]] = torch.constant.int 3 + // CHECK: %[[SIZE1_3:.*]] = torch.aten.size.int %arg1, %[[I3_3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[SIZE0_0]], %[[SIZE0_1]], %[[SIZE1_0]], %[[SIZE1_1]], %[[SIZE0_2]], %[[SIZE0_3]], %[[SIZE1_2]], %[[SIZE1_3]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[TENSOR:.*]] = torch.aten.tensor %[[LIST]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[2,2,2],si64> + // CHECK: return %[[TENSOR]] : !torch.vtensor<[2,2,2],si64> + %0 = torch.vtensor.literal(dense<2> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + %int-1 = torch.constant.int -1 + %int2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %1 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,?,?,?],f32> -> !torch.vtensor<[4],si64> + %2 = torch.aten._shape_as_tensor %arg1 : !torch.vtensor<[?,?,?,?],f32> -> !torch.vtensor<[4],si64> + %3 = torch.prim.ListConstruct %1, %2 : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.list + %4 = torch.aten.cat %3, %int0 : !torch.list, !torch.int -> !torch.vtensor<[8],si64> + %5 = torch.prim.ListConstruct %int2, %int2, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6 = torch.aten.view %4, %5 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[2,2,2],si64> + %7 = torch.aten.transpose.int %6, %int0, %int1 : !torch.vtensor<[2,2,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,2,2],si64> + %8 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list + %9 = torch.aten.view %7, %8 : !torch.vtensor<[2,2,2],si64>, !torch.list -> !torch.vtensor<[8],si64> + %10 = torch.aten.index_select %9, %int0, %0 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + %11 = torch.aten.item %10 : !torch.vtensor<[1],si64> -> !torch.int + %12 = torch.prim.ListConstruct %11 : (!torch.int) -> !torch.list + return %7 : !torch.vtensor<[2,2,2],si64> +} + +// ----- + +// CHECK-LABEL: @transpose$prop_3d_m1_0 +func.func @transpose$prop_3d_m1_0(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> 
!torch.vtensor<[2,2,2],si64> { + // CHECK: %[[I0:.*]] = torch.constant.int 0 + // CHECK: %[[SIZE0_0:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I1:.*]] = torch.constant.int 1 + // CHECK: %[[SIZE0_1:.*]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I2:.*]] = torch.constant.int 2 + // CHECK: %[[SIZE0_2:.*]] = torch.aten.size.int %arg0, %[[I2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I3:.*]] = torch.constant.int 3 + // CHECK: %[[SIZE0_3:.*]] = torch.aten.size.int %arg0, %[[I3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I0_0:.*]] = torch.constant.int 0 + // CHECK: %[[SIZE1_0:.*]] = torch.aten.size.int %arg1, %[[I0_0]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I1_1:.*]] = torch.constant.int 1 + // CHECK: %[[SIZE1_1:.*]] = torch.aten.size.int %arg1, %[[I1_1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I2_2:.*]] = torch.constant.int 2 + // CHECK: %[[SIZE1_2:.*]] = torch.aten.size.int %arg1, %[[I2_2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I3_3:.*]] = torch.constant.int 3 + // CHECK: %[[SIZE1_3:.*]] = torch.aten.size.int %arg1, %[[I3_3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[SIZE0_0]], %[[SIZE1_0]], %[[SIZE0_2]], %[[SIZE1_2]], %[[SIZE0_1]], %[[SIZE1_1]], %[[SIZE0_3]], %[[SIZE1_3]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[TENSOR:.*]] = torch.aten.tensor %[[LIST]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[2,2,2],si64> + // CHECK: return %[[TENSOR]] : !torch.vtensor<[2,2,2],si64> + %0 = torch.vtensor.literal(dense<2> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + %int-1 = torch.constant.int -1 + %int2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %1 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,?,?,?],f32> -> !torch.vtensor<[4],si64> + %2 = torch.aten._shape_as_tensor %arg1 : !torch.vtensor<[?,?,?,?],f32> -> !torch.vtensor<[4],si64> + %3 = torch.prim.ListConstruct %1, %2 : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.list + %4 = torch.aten.cat %3, %int0 : !torch.list, !torch.int -> !torch.vtensor<[8],si64> + %5 = torch.prim.ListConstruct %int2, %int2, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6 = torch.aten.view %4, %5 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[2,2,2],si64> + %7 = torch.aten.transpose.int %6, %int-1, %int0 : !torch.vtensor<[2,2,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,2,2],si64> + %8 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list + %9 = torch.aten.view %7, %8 : !torch.vtensor<[2,2,2],si64>, !torch.list -> !torch.vtensor<[8],si64> + %10 = torch.aten.index_select %9, %int0, %0 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + %11 = torch.aten.item %10 : !torch.vtensor<[1],si64> -> !torch.int + %12 = torch.prim.ListConstruct %11 : (!torch.int) -> !torch.list + return %7 : !torch.vtensor<[2,2,2],si64> +} diff --git a/test/Dialect/Torch/simplify-shape-calculations.mlir 
b/test/Dialect/Torch/simplify-shape-calculations.mlir index b7e7cf17ba0e..af96e108efbd 100644 --- a/test/Dialect/Torch/simplify-shape-calculations.mlir +++ b/test/Dialect/Torch/simplify-shape-calculations.mlir @@ -152,6 +152,23 @@ func.func @fully_unroll_prim_loop$no_unroll(%arg0: !torch.vtensor, %arg1: !torch return %0 : !torch.vtensor } +// CHECK-LABEL: func.func @fully_unroll_prim_loop$outside_region( +// CHECK: %[[LOOP:.*]] = torch.prim.Loop +func.func @fully_unroll_prim_loop$outside_region(%arg0: !torch.vtensor, %arg1: !torch.list, %arg2: !torch.int) -> !torch.vtensor { + %true = torch.constant.bool true + %0 = torch.prim.Loop %arg2, %true, init(%arg0) { + ^bb0(%arg3: !torch.int, %arg4: !torch.vtensor): + %1 = torch.shape.calculate { + torch.shape.calculate.yield %arg4 : !torch.vtensor + } shapes { + torch.prim.Print(%arg3) : !torch.int + torch.shape.calculate.yield.shapes %arg1 : !torch.list + } : !torch.vtensor + torch.prim.Loop.condition %true, iter(%1 : !torch.vtensor) + } : (!torch.int, !torch.bool, !torch.vtensor) -> !torch.vtensor + return %0 : !torch.vtensor +} + // CHECK-LABEL: func.func @abstractly_interpret_list_ops$basic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor, // CHECK-SAME: %[[ARG1:.*]]: !torch.int, @@ -489,3 +506,42 @@ func.func @shape_calc_with_two_uses(%arg0: !torch.vtensor<[2],f32>) -> !torch.vt return %arg0 : !torch.vtensor<[2],f32> } + +// CHECK-LABEL: func.func @unflat_shape_partial_dyn +// CHECK-DAG: %[[INT768:.*]] = torch.constant.int 768 +// CHECK-DAG: %[[INT0:.*]] = torch.constant.int 0 +// CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1 +// CHECK-DAG: %[[INT4:.*]] = torch.constant.int 4 +// CHECK : } shapes { +// CHECK : %[[SZE0:.*]] = torch.aten.size.int %arg0, %[[INT0]] : !torch.vtensor<[?,?,3072],f32>, !torch.int -> !torch.int +// CHECK : %[[SZE1:.*]] = torch.aten.size.int %arg0, %[[INT1]] : !torch.vtensor<[?,?,3072],f32>, !torch.int -> !torch.int +// CHECK : %[[LIST:.*]] = torch.prim.ListConstruct %[[SZE0]], %[[SZE1]], %[[INT4]], %[[INT768]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list +// CHECK : torch.shape.calculate.yield.shapes %[[LIST]] : !torch.list +// CHECK : } : !torch.vtensor<[?,?,4,768],f32> +func.func @unflat_shape_partial_dyn(%arg0: !torch.vtensor<[?,?,3072],f32>) -> !torch.vtensor<[?,?,4,?],f32> { + %int768 = torch.constant.int 768 + %int3072 = torch.constant.int 3072 + %int0 = torch.constant.int 0 + %int3 = torch.constant.int 3 + %int1 = torch.constant.int 1 + %none = torch.constant.none + %int-1 = torch.constant.int -1 + %int2 = torch.constant.int 2 + %int4 = torch.constant.int 4 + %0 = torch.prim.ListConstruct %int4, %int-1 : (!torch.int, !torch.int) -> !torch.list + %1 = torch.shape.calculate { + %2 = torch.aten.unflatten.int %arg0, %int2, %0 : !torch.vtensor<[?,?,3072],f32>, !torch.int, !torch.list -> !torch.vtensor<[?,?,4,?],f32> + torch.shape.calculate.yield %2 : !torch.vtensor<[?,?,4,?],f32> + } shapes { + %2 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,3072],f32>, !torch.int -> !torch.int + %3 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,3072],f32>, !torch.int -> !torch.int + %4 = torch.prim.ListConstruct %2, %3, %int3072 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %5 = torch.prim.ListConstruct %int4, %int768 : (!torch.int, !torch.int) -> !torch.list + %6 = torch.aten.slice.t %4, %none, %int2, %int1 : !torch.list, !torch.none, !torch.int, !torch.int -> !torch.list + %7 = torch.aten.add.t %6, %5 : !torch.list, !torch.list -> !torch.list + %8 = torch.aten.slice.t %4, %int3, 
%none, %int1 : !torch.list, !torch.int, !torch.none, !torch.int -> !torch.list + %9 = torch.aten.add.t %7, %8 : !torch.list, !torch.list -> !torch.list + torch.shape.calculate.yield.shapes %9 : !torch.list + } : !torch.vtensor<[?,?,4,?],f32> + return %1 : !torch.vtensor<[?,?,4,?],f32> +} diff --git a/test/Dialect/Torch/torch-nary-canonicalize.mlir b/test/Dialect/Torch/torch-nary-canonicalize.mlir index b0d22e35da9c..9fb5bac1f82f 100644 --- a/test/Dialect/Torch/torch-nary-canonicalize.mlir +++ b/test/Dialect/Torch/torch-nary-canonicalize.mlir @@ -141,3 +141,113 @@ func.func @fold_aten_mul_splat_float() -> !torch.vtensor<[4],f32> { %0 = torch.aten.mul.Tensor %cst_7, %cst_11 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[4],f32> return %0 : !torch.vtensor<[4],f32> } + +// ----- + +// CHECK-LABEL: @fold_aten_rsub_scalar_int +func.func @fold_aten_rsub_scalar_int() -> !torch.vtensor<[4],si64> { + // CHECK: torch.vtensor.literal(dense<-4> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %cst_2 = torch.constant.int 2 + %cst_3 = torch.vtensor.literal(dense<3> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %0 = torch.aten.rsub.Scalar %cst_3, %cst_2, %cst_2: !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[4],si64> + return %0 : !torch.vtensor<[4],si64> +} + +// ----- + +// CHECK-LABEL: @fold_aten_rsub_scalar_float +func.func @fold_aten_rsub_scalar_float() -> !torch.vtensor<[4],f32> { + // CHECK: torch.vtensor.literal(dense<-4.000000e+00> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %cst_2 = torch.constant.float 2.0 + %cst_3 = torch.vtensor.literal(dense<3.0> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %0 = torch.aten.rsub.Scalar %cst_3, %cst_2, %cst_2: !torch.vtensor<[4],f32>, !torch.float, !torch.float -> !torch.vtensor<[4],f32> + return %0 : !torch.vtensor<[4],f32> +} + +// ----- + +// CHECK-LABEL: @fold_aten_remainder_scalar_int +func.func @fold_aten_remainder_scalar_int() -> !torch.vtensor<[4],si64> { + // CHECK: torch.vtensor.literal(dense<1> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %cst_2 = torch.constant.int 2 + %cst_3 = torch.vtensor.literal(dense<3> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %0 = torch.aten.remainder.Scalar %cst_3, %cst_2 : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + return %0 : !torch.vtensor<[4],si64> +} + +// ----- + +// CHECK-LABEL: @fold_aten_remainder_scalar_float +func.func @fold_aten_remainder_scalar_float() -> !torch.vtensor<[4],f32> { + // CHECK: torch.vtensor.literal(dense<1.000000e+00> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %cst_2 = torch.constant.float 2.0 + %cst_3 = torch.vtensor.literal(dense<3.0> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %0 = torch.aten.remainder.Scalar %cst_3, %cst_2 : !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],f32> + return %0 : !torch.vtensor<[4],f32> +} + +// ----- + +// CHECK-LABEL: @fold_aten_int_tensor_int +func.func @fold_aten_int_tensor_int() -> !torch.int { + // CHECK: %int3 = torch.constant.int 3 + %cst_3 = torch.vtensor.literal(dense<3> : tensor) : !torch.vtensor<[],si64> + %0 = torch.aten.Int.Tensor %cst_3 : !torch.vtensor<[],si64> -> !torch.int + return %0 : !torch.int +} + +// ----- + +// CHECK-LABEL: @fold_aten_int_tensor_bool +func.func @fold_aten_int_tensor_bool() -> !torch.int { + // CHECK: %int1 = torch.constant.int 1 + %cst_false = torch.vtensor.literal(dense : tensor) : !torch.vtensor<[],i1> + %0 = torch.aten.Int.Tensor %cst_false : !torch.vtensor<[],i1> -> !torch.int + return %0 : !torch.int +} + +// ----- + +// 
CHECK-LABEL: @fold_aten_int_tensor_float +func.func @fold_aten_int_tensor_float() -> !torch.int { + // CHECK: %int3 = torch.constant.int 3 + %cst_3 = torch.vtensor.literal(dense<3.1> : tensor) : !torch.vtensor<[],f32> + %0 = torch.aten.Int.Tensor %cst_3 : !torch.vtensor<[],f32> -> !torch.int + return %0 : !torch.int +} + +// ----- + +// CHECK-LABEL: @fold_aten_div_tensor_mode_int +func.func @fold_aten_div_tensor_mode_int() -> !torch.vtensor<[4],si64> { + // CHECK: torch.vtensor.literal(dense<4> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %cst_8 = torch.vtensor.literal(dense<8> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %cst_2 = torch.vtensor.literal(dense<2> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %trunc = torch.constant.str "trunc" + %0 = torch.aten.div.Tensor_mode %cst_8, %cst_2, %trunc : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.str -> !torch.vtensor<[4],si64> + return %0 : !torch.vtensor<[4],si64> +} + +// ----- + +// CHECK-LABEL: @fold_aten_div_tensor_mode_float +func.func @fold_aten_div_tensor_mode_float() -> !torch.vtensor<[4],f32> { + // CHECK: torch.vtensor.literal(dense<3.000000e+00> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %cst_8 = torch.vtensor.literal(dense<8.0> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %cst_2 = torch.vtensor.literal(dense<2.1> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %floor = torch.constant.str "floor" + %0 = torch.aten.div.Tensor_mode %cst_8, %cst_2, %floor : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.str -> !torch.vtensor<[4],f32> + return %0 : !torch.vtensor<[4],f32> +} + +// ----- + +// CHECK-LABEL: @fold_aten_div_tensor_mode_none +func.func @fold_aten_div_tensor_mode_none() -> !torch.vtensor<[4],f32> { + // CHECK: torch.vtensor.literal(dense<2.66666675> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %cst_8 = torch.vtensor.literal(dense<8> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %cst_3 = torch.vtensor.literal(dense<3> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %none = torch.constant.none + %0 = torch.aten.div.Tensor_mode %cst_8, %cst_3, %none : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.none -> !torch.vtensor<[4],f32> + return %0 : !torch.vtensor<[4],f32> +} diff --git a/test/Dialect/Torch/torch-onnx-to-torch-backend-pipeline.mlir b/test/Dialect/Torch/torch-onnx-to-torch-backend-pipeline.mlir new file mode 100644 index 000000000000..752398474ce7 --- /dev/null +++ b/test/Dialect/Torch/torch-onnx-to-torch-backend-pipeline.mlir @@ -0,0 +1,67 @@ +// RUN: torch-mlir-opt -pass-pipeline='builtin.module(torch-onnx-to-torch-backend-pipeline{backend-legal-ops=aten.flatten.using_ints,aten.unflatten.int})' -split-input-file %s | FileCheck %s + +// CHECK-LABEL: func.func @test_reshape_negative_dim_decompose +func.func @test_reshape_negative_dim_decompose(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,6,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: %[[INT6:.+]] = torch.constant.int 6 + // CHECK: %[[RESULT_SHAPE:.+]] = torch.prim.ListConstruct %[[INT2]], %[[INT6]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: torch.aten.view %arg0, %[[RESULT_SHAPE]] : !torch.vtensor<[2,3,4],f32>, !torch.list -> !torch.vtensor<[2,6,2],f32> + %0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, 
!torch.vtensor<[3],si64>) -> !torch.vtensor<[2,6,2],f32> + return %0 : !torch.vtensor<[2,6,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_triu_decompose +func.func @test_triu_decompose(%arg0: !torch.vtensor<[4,5],si64>) -> !torch.vtensor<[4,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[ZERO_TENSOR:.+]] = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[INT4:.+]] = torch.constant.int 4 + // CHECK: %[[INT5:.+]] = torch.constant.int 5 + // CHECK: %[[ARANGE:.+]] = torch.aten.arange.start_step %[[INT0]], %[[INT4]], %[[INT1]], %[[INT4]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[4],si64> + // CHECK: %[[ARANGE_0:.+]] = torch.aten.arange.start_step %[[INT0]], %[[INT5]], %[[INT1]], %[[INT4]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.int, !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[5],si64> + // CHECK: %[[UNSQUEEZE:.+]] = torch.aten.unsqueeze %[[ARANGE]], %[[INT1]] : !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4,1],si64> + // CHECK: %[[UNSQUEEZE_0:.+]] = torch.aten.unsqueeze %[[ARANGE_0]], %[[INT0]] : !torch.vtensor<[5],si64>, !torch.int -> !torch.vtensor<[1,5],si64> + // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %[[UNSQUEEZE]], %[[INT0]], %[[INT1]] : !torch.vtensor<[4,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[4,1],si64> + // CHECK: %[[COND:.+]] = torch.aten.ge.Tensor %[[UNSQUEEZE_0]], %[[ADD]] : !torch.vtensor<[1,5],si64>, !torch.vtensor<[4,1],si64> -> !torch.vtensor<[4,5],i1> + // CHECK: %[[RESULT:.+]] = torch.aten.where.self %[[COND]], %arg0, %[[ZERO_TENSOR]] : !torch.vtensor<[4,5],i1>, !torch.vtensor<[4,5],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[4,5],si64> + %0 = torch.operator "onnx.Trilu"(%arg0) : (!torch.vtensor<[4,5],si64>) -> !torch.vtensor<[4,5],si64> + return %0 : !torch.vtensor<[4,5],si64> +} + +// ----- + +module { +// CHECK-LABEL: func.func @test_scalarize + func.func @test_scalarize(%arg0: !torch.vtensor<[?,?,16,64],f32>) -> !torch.vtensor<[?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "1.11.0"} { + // CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2 + // CHECK-DAG: %[[INT3:.+]] = torch.constant.int 3 + // CHECK: %[[ADD:.+]] = torch.aten.flatten.using_ints %arg0, %[[INT2]], %[[INT3]] : !torch.vtensor<[?,?,16,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,1024],f32> + %0 = torch.operator "onnx.Shape"(%arg0) : (!torch.vtensor<[?,?,16,64],f32>) -> !torch.vtensor<[4],si64> + %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__21> : tensor} : () -> !torch.vtensor<[],si64> + %2 = torch.operator "onnx.Gather"(%0, %1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[4],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> + %3 = torch.operator "onnx.Shape"(%arg0) : (!torch.vtensor<[?,?,16,64],f32>) -> !torch.vtensor<[4],si64> + %4 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__22> : tensor} : () -> !torch.vtensor<[],si64> + %5 = torch.operator "onnx.Gather"(%3, %4) {torch.onnx.axis 
= 0 : si64} : (!torch.vtensor<[4],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> + %6 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> + %7 = torch.operator "onnx.Unsqueeze"(%2, %6) : (!torch.vtensor<[],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64> + %8 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> + %9 = torch.operator "onnx.Unsqueeze"(%5, %8) : (!torch.vtensor<[],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64> + %10 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_onnx__Concat_3209> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> + %11 = torch.operator "onnx.Concat"(%7, %9, %10) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3],si64> + %12 = torch.operator "onnx.Reshape"(%arg0, %11) : (!torch.vtensor<[?,?,16,64],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,?,?],f32> + return %12 : !torch.vtensor<[?,?,?],f32> + } +} + +{-# + dialect_resources: { + builtin: { + __21: "0x080000000000000000000000", + __22: "0x080000000100000000000000", + _onnx__Concat_3209: "0x080000000004000000000000" + } + } +#-} diff --git a/test/Dialect/TorchConversion/finalizing-backend-type-conversion.mlir b/test/Dialect/TorchConversion/finalizing-backend-type-conversion.mlir index 57077a723ada..c77351831d2f 100644 --- a/test/Dialect/TorchConversion/finalizing-backend-type-conversion.mlir +++ b/test/Dialect/TorchConversion/finalizing-backend-type-conversion.mlir @@ -83,3 +83,13 @@ func.func @unable_to_convert_lone_tensor_load(%arg0: tensor) { "test.sink"(%0) : (!torch.vtensor<[],f32>) -> () return } + +// ----- + +// CHECK-LABEL: @extfTruncf +func.func @extfTruncf(%arg0: f32) -> f32 { + %f64 = arith.extf %arg0 : f32 to f64 + %f32 = arith.truncf %f64 : f64 to f32 + // CHECK: return %arg0 + return %f32 : f32 +} diff --git a/test/Dialect/TorchConversion/torch-backend-to-linalg-on-tensors-no-contract-check.mlir b/test/Dialect/TorchConversion/torch-backend-to-linalg-on-tensors-no-contract-check.mlir new file mode 100644 index 000000000000..33fbfcb90c66 --- /dev/null +++ b/test/Dialect/TorchConversion/torch-backend-to-linalg-on-tensors-no-contract-check.mlir @@ -0,0 +1,24 @@ +// RUN: torch-mlir-opt -p 'builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline{verify=0})' -split-input-file %s | FileCheck %s + +// CHECK: func.func @tosa +func.func @tosa(%arg0: tensor) -> tensor { + // CHECK: tosa.abs + %1 = tosa.abs %arg0 : (tensor) -> tensor + return %1 : tensor +} + +// ----- + +// CHECK: func.func @torch_gemm +func.func @torch_gemm(%arg0: tensor, %arg1: tensor<3x?xf32>, %arg2: tensor) -> (tensor {onnx.name = "gemm"}) attributes {torch.onnx_meta.opset_version = 19 : si64} { + %0 = torch_c.from_builtin_tensor %arg0 : tensor -> !torch.vtensor<[?,3],f32> + %1 = torch_c.from_builtin_tensor %arg1 : tensor<3x?xf32> -> !torch.vtensor<[3,?],f32> + %2 = torch_c.from_builtin_tensor %arg2 : tensor -> !torch.vtensor<[?,?],f32> + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %3 = torch.aten.mm %0, %1 : !torch.vtensor<[?,3],f32>, !torch.vtensor<[3,?],f32> -> !torch.vtensor<[?,?],f32> + %4 = torch.aten.add.Tensor %3, %2, %int1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32> + %5 = torch_c.to_builtin_tensor %4 : !torch.vtensor<[?,?],f32> -> tensor + %6 = tosa.abs %5 : (tensor) -> 
tensor + return %6 : tensor +} diff --git a/test/RefBackend/mlprogram-bufferize.mlir b/test/RefBackend/mlprogram-bufferize.mlir index bd8c2a6c0922..9e8065f57f1f 100644 --- a/test/RefBackend/mlprogram-bufferize.mlir +++ b/test/RefBackend/mlprogram-bufferize.mlir @@ -4,12 +4,12 @@ // CHECK-LABEL: func.func @forward() -> i64 { // CHECK: %[[CST127:.*]] = arith.constant 127 : i64 // CHECK: %[[GLOBAL_SEED:.*]] = memref.get_global @global_seed : memref<i64> -// CHECK: %[[TENSOR:.*]] = bufferization.to_tensor %[[GLOBAL_SEED]] : memref<i64> +// CHECK: %[[TENSOR:.*]] = bufferization.to_tensor %[[GLOBAL_SEED]] : memref<i64> to tensor<i64> // CHECK: %[[SEED:.*]] = tensor.extract %[[TENSOR]][] : tensor<i64> // CHECK: %[[NEXT_SEED:.*]] = arith.muli %[[SEED]], %[[CST127]] : i64 // CHECK: %[[INSERTED:.*]] = tensor.insert %[[NEXT_SEED]] into %[[TENSOR]][] : tensor<i64> // CHECK: %[[GLOBAL_SEED_1:.*]] = memref.get_global @global_seed : memref<i64> -// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[INSERTED]] : memref<i64> +// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[INSERTED]] : tensor<i64> to memref<i64> // CHECK: memref.copy %[[MEMREF]], %[[GLOBAL_SEED_1]] : memref<i64> to memref<i64> // CHECK: return %[[NEXT_SEED]] : i64 module { diff --git a/test/lit.cfg.py b/test/lit.cfg.py index 35d5558f8c93..6d4dcd602df8 100644 --- a/test/lit.cfg.py +++ b/test/lit.cfg.py @@ -18,6 +18,24 @@ # Configuration file for the 'lit' test runner. + +# Find path to the ASan runtime required for the Python interpreter. +def find_asan_runtime(): + if not "asan" in config.available_features or not "Linux" in config.host_os: + return "" + # Find the asan rt lib + return ( + subprocess.check_output( + [ + config.host_cxx.strip(), + f"-print-file-name=libclang_rt.asan-{config.host_arch}.so", + ] + ) + .decode("utf-8") + .strip() + ) + + # name: The name of this test suite. config.name = "TORCH_MLIR" @@ -66,10 +84,15 @@ "PATH", os.path.join(config.llvm_build_dir, "bin"), append_path=True ) +# Python configuration with sanitizer requires some magic preloading. This will only work on clang/linux. +# TODO: detect Darwin/Windows situation (or mark these tests as unsupported on these platforms). +if "asan" in config.available_features and "Linux" in config.host_os: + _asan_rt = find_asan_runtime() + config.python_executable = f"env LD_PRELOAD={_asan_rt} {config.python_executable}" # On Windows the path to python could contain spaces in which case it needs to # be provided in quotes. This is the equivalent of how %python is setup in # llvm/utils/lit/lit/llvm/config.py.
-if "Windows" in config.host_os: +elif "Windows" in config.host_os: config.python_executable = '"%s"' % (config.python_executable) tool_dirs = [ diff --git a/test/lit.site.cfg.py.in b/test/lit.site.cfg.py.in index 1be54aaf6c15..a6d923fdfdc9 100644 --- a/test/lit.site.cfg.py.in +++ b/test/lit.site.cfg.py.in @@ -7,6 +7,8 @@ config.torch_mlir_obj_root = "@TORCH_MLIR_BINARY_DIR@" config.torch_mlir_python_packages_dir = "@TORCH_MLIR_PYTHON_PACKAGES_DIR@" config.torch_mlir_enable_refbackend = @TORCH_MLIR_ENABLE_REFBACKEND@ config.host_os = "@HOST_OS@" +config.host_cxx = "@HOST_CXX@" +config.host_arch = "@HOST_ARCH@" config.llvm_src_root = "@LLVM_SOURCE_DIR@" config.llvm_obj_root = "@LLVM_BINARY_DIR@" config.llvm_tools_dir = "@LLVM_TOOLS_DIR@" diff --git a/test/python/compile.py b/test/python/compile.py index 32b47a25460f..ddb79b554ad7 100644 --- a/test/python/compile.py +++ b/test/python/compile.py @@ -1,10 +1,13 @@ -# RUN: %PYTHON -s %s 2>&1 | FileCheck %s +# RUN: %PYTHON %s 2>&1 | FileCheck %s import gc import sys import torch from torch_mlir import torchscript +# torchscript doesn't exist when TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS is OFF +# UNSUPPORTED: true + def run_test(f): print("TEST:", f.__name__, file=sys.stderr) @@ -34,5 +37,5 @@ def test_enable_ir_printing(): ) -# CHECK: // -----// IR Dump Before Canonicalizer (canonicalize) +# CHECK: // -----// IR Dump After Inliner (inline) # CHECK-NEXT: module attributes {torch.debug_module_name = "TinyModel"} { diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py index fde318630077..be2235ec80bf 100644 --- a/test/python/fx_importer/basic_test.py +++ b/test/python/fx_importer/basic_test.py @@ -88,7 +88,13 @@ def forward(self, x): @run # CHECK-LABEL: test_import_frozen_exported_program_with_dynamic_shapes -# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,4],f32>) -> !torch.vtensor<[?,4],f32> +# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,5],f32>) -> !torch.vtensor<[?,?,5],f32> +# CHECK: %[[S0:.*]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int +# CHECK: %[[S1:.*]] = torch.symbolic_int "s1" {min_val = 2, max_val = {{[0-9]+}}} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 5)> : !torch.vtensor<[?,?,5],f32> +# CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[?,?,5],f32> -> !torch.vtensor<[?,?,5],f32> +# CHECK: torch.bind_symbolic_shape %[[TANH]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 5)> : !torch.vtensor<[?,?,5],f32> +# CHECK: return %[[TANH]] : !torch.vtensor<[?,?,5],f32> def test_import_frozen_exported_program_with_dynamic_shapes(): class Basic(nn.Module): def __init__(self): @@ -97,10 +103,15 @@ def __init__(self): def forward(self, x): return torch.tanh(x) - batch = Dim("batch") - dynamic_shapes = {"x": {0: batch}} + batch = Dim("batch", max=10) + channel = Dim("channel", min=2) + dynamic_shapes = {"x": {0: batch, 1: channel}} m = fx.export_and_import( - Basic(), torch.randn(3, 4), dynamic_shapes=dynamic_shapes, func_name="test_net" + Basic(), + torch.randn(3, 4, 5), + dynamic_shapes=dynamic_shapes, + func_name="test_net", + import_symbolic_shape_expressions=True, ) print(m) @@ -108,6 +119,12 @@ def forward(self, x): @run # CHECK-LABEL: test_broadcast_with_dynamic_shapes # CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[1,2],f32>, %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,2],f32> +# 
CHECK: %[[S0:.*]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> +# CHECK: torch.aten.size.int +# CHECK: torch.prim.ListConstruct +# CHECK: %[[EXPAND:.*]] = torch.aten.expand +# CHECK: torch.bind_symbolic_shape %[[EXPAND]], [%[[S0]]], affine_map<()[s0] -> (s0, 2)> : !torch.vtensor<[?,2],f32> def test_broadcast_with_dynamic_shapes(): class Basic(nn.Module): def __init__(self): @@ -120,14 +137,19 @@ def forward(self, x, y): x = torch.randn(1, 2) y = torch.randn(10) - dim_0 = Dim("dim_0") + dim_0 = Dim("dim_0", max=10) dynamic_shapes = { "x": {}, "y": {0: dim_0}, } m = fx.export_and_import( - Basic(), x, y, dynamic_shapes=dynamic_shapes, func_name="test_net" + Basic(), + x, + y, + dynamic_shapes=dynamic_shapes, + func_name="test_net", + import_symbolic_shape_expressions=True, ) print(m) diff --git a/test/python/fx_importer/custom_op_test.py b/test/python/fx_importer/custom_op_test.py new file mode 100644 index 000000000000..9ce5820035b2 --- /dev/null +++ b/test/python/fx_importer/custom_op_test.py @@ -0,0 +1,133 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +# RUN: %PYTHON %s | FileCheck %s + +import torch +import torch.nn as nn +from torch.export import Dim +from torch.library import Library, impl, impl_abstract + +from torch_mlir import fx + + +def run(f): + print(f"{f.__name__}") + print("-" * len(f.__name__)) + f() + print() + + +@run +# CHECK-LABEL: test_tanh_sigmoid_cat_custom_op +# CHECK: func.func @main( +# CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>, +# CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>, +# CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> { +# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int +# CHECK: %[[S1:.+]] = torch.symbolic_int "s1" {min_val = {{[0-9]+}}, max_val = 100} : !torch.int +# CHECK: %[[S2:.+]] = torch.symbolic_int "s3" {min_val = {{[0-9]+}}, max_val = 50} : !torch.int +# CHECK: %[[S3:.+]] = torch.symbolic_int "s5" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]], %[[S2]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[ARG2]], [%[[S0]], %[[S3]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: %[[OP:.+]] = torch.operator "torch.my_custom_library.tanh_sigmoid_cat_op"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (!torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[OP]], [%[[S0]], %[[S1]], %[[S2]], %[[S3]]], affine_map<()[s0, s1, s2, s3] -> (s0, s2 + s3 + s1 * 2, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: return %[[OP]] : !torch.vtensor<[?,?,3],f32> +def test_tanh_sigmoid_cat_custom_op(): + + m = Library("my_custom_library", "DEF") + m.define("tanh_sigmoid_cat_op(Tensor x, Tensor y, Tensor z) -> Tensor") + + @impl(m, "tanh_sigmoid_cat_op", "CompositeExplicitAutograd") + def custom_op(x, y, z): + a = 
torch.tanh(x) + b = torch.sigmoid(y) + return torch.cat((a, a, b, z), dim=1) + + @impl_abstract("my_custom_library::tanh_sigmoid_cat_op") + def custom_op_meta(x, y, z): + result = custom_op(x, y, z) + return torch.empty_like(result) + + class TanhSigmoidCatCustomOp(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y, z): + return torch.ops.my_custom_library.tanh_sigmoid_cat_op(x, y, z) + + # Sample inputs + x = torch.randn(5, 2, 3) + y = torch.randn(5, 6, 3) + z = torch.randn(5, 4, 3) + + # Dynamic dim constraints + dim_n = Dim("n", min=5, max=10) + dim_x1 = Dim("x1", max=100) + dim_y1 = Dim("y1", max=50) + dim_z1 = Dim("z1", max=50) + dynamic_shapes = { + "x": {0: dim_n, 1: dim_x1}, + "y": {0: dim_n, 1: dim_y1}, + "z": {0: dim_n, 1: dim_z1}, + } + + m = fx.export_and_import( + TanhSigmoidCatCustomOp(), + x, + y, + z, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=True, + ) + print(m) + + +@run +# CHECK-LABEL: test_custom_op_array_output +# CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,3],f32>) +# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = 10} : !torch.int +# CHECK: %[[int:.+]] = torch.constant.int 4 +# CHECK: %[[V0:.+]] = torch.operator "torch.my_custom_library.array_output_op"(%[[int]], %[[ARG0]]) : (!torch.int, !torch.vtensor<[?,3],f32>) -> !torch.list +# CHECK: %[[V1:.+]]:4 = torch.prim.ListUnpack %[[V0]] : !torch.list -> !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[V1]]#0, [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[V1]]#1, [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[V1]]#2, [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[V1]]#3, [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32> +# CHECK: return %[[V1]]#0, %[[V1]]#1, %[[V1]]#2, %[[V1]]#3 : !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32>, !torch.vtensor<[?,3],f32> +def test_custom_op_array_output(): + m = Library("my_custom_library", "DEF") + m.define("array_output_op(int num_outs, Tensor a) -> Tensor[]") + + @impl(m, "array_output_op", "CompositeExplicitAutograd") + def custom_op(num_outs, a): + return [a] * num_outs + + @impl_abstract("my_custom_library::array_output_op") + def custom_op_meta(num_outs, a): + result = custom_op(num_outs, a) + return [torch.empty_like(t) for t in result] + + class ArrayOutputCustomOp(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a): + return torch.ops.my_custom_library.array_output_op(4, a) + + dim = Dim("n", max=10) + dynamic_shapes = { + "a": {0: dim}, + } + + a = torch.rand(2, 3) + m = fx.export_and_import( + ArrayOutputCustomOp(), + a, + import_symbolic_shape_expressions=True, + dynamic_shapes=dynamic_shapes, + ) + print(m) diff --git a/test/python/fx_importer/sparsity/lit.local.cfg b/test/python/fx_importer/sparsity/lit.local.cfg new file mode 100644 index 000000000000..274898b1438a --- /dev/null +++ b/test/python/fx_importer/sparsity/lit.local.cfg @@ -0,0 +1,10 @@ +config.unsupported = True + +try: + import torch + if "2.5.0" <= str(torch.__version__): + print("Enabling sparsity propagation tests") + config.unsupported = False + +except ModuleNotFoundError: + ... 
diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparsity/sparse_test.py similarity index 61% rename from test/python/fx_importer/sparse_test.py rename to test/python/fx_importer/sparsity/sparse_test.py index 474fe2bfddbc..26e908bb59c4 100644 --- a/test/python/fx_importer/sparse_test.py +++ b/test/python/fx_importer/sparsity/sparse_test.py @@ -5,15 +5,17 @@ # RUN: %PYTHON %s | FileCheck %s +# torch_mlir_e2e_test is not available downstream. +# UNSUPPORTED: true + from typing import Any, Callable, Optional, Tuple, Dict import torch -import torch.export import torch.nn as nn import numpy as np +from torch_mlir.extras.fx_decomp_util import get_decomposition_table from torch_mlir.extras.fx_importer import FxImporter -from torch_mlir.extras.fx_importer import SparsityMeta from torch_mlir import ir from torch_mlir.dialects import torch as torch_d from torch_mlir.compiler_utils import run_pipeline_with_repro_report @@ -22,127 +24,15 @@ ) -# All sparse layouts currently supported in torch.sparse. -SPARSE_LAYOUTS = [ - torch.sparse_coo, - torch.sparse_csr, - torch.sparse_csc, - torch.sparse_bsr, - torch.sparse_bsc, -] - - -def sparse_metadata(a: torch.Tensor) -> SparsityMeta: - """ - Returns a meta data tuple for the given sparse tensor. - - NOTE: this will be fully replaced by fx graph SparseTensorMetadata - """ - sparse_dim = a.sparse_dim() - dense_dim = a.dense_dim() - batch_dim = a.ndim - dense_dim - sparse_dim - blocksize = None - if a.layout is torch.sparse_coo: - return SparsityMeta( - a.layout, - batch_dim, - sparse_dim, - dense_dim, - blocksize, - a._indices().dtype, - a._indices().dtype, - ) - elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr: - if a.layout is torch.sparse_bsr: - blocksize = a.values().shape[batch_dim + 1 : batch_dim + 3] - return SparsityMeta( - a.layout, - batch_dim, - sparse_dim, - dense_dim, - blocksize, - a.crow_indices().dtype, - a.col_indices().dtype, - ) - elif a.layout is torch.sparse_csc or a.layout is torch.sparse_bsc: - if a.layout is torch.sparse_bsc: - blocksize = a.values().shape[batch_dim + 1 : batch_dim + 3] - return SparsityMeta( - a.layout, - batch_dim, - sparse_dim, - dense_dim, - blocksize, - a.ccol_indices().dtype, - a.row_indices().dtype, - ) - else: - raise RuntimeError(f"Unsupported sparse layout for {a}") - - -def sparse_export( - f: Callable, args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None -) -> torch.export.ExportedProgram: - """ - This is a ***temporary*** wrapper around `torch.export.export` - that eventually should be removed and simply replaced by the - standard API for exporting traced graphs. - - But until issue - - https://github.com/pytorch/pytorch/pull/117907 - - is addressed, this wrapper provides support for the sparse - tensor types by first converting all operands to dense tensors, - building the traced graph as for the dense case, then annotating - sparse parameters with their actual sparse layout attributes, - followed by some simple propagation rules. This temporary solution - accelerates testing torch-mlir with PyTorch sparse tensors until - the issue is resolved upstream. - """ - # Convert all arguments to dense. - dargs = tuple(a.to_dense() if a.layout in SPARSE_LAYOUTS else a for a in args) - mask = [a.layout in SPARSE_LAYOUTS for a in args] - # Build the regular FX traced graph with only dense arguments - # (the current version would crash otherwise, see issue above). 
- prog = torch.export.export(f, dargs, kwargs) - # Annotate sparse arguments in the graph and apply some very - # basic propagation rules for sparsity. - specs = prog.graph_signature.input_specs - alen = len(specs) - k = 0 - for i, node in enumerate(prog.graph.nodes): - if node.op == "placeholder": - # Argument. - spec = specs[i] - if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT: - if mask[k]: - node.meta["sparsity"] = sparse_metadata(args[k]) - k = k + 1 - elif node.op == "call_function": - # TODO: use upstream _opname implementation when available - opname = node.target._schema.name.split("::")[1] - # Zero preserving elt-wise unary op. - if opname in {"abs", "neg", "relu", "sin"}: - node.meta["sparsity"] = node.args[0].meta.get("sparsity", None) - elif opname == "_to_sparse": - dim = len(node.meta.get("val").shape) - node.meta["sparsity"] = SparsityMeta( - torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64 - ) - # TODO: Uncomment this to hack sparsity into the network. - # elif opname == "_to_dense": - # # hack (assumes we never really want the to_dense for now) - # node.meta["sparsity"] = node.args[0].meta.get("sparsity", None) - return prog - - def export_and_import(f, *args, **kwargs): - """This method implements Stella's importer, stripped down to essentials.""" + """An FX graph importer, stripped down to essentials.""" context = ir.Context() torch_d.register_dialect(context) fx_importer = FxImporter(context=context) - prog = sparse_export(f, args, kwargs) + prog = torch.export.export(f, args, kwargs) + decomposition_table = get_decomposition_table() + if decomposition_table: + prog = prog.run_decompositions(decomposition_table) fx_importer.import_frozen_program(prog) return fx_importer.module @@ -162,12 +52,13 @@ def sparse_jit(f, *args, **kwargs): enable_ir_printing=False, ) # Compile with reference Linalg backend. - backend = RefBackendLinalgOnTensorsBackend() + # TODO: runtime verification fails with 'rank mismatch' on memref.cast + backend = RefBackendLinalgOnTensorsBackend(generate_runtime_verification=False) compiled = backend.compile(module) invoker = backend.load(compiled) xargs = [] - # Prepare the buffer parameters (assume all dense). - # TODO: filters out scalar arguments, anything else? + # Prepare all the named buffer parameters (assume all dense). + # All scalar arguments are filtered out since they appear inline. params = dict(f.named_buffers(remove_duplicate=True)) params_flat, params_spec = torch.utils._pytree.tree_flatten(params) for p in params_flat: @@ -203,7 +94,8 @@ def sparse_jit(f, *args, **kwargs): def run(f): - print(f"{f.__name__}") + # Print test name and torch version (for debugging).
+ print(f"{f.__name__} ({torch.__version__})") print("-" * len(f.__name__)) f() print() @@ -327,25 +219,25 @@ def forward(self, x, v): print("torch.mlir =", res2) -@run -# -# CHECK-LABEL: test_sparse_SpMM -# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }> -# CHECK: func.func @main( -# CHECK-SAME: %[[A:.*0]]: !torch.vtensor<[8,8],f32,#[[$COO]]>, -# CHECK-SAME: %[[B:.*1]]: !torch.vtensor<[8,8],f32>) -> !torch.vtensor<[8,8],f32> { -# CHECK: %[[R:.*]] = torch.aten.mm %[[A]], %[[B]] : !torch.vtensor<[8,8],f32,#[[$COO]]>, !torch.vtensor<[8,8],f32> -> !torch.vtensor<[8,8],f32> -# CHECK: return %[[R]] : !torch.vtensor<[8,8],f32> -# CHECK: } +# @run # -# CHECK: torch.sparse -# CHECK: tensor({{\[}}[8., 8., 8., 8., 8., 8., 8., 8.], -# CHECK-COUNT-6: [8., 8., 8., 8., 8., 8., 8., 8.], -# CHECK: [8., 8., 8., 8., 8., 8., 8., 8.]{{\]}}) -# CHECK: torch.mlir -# CHECK: {{\[}}[8. 8. 8. 8. 8. 8. 8. 8.] -# CHECK-COUNT-6: [8. 8. 8. 8. 8. 8. 8. 8.] -# CHECK: [8. 8. 8. 8. 8. 8. 8. 8.]{{\]}} +# C_HECK-LABEL: test_sparse_SpMM +# C_HECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }> +# C_HECK: func.func @main( +# C_HECK-SAME: %[[A:.*0]]: !torch.vtensor<[8,8],f32,#[[$COO]]>, +# C_HECK-SAME: %[[B:.*1]]: !torch.vtensor<[8,8],f32>) -> !torch.vtensor<[8,8],f32> { +# C_HECK: %[[R:.*]] = torch.aten.{{matmul|mm}} %[[A]], %[[B]] : !torch.vtensor<[8,8],f32,#[[$COO]]>, !torch.vtensor<[8,8],f32> -> !torch.vtensor<[8,8],f32> +# C_HECK: return %[[R]] : !torch.vtensor<[8,8],f32> +# C_HECK: } +## +# C_HECK: torch.sparse +# C_HECK: tensor({{\[}}[8., 8., 8., 8., 8., 8., 8., 8.], +# C_HECK-COUNT-6: [8., 8., 8., 8., 8., 8., 8., 8.], +# C_HECK: [8., 8., 8., 8., 8., 8., 8., 8.]{{\]}}) +# C_HECK: torch.mlir +# C_HECK: {{\[}}[8. 8. 8. 8. 8. 8. 8. 8.] +# C_HECK-COUNT-6: [8. 8. 8. 8. 8. 8. 8. 8.] +# C_HECK: [8. 8. 8. 8. 8. 8. 8. 
8.]{{\]}} # def test_sparse_SpMM(): class MatMulNet(torch.nn.Module): @@ -370,40 +262,40 @@ def forward(self, x, y): print(res2) -@run +# @run # -# CHECK-LABEL: test_sparse_eltwise -# CHECK: #[[$CSRD:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : dense), posWidth = 64, crdWidth = 64 }> -# CHECK: func.func @main( -# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[4,2,2],f32,#[[$CSRD]]>) -> !torch.vtensor<[4,2,2],f32,#[[$CSRD]]> { -# CHECK: %[[R:.*]] = torch.aten.neg %[[A]] : !torch.vtensor<[4,2,2],f32,#[[$CSRD]]> -> !torch.vtensor<[4,2,2],f32,#[[$CSRD]]> -# CHECK: return %[[R]] : !torch.vtensor<[4,2,2],f32,#[[$CSRD]]> -# CHECK: } -# CHECK: #[[$BCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : batch, d1 : dense, d2 : compressed), posWidth = 64, crdWidth = 64 }> -# CHECK: func.func @main( -# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[4,2,2],f32,#[[$BCSR]]>) -> !torch.vtensor<[4,2,2],f32,#[[$BCSR]]> { -# CHECK: %[[R:.*]] = torch.aten.neg %[[A]] : !torch.vtensor<[4,2,2],f32,#[[$BCSR]]> -> !torch.vtensor<[4,2,2],f32,#[[$BCSR]]> -# CHECK: return %[[R]] : !torch.vtensor<[4,2,2],f32,#[[$BCSR]]> -# CHECK: } +# C_HECK-LABEL: test_sparse_eltwise +# C_HECK: #[[$CSRD:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : dense), posWidth = 64, crdWidth = 64 }> +# C_HECK: func.func @main( +# C_HECK-SAME: %[[A:.*]]: !torch.vtensor<[4,2,2],f32,#[[$CSRD]]>) -> !torch.vtensor<[4,2,2],f32,#[[$CSRD]]> { +# C_HECK: %[[R:.*]] = torch.aten.neg %[[A]] : !torch.vtensor<[4,2,2],f32,#[[$CSRD]]> -> !torch.vtensor<[4,2,2],f32,#[[$CSRD]]> +# C_HECK: return %[[R]] : !torch.vtensor<[4,2,2],f32,#[[$CSRD]]> +# C_HECK: } +# C_HECK: #[[$BCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : batch, d1 : dense, d2 : compressed), posWidth = 64, crdWidth = 64 }> +# C_HECK: func.func @main( +# C_HECK-SAME: %[[A:.*]]: !torch.vtensor<[4,2,2],f32,#[[$BCSR]]>) -> !torch.vtensor<[4,2,2],f32,#[[$BCSR]]> { +# C_HECK: %[[R:.*]] = torch.aten.neg %[[A]] : !torch.vtensor<[4,2,2],f32,#[[$BCSR]]> -> !torch.vtensor<[4,2,2],f32,#[[$BCSR]]> +# C_HECK: return %[[R]] : !torch.vtensor<[4,2,2],f32,#[[$BCSR]]> +# C_HECK: } # -# CHECK: torch.sparse -# CHECK: tensor(crow_indices=tensor([0, 2, 4, 6, 8]), -# CHECK: col_indices=tensor([0, 1, 0, 1, 0, 1, 0, 1]), -# CHECK: values=tensor({{\[}}[ -1., -2.], -# CHECK: [ -3., -4.], -# CHECK: [ -5., -6.], -# CHECK: [ -7., -8.], -# CHECK: [ -9., -10.], -# CHECK: [-11., -12.], -# CHECK: [-13., -14.], -# CHECK: [-15., -16.]{{\]}}), size=(4, 2, 2), nnz=8, -# CHECK: layout=torch.sparse_csr) -# CHECK: torch.mlir -# CHECK: [0 2 4 6 8] -# CHECK: [0 1 0 1 0 1 0 1] -# CHECK: [ -1. -2. -3. -4. -5. -6. -7. -8. -9. -10. -11. -12. -13. -14. -# CHECK: -15. -16.] -# CHECK: torch.mlir.batch +# C_HECK: torch.sparse +# C_HECK: tensor(crow_indices=tensor([0, 2, 4, 6, 8]), +# C_HECK: col_indices=tensor([0, 1, 0, 1, 0, 1, 0, 1]), +# C_HECK: values=tensor({{\[}}[ -1., -2.], +# C_HECK: [ -3., -4.], +# C_HECK: [ -5., -6.], +# C_HECK: [ -7., -8.], +# C_HECK: [ -9., -10.], +# C_HECK: [-11., -12.], +# C_HECK: [-13., -14.], +# C_HECK: [-15., -16.]{{\]}}), size=(4, 2, 2), nnz=8, +# C_HECK: layout=torch.sparse_csr) +# C_HECK: torch.mlir +# C_HECK: [0 2 4 6 8] +# C_HECK: [0 1 0 1 0 1 0 1] +# C_HECK: [ -1. -2. -3. -4. -5. -6. -7. -8. -9. -10. -11. -12. -13. -14. +# C_HECK: -15. -16.] 
+# C_HECK: torch.mlir.batch # def test_sparse_eltwise(): class EltNet(torch.nn.Module): @@ -431,7 +323,7 @@ def forward(self, x): # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. res1 = net(sparse_input) res2 = sparse_jit(net, sparse_input) - # TODO: make this work + # TODO: make this work in MLIR # res3 = sparse_jit(net, batch_input) print("torch.sparse") print(res1) @@ -459,6 +351,11 @@ def forward(self, x): # CHECK: values=tensor([ 0., 0., 1., 2., 3., 1000.]), # CHECK: size=(10, 20, 30), nnz=6, dtype=torch.float64, layout=torch.sparse_coo) # CHECK: torch.mlir +# CHECK: [0 6] +# CHECK: [0 1 1 4 9 9] +# CHECK: [ 0 1 1 5 19 19] +# CHECK: [ 0 1 3 6 28 29] +# CHECK: [ 0. 0. 1. 2. 3. 1000.] # def test_sparse_coo3(): class COO3Net(torch.nn.Module): @@ -481,11 +378,15 @@ def forward(self, x): # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. res1 = net(sparse_input) - # TODO: make coo3 work - # res2 = sparse_jit(net, sparse_input) + res2 = sparse_jit(net, sparse_input) print("torch.sparse") print(res1) print("torch.mlir") + print(res2[0]) + print(res2[1]) + print(res2[2]) + print(res2[3]) + print(res2[4]) @run @@ -497,7 +398,7 @@ def forward(self, x): # CHECK: %[[N1:.*]] = torch.constant.none # CHECK: %[[N2:.*]] = torch.constant.none # CHECK: %[[N3:.*]] = torch.constant.none -# CHECK: %[[R:.*]] = torch.operator "torch.aten._to_sparse"(%[[A]], %[[N1]], %[[N2]], %[[N3]]) : (!torch.vtensor<[2,2,2],f32>, !torch.none, !torch.none, !torch.none) -> !torch.vtensor<[2,2,2],f32,#[[$COO]]> +# CHECK: %[[R:.*]] = torch.operator "torch.aten.{{to_sparse|_to_sparse}}"(%[[A]], %[[N1]], %[[N2]], %[[N3]]) : (!torch.vtensor<[2,2,2],f32>, !torch.none, !torch.none, !torch.none) -> !torch.vtensor<[2,2,2],f32,#[[$COO]]> # CHECK: return %[[R]] : !torch.vtensor<[2,2,2],f32,#[[$COO]]> # CHECK: } # @@ -537,20 +438,20 @@ def forward(self, x): print(res2[4]) -@run +# @run # -# CHECK-LABEL: test_sparse_network -# CHECK: func.func @main( -# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[2,3,8,8],f32>) -> !torch.vtensor<[8],f32> { +# C_HECK-LABEL: test_sparse_network +# C_HECK: func.func @main( +# C_HECK-SAME: %[[A:.*]]: !torch.vtensor<[2,3,8,8],f32>) -> !torch.vtensor<[8],f32> { # ... lots of IR ... -# CHECK-COUNT-15: torch.aten.mul.Tensor +# C_HECK-COUNT-15: torch.aten.mul.Tensor # ... lots of IR ... -# CHECK: } +# C_HECK: } # -# CHECK: torch.sparse -# CHECK: tensor([ 0., 11., 9., 11., 13., 11., 10., 12.]) -# CHECK: torch.mlir -# CHECK: [ 0. 11. 9. 11. 13. 11. 10. 12.] +# C_HECK: torch.sparse +# C_HECK: tensor([ 0., 11., 9., 11., 13., 11., 10., 12.]) +# C_HECK: torch.mlir +# C_HECK: [ 0. 11. 9. 11. 13. 11. 10. 12.] # def test_sparse_network(): def spike(input): @@ -574,8 +475,8 @@ def forward(self, X): for t in range(T): mem = mem * self.decay + X[..., t] spike = self.act(mem - self.thresh) - mem = mem * (1.0 - spike) spike = spike.to_sparse().to_dense() # prop hack + mem = mem * (1.0 - spike) spike_pot.append(spike) spike_pot = torch.stack(spike_pot, dim=-1) return spike_pot @@ -621,3 +522,119 @@ def forward(self, X): print(res1) print("torch.mlir") print(res2) + + +# @run +# +# C_HECK-LABEL: test_sparse_feature_scaling +# C_HECK: func.func @main( +# C_HECK-SAME: %[[A:.*]]: !torch.vtensor<[4,4],f32>) -> !torch.vtensor<[4,4],f32> { +# ... more IR ... 
+# C_HECK: %[[D:.*]] = torch.operator "torch.aten.{{to_sparse|_to_sparse}}" +# C_HECK: %[[R:.*]] = torch.aten.{{matmul|mm}} %[[D]], %[[A]] +# C_HECK return %[[R]] : !torch.vtensor<[4,4],f32> +# C_HECK: } +# +# C_HECK: torch.sparse +# C_HECK: tensor({{\[}}[0.3342, 0.5173, 0.0596, 0.0889], +# C_HECK: [0.1321, 0.2724, 0.2105, 0.3851], +# C_HECK: [0.2478, 0.3439, 0.1898, 0.2185], +# C_HECK: [0.0222, 0.1683, 0.2928, 0.5167]{{\]}}) +# +# TODO: first row looks suspect... +# +# C_HECK: torch.mlir +# C_HECK: {{\[}}[0. 0. 0. 0. ] +# C_HECK: [0.13205223 0.27236593 0.21051763 0.38506418] +# C_HECK: [0.24781987 0.34391665 0.18976606 0.2184974 ] +# C_HECK: [0.02224578 0.16825409 0.29283574 0.51666445]{{\]}} +# +def test_sparse_feature_scaling(): + class Scale(nn.Module): + def forward(self, F): + sum_vector = torch.sum(F, dim=1) + reciprocal_vector = 1 / sum_vector + reciprocal_vector[reciprocal_vector == float("inf")] = 0 + scaling_diagonal = torch.diag(reciprocal_vector).to_sparse() + return scaling_diagonal @ F + + net = Scale() + + # Get a random (but reproducible) features input. + torch.manual_seed(0) + f = torch.rand(4, 4) + m = export_and_import(net, f) + print(m) + + # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. + res1 = net(f) + res2 = sparse_jit(net, f) + print("torch.sparse") + print(res1) + print("torch.mlir") + print(res2) + + +@run +# +# CHECK-LABEL: test_sparse_gcn +# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }> +# CHECK: func.func @main( +# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[4,4],f32>, +# CHECK-SAME: %[[B:.*]]: !torch.vtensor<[4,4],f32,#[[$COO]]>) -> !torch.vtensor<[4,4],f32> { +# CHECK: %[[LIT:.*]] = torch.vtensor.literal(dense_resource : tensor<4x4xf32>) : !torch.vtensor<[4,4],f32> +# CHECK: %[[MM:.*]] = torch.aten.mm %[[A]], %[[LIT]] : !torch.vtensor<[4,4],f32>, !torch.vtensor<[4,4],f32> -> !torch.vtensor<[4,4],f32> +# CHECK: %[[SMM:.*]] = torch.aten.mm %[[B]], %[[MM]] : !torch.vtensor<[4,4],f32,#sparse>, !torch.vtensor<[4,4],f32> -> !torch.vtensor<[4,4],f32> +# CHECK: %[[BIAS:.*]] = torch.vtensor.literal(dense_resource : tensor<4xf32>) : !torch.vtensor<[4],f32> +# CHECK: %[[ONE:.*]] = torch.constant.int 1 +# CHECK: %[[R:.*]] = torch.aten.add.Tensor %[[SMM]], %[[BIAS]], %[[ONE]] : !torch.vtensor<[4,4],f32>, !torch.vtensor<[4],f32>, !torch.int -> !torch.vtensor<[4,4],f32> +# CHECK return %[[R]] : !torch.vtensor<[4,4],f32> +# CHECK: } +# +# CHECK: torch.sparse +# CHECK: tensor({{\[}}[4.4778, 4.4778, 4.4778, 4.4778], +# CHECK: [5.7502, 5.7502, 5.7502, 5.7502], +# CHECK: [4.6980, 4.6980, 4.6980, 4.6980], +# CHECK: [3.6407, 3.6407, 3.6407, 3.6407]{{\]}}) +# CHECK: torch.mlir +# CHECK: {{\[}}[4.477828 4.477828 4.477828 4.477828 ] +# CHECK: [5.7501717 5.7501717 5.7501717 5.7501717] +# CHECK: [4.697952 4.697952 4.697952 4.697952 ] +# CHECK: [3.640687 3.640687 3.640687 3.640687 ]{{\]}} +# +def test_sparse_gcn(): + class GraphConv(nn.Module): + def __init__(self, input_dim, output_dim): + super(GraphConv, self).__init__() + self.kernel = nn.Parameter(torch.Tensor(input_dim, output_dim)) + nn.init.ones_(self.kernel) + self.bias = nn.Parameter(torch.Tensor(output_dim)) + nn.init.ones_(self.bias) + + def forward(self, inp, adj_mat): + # Input matrix times weight matrix. + support = torch.mm(inp, self.kernel) + # Sparse adjacency matrix times support matrix. + output = torch.spmm(adj_mat, support) + # Add bias. 
+ output = output + self.bias + return output + + net = GraphConv(4, 4) + + # Get a random (but reproducible) matrices. + torch.manual_seed(0) + inp = torch.rand(4, 4) + adj_mat = torch.rand(4, 4).to_sparse() + m = export_and_import(net, inp, adj_mat) + print(m) + + # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. + # Set to inference mode to avoid autograd component in result. + with torch.no_grad(): + res1 = net(inp, adj_mat) + res2 = sparse_jit(net, inp, adj_mat) + print("torch.sparse") + print(res1) + print("torch.mlir") + print(res2) diff --git a/test/python/fx_importer/symbolic_shape_expr_test.py b/test/python/fx_importer/symbolic_shape_expr_test.py new file mode 100644 index 000000000000..3b8274ccae46 --- /dev/null +++ b/test/python/fx_importer/symbolic_shape_expr_test.py @@ -0,0 +1,501 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +# RUN: %PYTHON %s | FileCheck %s +# This file contains tests of various op special forms that the fx_importer +# handles. + +import torch +import torch.export +import torch.nn as nn +from torch.export import Dim + +from torch_mlir import fx + + +def run(f): + print(f"{f.__name__}") + print("-" * len(f.__name__)) + f() + print() + + +@run +# CHECK-LABEL: test_tanh_sigmoid_cat +# CHECK: func.func @main( +# CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>, +# CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>, +# CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> { +# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int +# CHECK: %[[S1:.+]] = torch.symbolic_int "s1" {min_val = {{[0-9]+}}, max_val = 100} : !torch.int +# CHECK: %[[S2:.+]] = torch.symbolic_int "s3" {min_val = {{[0-9]+}}, max_val = 50} : !torch.int +# CHECK: %[[S3:.+]] = torch.symbolic_int "s5" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]], %[[S2]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[ARG2]], [%[[S0]], %[[S3]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: %[[TANH:.+]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[TANH]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: %[[SIG:.+]] = torch.aten.sigmoid %[[ARG1]] : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[SIG]], [%[[S0]], %[[S2]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[TANH]], %[[TANH]], %[[SIG]], %[[ARG2]] : (!torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>) -> !torch.list +# CHECK: %[[CAT:.+]] = torch.aten.cat %[[LIST]], {{.*}} : !torch.list, !torch.int -> !torch.vtensor<[?,?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[CAT]], [%[[S0]], %[[S1]], %[[S2]], %[[S3]]], affine_map<()[s0, s1, s2, s3] -> (s0, s2 + s3 + s1 * 2, 3)> : !torch.vtensor<[?,?,3],f32> +# CHECK: return 
%[[CAT]] : !torch.vtensor<[?,?,3],f32> +def test_tanh_sigmoid_cat(): + class TanhSigmoidCat(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y, z): + a = torch.tanh(x) + b = torch.sigmoid(y) + return torch.cat((a, a, b, z), dim=1) + + # Sample inputs + x = torch.randn(5, 2, 3) + y = torch.randn(5, 6, 3) + z = torch.randn(5, 4, 3) + + # Dynamic dim constraints + dim_n = Dim("n", min=5, max=10) + dim_x1 = Dim("x1", max=100) + dim_y1 = Dim("y1", max=50) + dim_z1 = Dim("z1", max=50) + dynamic_shapes = { + "x": {0: dim_n, 1: dim_x1}, + "y": {0: dim_n, 1: dim_y1}, + "z": {0: dim_n, 1: dim_z1}, + } + + m = fx.export_and_import( + TanhSigmoidCat(), + x, + y, + z, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=True, + ) + print(m) + + +@run +# CHECK-LABEL: test_symbolic_dim_differ_by_one +# CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>, %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> attributes {torch.assume_strict_symbolic_shapes} { +# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 3, max_val = 6} : !torch.int +# FIXME: This appears in torch-nightly, but not in torch-stable (re-enable once we've moved torch-stable to 2.4+) +# CHECK-DISABLED: %[[S1:.+]] = torch.symbolic_int "s0 + 1" {min_val = 4, max_val = 7} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> +# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0 + 1)> : !torch.vtensor<[?],f32> +# CHECK: %[[SLICE:.+]] = torch.aten.slice.Tensor %arg1, {{.*}}, {{.*}}, {{.*}}, {{.*}} : !torch.vtensor<[?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?],f32> +# CHECK: torch.bind_symbolic_shape %[[SLICE]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> +# CHECK: %[[ADD:.+]] = torch.aten.add.Tensor %[[ARG0]], %[[SLICE]], {{.*}} : !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?],f32> +# CHECK: torch.bind_symbolic_shape %[[ADD]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> +# CHECK: return %[[ADD]] : !torch.vtensor<[?],f32> +def test_symbolic_dim_differ_by_one(): + class SymbolicDimDifferByOne(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return x + y[1:] + + # Sample inputs + x = torch.randn(5) + y = torch.randn(6) + + # Dynamic dim constraints + dimx = Dim("dimx", min=3, max=6) + dimy = dimx + 1 + dynamic_shapes = { + "x": {0: dimx}, + "y": {0: dimy}, + } + + m = fx.export_and_import( + SymbolicDimDifferByOne(), + x, + y, + dynamic_shapes=dynamic_shapes, + experimental_support_mutation=True, + import_symbolic_shape_expressions=True, + ) + print(m) + + +@run +# CHECK-LABEL: test_outer_with_squared_shape +# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> { +# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> +# CHECK: %[[I0:.+]] = torch.constant.int 0 +# CHECK: %[[SIZE:.+]] = torch.aten.size.int %[[ARG0]], %[[I0]] : !torch.vtensor<[?],f32>, !torch.int -> !torch.int +# The Torch 2.6 generates `torch.aten.outer` as an op in this example while the torch versions < 2.6 does not, hence this check is kept as a "COM". 
+# COM: %[[OUTER:.+]] = torch.aten.outer %[[ARG0]], %[[ARG0]] : !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32> -> !torch.vtensor<[?,?],f32> +# CHECK: torch.bind_symbolic_shape %{{.*}}, [%[[S0]]], affine_map<()[s0] -> (s0, s0)> : !torch.vtensor<[?,?],f32> +# CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[SIZE]], %[[SIZE]] : !torch.int, !torch.int -> !torch.int +# CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[MUL]] : (!torch.int) -> !torch.list +# CHECK: %[[VIEW:.+]] = torch.aten.view %{{.*}}, %[[LIST]] : !torch.vtensor<[?,?],f32>, !torch.list -> !torch.vtensor<[?],f32> +# CHECK: torch.bind_symbolic_shape %[[VIEW]], [%[[S0]]], affine_map<()[s0] -> (s0 * s0)> : !torch.vtensor<[?],f32> +# CHECK: return %[[VIEW]] : !torch.vtensor<[?],f32> +def test_outer_with_squared_shape(): + class OuterWithSquaredShape(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.outer(x, x).flatten() + + # Sample inputs + x = torch.rand(10) + + # Dynamic dim constraints + batch = Dim("batch", max=10) + dynamic_shapes = {"x": {0: batch}} + + m = fx.export_and_import( + OuterWithSquaredShape(), + x, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=True, + ) + print(m) + + +@run +# CHECK-LABEL: test_slice_tensor_static_output +# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?,3],f32>) -> !torch.vtensor<[2,1],f32> { +# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 3, max_val = 10} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32> +# CHECK: %[[SLICE1:.+]] = torch.aten.slice.Tensor %[[ARG0]], {{.*}}, {{.*}}, {{.*}}, {{.*}} : !torch.vtensor<[?,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32> +# CHECK: %[[SLICE2:.+]] = torch.aten.slice.Tensor %[[SLICE1]], {{.*}}, {{.*}}, {{.*}}, {{.*}} : !torch.vtensor<[2,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,1],f32> +# CHECK: return %[[SLICE2]] : !torch.vtensor<[2,1],f32> +def test_slice_tensor_static_output(): + class SliceTensorStaticOutput(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x[0:2, :1] + + # Sample inputs + x = torch.randn(4, 3) + + # Dynamic dim constraints + batch = Dim("batch", min=3, max=10) + dynamic_shapes = {"x": {0: batch}} + + m = fx.export_and_import( + SliceTensorStaticOutput(), + x, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=True, + ) + print(m) + + +@run +# CHECK-LABEL: test_slice_tensor_dynamic_output +# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> { +# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> +# CHECK: %[[SLICE:.+]] = torch.aten.slice.Tensor %[[ARG0]], {{.*}}, {{.*}}, {{.*}}, {{.*}} : !torch.vtensor<[?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?],f32> +# CHECK: torch.bind_symbolic_shape %[[SLICE]], [%[[S0]]], affine_map<()[s0] -> (s0 - 5)> : !torch.vtensor<[?],f32> +# CHECK: return %[[SLICE]] : !torch.vtensor<[?],f32> +def test_slice_tensor_dynamic_output(): + class SliceTensorDynamicOutput(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x[5:] + + # Sample inputs + x = torch.randn(10) + + # Dynamic dim constraints + 
dimx = Dim("dimx", min=5, max=10) + dynamic_shapes = {"x": {0: dimx}} + + m = fx.export_and_import( + SliceTensorDynamicOutput(), + x, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=True, + ) + print(m) + + +@run +# CHECK-LABEL: test_div_tensor_mixed_ranks +# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[],f32>, %[[ARG1:.+]]: !torch.vtensor<[?,3],f32>) -> !torch.vtensor<[?,3],f32> { +# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32> +# CHECK: %[[DIV:.+]] = torch.aten.div.Tensor %[[ARG0]], %[[ARG1]] : !torch.vtensor<[],f32>, !torch.vtensor<[?,3],f32> -> !torch.vtensor<[?,3],f32> +# CHECK: torch.bind_symbolic_shape %[[DIV]], [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32> +# CHECK: return %[[DIV]] : !torch.vtensor<[?,3],f32> +def test_div_tensor_mixed_ranks(): + class DivTensorMixedRanks(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + div = torch.div(x, y) + return div + + # Sample inputs + x = torch.tensor(10.0) + y = torch.randn(2, 3) + + # Dynamic dim constraints + batch = Dim("batch", max=10) + dynamic_shapes = {"x": None, "y": {0: batch}} + + m = fx.export_and_import( + DivTensorMixedRanks(), + x, + y, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=True, + ) + print(m) + + +@run +# CHECK-LABEL: test_shape_div +# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?,7],f32>) -> !torch.vtensor<[?,5],f32> { +# FIXME: This appears in torch-nightly, but not in torch-stable (re-enable once we've moved torch-stable to 2.4+) +# CHECK-DISABLED: %[[S0:.+]] = torch.symbolic_int "5*s1" {min_val = 0, max_val = 5000} : !torch.int +# CHECK: %[[S1:.+]] = torch.symbolic_int "s1" {min_val = 2, max_val = 1000} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S1]]], affine_map<()[s0] -> (s0 * 5, 7)> : !torch.vtensor<[?,7],f32> +# CHECK: %[[VIEW:.+]] = torch.aten.view %[[ARG0]], {{.*}} : !torch.vtensor<[?,7],f32>, !torch.list -> !torch.vtensor<[?,5],f32> +# CHECK: torch.bind_symbolic_shape %[[VIEW]], [%[[S1]]], affine_map<()[s0] -> (s0 * 7, 5)> : !torch.vtensor<[?,5],f32> +# CHECK: return %[[VIEW]] : !torch.vtensor<[?,5],f32> +def test_shape_div(): + class ShapeDiv(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.reshape(-1, 5) + + # Sample inputs + x = torch.rand(10, 7) + + # Dynamic dim constraints + batch = Dim("batch", max=1000) * 5 + dynamic_shapes = {"x": {0: batch}} + + m = fx.export_and_import( + ShapeDiv(), + x, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=True, + ) + print(m) + + +@run +# CHECK-LABEL: test_broadcast_unit_dim_to_static_with_unchanged_dim_dynamic +# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[1,?],f32>) -> !torch.vtensor<[3,?],f32> { +# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (1, s0)> : !torch.vtensor<[1,?],f32> +# CHECK: %[[EXPAND:.+]] = torch.aten.expand %[[ARG0]], {{.*}}, {{.*}} : !torch.vtensor<[1,?],f32>, !torch.list, !torch.bool -> !torch.vtensor<[3,?],f32> +# CHECK: torch.bind_symbolic_shape %[[EXPAND]], [%[[S0]]], affine_map<()[s0] -> (3, s0)> : !torch.vtensor<[3,?],f32> +# CHECK: return 
%[[EXPAND]] : !torch.vtensor<[3,?],f32> +def test_broadcast_unit_dim_to_static_with_unchanged_dim_dynamic(): + class BroadcastUnitDimToStaticWithUnchangedDimDynamic(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.broadcast_to(x, (3, -1)) + + # Sample inputs + x = torch.randn(1, 2) + + # Dynamic dim constraints + dim_1 = Dim("dim_1", max=10) + dynamic_shapes = {"x": {1: dim_1}} + + m = fx.export_and_import( + BroadcastUnitDimToStaticWithUnchangedDimDynamic(), + x, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=True, + ) + print(m) + + +@run +# CHECK-LABEL: test_broadcast_unit_dim_to_dynamic_with_unchanged_dim_static +# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[1,2],f32>, %[[ARG1:.+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,2],f32> { +# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> +# CHECK: %[[EXPAND:.+]] = torch.aten.expand %[[ARG0]], {{.*}}, {{.*}} : !torch.vtensor<[1,2],f32>, !torch.list, !torch.bool -> !torch.vtensor<[?,2],f32> +# CHECK: torch.bind_symbolic_shape %[[EXPAND]], [%[[S0]]], affine_map<()[s0] -> (s0, 2)> : !torch.vtensor<[?,2],f32> +# CHECK: return %3 : !torch.vtensor<[?,2],f32> +def test_broadcast_unit_dim_to_dynamic_with_unchanged_dim_static(): + class BroadcastUnitDimToDynamicWithUnchangedDimStatic(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return torch.broadcast_to(x, (y.shape[0], -1)) + + # Sample inputs + x = torch.randn(1, 2) + y = torch.randn(10) + + # Dynamic dim constraints + dim_0 = Dim("dim_0", max=10) + dynamic_shapes = {"x": {}, "y": {0: dim_0}} + + m = fx.export_and_import( + BroadcastUnitDimToDynamicWithUnchangedDimStatic(), + x, + y, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=True, + ) + print(m) + + +@run +# CHECK-LABEL: test_broadcast_unit_dim_to_dynamic_with_unchanged_dim_dynamic +# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[1,?],f32>, %[[ARG1:.+]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> { +# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int +# CHECK: %[[S1:.+]] = torch.symbolic_int "s1" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (1, s0)> : !torch.vtensor<[1,?],f32> +# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S1]]], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32> +# CHECK: %[[EXPAND:.+]] = torch.aten.expand %[[ARG0]], {{.*}}, {{.*}} : !torch.vtensor<[1,?],f32>, !torch.list, !torch.bool -> !torch.vtensor<[?,?],f32> +# CHECK: torch.bind_symbolic_shape %[[EXPAND]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s1, s0)> : !torch.vtensor<[?,?],f32> +# CHECK: return %[[EXPAND]] : !torch.vtensor<[?,?],f32> +def test_broadcast_unit_dim_to_dynamic_with_unchanged_dim_dynamic(): + class BroadcastUnitDimToDynamicWithUnchangedDimDynamic(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return torch.broadcast_to(x, (y.shape[0], -1)) + + # Sample inputs + x = torch.randn(1, 2) + y = torch.randn(10) + + # Dynamic dim constraints + dim_0 = Dim("dim_0", max=10) + dim_1 = Dim("dim_1", max=10) + 
dynamic_shapes = {"x": {1: dim_1}, "y": {0: dim_0}} + + m = fx.export_and_import( + BroadcastUnitDimToDynamicWithUnchangedDimDynamic(), + x, + y, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=True, + ) + print(m) + + +@run +# CHECK-LABEL: test_broadcast_unit_dim_to_dynamic_with_rank_increase +# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[1,2],f32>, %[[ARG1:.+]]: !torch.vtensor<[?,3,2],f32>) -> !torch.vtensor<[?,3,2],f32> { +# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]]], affine_map<()[s0] -> (s0, 3, 2)> : !torch.vtensor<[?,3,2],f32> +# CHECK: %[[EXPAND:.+]] = torch.aten.expand %[[ARG0]], {{.*}}, {{.*}} : !torch.vtensor<[1,2],f32>, !torch.list, !torch.bool -> !torch.vtensor<[?,3,2],f32> +# CHECK: torch.bind_symbolic_shape %[[EXPAND]], [%[[S0]]], affine_map<()[s0] -> (s0, 3, 2)> : !torch.vtensor<[?,3,2],f32> +# CHECK: return %[[EXPAND]] : !torch.vtensor<[?,3,2],f32> +def test_broadcast_unit_dim_to_dynamic_with_rank_increase(): + class BroadcastUnitDimToDynamicWithRankIncrease(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return torch.broadcast_to(x, y.size()) + + # Sample inputs + x = torch.randn(1, 2) + y = torch.randn(4, 3, 2) + + # Dynamic dim constraints + dim_0 = Dim("dim_0", max=25) + dynamic_shapes = {"x": {}, "y": {0: dim_0}} + + m = fx.export_and_import( + BroadcastUnitDimToDynamicWithRankIncrease(), + x, + y, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=True, + ) + print(m) + + +@run +# CHECK-LABEL: test_gather_elements +# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?,3],f32>, %[[ARG1:.+]]: !torch.vtensor<[2,3],si64>) -> !torch.vtensor<[2,3],f32> { +# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 3, max_val = 100} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32> +# CHECK: %[[GATHER:.+]] = torch.aten.gather %[[ARG0]], {{.*}}, {{.*}}, {{.*}} : !torch.vtensor<[?,3],f32>, !torch.int, !torch.vtensor<[2,3],si64>, !torch.bool -> !torch.vtensor<[2,3],f32> +# CHECK: return %[[GATHER]] : !torch.vtensor<[2,3],f32> +def test_gather_elements(): + class GatherElements(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return torch.gather(x, 0, y) + + # Sample inputs + x = torch.randn(4, 3) + y = torch.tensor([[0, 0, 0], [1, 1, 1]]) + + # Dynamic dim constraints + batch = Dim("batch", min=3, max=100) + dynamic_shapes = {"x": {0: batch}, "y": {}} + + m = fx.export_and_import( + GatherElements(), + x, + y, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=True, + ) + print(m) + + +@run +# CHECK-LABEL: test_nonzero +# CHECK: func.func @main(%[[ARG0:.+]]: !torch.vtensor<[?,3],f32>) -> !torch.vtensor<[?,2],si64> { +# FIXME: There's a bug in the torch 2.3 stable release which creates redundant symbolic_int ops for the nonzero +# output which is fixed in the 2.4 nightlies. 
Once we move to a 2.4 stable release, this check may be re-enabled +# CHECK-DISABLED: %[[U0:.+]] = torch.symbolic_int "u0" {min_val = 0, max_val = 9223372036854775806} : !torch.int +# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 3, max_val = 10} : !torch.int +# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]]], affine_map<()[s0] -> (s0, 3)> : !torch.vtensor<[?,3],f32> +# CHECK: %[[NZERO:.+]] = torch.aten.nonzero %[[ARG0]] : !torch.vtensor<[?,3],f32> -> !torch.vtensor<[?,2],si64> +# CHECK-DISABLED: torch.bind_symbolic_shape %[[NZERO]], [%[[U0]]], affine_map<()[s0] -> (s0, 2)> : !torch.vtensor<[?,2],si64> +# CHECK: return %[[NZERO]] : !torch.vtensor<[?,2],si64> +def test_nonzero(): + class Nonzero(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.nonzero(x) + + # Sample inputs + x = torch.randn(4, 3) + + # Dynamic dim constraints + batch = Dim("batch", min=3, max=10) + dynamic_shapes = {"x": {0: batch}} + + m = fx.export_and_import( + Nonzero(), + x, + dynamic_shapes=dynamic_shapes, + import_symbolic_shape_expressions=True, + ) + print(m) diff --git a/test/python/fx_importer/sympy_to_affine_expr_test.py b/test/python/fx_importer/sympy_to_affine_expr_test.py new file mode 100644 index 000000000000..0c366040d216 --- /dev/null +++ b/test/python/fx_importer/sympy_to_affine_expr_test.py @@ -0,0 +1,69 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +# RUN: %PYTHON %s | FileCheck %s +# This file contains tests checking translating sympy expressions to (semi-)affine expressions. 
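+# For example, the sympy expression (s0 + 5 - 2*s1) is expected to translate to the +# semi-affine expression "s0 - s1 * 2 + 5", and s0**2 to "s0 * s0" (see the SYMPY_EXPRS cases below).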
+ +from sympy import Symbol +from torch_mlir.extras.fx_importer import sympy_expr_to_semi_affine_expr + +from torch_mlir.ir import ( + AffineSymbolExpr, + Context, +) + + +def run(f): + print(f"{f.__name__}") + print("-" * len(f.__name__)) + f() + print() + + +@run +# CHECK-LABEL: test_sympy_to_semi_affine_expr_translation +def test_sympy_to_semi_affine_expr_translation(): + with Context(): + s0 = Symbol("s0", positive=True, integer=True) + s1 = Symbol("s1", positive=True, integer=True) + + symbols_set = sorted({s0, s1}, key=lambda x: x.name) + symbols_map = { + str(symbol): AffineSymbolExpr.get(i) for i, symbol in enumerate(symbols_set) + } + + SYMPY_EXPRS = [ + # CHECK: 10 + (10), + # CHECK: s0 + (s0), + # CHECK: s0 + (s0 + 0), + # CHECK: s0 + 1 + (s0 + 1), + # CHECK: s0 + (s0 * 1), + # CHECK: s0 * 2 + (s0 * 2), + # CHECK: s0 * s0 + (s0 * s0), + # CHECK: s0 * s1 + (s0 * s1), + # CHECK: s0 * s0 + (s0**2), + # CHECK: (s0 * s0) * s0 + (s0**3), + # CHECK: ((((s0 * s0) * s0) * s0) * s0) * s0 + ((s0**2) ** 3), + # CHECK: ((((((s0 * s0) * s0) * s0) * s0) * s0) * s0) * s0 + (s0 ** (2**3)), + # CHECK: s0 mod 10 + (s0 % 10), + # CHECK: s0 - s1 * 2 + 5 + (s0 + 5 - 2 * s1), + ] + + for expr in SYMPY_EXPRS: + print(sympy_expr_to_semi_affine_expr(expr, symbols_map)) diff --git a/test/python/fx_importer/v2.3/auto_functionalized.py b/test/python/fx_importer/v2.3/auto_functionalized.py index ab7401dcc2fb..7fb0eeb3b67f 100644 --- a/test/python/fx_importer/v2.3/auto_functionalized.py +++ b/test/python/fx_importer/v2.3/auto_functionalized.py @@ -59,8 +59,9 @@ def forward(self, x): # AssertionError: Current active mode not registered decomposition_table=[], ) - # CHECK: %[[TIED:.*]] = torch.operator "torch.torch_mlir_test.inplace_modify"({{.*}}) : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> - # CHECK: torch.aten.mul.Tensor %[[TIED]], %[[TIED]] + # Torch 2.6 produces the IR shown below, while torch versions < 2.6 do not; hence this check is kept as a "COM". + # COM: torch.operator "torch.torch_mlir_test.inplace_modify"({{.*}}) : (!torch.vtensor<[3,4],f32>) -> () + # CHECK: torch.aten.mul.Tensor %{{.*}}, %{{.*}} print(m) m.operation.verify() @@ -86,7 +87,8 @@ def forward(self, x): # AssertionError: Current active mode not registered decomposition_table=[], ) - # CHECK: %[[TIED:.*]]:2 = torch.operator "torch.torch_mlir_test.inplace_modify_calc"(%0) : (!torch.vtensor<[3,4],f32>) -> (!torch.vtensor<[3,4],f32>, !torch.vtensor<[3,4],f32>) - # CHECK: torch.aten.mul.Tensor %[[TIED]]#1, %[[TIED]]#0 + # Torch 2.6 produces the IR shown below, while torch versions < 2.6 do not; hence this check is kept as a "COM".
+ # COM: %[[TIED:.*]] = torch.operator "torch.torch_mlir_test.inplace_modify_calc"(%arg0) : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> + # CHECK: torch.aten.mul.Tensor %{{.*}}, %{{.*}} print(m) m.operation.verify() diff --git a/test/python/fx_importer/v2.3/mutation_import.py b/test/python/fx_importer/v2.3/mutation_import.py index c62b12706e58..c2e5d9f14e2f 100644 --- a/test/python/fx_importer/v2.3/mutation_import.py +++ b/test/python/fx_importer/v2.3/mutation_import.py @@ -65,7 +65,9 @@ def forward(self, x): # CHECK: func.func @main(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.tensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> # CHECK-DAG: %[[arg1_copy:.+]] = torch.copy.to_vtensor %arg1 : !torch.vtensor<[3,4],f32> # CHECK-DAG: %[[arg1_mul:.+]] = torch.aten.mul.Tensor %[[arg1_copy]], %arg0 -# CHECK-DAG: torch.overwrite.tensor.contents %[[arg1_mul]] overwrites %arg1 +# Torch 2.6 generates a `torch.aten.copy` op in this example, while torch versions < 2.6 do not; hence this check is kept as a "COM". +# COM: %{{.*}} = torch.aten.copy %[[arg1_copy]], %[[arg1_mul]], %false : !torch.vtensor<[3,4],f32>, !torch.vtensor<[3,4],f32>, !torch.bool -> !torch.vtensor<[3,4],f32> +# CHECK-DAG: torch.overwrite.tensor.contents %{{.*}} overwrites %arg1 # CHECK-DAG: %[[arg0_mul:.+]] = torch.aten.mul.Tensor %arg0, %[[arg1_mul]] # CHECK: return %[[arg0_mul]] def test_user_input_mutate(): @@ -105,6 +107,27 @@ def forward(self, x): m.operation.verify() +@run +# CHECK-LABEL: test_frozen_buffer_non_persistent +# CHECK: %[[buffer_literal:.+]] = torch.vtensor.literal +# CHECK: %[[mul:.+]] = torch.aten.mul.Tensor %arg0, %0 +# CHECK: return %[[mul]] +def test_frozen_buffer_non_persistent(): + class Basic(nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("buffer", torch.randn(3, 4), persistent=False) + + def forward(self, x): + return x * self.buffer + + m = fx.export_and_import( + Basic(), torch.randn(3, 4), experimental_support_mutation=True + ) + print(m) + m.operation.verify() + + class ExternalBufferHooks(fx.FxImporterHooks): def prepare_module(self, module_op: Operation): module_op.context.allow_unregistered_dialects = True diff --git a/test/python/fx_importer/v2.3/types_test.py b/test/python/fx_importer/v2.3/types_test.py index 19dee8b7b2cb..cb897a8c88bd 100644 --- a/test/python/fx_importer/v2.3/types_test.py +++ b/test/python/fx_importer/v2.3/types_test.py @@ -36,8 +36,13 @@ def forward(self, x): x = x + 1.0 return x.shape[0] + # CHECK: %[[S0:.*]] = torch.symbolic_int "s0" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int + # CHECK: torch.bind_symbolic_shape %arg0, [%[[S0]]], affine_map<()[s0] -> (s0, 4)> : !torch.vtensor<[?,4],f32> # CHECK: torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,4],f32>, !torch.int -> !torch.int m = fx.export_and_import( - Basic(), torch.randn(3, 4), dynamic_shapes={"x": {0: torch.export.Dim("b")}} + Basic(), + torch.randn(3, 4), + dynamic_shapes={"x": {0: torch.export.Dim("b", min=3, max=10)}}, + import_symbolic_shape_expressions=True, ) print(m) diff --git a/test/python/onnx_importer/BadName.onnx b/test/python/onnx_importer/BadName.onnx new file mode 100644 index 000000000000..b63cda401726 Binary files /dev/null and b/test/python/onnx_importer/BadName.onnx differ diff --git a/test/python/onnx_importer/BadName.runlit b/test/python/onnx_importer/BadName.runlit new file mode 100644 index 000000000000..3ae08941e8a8 --- /dev/null +++ b/test/python/onnx_importer/BadName.runlit @@ -0,0 +1,5 @@ +# The original constant name :
"abz_.(1, 2)[$something, %anotherthing]" + +# RUN: %PYTHON -m torch_mlir.tools.import_onnx %S/BadName.onnx | FileCheck %s + +# CHECK: torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_abz_._1__2___something___anotherthing_> diff --git a/test/python/onnx_importer/function_expansion/GreaterOrEqual.runlit b/test/python/onnx_importer/function_expansion/GreaterOrEqual.runlit new file mode 100644 index 000000000000..dd67aadabde8 --- /dev/null +++ b/test/python/onnx_importer/function_expansion/GreaterOrEqual.runlit @@ -0,0 +1,18 @@ +# Test that expansion of ONNX operators that are functions works for a simple +# example. The exact name mangling scheme used is not matched against, all that +# matters is that it has the name of the operator (GreaterOrEqual here) in it. +# Attributes are also not checked here. What we are interested in is the types +# and operations. +# +# The model comes from an upstream ONNX test: backend/test/data/node/test_greater_equal/model.onnx + +# RUN: %PYTHON -m torch_mlir.tools.import_onnx --disable-function-expansion-allowlist %s.onnx | FileCheck %s + +# CHECK-LABEL: func.func @test_greater_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> +# CHECK: %0 = call @"{{.*}}GreaterOrEqual{{.*}}"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> + +# CHECK-LABEL: func.func private @"{{.*}}GreaterOrEqual{{.*}}"(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> +# CHECK: %0 = torch.operator "onnx.Greater"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> +# CHECK: %1 = torch.operator "onnx.Equal"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> +# CHECK: %2 = torch.operator "onnx.Or"(%0, %1) : (!torch.vtensor<[3,4,5],i1>, !torch.vtensor<[3,4,5],i1>) -> !torch.vtensor<[3,4,5],i1> +# CHECK: return %2 : !torch.vtensor<[3,4,5],i1> diff --git a/test/python/onnx_importer/function_expansion/GreaterOrEqual.runlit.onnx b/test/python/onnx_importer/function_expansion/GreaterOrEqual.runlit.onnx new file mode 100644 index 000000000000..061aed0d57fd Binary files /dev/null and b/test/python/onnx_importer/function_expansion/GreaterOrEqual.runlit.onnx differ diff --git a/test/python/onnx_importer/function_expansion/ReduceSumSquare_keepdims=0.runlit b/test/python/onnx_importer/function_expansion/ReduceSumSquare_keepdims=0.runlit new file mode 100644 index 000000000000..84e0cac63c7b --- /dev/null +++ b/test/python/onnx_importer/function_expansion/ReduceSumSquare_keepdims=0.runlit @@ -0,0 +1,22 @@ +# Test the expansion of ONNX operators that are functions, specifically the +# propagation of attribute values from the call-site to nodes within the +# expanded function. +# +# In this case, the model has a ReduceSumSquare node with the attribute +# 'keepdims' set to 0, and the definition of this version of ReduceSumSquare +# contains a ReduceSum node that references the value of 'keepdims', so we +# expect to see this value propagated to the ReduceSum node in the expansion. +# +# This also tests that the absence of 'axes' (as an optional attribute with no +# default value) is propagated in the same way. 
+# +# The model comes from an upstream ONNX test: backend/test/data/node/test_reduce_sum_square_do_not_keepdims_example/model.onnx + +# RUN: %PYTHON -m torch_mlir.tools.import_onnx --disable-function-expansion-allowlist %s.onnx | FileCheck %s +# +# CHECK-LABEL: func.func @test_reduce_sum_square_do_not_keepdims_example +# CHECK: %0 = call @"{{.*}}ReduceSumSquare{{.*}}" +# +# CHECK-LABEL: func.func private @"{{.*}}ReduceSumSquare{{.*}}" +# CHECK: %0 = torch.operator "onnx.Mul" +# CHECK: %1 = torch.operator "onnx.ReduceSum"{{.*}}{torch.onnx.keepdims = 0 : si64} diff --git a/test/python/onnx_importer/function_expansion/ReduceSumSquare_keepdims=0.runlit.onnx b/test/python/onnx_importer/function_expansion/ReduceSumSquare_keepdims=0.runlit.onnx new file mode 100644 index 000000000000..cfdc1b20352e Binary files /dev/null and b/test/python/onnx_importer/function_expansion/ReduceSumSquare_keepdims=0.runlit.onnx differ diff --git a/test/python/onnx_importer/function_expansion/ReduceSumSquare_no_attrs.runlit b/test/python/onnx_importer/function_expansion/ReduceSumSquare_no_attrs.runlit new file mode 100644 index 000000000000..a5dcdd9c5ad5 --- /dev/null +++ b/test/python/onnx_importer/function_expansion/ReduceSumSquare_no_attrs.runlit @@ -0,0 +1,23 @@ +# Test the expansion of ONNX operators that are functions, specifically the +# propagation of attribute values from the call-site to nodes within the +# expanded function. +# +# In this case, the model has a ReduceSumSquare node with no attributes, but the +# definition of this version of ReduceSumSquare contains a ReduceSum node that +# references the value of 'keepdims', and the definition says its default value +# is 1, so we expect to see this value propagated to the ReduceSum node in the +# expansion. +# +# This also tests that the absence of 'axes' (as an optional attribute with no +# default value) is propagated in the same way. 
+# +# The model comes from an upstream ONNX test: backend/test/data/node/test_reduce_sum_square_empty_set/model.onnx + +# RUN: %PYTHON -m torch_mlir.tools.import_onnx --disable-function-expansion-allowlist %s.onnx | FileCheck %s +# +# CHECK-LABEL: func.func @test_reduce_sum_square_empty_set +# CHECK: %0 = call @"{{.*}}ReduceSumSquare{{.*}}" +# +# CHECK-LABEL: func.func private @"{{.*}}ReduceSumSquare{{.*}}" +# CHECK: %0 = torch.operator "onnx.Mul" +# CHECK: %1 = torch.operator "onnx.ReduceSum"{{.*}}{torch.onnx.keepdims = 1 : si64} diff --git a/test/python/onnx_importer/function_expansion/ReduceSumSquare_no_attrs.runlit.onnx b/test/python/onnx_importer/function_expansion/ReduceSumSquare_no_attrs.runlit.onnx new file mode 100644 index 000000000000..5277d1371557 Binary files /dev/null and b/test/python/onnx_importer/function_expansion/ReduceSumSquare_no_attrs.runlit.onnx differ diff --git a/tools/torch-mlir-opt/torch-mlir-opt.cpp b/tools/torch-mlir-opt/torch-mlir-opt.cpp index 2750ee2b7145..0fa392de43b3 100644 --- a/tools/torch-mlir-opt/torch-mlir-opt.cpp +++ b/tools/torch-mlir-opt/torch-mlir-opt.cpp @@ -7,6 +7,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" @@ -33,8 +34,13 @@ int main(int argc, char **argv) { registerStripDebugInfoPass(); registerSymbolDCEPass(); + // memref passes used in torch-backend-to-linalg-on-tensors-backend-pipeline + memref::registerExpandOpsPass(); + memref::registerResolveShapedTypeResultDimsPass(); + DialectRegistry registry; mlir::torch::registerAllDialects(registry); + mlir::torch::registerAllExtensions(registry); mlir::torch::registerOptionalInputDialects(registry); #ifdef TORCH_MLIR_ENABLE_STABLEHLO diff --git a/torchvision-requirements.txt b/torchvision-requirements.txt index a7da638bc2bf..7521ee5dbec8 100644 --- a/torchvision-requirements.txt +++ b/torchvision-requirements.txt @@ -1,3 +1,7 @@ --f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html +-f https://download.pytorch.org/whl/nightly/cpu/torchvision/ +# The nightly wheels for torchvision are regularly deleted and we don't bump the +# versions at the same pace. The wheels will therefore be cached on the xilinx +# release page, and we use this page as an additional source for the wheels. 
+-f https://xilinx.github.io/torch-mlir/package-index/ --pre -torchvision==0.19.0.dev20240505 +torchvision==0.22.0.dev20250310 diff --git a/utils/bazel/torch-mlir-overlay/BUILD.bazel b/utils/bazel/torch-mlir-overlay/BUILD.bazel index e62780ff9634..fdde0d63481a 100644 --- a/utils/bazel/torch-mlir-overlay/BUILD.bazel +++ b/utils/bazel/torch-mlir-overlay/BUILD.bazel @@ -64,6 +64,7 @@ gentbl_cc_library( td_file = "include/torch-mlir/Dialect/Torch/IR/TorchOps.td", deps = [ ":MLIRTorchOpsIncGenTdFiles", + "@llvm-project//mlir:BuiltinDialectTdFiles", ], ) @@ -90,6 +91,7 @@ gentbl_cc_library( cc_library( name = "TorchMLIRTorchDialectUtils", srcs = [ + "lib/Dialect/Torch/Utils/SparsityUtils.cpp", "lib/Dialect/Torch/Utils/TorchUpstream.cpp", "lib/Dialect/Torch/Utils/Utils.cpp", ], @@ -97,6 +99,7 @@ cc_library( "include/torch-mlir/Dialect/Torch/IR/TorchOps.h", "include/torch-mlir/Dialect/Torch/IR/TorchTraits.h", "include/torch-mlir/Dialect/Torch/IR/TorchTypes.h", + "include/torch-mlir/Dialect/Torch/Utils/SparsityUtils.h", "include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h", "include/torch-mlir/Dialect/Torch/Utils/Utils.h", ], @@ -108,6 +111,8 @@ cc_library( "@llvm-project//mlir:ControlFlowInterfaces", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:SparseTensorDialect", + "@llvm-project//mlir:SparseTensorEnums", ], ) @@ -180,6 +185,7 @@ cc_library( deps = [ ":TorchMLIRTorchBackendTypeConversion", ":TorchMLIRTorchDialect", + ":TorchMLIRTorchOnnxToTorch", ":TorchMLIRTorchPassesIncGen", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", @@ -272,6 +278,7 @@ gentbl_cc_library( [ "-gen-pass-decls", "-DTORCH_MLIR_ENABLE_STABLEHLO", + "-DTORCH_MLIR_ENABLE_TOSA", ], "include/torch-mlir/Conversion/Passes.h.inc", ), @@ -325,7 +332,11 @@ gentbl_cc_library( strip_include_prefix = "include", tbl_outs = [ ( - ["-gen-pass-decls"], + [ + "-gen-pass-decls", + "-DTORCH_MLIR_ENABLE_STABLEHLO", + "-DTORCH_MLIR_ENABLE_TOSA", + ], "include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h.inc", ), ], @@ -492,6 +503,9 @@ cc_library( "lib/Conversion/TorchToStablehlo/*.cpp", ]), hdrs = glob(["include/torch-mlir/Conversion/TorchToStablehlo/*.h"]), + defines = [ + "TORCH_MLIR_ENABLE_STABLEHLO", + ], strip_include_prefix = "include", deps = [ ":TorchMLIRConversionPassesIncGen", @@ -530,6 +544,7 @@ cc_library( ], defines = [ "TORCH_MLIR_ENABLE_STABLEHLO", + "TORCH_MLIR_ENABLE_TOSA", ], strip_include_prefix = "include", deps = [ @@ -552,6 +567,10 @@ cc_library( "lib/Dialect/TorchConversion/Transforms/*.h", ]), hdrs = glob(["include/torch-mlir/Dialect/TorchConversion/Transforms/*.h"]), + defines = [ + "TORCH_MLIR_ENABLE_STABLEHLO", + "TORCH_MLIR_ENABLE_TOSA", + ], strip_include_prefix = "include", deps = [ ":TorchMLIRTorchBackendTypeConversion", @@ -585,6 +604,9 @@ cc_library( "lib/Conversion/TorchToTosa/*.cpp", ]), hdrs = glob(["include/torch-mlir/Conversion/TorchToTosa/*.h"]), + defines = [ + "TORCH_MLIR_ENABLE_TOSA", + ], strip_include_prefix = "include", deps = [ ":TorchMLIRConversionPassesIncGen", @@ -872,6 +894,7 @@ cc_library( copts = [ "-DTORCH_MLIR_ENABLE_REFBACKEND", "-DTORCH_MLIR_ENABLE_STABLEHLO", + "-DTORCH_MLIR_ENABLE_TOSA", ], strip_include_prefix = "include", deps = [ @@ -887,6 +910,7 @@ cc_library( "@llvm-project//mlir:Dialect", "@llvm-project//mlir:DialectUtils", "@llvm-project//mlir:IR", + "@llvm-project//mlir:TensorInferTypeOpInterfaceImpl", "@stablehlo//:linalg_passes", "@stablehlo//:stablehlo_passes", ],