diff --git a/.github/actions/docker-run/action.yml b/.github/actions/docker-run/action.yml index 08eba0c685e..5758188f6db 100644 --- a/.github/actions/docker-run/action.yml +++ b/.github/actions/docker-run/action.yml @@ -9,8 +9,8 @@ inputs: description: 'Docker image architecture' required: false default: tt-metalium/ubuntu-20.04-amd64 - docker_version: - description: 'Specify version for the Docker image tag to use.' + docker_image: + description: 'Specify Docker image to use.' required: false docker_username: description: docker login username @@ -37,11 +37,16 @@ inputs: runs: using: "composite" steps: + - name: Set docker image tag + if: ${{ inputs.docker_image }} + shell: bash + run: | + echo "TT_METAL_DOCKER_IMAGE_TAG=${{ inputs.docker_image }}" >> $GITHUB_ENV - name: Determine docker image tag + if: ${{ ! inputs.docker_image }} uses: ./.github/actions/generate-docker-tag with: image: ${{ inputs.docker_os_arch }} - version: ${{ inputs.docker_version }} - name: Set shell: bash run: | diff --git a/.github/actions/generate-docker-tag/action.yml b/.github/actions/generate-docker-tag/action.yml index 1ba5a1afd6a..a6b34761dec 100644 --- a/.github/actions/generate-docker-tag/action.yml +++ b/.github/actions/generate-docker-tag/action.yml @@ -2,39 +2,34 @@ name: "Run set of commands in Docker" description: "Run commands in docker" inputs: - run_args: - description: 'Commands to run in docker' - required: true image: description: 'Docker image to run commands in - follows os-arch format' required: false default: ubuntu-20.04-amd64 - version: - description: 'Docker image version' - required: false runs: using: "composite" steps: - - name: Determine Docker Tag - shell: bash - run: | - # If the version was provided use it, otherwise, determine what the version should be. - if [ "${{ inputs.version }}" != "" ]; then - echo "IMAGE_TAG=${{ inputs.version }}" >> $GITHUB_ENV - else - if [[ "${GITHUB_REF_NAME}" == "main" ]]; then - echo "IMAGE_TAG=latest" >> $GITHUB_ENV - else - echo "IMAGE_TAG=dev-${GITHUB_REF_NAME//\//-}" >> $GITHUB_ENV - fi - fi - - name: Determine Full Docker Image Tag + - name: Deprecation warning shell: bash run: | - echo "TT_METAL_DOCKER_IMAGE_TAG=ghcr.io/${{ github.repository }}/${{ inputs.image }}:${{ env.IMAGE_TAG }}" >> $GITHUB_ENV - echo "TT_METAL_REF_IMAGE_TAG=ghcr.io/${{ github.repository }}/${{ inputs.image }}:latest" >> $GITHUB_ENV - - name: Output Docker Image Tag + echo "::notice::[DEPRECATION] This action is deprecated. Please migrate to reading the Docker image from the pipeline." 
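For reference on the docker-run change above: callers now pass a fully qualified image via the new docker_image input instead of the removed docker_version. A minimal, hedged usage sketch; the image reference and the test command are placeholders rather than real published values:

- name: Run tests in container
  uses: ./.github/actions/docker-run
  with:
    docker_image: ghcr.io/<org>/tt-metal/tt-metalium/ubuntu-20.04-amd64:<tag>  # placeholder image reference
    docker_password: ${{ secrets.GITHUB_TOKEN }}
    run_args: |
      pytest --version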
+ + - name: Checkout repo + uses: actions/checkout@v3 + with: + fetch-depth: 1 + clean: false + + - name: Compute tags + id: tags shell: bash run: | - echo "IMAGE_TAG=${{ env.IMAGE_TAG }}" - echo "TT_METAL_DOCKER_IMAGE_TAG=${{ env.TT_METAL_DOCKER_IMAGE_TAG}}" + BUILD_TAG=$(cat \ + install_dependencies.sh \ + dockerfile/Dockerfile \ + tt_metal/python_env/requirements-dev.txt \ + docs/requirements-docs.txt \ + tests/sweep_framework/requirements-sweeps.txt \ + | sha1sum | cut -d' ' -f1) + echo "BUILD_TAG=$BUILD_TAG" >> $GITHUB_ENV + echo "TT_METAL_DOCKER_IMAGE_TAG=ghcr.io/${{ github.repository }}/${{ inputs.image }}:${BUILD_TAG}" >> $GITHUB_ENV diff --git a/.github/workflows/_build-wheels-impl.yaml b/.github/workflows/_build-wheels-impl.yaml deleted file mode 100644 index 70e211af017..00000000000 --- a/.github/workflows/_build-wheels-impl.yaml +++ /dev/null @@ -1,67 +0,0 @@ -name: "[internal] Python wheels build impl" - -on: - workflow_call: - inputs: - os: - required: True - type: string - from-precompiled: - required: True - default: True - type: boolean - -jobs: - build-wheel: - runs-on: ${{ inputs.os }} - steps: - - uses: tenstorrent/tt-metal/.github/actions/checkout-with-submodule-lfs@main - with: - fetch-depth: 0 - - uses: ./.github/actions/install-metal-deps - with: - os: ${{ inputs.os }} - - uses: ./.github/actions/install-metal-dev-deps - with: - os: ${{ inputs.os }} - - name: Clean up dirty files - run: git clean -f -d - - name: Set Python Version - id: python-version - run: | - if [[ "${{ inputs.os }}" == "ubuntu-20.04" ]]; then - echo "python-version=3.8" >> $GITHUB_ENV - elif [[ "${{ inputs.os }}" == "ubuntu-22.04" ]]; then - echo "python-version=3.10" >> $GITHUB_ENV - else - echo "Unsupported OS version: ${{ inputs.os }}" - exit 1 - fi - - uses: actions/setup-python@v5.0.0 - with: - cache: 'pip' - cache-dependency-path: | - tt_metal/python_env/requirements-dev.txt - pyproject.toml - - name: Install python deps for packaging - run: pip install build - - name: Use g++ as umd compiler for ubuntu 22.04 - if: ${{ inputs.os == 'ubuntu-22.04' }} - run: | - echo "DEVICE_CXX=g++" >> $GITHUB_ENV - - uses: ./.github/actions/prepare-metal-run - if: ${{ inputs.from-precompiled }} - with: - python-version: ${{ env.python-version }} - - name: Set precompiled dir for precompile builds - if: ${{ inputs.from-precompiled }} - # TT_FROM_PRECOMPILED_DIR env variable allows us to not re-run the full C++ build and instead - # rely on the artifact that was already compiled. We point it to where the repo is. 
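For context on the Compute tags step above: the Docker tag is now a pure function of the files that define the image, so an unchanged dependency set resolves to a tag that is already published (build-docker-artifact below checks this with docker manifest inspect before rebuilding). A minimal local sketch of the same derivation, assuming it is run from the repository root; the registry path is illustrative:

BUILD_TAG=$(cat \
  install_dependencies.sh \
  dockerfile/Dockerfile \
  tt_metal/python_env/requirements-dev.txt \
  docs/requirements-docs.txt \
  tests/sweep_framework/requirements-sweeps.txt \
  | sha1sum | cut -d' ' -f1)
echo "ghcr.io/<org>/tt-metal/tt-metalium/ubuntu-20.04-amd64:${BUILD_TAG}"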
- run: echo "TT_FROM_PRECOMPILED_DIR=${{ github.workspace }}" >> $GITHUB_ENV - - name: Build Python package distribution - run: python -m build - - name: Upload distribution as artifact - uses: actions/upload-artifact@v4 - with: - name: eager-dist-${{ inputs.os }}-any - path: dist/ diff --git a/.github/workflows/_produce-data.yaml b/.github/workflows/_produce-data.yaml index abc9e548df1..eca1d625272 100644 --- a/.github/workflows/_produce-data.yaml +++ b/.github/workflows/_produce-data.yaml @@ -95,6 +95,12 @@ jobs: echo "attempt-number=$attempt_number" >> "$GITHUB_OUTPUT" echo "::notice title=target-workflow-link::The workflow being analyzed is available at https://github.com/tenstorrent/tt-metal/actions/runs/$run_id/attempts/$attempt_number" + - name: Get API rate limit status + env: + GH_TOKEN: ${{ github.token }} + run: | + echo "[Info] Grabbing API rate limit status" + gh api rate_limit - name: Output auxiliary values env: GH_TOKEN: ${{ github.token }} diff --git a/.github/workflows/all-post-commit-workflows.yaml b/.github/workflows/all-post-commit-workflows.yaml index 9dbab807542..f4bd6f0dc6d 100644 --- a/.github/workflows/all-post-commit-workflows.yaml +++ b/.github/workflows/all-post-commit-workflows.yaml @@ -38,35 +38,24 @@ jobs: static-checks: uses: ./.github/workflows/all-static-checks.yaml secrets: inherit - build-wheels: - needs: build-artifact - strategy: - matrix: - # Since pre-compiled builds only run on 20.04, we can only test on 20.04 for now - # The full 22.04 flow can be tested without precompiled - os: [ubuntu-20.04] - uses: ./.github/workflows/_build-wheels-impl.yaml - with: - os: ${{ matrix.os }} - from-precompiled: true - secrets: inherit - test-wheels: - needs: build-wheels - uses: ./.github/workflows/_test-wheels-impl.yaml - with: - from-precompiled: true - secrets: inherit build-artifact: uses: ./.github/workflows/build-artifact.yaml secrets: inherit with: build-type: ${{ inputs.build-type || 'Release' }} + build-wheel: true build-artifact-profiler: uses: ./.github/workflows/build-artifact.yaml with: build-type: ${{ inputs.build-type || 'Release' }} tracy: true secrets: inherit + test-wheels: + needs: build-artifact + uses: ./.github/workflows/_test-wheels-impl.yaml + with: + from-precompiled: true + secrets: inherit # Slow Dispatch Unit Tests sd-unit-tests: needs: build-artifact @@ -85,7 +74,7 @@ jobs: runner-label: ${{ matrix.test-group.runner-label }} # Fast Dispatch Unit Tests fast-dispatch-unit-tests: - needs: build-wheels + needs: build-artifact secrets: inherit strategy: fail-fast: false @@ -102,7 +91,7 @@ jobs: runner-label: ${{ matrix.test-group.runner-label }} # TTNN FD Unit tests ttnn-unit-tests: - needs: build-wheels + needs: build-artifact secrets: inherit strategy: fail-fast: false @@ -118,7 +107,7 @@ jobs: runner-label: ${{ matrix.test-group.runner-label }} # FD Model Tests models-unit-tests: - needs: build-wheels + needs: build-artifact secrets: inherit strategy: fail-fast: false @@ -168,10 +157,21 @@ jobs: with: arch: ${{ matrix.test-group.arch }} runner-label: ${{ matrix.test-group.runner-label }} - profiler-regression: + run-profiler-regression: needs: build-artifact-profiler + strategy: + fail-fast: false + matrix: + test-group: [ + { arch: grayskull, runner-label: E150 }, + { arch: wormhole_b0, runner-label: N150 }, + { arch: wormhole_b0, runner-label: N300 }, + ] uses: ./.github/workflows/run-profiler-regression.yaml secrets: inherit + with: + arch: ${{ matrix.test-group.arch}} + runner-label: ${{ matrix.test-group.runner-label}} build-docs: 
needs: build-artifact uses: ./.github/workflows/docs-latest-public.yaml diff --git a/.github/workflows/all-static-checks.yaml b/.github/workflows/all-static-checks.yaml index aefed2a61c6..c46bb1b8c39 100644 --- a/.github/workflows/all-static-checks.yaml +++ b/.github/workflows/all-static-checks.yaml @@ -104,23 +104,13 @@ jobs: - uses: lukka/get-cmake@b516803a3c5fac40e2e922349d15cdebdba01e60 if: steps.changed-cmake-files.outputs.any_changed == 'true' with: - cmakeVersion: "~3.18.0" + cmakeVersion: "~3.19.0" - name: Check CMake version if: steps.changed-cmake-files.outputs.any_changed == 'true' run: cmake --version - - name: Install LLVM and Clang + - name: Install Build Dependencies if: steps.changed-cmake-files.outputs.any_changed == 'true' - run: | - wget https://apt.llvm.org/llvm.sh - chmod u+x llvm.sh - sudo ./llvm.sh 17 - - name: Install deps - if: steps.changed-cmake-files.outputs.any_changed == 'true' - env: - DEBIAN_FRONTEND: noninteractive - run: | - sudo apt update - sudo xargs -a scripts/docker/requirements-22.04.txt apt install -y --no-install-recommends + run: sudo ./install_dependencies.sh --mode build - name: Check CMake compatibility if: steps.changed-cmake-files.outputs.any_changed == 'true' env: diff --git a/.github/workflows/blackhole-post-commit.yaml b/.github/workflows/blackhole-post-commit.yaml index a33fa4631e1..f8e4e25d429 100644 --- a/.github/workflows/blackhole-post-commit.yaml +++ b/.github/workflows/blackhole-post-commit.yaml @@ -15,6 +15,15 @@ on: required: true type: string default: 'BH' + build-type: + required: false + default: Release + type: choice + options: + - Release + - Debug + - RelWithDebInfo + - CI schedule: - cron: "0 */2 * * *" # Pause this since not enough runners to support every commit to main @@ -37,13 +46,23 @@ jobs: uses: ./.github/workflows/build-artifact.yaml secrets: inherit with: + build-type: ${{ inputs.build-type || 'Release' }} + build-wheel: true version: "22.04" - build-wheels: - needs: build-artifact - uses: ./.github/workflows/_build-wheels-impl.yaml + build-artifact-profiler: + uses: ./.github/workflows/build-artifact.yaml + secrets: inherit with: - os: "ubuntu-22.04" - from-precompiled: true + build-type: ${{ inputs.build-type || 'Release' }} + tracy: true + version: "20.04" + run-profiler-regression: + needs: build-artifact-profiler + uses: ./.github/workflows/run-profiler-regression.yaml + secrets: inherit + with: + arch: "blackhole" + runner-label: ${{ inputs.runner-label || 'BH' }} umd-unit-tests: secrets: inherit uses: ./.github/workflows/umd-unit-tests.yaml @@ -60,7 +79,7 @@ jobs: timeout: 30 os: "ubuntu-22.04" fd-unit-tests: - needs: build-wheels + needs: build-artifact uses: ./.github/workflows/fast-dispatch-build-and-unit-tests.yaml secrets: inherit with: diff --git a/.github/workflows/build-and-test-wheels.yaml b/.github/workflows/build-and-test-wheels.yaml index d21c08d1f76..27494489a25 100644 --- a/.github/workflows/build-and-test-wheels.yaml +++ b/.github/workflows/build-and-test-wheels.yaml @@ -15,20 +15,8 @@ jobs: if: ${{ github.event_name == 'workflow_dispatch' && inputs.from-precompiled }} uses: ./.github/workflows/build-artifact.yaml secrets: inherit - build-wheels: - needs: build-artifact - if: ${{ always() }} - strategy: - matrix: - # Since pre-compiled builds only run on 20.04, we can only test on 20.04 for now - # The full 22.04 flow can be tested without precompiled - os: ${{ fromJson((github.event_name == 'schedule' || inputs.from-precompiled) && '["ubuntu-20.04"]' || '["ubuntu-20.04", "ubuntu-22.04"]') 
}} - uses: ./.github/workflows/_build-wheels-impl.yaml - with: - os: ${{ matrix.os }} - from-precompiled: ${{ inputs.from-precompiled }} test-wheels: - needs: build-wheels + needs: build-artifact if: ${{ always() }} uses: ./.github/workflows/_test-wheels-impl.yaml with: diff --git a/.github/workflows/build-artifact.yaml b/.github/workflows/build-artifact.yaml index 9ddc81d266f..c9fed1b5405 100644 --- a/.github/workflows/build-artifact.yaml +++ b/.github/workflows/build-artifact.yaml @@ -12,6 +12,11 @@ on: type: boolean default: false description: "Build with tracy enabled" + build-wheel: + required: false + type: boolean + default: false + description: "Build Python Wheel" distro: required: false type: string @@ -42,6 +47,13 @@ on: required: false type: boolean default: false + outputs: + ci-build-docker-image: + description: "Docker tag for the CI Build Docker image for building TT-Metalium et al" + value: ${{ jobs.build-docker-image.outputs.ci-build-tag }} + #ci-test-docker-image: + # description: "Docker tag for the CI Test Docker image for testing TT-Metalium et al" + # value: ${{ jobs.build-docker-image.outputs.ci-test-tag }} workflow_dispatch: @@ -55,6 +67,11 @@ on: type: boolean default: false description: "Build with tracy enabled" + build-wheel: + required: false + type: boolean + default: false + description: "Build Python Wheel" distro: required: false type: string @@ -87,13 +104,29 @@ jobs: name: "🛠️ Build ${{ inputs.build-type }} ${{ inputs.distro }} ${{ inputs.version }}" needs: build-docker-image timeout-minutes: 30 - env: - SILENT: 0 - VERBOSE: 1 - IMAGE_PARAMS: "${{ inputs.distro }}-${{ inputs.version }}-${{ inputs.architecture }}" runs-on: - build - in-service + container: + image: ${{ needs.build-docker-image.outputs.ci-build-tag }} + env: + CCACHE_TEMPDIR: /tmp/ccache + CARGO_HOME: /tmp/.cargo + TT_FROM_PRECOMPILED_DIR: /work + volumes: + - ${{ github.workspace }}/docker-job:/work # Subdir to workaround https://github.com/actions/runner/issues/691 + - /home/ubuntu/.ccache-ci:/github/home/.ccache # HOME is hardcoded for no clear reason: https://github.com/actions/runner/issues/863 + - /mnt/MLPerf/ccache:/mnt/MLPerf/ccache + # Group 1457 is for the shared ccache drive + # tmpfs is for efficiency + options: > + --group-add 1457 + --tmpfs /tmp + defaults: + run: + shell: bash + working-directory: /work # https://github.com/actions/runner/issues/878 + steps: - name: Verify ccache availability shell: bash @@ -102,85 +135,98 @@ jobs: echo "::error title=ccache-mlperf-not-mounted::NFS drive is not mounted; build machine not properly provisioned." exit 1 fi - if [ ! -d "$HOME/.ccache-ci" ]; then + if [ ! -d "$HOME/.ccache" ]; then echo "::error title=ccache-not-provisioned::Ccache is not properly provisioned." exit 1 fi - - uses: tenstorrent/tt-metal/.github/actions/checkout-with-submodule-lfs@main - - name: Set up dynamic env vars for build + + - name: ⬇️ Checkout + uses: actions/checkout@v4 + with: + submodules: recursive + path: docker-job # Here be dragons; keep it scoped to our desired volume, yet must be under github.workspace and be sure to clean up at the end + + - name: Sanity check run: | - echo "TT_METAL_HOME=$(pwd)" >> $GITHUB_ENV - echo "RUNNER_UID=$(id -u)" >> $GITHUB_ENV - echo "RUNNER_GID=$(id -g)" >> $GITHUB_ENV - - name: Update submodules + set -eu # basic shell hygiene + if find . -maxdepth 1 -type d -name 'build*' -print -quit | grep -q .; then + echo "!!! ALERT !!! This should never happen, but does explain an issue we've been hunting. 
Please send a link to this job to Metal Infra. kthxbye." + exit 42 + fi + + - name: Create ccache tmpdir run: | - git submodule update --init --recursive - - name: Generate docker tag - id: generate-docker-tag - uses: ./.github/actions/generate-docker-tag - with: - image: tt-metalium/${{ env.IMAGE_PARAMS }} - - name: Docker login - uses: docker/login-action@v3 - with: - registry: https://ghcr.io - username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} - - name: Pull docker image - run: docker pull ${{ env.TT_METAL_DOCKER_IMAGE_TAG }} - - name: Build tt-metal and libs - uses: tenstorrent/docker-run-action@v5 - with: - image: ${{ env.TT_METAL_DOCKER_IMAGE_TAG }} - options: | - --rm - --tmpfs /tmp - -u ${{ env.RUNNER_UID }}:${{ env.RUNNER_GID }} - --group-add 1457 - -v ${{ github.workspace }}:${{ github.workspace }} - -v /etc/passwd:/etc/passwd:ro - -v /etc/shadow:/etc/shadow:ro - -v /etc/bashrc:/etc/bashrc:ro - -v /home/ubuntu/.ccache-ci:/home/ubuntu/.ccache - -v /mnt/MLPerf/ccache:/mnt/MLPerf/ccache - -e CARGO_HOME=${{ github.workspace }}/.cargo - -w ${{ github.workspace }} - run: | - set -eu # basic shell hygiene - - # /tmp is a tmpfs; more efficient than persisted storage - mkdir -p /tmp/ccache - export CCACHE_TEMPDIR=/tmp/ccache - - # Zero out the stats so we can see how we did this build - # NOTE: may be inaccurate if we have >1 build runner on the same machine, using the same local cache - ccache -z - - args_fixme=$([ "${{ inputs.skip-tt-train }}" = "true" ] && echo "--build-metal-tests --build-ttnn-tests --build-programming-examples" || echo "--build-all") - echo "Args: ${args_fixme}" - build_command="./build_metal.sh --build-type ${{ inputs.build-type }} --toolchain-path ${{ inputs.toolchain }} ${args_fixme} --enable-ccache" - echo "Build tracy: ${{ inputs.tracy }}" - if [ "${{ inputs.tracy }}" = "true" ]; then - build_command="$build_command --enable-profiler" - fi - - [ -n "$(find . -maxdepth 1 -type d -name 'build*' -print -quit)" ] && - { echo "!!! ALERT !!! This should never happen, but does explain an issue we've been hunting. Please send a link to this job to Metal Infra. 
kthxbye."; exit 1; } - - nice -n 19 $build_command - ccache -s > build/ccache.stats - - name: Publish Ccache summary + mkdir -p /tmp/ccache + + - name: Prepare ccache summary + run: | + # Zero out the stats so we can see how we did this build + # NOTE: may be inaccurate if we have >1 build runner on the same machine, using the same local cache + ccache -z + + - name: 🔧 CMake configure + run: | + set -eu # basic shell hygiene + + args_fixme=$([ "${{ inputs.skip-tt-train }}" = "true" ] && echo "--build-metal-tests --build-ttnn-tests --build-programming-examples" || echo "--build-all") + echo "Args: ${args_fixme}" + build_command="./build_metal.sh --build-type ${{ inputs.build-type }} --toolchain-path ${{ inputs.toolchain }} ${args_fixme} --enable-ccache --configure-only" + echo "Build tracy: ${{ inputs.tracy }}" + if [ "${{ inputs.tracy }}" = "true" ]; then + build_command="$build_command --enable-profiler" + fi + + nice -n 19 $build_command + + - name: 🛠️ Compile + run: | + # --target install is for the tarball that should get replaced by proper packaging later + nice -19 cmake --build build --target install + + - name: 📦 Package + if: false # Packaging coming later + run: | + nice -19 cmake --build $build_dir --target package + + - name: 🐍 Build wheel + if: ${{ inputs.build-wheel }} + run: | + nice -n 19 python3 -m build + + - name: Publish ccache summary run: | echo '## CCache Summary' >> $GITHUB_STEP_SUMMARY echo '```' >> $GITHUB_STEP_SUMMARY - cat build/ccache.stats >> $GITHUB_STEP_SUMMARY + ccache -s >> $GITHUB_STEP_SUMMARY echo '```' >> $GITHUB_STEP_SUMMARY + + - name: ☁️ Upload wheel + if: ${{ inputs.build-wheel }} + uses: actions/upload-artifact@v4 + with: + name: eager-dist-${{ inputs.distro }}-${{ inputs.version }}-any + path: /work/dist/ + if-no-files-found: error + - name: 'Tar files' if: ${{ inputs.publish-artifact }} - run: tar -cvhf ttm_any.tar ttnn/ttnn/*.so build/lib ttnn/ttnn/*.so build/programming_examples build/test build/tools build/tt-train data runtime - - name: 'Upload Artifact' + run: tar -cvhf /work/ttm_any.tar ttnn/ttnn/*.so build/lib ttnn/ttnn/*.so build/programming_examples build/test build/tools build/tt-train data runtime + + - name: ☁️ Upload tarball if: ${{ inputs.publish-artifact }} uses: actions/upload-artifact@v4 with: name: TTMetal_build_any${{ (inputs.tracy && '_profiler') || '' }} - path: ttm_any.tar + path: /work/ttm_any.tar + if-no-files-found: error + + - name: Cleanup + if: always() + run: | + # We are forced to checkout the repo into a subdir of the host's workdir; this pollutes the host + # with root-owned files. Be sure to clean up after ourselves in case we're on a non-ephemeral runner. 
+ echo "pre rm" + ls -al /__w/tt-metal/tt-metal + rm -rf /__w/tt-metal/tt-metal/docker-job + echo "post rm" + ls -al /__w/tt-metal/tt-metal diff --git a/.github/workflows/build-docker-artifact.yaml b/.github/workflows/build-docker-artifact.yaml index 1110805e2a2..b4fec5670b4 100644 --- a/.github/workflows/build-docker-artifact.yaml +++ b/.github/workflows/build-docker-artifact.yaml @@ -15,6 +15,13 @@ on: required: false type: string default: "amd64" + outputs: + ci-build-tag: + description: "Docker tag for the CI Build Docker image for building TT-Metalium et al" + value: ${{ jobs.check-docker-images.outputs.ci-build-tag }} + #ci-test-tag: + # description: "Docker tag for the CI Test Docker image for testing TT-Metalium et al" + # value: ${{ jobs.check-docker-images.outputs.ci-test-tag }} workflow_dispatch: inputs: distro: @@ -37,62 +44,76 @@ on: default: "amd64" options: - "amd64" + +env: + IMAGE_NAME: ${{ inputs.distro }}-${{ inputs.version }}-${{ inputs.architecture }} + jobs: + check-docker-images: + runs-on: ubuntu-latest + outputs: + ci-build-exists: ${{ steps.images.outputs.ci-build-exists }} + ci-build-tag: ${{ steps.tags.outputs.ci-build-tag }} + # ci-test-exists: ${{ steps.images.outputs.ci-test-exists }} + # ci-test-tag: ${{ steps.tags.outputs.ci-test-tag }} + steps: + - name: Checkout repo + uses: actions/checkout@v3 + with: + fetch-depth: 1 + + - name: Compute tags + id: tags + run: | + BUILD_TAG=$(cat \ + install_dependencies.sh \ + dockerfile/Dockerfile \ + tt_metal/python_env/requirements-dev.txt \ + docs/requirements-docs.txt \ + tests/sweep_framework/requirements-sweeps.txt \ + | sha1sum | cut -d' ' -f1) + echo "ci-build-tag=ghcr.io/${{ github.repository }}/tt-metalium/${{ env.IMAGE_NAME }}:${BUILD_TAG}" >> $GITHUB_OUTPUT + + # TODO: When we have multiple Docker images, do something like this: + # TEST_TAG=$(cat tt_metal/python_env/requirements-dev.txt pyproject.toml | sha1sum | cut -d' ' -f1) + # echo "ci-test-tag=ghcr.io/${{ github.repository }}/tt-metalium/${{ env.IMAGE_NAME }}:${TEST_TAG}" >> $GITHUB_OUTPUT + + - name: Query images exist + id: images + run: | + if docker manifest inspect ${{ steps.tags.outputs.ci-build-tag }} > /dev/null 2>&1; then + echo "${{ steps.tags.outputs.ci-build-tag }} exists" + echo "ci-build-exists=true" >> $GITHUB_OUTPUT + else + echo "${{ steps.tags.outputs.ci-build-tag }} does not exist" + echo "ci-build-exists=false" >> $GITHUB_OUTPUT + fi + + build-docker-image: - name: "🐳️ Build ${{ inputs.distro }} ${{inputs.version }} image" + name: "🐳️ Build image" + needs: check-docker-images + if: needs.check-docker-images.outputs.ci-build-exists != 'true' timeout-minutes: 30 - env: - CONFIG: ci - SILENT: 0 - VERBOSE: 1 - IMAGE_PARAMS: "${{ inputs.distro }}-${{ inputs.version }}-${{ inputs.architecture }}" - IMAGE: tt-metalium/ubuntu-20.04-amd64 - DOCKERFILE: ubuntu-20.04-amd64 runs-on: - build-docker - in-service steps: - uses: tenstorrent/tt-metal/.github/actions/checkout-with-submodule-lfs@main - with: - fetch-depth: 0 - name: Login to GitHub Container Registry uses: docker/login-action@v3 with: registry: https://ghcr.io username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} - - name: Get all test, doc and src files that have changed - id: changed-files-specific - uses: tj-actions/changed-files@v45 - with: - files: | - dockerfile/**.Dockerfile - scripts/docker/install_test_deps.sh - scripts/docker/requirements* - pyproject.toml - tt_metal/python_env/requirements-dev.txt - base_sha: 'main' - - name: Determine docker image tag - 
uses: ./.github/actions/generate-docker-tag - with: - image: tt-metalium/${{ env.IMAGE_PARAMS }} - name: Build Docker image and push to GHCR - if: steps.changed-files-specific.outputs.any_changed == 'true' uses: docker/build-push-action@v6 with: context: ${{ github.workspace }} - file: dockerfile/${{ env.IMAGE_PARAMS }}.Dockerfile + file: dockerfile/Dockerfile + target: dev push: true - tags: ${{ env.TT_METAL_DOCKER_IMAGE_TAG}} + tags: ${{ needs.check-docker-images.outputs.ci-build-tag }} build-args: UBUNTU_VERSION=${{ inputs.version }} - cache-from: type=registry,ref=${{ env.TT_METAL_REF_IMAGE_TAG }} cache-to: type=inline pull: true - - name: Tag Docker main image as current image - if: steps.changed-files-specific.outputs.any_changed != 'true' - run: | - docker pull ghcr.io/${{ github.repository }}/tt-metalium/${{ env.IMAGE_PARAMS }}:latest - docker tag ghcr.io/${{ github.repository }}/tt-metalium/${{ env.IMAGE_PARAMS }}:latest ${{ env.TT_METAL_DOCKER_IMAGE_TAG}} - - name: Push Docker image to GitHub Container Registry - run: | - docker push ${{ env.TT_METAL_DOCKER_IMAGE_TAG }} diff --git a/.github/workflows/code-analysis.yaml b/.github/workflows/code-analysis.yaml index a88cf647691..b78af4fb6c1 100644 --- a/.github/workflows/code-analysis.yaml +++ b/.github/workflows/code-analysis.yaml @@ -67,6 +67,8 @@ jobs: echo "::error title=ccache-not-provisioned::Ccache is not properly provisioned." exit 1 fi + - name: Check out repo + uses: actions/checkout@v4 - name: Set up dynamic env vars for build run: | echo "TT_METAL_HOME=$(pwd)" >> $GITHUB_ENV diff --git a/.github/workflows/cpp-post-commit.yaml b/.github/workflows/cpp-post-commit.yaml index 1177c1f6efe..0feaa3b80cb 100644 --- a/.github/workflows/cpp-post-commit.yaml +++ b/.github/workflows/cpp-post-commit.yaml @@ -64,6 +64,7 @@ jobs: {name: stl, cmd: "./build/test/tt_metal/unit_tests_stl"}, {name: distributed, cmd: "./build/test/tt_metal/distributed/distributed_unit_tests_${{ inputs.arch }} --gtest_filter=MeshDeviceSuite.*"}, + {name: lightmetal, cmd: "./build/test/tt_metal/unit_tests_lightmetal"}, {name: dispatch multicmd queue, cmd: "TT_METAL_GTEST_NUM_HW_CQS=2 ./build/test/tt_metal/unit_tests_dispatch_${{ inputs.arch }} --gtest_filter=MultiCommandQueue*Fixture.*"}, {name: ttnn cpp unit tests, cmd: ./build/test/ttnn/unit_tests_ttnn}, diff --git a/.github/workflows/docs-latest-public.yaml b/.github/workflows/docs-latest-public.yaml index 2afe136086e..85e76a877c7 100644 --- a/.github/workflows/docs-latest-public.yaml +++ b/.github/workflows/docs-latest-public.yaml @@ -73,3 +73,13 @@ jobs: if: ${{ github.ref == 'refs/heads/main' }} id: deployment uses: actions/deploy-pages@v4.0.4 + - name: Delete artifact if deployment failed + # When the deployment API call fails, the artifacts are not cleaned up correctly + # and the next attempt (!) run will cause an error. 
+ # See more: + # https://github.com/tenstorrent/tt-metal/issues/17623 + if: ${{ failure() }} + uses: geekyeggo/delete-artifact@v5 + continue-on-error: true + with: + name: github-pages diff --git a/.github/workflows/fast-dispatch-build-and-unit-tests-wrapper.yaml b/.github/workflows/fast-dispatch-build-and-unit-tests-wrapper.yaml index c3e1c4f3879..cfbaf686cd5 100644 --- a/.github/workflows/fast-dispatch-build-and-unit-tests-wrapper.yaml +++ b/.github/workflows/fast-dispatch-build-and-unit-tests-wrapper.yaml @@ -11,21 +11,9 @@ jobs: needs: build-docker-artifact uses: ./.github/workflows/build-artifact.yaml secrets: inherit - build-wheels: - needs: build-artifact - strategy: - matrix: - # Since pre-compiled builds only run on 20.04, we can only test on 20.04 for now - # The full 22.04 flow can be tested without precompiled - os: [ubuntu-20.04] - uses: ./.github/workflows/_build-wheels-impl.yaml - with: - os: ${{ matrix.os }} - from-precompiled: true - secrets: inherit # FD Unit Tests fast-dispatch-unit-tests: - needs: build-wheels + needs: build-artifact secrets: inherit strategy: fail-fast: false @@ -41,7 +29,7 @@ jobs: runner-label: ${{ matrix.test-group.runner-label}} # TTNN FD Unit tests ttnn-unit-tests: - needs: build-wheels + needs: build-artifact secrets: inherit strategy: fail-fast: false @@ -58,7 +46,7 @@ jobs: # FD Model Tests models-unit-tests: - needs: build-wheels + needs: build-artifact secrets: inherit strategy: fail-fast: false @@ -75,7 +63,7 @@ jobs: # FD C++ Unit Tests cpp-unit-tests: - needs: build-wheels + needs: build-artifact secrets: inherit strategy: fail-fast: false diff --git a/.github/workflows/fast-dispatch-build-and-unit-tests.yaml b/.github/workflows/fast-dispatch-build-and-unit-tests.yaml index 8042f7cd7ca..125a0cf4f41 100644 --- a/.github/workflows/fast-dispatch-build-and-unit-tests.yaml +++ b/.github/workflows/fast-dispatch-build-and-unit-tests.yaml @@ -59,7 +59,6 @@ jobs: {name: eager unit tests 5, cmd: pytest tests/tt_eager/python_api_testing/unit_testing/ -xvvv --splits 7 --group 5 }, {name: eager unit tests 6, cmd: pytest tests/tt_eager/python_api_testing/unit_testing/ -xvvv --splits 7 --group 6 }, {name: eager unit tests 7, cmd: pytest tests/tt_eager/python_api_testing/unit_testing/ -xvvv --splits 7 --group 7 }, - {name: eager trace tests, cmd: pytest tests/tt_eager/python_api_testing/trace_testing/ -xvvv}, {name: sweep, cmd: pytest tests/tt_eager/python_api_testing/sweep_tests/pytests/ -xvvv}, ] name: ${{ matrix.test-group.name }} ${{ inputs.arch }} ${{ inputs.runner-label }} diff --git a/.github/workflows/metal-api-surface.yaml b/.github/workflows/metal-api-surface.yaml index a9e87c58f83..2a3376c1154 100644 --- a/.github/workflows/metal-api-surface.yaml +++ b/.github/workflows/metal-api-surface.yaml @@ -58,7 +58,7 @@ jobs: payload: | { "text": "\nTT_METAL_API_SURFACE:\ndate: ${{ env.DATE }} \nnum_files: ${{ env.NUM_FILES }} \nnum_types: ${{ env.NUM_TYPES }} \nnum_methods: ${{ env.NUM_METHODS }}", - "owner": "U0704SWRY9H" + "owner": "U07J3K6KS1K" } env: SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }} diff --git a/.github/workflows/models-post-commit-wrapper.yaml b/.github/workflows/models-post-commit-wrapper.yaml index ccdccc25a4a..be31f38a4ce 100644 --- a/.github/workflows/models-post-commit-wrapper.yaml +++ b/.github/workflows/models-post-commit-wrapper.yaml @@ -8,27 +8,11 @@ jobs: static-checks: uses: ./.github/workflows/all-static-checks.yaml secrets: inherit - build-docker-artifact: - uses: ./.github/workflows/build-docker-artifact.yaml - 
secrets: inherit build-artifact: - needs: build-docker-artifact uses: ./.github/workflows/build-artifact.yaml secrets: inherit - build-wheels: - needs: build-artifact - strategy: - matrix: - # Since pre-compiled builds only run on 20.04, we can only test on 20.04 for now - # The full 22.04 flow can be tested without precompiled - os: [ubuntu-20.04] - uses: ./.github/workflows/_build-wheels-impl.yaml - with: - os: ${{ matrix.os }} - from-precompiled: true - secrets: inherit models-unit-tests: - needs: build-wheels + needs: build-artifact secrets: inherit strategy: fail-fast: false diff --git a/.github/workflows/package-and-release.yaml b/.github/workflows/package-and-release.yaml index 35e98549167..c5dfdcb0f50 100644 --- a/.github/workflows/package-and-release.yaml +++ b/.github/workflows/package-and-release.yaml @@ -118,21 +118,12 @@ jobs: with: name: release-notes path: RELEASE_NOTES.txt - build-wheels: - needs: create-tag - strategy: - matrix: - os: [ubuntu-20.04] - uses: ./.github/workflows/_build-wheels-impl.yaml - with: - os: ${{ matrix.os }} - from-precompiled: false # Candidate for breaking up create-and-upload-draft-release: needs: [ create-tag, create-release-notes, - build-wheels, + build-artifact, ] strategy: matrix: @@ -186,12 +177,14 @@ jobs: fail_on_unmatched_files: true create-docker-release-image: needs: [ + build-artifact, create-tag, create-and-upload-draft-release ] uses: ./.github/workflows/publish-release-image.yaml secrets: inherit with: + base-image: ${{ needs.build-artifact.outputs.ci-build-docker-image }} version: ${{ needs.create-tag.outputs.version }} is_major_version: ${{ needs.get-params.outputs.is-release-candidate !='true' && needs.get-params.outputs.should-create-release == 'true' }} release-docs: diff --git a/.github/workflows/pr-gate.yaml b/.github/workflows/pr-gate.yaml index 00a28443888..d113f88207c 100644 --- a/.github/workflows/pr-gate.yaml +++ b/.github/workflows/pr-gate.yaml @@ -33,7 +33,8 @@ concurrency: cancel-in-progress: true jobs: - build-artifact: + pr-gate-build: + name: Build if: github.event_name != 'pull_request' || !github.event.pull_request.draft uses: ./.github/workflows/build-artifact.yaml with: diff --git a/.github/workflows/publish-release-image-wrapper.yaml b/.github/workflows/publish-release-image-wrapper.yaml index 45ff119d4d4..371732d2f80 100644 --- a/.github/workflows/publish-release-image-wrapper.yaml +++ b/.github/workflows/publish-release-image-wrapper.yaml @@ -6,21 +6,11 @@ jobs: build-artifact: uses: ./.github/workflows/build-artifact.yaml secrets: inherit - build-wheels: - needs: build-artifact - strategy: - matrix: - # Since pre-compiled builds only run on 20.04, we can only test on 20.04 for now - # The full 22.04 flow can be tested without precompiled - os: [ubuntu-20.04] - uses: ./.github/workflows/_build-wheels-impl.yaml - with: - os: ${{ matrix.os }} - from-precompiled: true publish-release-image: - needs: build-wheels + needs: build-artifact uses: ./.github/workflows/publish-release-image.yaml secrets: inherit with: + base-image: ${{ needs.build-artifact.outputs.ci-build-docker-image }} version: dev-${GITHUB_REF_NAME//\//-} is_major_version: false diff --git a/.github/workflows/publish-release-image.yaml b/.github/workflows/publish-release-image.yaml index 4548f43bb2a..6eae630fef5 100644 --- a/.github/workflows/publish-release-image.yaml +++ b/.github/workflows/publish-release-image.yaml @@ -3,6 +3,10 @@ name: "[internal] Create and Publish Release Docker Image" on: workflow_call: inputs: + base-image: + description: 
"Base image to build on top of" + required: true + type: string version: required: true type: string @@ -51,10 +55,11 @@ jobs: push: true build-args: | WHEEL_FILENAME=${{ env.WHEEL_FILENAME }} - BASE_IMAGE_NAME=tt-metalium/${{ matrix.os }}-amd64 + BASE_IMAGE=${{ inputs.base-image }} tags: ${{ env.TAG_NAME }} context: . - file: dockerfile/release.Dockerfile + file: dockerfile/Dockerfile + target: release smoke-test-docker-image: needs: create-docker-release-image strategy: @@ -85,10 +90,11 @@ jobs: timeout-minutes: ${{ inputs.timeout }} uses: ./.github/actions/docker-run with: - docker_os_arch: tt-metalium-${{ matrix.os }}-amd64-release - docker_version: ${{ inputs.version }} + docker_image: ghcr.io/${{ github.repository }}/tt-metalium-${{ matrix.os }}-amd64-release:${{ inputs.version }} docker_password: ${{ secrets.GITHUB_TOKEN }} run_args: | + pip install pytest + export PATH="$(pwd)/.local/bin:$PATH" ${{ matrix.test_group.cmd }} tag-docker-image-as-latest: needs: [smoke-test-docker-image, create-docker-release-image] diff --git a/.github/workflows/run-profiler-regression-wrapper.yaml b/.github/workflows/run-profiler-regression-wrapper.yaml index 915e8580082..52248542b21 100644 --- a/.github/workflows/run-profiler-regression-wrapper.yaml +++ b/.github/workflows/run-profiler-regression-wrapper.yaml @@ -12,5 +12,16 @@ jobs: secrets: inherit run-profiler-regression: needs: build-artifact-profiler + strategy: + fail-fast: false + matrix: + test-group: [ + { arch: grayskull, runner-label: E150 }, + { arch: wormhole_b0, runner-label: N150 }, + { arch: wormhole_b0, runner-label: N300 }, + ] uses: ./.github/workflows/run-profiler-regression.yaml secrets: inherit + with: + arch: ${{ matrix.test-group.arch}} + runner-label: ${{ matrix.test-group.runner-label}} diff --git a/.github/workflows/run-profiler-regression.yaml b/.github/workflows/run-profiler-regression.yaml index adbef02dea0..4cbc4224b45 100644 --- a/.github/workflows/run-profiler-regression.yaml +++ b/.github/workflows/run-profiler-regression.yaml @@ -2,6 +2,46 @@ name: "[internal] metal - Run profiler regression impl" on: workflow_call: + inputs: + arch: + required: true + type: string + runner-label: + required: true + type: string + timeout: + required: false + type: number + default: 35 + os: + required: false + type: string + default: "ubuntu-20.04" + workflow_dispatch: + inputs: + arch: + required: true + type: choice + options: + - grayskull + - wormhole_b0 + - blackhole + runner-label: + required: true + type: choice + options: + - E150 + - N150 + - N300 + - BH + timeout: + required: false + type: number + default: 35 + os: + required: false + type: string + default: "ubuntu-20.04" jobs: profiler-regression: @@ -9,21 +49,14 @@ jobs: # Do not fail-fast because we need to ensure all tests go to completion # so we try not to get hanging machines fail-fast: false - matrix: - runner-info: [ - # E150 - {arch: grayskull, runs-on: ["cloud-virtual-machine", "E150", "in-service"], name: E150}, - # N150 - {arch: wormhole_b0, runs-on: ["cloud-virtual-machine", "N150", "in-service"], name: N150}, - # N300 - {arch: wormhole_b0, runs-on: ["cloud-virtual-machine", "N300", "in-service"], name: N300}, - ] env: - TT_METAL_ENV: ${{ vars.TT_METAL_ENV }} - ARCH_NAME: ${{ matrix.runner-info.arch }} + ARCH_NAME: ${{ inputs.arch }} LOGURU_LEVEL: INFO LD_LIBRARY_PATH: ${{ github.workspace }}/build/lib - runs-on: ${{ matrix.runner-info.runs-on }} + runs-on: + - ${{ inputs.runner-label }} + - cloud-virtual-machine + - in-service steps: - uses: 
tenstorrent/tt-metal/.github/actions/checkout-with-submodule-lfs@main - name: Set up dynamic env vars for build @@ -36,7 +69,7 @@ jobs: run: tar -xvf ttm_any.tar - uses: ./.github/actions/install-python-deps - name: Run profiler regression tests - timeout-minutes: 30 + timeout-minutes: ${{ inputs.timeout }} run: | ./tests/scripts/run_profiler_regressions.sh - uses: ./.github/actions/slack-report diff --git a/.github/workflows/tt-metal-l2-nightly.yaml b/.github/workflows/tt-metal-l2-nightly.yaml new file mode 100644 index 00000000000..bbbbb618607 --- /dev/null +++ b/.github/workflows/tt-metal-l2-nightly.yaml @@ -0,0 +1,83 @@ +name: "[internal] tt-metal l2 nightly tests" + +on: + workflow_call: + inputs: + arch: + required: true + type: string + runner-label: + required: true + type: string + timeout: + required: false + type: number + default: 45 + workflow_dispatch: + inputs: + arch: + required: true + type: choice + options: + - grayskull + - wormhole_b0 + - blackhole + runner-label: + required: true + type: choice + options: + - E150 + - N150 + - N300 + - BH + timeout: + required: false + type: number + default: 45 + schedule: + - cron: "0 22 * * *" + +jobs: + build: + uses: ./.github/workflows/build-artifact.yaml + secrets: inherit + with: + build-wheel: true + test: + needs: build + strategy: + fail-fast: false + matrix: + os: ["ubuntu-20.04"] + test-group: + - name: ttnn example tests + cmd: ./tests/scripts/run_ttnn_examples.sh + name: ${{ matrix.test-group.name }} ${{ inputs.arch }} ${{ inputs.runner-label }} + env: + LOGURU_LEVEL: INFO + runs-on: + - ${{ inputs.runner-label }} + - "in-service" + steps: + - uses: tenstorrent/tt-metal/.github/actions/checkout-with-submodule-lfs@main + - uses: actions/download-artifact@v4 + with: + name: eager-dist-${{ matrix.os }}-any + - name: ${{ matrix.test-group.name }} tests + timeout-minutes: ${{ inputs.timeout }} + uses: ./.github/actions/docker-run + with: + docker_username: ${{ github.actor }} + docker_password: ${{ secrets.GITHUB_TOKEN }} + docker_opts: | + -e ARCH_NAME=${{ inputs.arch }} + run_args: | + WHEEL_FILENAME=$(ls -1 *.whl) + pip3 install --user $WHEEL_FILENAME + ${{ matrix.test-group.cmd }} + + - uses: ./.github/actions/slack-report + if: ${{ failure() }} + with: + slack_webhook_url: ${{ secrets.SLACK_WEBHOOK_URL }} + owner: U07HTBQPHFG # Bryan Keith diff --git a/.github/workflows/ttnn-post-commit-wrapper.yaml b/.github/workflows/ttnn-post-commit-wrapper.yaml index 324f6582f5d..74a5c9575ea 100644 --- a/.github/workflows/ttnn-post-commit-wrapper.yaml +++ b/.github/workflows/ttnn-post-commit-wrapper.yaml @@ -11,20 +11,8 @@ jobs: build-artifact: uses: ./.github/workflows/build-artifact.yaml secrets: inherit - build-wheels: - needs: build-artifact - strategy: - matrix: - # Since pre-compiled builds only run on 20.04, we can only test on 20.04 for now - # The full 22.04 flow can be tested without precompiled - os: [ubuntu-20.04] - uses: ./.github/workflows/_build-wheels-impl.yaml - with: - os: ${{ matrix.os }} - from-precompiled: true - secrets: inherit ttnn-unit-tests: - needs: build-wheels + needs: build-artifact secrets: inherit strategy: fail-fast: false diff --git a/CMakeLists.txt b/CMakeLists.txt index 0a36f8d106d..a26b956890a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.18...3.30) +cmake_minimum_required(VERSION 3.19...3.30) # Sanity check, forgetting to clone submodules is a common omission and results in a poor error message if(NOT EXISTS 
"${CMAKE_CURRENT_SOURCE_DIR}/tt_metal/third_party/umd/CMakeLists.txt") @@ -111,6 +111,7 @@ message(STATUS "Build TT METAL Tests: ${TT_METAL_BUILD_TESTS}") message(STATUS "Build TTNN Tests: ${TTNN_BUILD_TESTS}") message(STATUS "Build with Unity builds: ${TT_UNITY_BUILDS}") message(STATUS "Build with Shared TTNN Sublibraries: ${ENABLE_TTNN_SHARED_SUBLIBS}") +message(STATUS "Build with LightMetal Trace Enabled: ${TT_ENABLE_LIGHT_METAL_TRACE}") ############################################################################################################################ @@ -145,7 +146,6 @@ unset(SANITIZER_ENABLED) ############################################################################################################################ # Find all required libraries to build ############################################################################################################################ -set(ENV{CPM_SOURCE_CACHE} "${PROJECT_SOURCE_DIR}/.cpmcache") include(CPM) if(CMAKE_VERSION VERSION_LESS 3.25) # FIXME(14681): `SYSTEM` was introduced in v3.25; remove this when we can require v3.25 @@ -232,6 +232,13 @@ add_link_options( "$<$:-fsanitize=undefined>" ) +# Planned to be temporary, remove later. +if(TT_ENABLE_LIGHT_METAL_TRACE) + add_compile_definitions(TT_ENABLE_LIGHT_METAL_TRACE=1) +else() + add_compile_definitions(TT_ENABLE_LIGHT_METAL_TRACE=0) +endif() + if(ENABLE_CODE_TIMERS) add_compile_definitions(TT_ENABLE_CODE_TIMERS) endif() diff --git a/CODEOWNERS b/CODEOWNERS index 025c67e398d..6e2fffa151b 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -92,6 +92,9 @@ tt_metal/hostdevcommon/profiler_common.h @mo-tenstorrent docs/source/performance_measurement_tools/profiler.rst @mo-tenstorrent tt-metal/tt_metal/programming_examples/profiler @mo-tenstorrent +# Metalium - flatbuffer schemas +tt_metal/impl/flatbuffer/ @kmabeeTT @nsmithtt @omilyutin-tt + # test scripts tests/scripts/run_profiler_regressions.sh @mo-tenstorrent @tenstorrent/metalium-developers-infra tests/scripts/run_performance.sh @tenstorrent/metalium-developers-infra diff --git a/METALIUM_GUIDE.md b/METALIUM_GUIDE.md index 28042f6589e..5ddc05de55e 100644 --- a/METALIUM_GUIDE.md +++ b/METALIUM_GUIDE.md @@ -372,7 +372,7 @@ void MAIN { constexpr auto cb_out0 = tt::CBIndex::c_16; binary_op_init_common(cb_in0, cb_in1, cb_out0); - add_tiles_init(); + add_tiles_init(cb_in0, cb_in1); for(uint32_t block = 0; block < per_core_block_cnt; ++block) { diff --git a/README.md b/README.md index 6394d5a6a76..e4d2c5b951d 100644 --- a/README.md +++ b/README.md @@ -65,6 +65,7 @@ | [Stable Diffusion 1.4 (512x512)](./models/demos/wormhole/stable_diffusion) | 1 | [n150](https://tenstorrent.com/hardware/wormhole) | 0.167 | 0.3 | | | [YOLOv4 (320x320)](./models/demos/yolov4) | 1 | [n150](https://tenstorrent.com/hardware/wormhole) | 95 | 300 | | | [SegFormer Semantic Segmentation (512x512)](./models/demos/segformer) | 1 | [n150](https://tenstorrent.com/hardware/wormhole) | 90 | 300 | | +| [Stable Diffusion 3.5 medium (512x512)](https://github.com/tenstorrent/tt-metal/blob/mbahnas/sd35_medium_512_spacelike_feb05/models/experimental/stable_diffusion3) | 1 | [n150](https://tenstorrent.com/hardware/wormhole) | 0.06 | 0.3 | | ## NLPs diff --git a/build_metal.sh b/build_metal.sh index 0ade43090ce..5d962c7472c 100755 --- a/build_metal.sh +++ b/build_metal.sh @@ -29,10 +29,13 @@ show_help() { echo " --clean Remove build workspaces." 
echo " --build-static-libs Build tt_metal (not ttnn) as a static lib (BUILD_SHARED_LIBS=OFF)" echo " --disable-unity-builds Disable Unity builds" + echo " --disable-light-metal-trace Disable Light Metal tracing to binary." echo " --cxx-compiler-path Set path to C++ compiler." echo " --c-compiler-path Set path to C++ compiler." + echo " --cpm-source-cache Set path to CPM Source Cache." echo " --ttnn-shared-sub-libs Use shared libraries for ttnn." echo " --toolchain-path Set path to CMake toolchain file." + echo " --configure-only Only configure the project, do not build." } clean() { @@ -58,11 +61,14 @@ build_programming_examples="OFF" build_tt_train="OFF" build_static_libs="OFF" unity_builds="ON" +light_metal_trace="ON" build_all="OFF" cxx_compiler_path="" +cpm_source_cache="" c_compiler_path="" ttnn_shared_sub_libs="OFF" toolchain_path="cmake/x86_64-linux-clang-17-libcpp-toolchain.cmake" +configure_only="OFF" declare -a cmake_args @@ -88,14 +94,17 @@ build-programming-examples build-tt-train build-static-libs disable-unity-builds +disable-light-metal-trace release development debug clean cxx-compiler-path: +cpm-source-cache: c-compiler-path: ttnn-shared-sub-libs toolchain-path: +configure-only " # Flatten LONGOPTIONS into a comma-separated string for getopt @@ -153,10 +162,16 @@ while true; do build_all="ON";; --ttnn-shared-sub-libs) ttnn_shared_sub_libs="ON";; + --configure-only) + configure_only="ON";; --disable-unity-builds) unity_builds="OFF";; + --disable-light-metal-trace) + light_metal_trace="OFF";; --cxx-compiler-path) cxx_compiler_path="$2";shift;; + --cpm-source-cache) + cpm_source_cache="$2";shift;; --c-compiler-path) c_compiler_path="$2";shift;; --toolchain-path) @@ -218,6 +233,7 @@ echo "INFO: Install Prefix: $cmake_install_prefix" echo "INFO: Build tests: $build_tests" echo "INFO: Enable Unity builds: $unity_builds" echo "INFO: TTNN Shared sub libs : $ttnn_shared_sub_libs" +echo "INFO: Enable Light Metal Trace: $light_metal_trace" # Prepare cmake arguments cmake_args+=("-B" "$build_dir") @@ -234,6 +250,11 @@ if [ "$c_compiler_path" != "" ]; then cmake_args+=("-DCMAKE_C_COMPILER=$c_compiler_path") fi +if [ "$cpm_source_cache" != "" ]; then + echo "INFO: CPM_SOURCE_CACHE: $cpm_source_cache" + cmake_args+=("-DCPM_SOURCE_CACHE=$cpm_source_cache") +fi + if [ "$enable_ccache" = "ON" ]; then cmake_args+=("-DCMAKE_DISABLE_PRECOMPILE_HEADERS=TRUE") cmake_args+=("-DENABLE_CCACHE=TRUE") @@ -308,6 +329,12 @@ else cmake_args+=("-DTT_UNITY_BUILDS=OFF") fi +if [ "$light_metal_trace" = "ON" ]; then + cmake_args+=("-DTT_ENABLE_LIGHT_METAL_TRACE=ON") +else + cmake_args+=("-DTT_ENABLE_LIGHT_METAL_TRACE=OFF") +fi + if [ "$build_all" = "ON" ]; then cmake_args+=("-DTT_METAL_BUILD_TESTS=ON") cmake_args+=("-DTTNN_BUILD_TESTS=ON") @@ -331,5 +358,7 @@ echo "INFO: Running: cmake "${cmake_args[@]}"" cmake "${cmake_args[@]}" # Build libraries and cpp tests -echo "INFO: Building Project" -cmake --build $build_dir --target install +if [ "$configure_only" = "OFF" ]; then + echo "INFO: Building Project" + cmake --build $build_dir --target install +fi diff --git a/cmake/CPM.cmake b/cmake/CPM.cmake index 3ec3685c7f1..c65a6dadb5b 100644 --- a/cmake/CPM.cmake +++ b/cmake/CPM.cmake @@ -5,13 +5,10 @@ set(CPM_DOWNLOAD_VERSION 0.40.2) set(CPM_HASH_SUM "c8cdc32c03816538ce22781ed72964dc864b2a34a310d3b7104812a5ca2d835d") -if(CPM_SOURCE_CACHE) - set(CPM_DOWNLOAD_LOCATION "${CPM_SOURCE_CACHE}/cpm/CPM_${CPM_DOWNLOAD_VERSION}.cmake") -elseif(DEFINED ENV{CPM_SOURCE_CACHE}) - set(CPM_DOWNLOAD_LOCATION 
"$ENV{CPM_SOURCE_CACHE}/cpm/CPM_${CPM_DOWNLOAD_VERSION}.cmake") -else() - set(CPM_DOWNLOAD_LOCATION "${PROJECT_BINARY_DIR}/cmake/CPM_${CPM_DOWNLOAD_VERSION}.cmake") -endif() +# Always Require the CMake option, but provide default +set(CPM_SOURCE_CACHE "${CMAKE_SOURCE_DIR}/.cpmcache" CACHE STRING "Path to CPM source cache") + +set(CPM_DOWNLOAD_LOCATION "${CPM_SOURCE_CACHE}/cpm/CPM_${CPM_DOWNLOAD_VERSION}.cmake") # Expand relative path. This is important if the provided path contains a tilde (~) get_filename_component(CPM_DOWNLOAD_LOCATION ${CPM_DOWNLOAD_LOCATION} ABSOLUTE) @@ -23,5 +20,4 @@ file( EXPECTED_HASH SHA256=${CPM_HASH_SUM} ) -set(ENV{CPM_SOURCE_CACHE} "${PROJECT_SOURCE_DIR}/.cpmcache") include(${CPM_DOWNLOAD_LOCATION}) diff --git a/cmake/fetch_boost.cmake b/cmake/fetch_boost.cmake deleted file mode 100644 index 4987d256c45..00000000000 --- a/cmake/fetch_boost.cmake +++ /dev/null @@ -1,27 +0,0 @@ -include(${PROJECT_SOURCE_DIR}/cmake/CPM.cmake) - -function(fetch_boost_library BOOST_PROJECT_NAME) - CPMAddPackage( - NAME boost_${BOOST_PROJECT_NAME} - GITHUB_REPOSITORY boostorg/${BOOST_PROJECT_NAME} - GIT_TAG boost-1.85.0 - OPTIONS - "BUILD_SHARED_LIBS OFF" - ) - - get_target_property(BOOST_INTERFACE_LINK_LIBRARIES boost_${BOOST_PROJECT_NAME} INTERFACE_LINK_LIBRARIES) - - if(NOT BOOST_INTERFACE_LINK_LIBRARIES STREQUAL BOOST_INTERFACE_LINK_LIBRARIES-NOTFOUND) - foreach(BOOST_INTERFACE_LINK_LIBRARY IN ITEMS ${BOOST_INTERFACE_LINK_LIBRARIES}) - if( - NOT TARGET - ${BOOST_INTERFACE_LINK_LIBRARY} - AND BOOST_INTERFACE_LINK_LIBRARY - MATCHES - "^Boost::([a-z0-9_]+)$" - ) - fetch_boost_library(${CMAKE_MATCH_1}) - endif() - endforeach() - endif() -endfunction() diff --git a/cmake/project_options.cmake b/cmake/project_options.cmake index 3187b2efc10..3937b609500 100644 --- a/cmake/project_options.cmake +++ b/cmake/project_options.cmake @@ -19,6 +19,7 @@ option(ENABLE_CCACHE "Build with compiler cache" FALSE) option(TT_UNITY_BUILDS "Build with Unity builds" ON) option(BUILD_TT_TRAIN "Enables build of tt-train" OFF) option(ENABLE_TTNN_SHARED_SUBLIBS "Use shared libraries for ttnn to speed up incremental builds" OFF) +option(TT_ENABLE_LIGHT_METAL_TRACE "Enable Light Metal Trace" ON) ########################################################################################### diff --git a/dependencies/CMakeLists.txt b/dependencies/CMakeLists.txt index 3daf27f42d0..793e7f8c859 100644 --- a/dependencies/CMakeLists.txt +++ b/dependencies/CMakeLists.txt @@ -8,16 +8,26 @@ set(CMAKE_CXX_CLANG_TIDY "") # Boost ############################################################################################################################ -include(${PROJECT_SOURCE_DIR}/cmake/fetch_boost.cmake) - -fetch_boost_library(core) -fetch_boost_library(smart_ptr) -fetch_boost_library(container) -fetch_boost_library(interprocess) +CPMAddPackage( + NAME Boost + VERSION 1.86.0 + URL + https://github.com/boostorg/boost/releases/download/boost-1.86.0/boost-1.86.0-cmake.tar.xz + URL_HASH + SHA256=2c5ec5edcdff47ff55e27ed9560b0a0b94b07bd07ed9928b476150e16b0efc57 + OPTIONS + "BOOST_ENABLE_CMAKE ON" + "BOOST_SKIP_INSTALL_RULES ON" + "BUILD_SHARED_LIBS OFF" + "BOOST_INCLUDE_LIBRARIES core\\\;container\\\;smart_ptr\\\;interprocess" +) add_library(span INTERFACE) target_link_libraries(span INTERFACE Boost::core) +add_library(small_vector INTERFACE) +target_link_libraries(small_vector INTERFACE Boost::container) + ############################################################################################################################ 
# yaml-cpp ############################################################################################################################ diff --git a/dockerfile/Dockerfile b/dockerfile/Dockerfile new file mode 100644 index 00000000000..e1a388d2f2b --- /dev/null +++ b/dockerfile/Dockerfile @@ -0,0 +1,125 @@ + +############################################################# + +# Accept an argument to specify the Ubuntu version +ARG UBUNTU_VERSION=20.04 +FROM public.ecr.aws/ubuntu/ubuntu:${UBUNTU_VERSION} AS base + +ENV DEBIAN_FRONTEND=noninteractive + +# Install runtime deps +COPY /install_dependencies.sh /opt/tt_metal_infra/scripts/docker/install_dependencies.sh +RUN /bin/bash /opt/tt_metal_infra/scripts/docker/install_dependencies.sh --docker --mode runtime + + +############################################################# + +FROM base AS ci-build + +RUN /bin/bash /opt/tt_metal_infra/scripts/docker/install_dependencies.sh --docker --mode build + +# Install ccache from upstream; Apt's version for 20.04 predates remote_storage support +RUN mkdir -p /usr/local/bin && wget -O /tmp/ccache.tar.xz https://github.com/ccache/ccache/releases/download/v4.10.2/ccache-4.10.2-linux-x86_64.tar.xz && \ + tar -xf /tmp/ccache.tar.xz -C /usr/local/bin --strip-components=1 && \ + rm /tmp/ccache.tar.xz + +ARG DOXYGEN_VERSION=1.9.6 +RUN mkdir -p /tmp/doxygen \ + && wget -O /tmp/doxygen/doxygen-${DOXYGEN_VERSION}.linux.bin.tar.gz "https://www.doxygen.nl/files/doxygen-${DOXYGEN_VERSION}.linux.bin.tar.gz" \ + && tar -xzf /tmp/doxygen/doxygen-${DOXYGEN_VERSION}.linux.bin.tar.gz -C /tmp/doxygen --strip-components=1 \ + && make -C /tmp/doxygen -j$(nproc) \ + && make -C /tmp/doxygen install \ + && rm -rf /tmp/doxygen + +RUN mkdir -p /tmp/cba \ + && wget -O /tmp/cba/cba.tar.gz https://github.com/aras-p/ClangBuildAnalyzer/archive/refs/tags/v1.6.0.tar.gz \ + && tar -xzf /tmp/cba/cba.tar.gz -C /tmp/cba --strip-components=1 \ + && cmake -S /tmp/cba/ -B /tmp/cba/build -DCMAKE_BUILD_TYPE=Release \ + && cmake --build /tmp/cba/build \ + && cmake --install /tmp/cba/build \ + && rm -rf /tmp/cba + +# Install extra ci apt requirements +RUN apt-get update && apt-get install -y --no-install-recommends \ + apt-utils \ + bc \ + clang-tidy-17 \ + curl \ + dialog \ + graphviz \ + jq \ + pandoc \ + sudo \ + wget \ + libtbb-dev \ + libcapstone-dev \ + libfmt-dev \ + libyaml-cpp-dev \ + pybind11-dev \ + nlohmann-json3-dev \ + libgtest-dev \ + libboost-all-dev \ + && apt-get clean && rm -rf /var/lib/apt/lists/* + +ENV CCACHE_TEMPDIR=/tmp/ccache + +############################################################# + +FROM ci-build AS ci-test + +ARG TT_METAL_INFRA_DIR=/opt/tt_metal_infra + +# Create directories for infra +RUN mkdir -p ${TT_METAL_INFRA_DIR}/tt-metal/docs/ +RUN mkdir -p ${TT_METAL_INFRA_DIR}/tt-metal/tests/sweep_framework/ +RUN mkdir -p ${TT_METAL_INFRA_DIR}/tt-metal/tt_metal/python_env/ +# Copy requirements from tt-metal folders with requirements.txt docs +COPY /docs/requirements-docs.txt ${TT_METAL_INFRA_DIR}/tt-metal/docs/. +# Copy requirements from tt-metal folders for sweeps (requirements-sweeps.txt) +COPY /tests/sweep_framework/requirements-sweeps.txt ${TT_METAL_INFRA_DIR}/tt-metal/tests/sweep_framework/. +COPY /tt_metal/python_env/requirements-dev.txt ${TT_METAL_INFRA_DIR}/tt-metal/tt_metal/python_env/. 
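Individual stages of the consolidated Dockerfile above can also be built directly with --target. A hedged local sketch for the CI build stage; the output tag name is illustrative:

docker build -f dockerfile/Dockerfile --target ci-build \
  --build-arg UBUNTU_VERSION=22.04 \
  -t tt-metalium-local:ci-build .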
+ +RUN python3 -m pip config set global.extra-index-url https://download.pytorch.org/whl/cpu && \ + python3 -m pip install setuptools wheel && \ + python3 -m pip install -r ${TT_METAL_INFRA_DIR}/tt-metal/tt_metal/python_env/requirements-dev.txt && \ + python3 -m pip install -r ${TT_METAL_INFRA_DIR}/tt-metal/docs/requirements-docs.txt + +############################################################# + +FROM ci-test AS dev + +# Need this to build GDB +RUN apt-get -y update \ + && apt-get install -y libmpfr-dev \ + && rm -rf /var/lib/apt/lists/* + +# Install the gdb that is compatible with clang-17 +RUN apt-get remove -y gdb || true \ + && mkdir -p /tmp/gdb-build && cd /tmp/gdb-build/ \ + && wget -O /tmp/gdb-build/gdb.tar.gz https://ftp.gnu.org/gnu/gdb/gdb-14.2.tar.gz \ + && tar -xvf /tmp/gdb-build/gdb.tar.gz -C /tmp/gdb-build --strip-components=1 \ + && /tmp/gdb-build/configure --prefix=/usr/local \ + && make -j$(nproc) \ + && make install \ + && rm -rf /tmp/gdb-build + +# Install dev deps +RUN apt-get update && apt-get install -y --no-install-recommends \ + acl \ + emacs \ + less \ + nano \ + openssh-server \ + vim \ + && apt-get clean && rm -rf /var/lib/apt/lists/* + +############################################################# + +FROM dev AS release + +RUN mkdir -p /etc && \ + echo "[global]\nextra-index-url = https://download.pytorch.org/whl/cpu" > /etc/pip.conf + +ARG WHEEL_FILENAME +ADD $WHEEL_FILENAME $WHEEL_FILENAME +RUN pip3 install $WHEEL_FILENAME diff --git a/dockerfile/release.Dockerfile b/dockerfile/release.Dockerfile deleted file mode 100644 index 4f6fc2dc951..00000000000 --- a/dockerfile/release.Dockerfile +++ /dev/null @@ -1,10 +0,0 @@ -ARG BASE_IMAGE_NAME=tt-metalium/ubuntu-20.04-amd64 -# -# Currently the release image uses the base image which is also the build image. -# However, in the future, we could point a true base image that is a base for both releases and builds. 
-# This work is described in https://github.com/tenstorrent/tt-metal/issues/11974 -FROM ghcr.io/tenstorrent/tt-metal/$BASE_IMAGE_NAME - -ARG WHEEL_FILENAME -ADD $WHEEL_FILENAME $WHEEL_FILENAME -RUN pip3 install $WHEEL_FILENAME diff --git a/dockerfile/ubuntu-20.04-amd64.Dockerfile b/dockerfile/ubuntu-20.04-amd64.Dockerfile deleted file mode 100644 index 83c899a65f5..00000000000 --- a/dockerfile/ubuntu-20.04-amd64.Dockerfile +++ /dev/null @@ -1,92 +0,0 @@ -# TT-METAL UBUNTU 20.04 AMD64 DOCKERFILE -FROM public.ecr.aws/ubuntu/ubuntu:20.04 - -ARG DEBIAN_FRONTEND=noninteractive -ENV DOXYGEN_VERSION=1.9.6 -ARG UBUNTU_VERSION=20.04 -ENV CCACHE_TEMPDIR=/tmp/ccache - -# Use a newer version of CMake than what is available from Canonical for 20.04 -RUN apt -y update \ - && apt install -y --no-install-recommends ca-certificates gpg wget \ - && wget -O - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null | gpg --dearmor - | tee /usr/share/keyrings/kitware-archive-keyring.gpg >/dev/null \ - && echo 'deb [signed-by=/usr/share/keyrings/kitware-archive-keyring.gpg] https://apt.kitware.com/ubuntu/ focal main' | tee /etc/apt/sources.list.d/kitware.list >/dev/null \ - && rm -rf /var/lib/apt/lists/* - -# Install build and runtime deps -COPY /scripts/docker/requirements-${UBUNTU_VERSION}.txt /opt/tt_metal_infra/scripts/docker/requirements.txt -RUN apt-get -y update \ - && xargs -a /opt/tt_metal_infra/scripts/docker/requirements.txt apt-get install -y --no-install-recommends \ - && rm -rf /var/lib/apt/lists/* - -# Install dev deps -COPY /scripts/docker/requirements_dev.txt /opt/tt_metal_infra/scripts/docker/requirements_dev.txt -RUN apt-get -y update \ - && xargs -a /opt/tt_metal_infra/scripts/docker/requirements_dev.txt apt-get install -y --no-install-recommends \ - && rm -rf /var/lib/apt/lists/* - -## Test Related Dependencies -COPY /scripts/docker/install_test_deps.sh /opt/tt_metal_infra/scripts/docker/install_test_deps.sh -RUN /bin/bash /opt/tt_metal_infra/scripts/docker/install_test_deps.sh ${DOXYGEN_VERSION} - -# Copy remaining convenience scripts -COPY /scripts /opt/tt_metal_infra/scripts -COPY build_metal.sh /scripts/build_metal.sh - -# Setup Env variables to setup Python Virtualenv - Install TT-Metal Python deps -ENV TT_METAL_INFRA_DIR=/opt/tt_metal_infra -ENV PYTHON_ENV_DIR=${TT_METAL_INFRA_DIR}/tt-metal/python_env - -# Disable using venv since this is isolated in a docker container -# RUN python3 -m venv $PYTHON_ENV_DIR -# ENV PATH="$PYTHON_ENV_DIR/bin:$PATH" - -# Create directories for infra -RUN mkdir -p ${TT_METAL_INFRA_DIR}/tt-metal/docs/ -RUN mkdir -p ${TT_METAL_INFRA_DIR}/tt-metal/tests/sweep_framework/ -RUN mkdir -p ${TT_METAL_INFRA_DIR}/tt-metal/tt_metal/python_env/ - -# Copy requirements from tt-metal folders with requirements.txt docs -COPY /docs/requirements-docs.txt ${TT_METAL_INFRA_DIR}/tt-metal/docs/. -# Copy requirements from tt-metal folders for sweeps (requirements-sweeps.txt) -COPY /tests/sweep_framework/requirements-sweeps.txt ${TT_METAL_INFRA_DIR}/tt-metal/tests/sweep_framework/. -COPY /tt_metal/python_env/requirements-dev.txt ${TT_METAL_INFRA_DIR}/tt-metal/tt_metal/python_env/. 
- -RUN python3 -m pip config set global.extra-index-url https://download.pytorch.org/whl/cpu \ - && python3 -m pip install setuptools wheel - -RUN python3 -m pip install -r ${TT_METAL_INFRA_DIR}/tt-metal/tt_metal/python_env/requirements-dev.txt -RUN python3 -m pip install -r ${TT_METAL_INFRA_DIR}/tt-metal/docs/requirements-docs.txt - -# Install Clang-17 -RUN cd $TT_METAL_INFRA_DIR \ - && wget https://apt.llvm.org/llvm.sh \ - && chmod u+x llvm.sh \ - && ./llvm.sh 17 - -# Install compatible gdb debugger for clang-17 -RUN cd $TT_METAL_INFRA_DIR \ - && wget https://ftp.gnu.org/gnu/gdb/gdb-14.2.tar.gz \ - && tar -xvf gdb-14.2.tar.gz \ - && cd gdb-14.2 \ - && ./configure \ - && make -j$(nproc) -ENV PATH="$TT_METAL_INFRA_DIR/gdb-14.2/gdb:$PATH" - -# Can only be installed after Clang-17 installed -RUN apt-get -y update \ - && apt-get install -y --no-install-recommends \ - libc++-17-dev \ - libc++abi-17-dev \ - clang-tidy-17 \ - && rm -rf /var/lib/apt/lists/* - -RUN mkdir -p /usr/app - -# Install ccache from upstream; Apt's version for 20.04 predates remote_storage support -RUN wget -O /tmp/ccache.tar.xz https://github.com/ccache/ccache/releases/download/v4.10.2/ccache-4.10.2-linux-x86_64.tar.xz && \ - tar -xf /tmp/ccache.tar.xz -C /usr/local/bin --strip-components=1 && \ - rm /tmp/ccache.tar.xz -RUN ccache --version - -CMD ["tail", "-f", "/dev/null"] diff --git a/dockerfile/ubuntu-22.04-amd64.Dockerfile b/dockerfile/ubuntu-22.04-amd64.Dockerfile deleted file mode 100644 index 5233338f3ee..00000000000 --- a/dockerfile/ubuntu-22.04-amd64.Dockerfile +++ /dev/null @@ -1,126 +0,0 @@ -# TT-METAL UBUNTU 22.04 AMD64 DOCKERFILE -FROM public.ecr.aws/ubuntu/ubuntu:22.04 - -ARG DEBIAN_FRONTEND=noninteractive -ARG UBUNTU_VERSION=22.04 -ENV DOXYGEN_VERSION=1.9.6 -ENV CCACHE_TEMPDIR=/tmp/ccache - -# Use a newer version of CMake than what is available from Canonical for 22.04 -RUN apt -y update \ - && apt install -y --no-install-recommends ca-certificates gpg wget \ - && wget -O - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null | gpg --dearmor - | tee /usr/share/keyrings/kitware-archive-keyring.gpg >/dev/null \ - && echo 'deb [signed-by=/usr/share/keyrings/kitware-archive-keyring.gpg] https://apt.kitware.com/ubuntu/ jammy main' | tee /etc/apt/sources.list.d/kitware.list >/dev/null \ - && rm -rf /var/lib/apt/lists/* - -RUN apt update -y && apt install software-properties-common gpg-agent -y - -# add custom repo -RUN add-apt-repository ppa:deadsnakes/ppa - -# Install build and runtime deps -COPY /scripts/docker/requirements-${UBUNTU_VERSION}.txt /opt/tt_metal_infra/scripts/docker/requirements.txt -RUN apt-get -y update \ - && xargs -a /opt/tt_metal_infra/scripts/docker/requirements.txt apt-get install -y --no-install-recommends \ - && rm -rf /var/lib/apt/lists/* - -# Install dev deps -COPY /scripts/docker/requirements_dev.txt /opt/tt_metal_infra/scripts/docker/requirements_dev.txt -RUN apt-get -y update \ - && xargs -a /opt/tt_metal_infra/scripts/docker/requirements_dev.txt apt-get install -y --no-install-recommends \ - && rm -rf /var/lib/apt/lists/* - -## Test Related Dependencies -COPY /scripts/docker/install_test_deps.sh /opt/tt_metal_infra/scripts/docker/install_test_deps.sh -RUN /bin/bash /opt/tt_metal_infra/scripts/docker/install_test_deps.sh ${DOXYGEN_VERSION} - -# Copy remaining convenience scripts -COPY /scripts /opt/tt_metal_infra/scripts -COPY build_metal.sh /scripts/build_metal.sh - -# Setup Env variables to setup Python Virtualenv - Install TT-Metal Python deps -ENV 
TT_METAL_INFRA_DIR=/opt/tt_metal_infra -ENV PYTHON_ENV_DIR=${TT_METAL_INFRA_DIR}/tt-metal/python_env - -# Disable using venv since this is isolated in a docker container -# RUN python3 -m venv $PYTHON_ENV_DIR -# ENV PATH="$PYTHON_ENV_DIR/bin:$PATH" - -# Create directories for infra -RUN mkdir -p ${TT_METAL_INFRA_DIR}/tt-metal/docs/ -RUN mkdir -p ${TT_METAL_INFRA_DIR}/tt-metal/tests/sweep_framework/ -RUN mkdir -p ${TT_METAL_INFRA_DIR}/tt-metal/tt_metal/python_env/ - -# Copy requirements from tt-metal folders with requirements.txt docs -COPY /docs/requirements-docs.txt ${TT_METAL_INFRA_DIR}/tt-metal/docs/. -# Copy requirements from tt-metal folders for sweeps (requirements-sweeps.txt) -COPY /tests/sweep_framework/requirements-sweeps.txt ${TT_METAL_INFRA_DIR}/tt-metal/tests/sweep_framework/. -COPY /tt_metal/python_env/* ${TT_METAL_INFRA_DIR}/tt-metal/tt_metal/python_env/. -RUN python3 -m pip config set global.extra-index-url https://download.pytorch.org/whl/cpu \ - && python3 -m pip install setuptools wheel - -RUN python3 -m pip install -r ${TT_METAL_INFRA_DIR}/tt-metal/tt_metal/python_env/requirements-dev.txt -RUN python3 -m pip install -r ${TT_METAL_INFRA_DIR}/tt-metal/docs/requirements-docs.txt - -# Install Clang-17 -RUN cd $TT_METAL_INFRA_DIR \ - && wget https://apt.llvm.org/llvm.sh \ - && chmod u+x llvm.sh \ - && ./llvm.sh 17 - -# Install compatible gdb debugger for clang-17 -RUN cd $TT_METAL_INFRA_DIR \ - && wget https://ftp.gnu.org/gnu/gdb/gdb-14.2.tar.gz \ - && tar -xvf gdb-14.2.tar.gz \ - && cd gdb-14.2 \ - && ./configure \ - && make -j$(nproc) -ENV PATH="$TT_METAL_INFRA_DIR/gdb-14.2/gdb:$PATH" - -# Can only be installed after Clang-17 installed -RUN apt-get -y update \ - && apt-get install -y --no-install-recommends \ - libc++-17-dev \ - libc++abi-17-dev \ - clang-tidy-17 \ - && rm -rf /var/lib/apt/lists/* - -# Setup Env variables to setup Python Virtualenv - Install TT-Metal Python deps -ENV TT_METAL_INFRA_DIR=/opt/tt_metal_infra -ENV PYTHON_ENV_DIR=${TT_METAL_INFRA_DIR}/tt-metal/python_env - -# Disable using venv since this is isolated in a docker container -# RUN python3 -m venv $PYTHON_ENV_DIR -# ENV PATH="$PYTHON_ENV_DIR/bin:$PATH" - -# Copy requirements from tt-metal folders with requirements.txt docs -COPY /docs/requirements-docs.txt ${TT_METAL_INFRA_DIR}/tt-metal/docs/. -# Copy requirements from tt-metal folders for sweeps (requirements-sweeps.txt) -COPY /tests/sweep_framework/requirements-sweeps.txt ${TT_METAL_INFRA_DIR}/tt-metal/tests/sweep_framework/. -COPY /tt_metal/python_env/* ${TT_METAL_INFRA_DIR}/tt-metal/tt_metal/python_env/. -RUN python3 -m pip config set global.extra-index-url https://download.pytorch.org/whl/cpu \ - && python3 -m pip install setuptools wheel - -RUN python3 -m pip install -r ${TT_METAL_INFRA_DIR}/tt-metal/tt_metal/python_env/requirements-dev.txt -RUN python3 -m pip install -r ${TT_METAL_INFRA_DIR}/tt-metal/docs/requirements-docs.txt - -# Set python 3.11 and gcc-12 to be default -# RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 10 -# RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.11 11 - -RUN update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-11 11 -RUN update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-12 12 - -RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-11 11 -RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 12 - -# Ccache is not in requirements_dev.txt because 20.04's version is too old for remote_storage support. 
-# When we drop 20.04, can put it in requirements_dev.txt instead of here. -RUN wget -O /tmp/ccache.tar.xz https://github.com/ccache/ccache/releases/download/v4.10.2/ccache-4.10.2-linux-x86_64.tar.xz && \ - tar -xf /tmp/ccache.tar.xz -C /usr/local/bin --strip-components=1 && \ - rm /tmp/ccache.tar.xz -RUN ccache --version - -RUN mkdir -p /usr/app - -# CMD ["tail", "-f", "/dev/null"] diff --git a/docs/source/tt-metalium/tt_metal/apis/kernel_apis/compute/binary_op_init_funcs.rst b/docs/source/tt-metalium/tt_metal/apis/kernel_apis/compute/binary_op_init_funcs.rst index c0d31f72b9c..b1c393a4cdd 100644 --- a/docs/source/tt-metalium/tt_metal/apis/kernel_apis/compute/binary_op_init_funcs.rst +++ b/docs/source/tt-metalium/tt_metal/apis/kernel_apis/compute/binary_op_init_funcs.rst @@ -3,4 +3,4 @@ binary_init_funcs .. doxygenfunction:: binary_op_init_common(uint32_t icb0, uint32_t icb1, uint32_t ocb) -.. doxygenfunction:: binary_op_specific_init() +.. doxygenfunction:: binary_op_specific_init(uint32_t icb0, uint32_t icb1) diff --git a/infra/data_collection/github/download_cicd_logs_and_artifacts.sh b/infra/data_collection/github/download_cicd_logs_and_artifacts.sh index cc25c1014d3..4e05809206a 100755 --- a/infra/data_collection/github/download_cicd_logs_and_artifacts.sh +++ b/infra/data_collection/github/download_cicd_logs_and_artifacts.sh @@ -27,24 +27,20 @@ download_artifacts() { download_logs_for_all_jobs() { local repo=$1 local workflow_run_id=$2 - local max_attempts=$3 - - echo "[info] downloading logs for job with id $job_id for all attempts up to $max_attempts" - for attempt_number in $(seq 1 $max_attempts); do - echo "[Info] Downloading for attempt $attempt_number" - - gh api /repos/$repo/actions/runs/$workflow_run_id/attempts/$attempt_number/jobs --paginate | jq -c '.jobs[] | {id: .id, conclusion: .conclusion}' | while read -r job; do - job_id=$(echo "$job" | jq -r '.id') - job_conclusion=$(echo "$job" | jq -r '.conclusion') - echo "[info] download logs for job with id $job_id, attempt number $attempt_number" - gh api /repos/$repo/actions/jobs/$job_id/logs > generated/cicd/$workflow_run_id/logs/$job_id.log - - # Only download annotations for failed jobs - if [[ "$job_conclusion" == "failure" ]]; then - echo "[info] downloading annotations for failed job $job_id" - gh api /repos/$repo/check-runs/$job_id/annotations > generated/cicd/$workflow_run_id/logs/${job_id}_annotations.json - fi - done + local attempt_number=$3 + + echo "[info] Downloading logs for workflow with id $workflow_run_id for attempt $attempt_number" + gh api /repos/$repo/actions/runs/$workflow_run_id/attempts/$attempt_number/jobs --paginate | jq -c '.jobs[] | {id: .id, conclusion: .conclusion}' | while read -r job; do + job_id=$(echo "$job" | jq -r '.id') + job_conclusion=$(echo "$job" | jq -r '.conclusion') + echo "[info] download logs for job with id $job_id, attempt number $attempt_number" + gh api /repos/$repo/actions/jobs/$job_id/logs > generated/cicd/$workflow_run_id/logs/$job_id.log + + # Only download annotations for failed jobs + if [[ "$job_conclusion" == "failure" ]]; then + echo "[info] downloading annotations for failed job $job_id" + gh api /repos/$repo/check-runs/$job_id/annotations > generated/cicd/$workflow_run_id/logs/${job_id}_annotations.json + fi done } diff --git a/infra/data_collection/github/utils.py b/infra/data_collection/github/utils.py index eb2edfa8e2a..b898ca00cd3 100644 --- a/infra/data_collection/github/utils.py +++ b/infra/data_collection/github/utils.py @@ -97,6 +97,8 @@ def 
get_job_failure_signature_(github_job, failure_description) -> Optional[Unio "lost communication with the server": str(InfraErrorV1.RUNNER_COMM_FAILURE), "runner has received a shutdown signal": str(InfraErrorV1.RUNNER_SHUTDOWN_FAILURE), "No space left on device": str(InfraErrorV1.DISK_SPACE_FAILURE), + "API rate limit exceeded": str(InfraErrorV1.API_RATE_LIMIT_FAILURE), + "Tenstorrent cards seem to be in use": str(InfraErrorV1.RUNNER_CARD_IN_USE_FAILURE), } # Check the mapping dictionary for specific failure signature types diff --git a/infra/data_collection/models.py b/infra/data_collection/models.py index 25c9452a06d..078e55d04c2 100644 --- a/infra/data_collection/models.py +++ b/infra/data_collection/models.py @@ -9,3 +9,5 @@ class InfraErrorV1(enum.Enum): DISK_SPACE_FAILURE = enum.auto() RUNNER_COMM_FAILURE = enum.auto() RUNNER_SHUTDOWN_FAILURE = enum.auto() + API_RATE_LIMIT_FAILURE = enum.auto() + RUNNER_CARD_IN_USE_FAILURE = enum.auto() diff --git a/install_dependencies.sh b/install_dependencies.sh index 8875a5d8def..292ffd8ac60 100755 --- a/install_dependencies.sh +++ b/install_dependencies.sh @@ -86,10 +86,13 @@ ub_runtime_packages() ub_buildtime_packages() { UB_BUILDTIME_LIST=(\ - libpython3-dev \ - python3-pip \ + git \ + python3-dev \ + pkg-config \ + cargo \ cmake \ ninja-build \ + libboost-dev \ libhwloc-dev \ libc++-17-dev \ libc++abi-17-dev \ diff --git a/models/bringup_testing/Tutorial_Adding_a_Model.md b/models/bringup_testing/Tutorial_Adding_a_Model.md deleted file mode 100644 index d6fe86ec449..00000000000 --- a/models/bringup_testing/Tutorial_Adding_a_Model.md +++ /dev/null @@ -1,37 +0,0 @@ -# Adding a Model - -## Basic Requirements - -- Access to TT-Hardware -- Knowledge of PyTorch and Transformers -- Familiarity with TT-Metalium and TTNN -- See [TT-Metal README.md](https://github.com/tenstorrent/tt-metal/blob/main/README.md) for the latest updates to Tenstorrent models. - -## Initial Model Bring-up - -1. Run a reference model to ensure you have correctly set up your model with correct weights, attributes, etc. See the [HuggingFace PyTorch GitHub](https://github.com/huggingface/pytorch-image-models) for reference models. -2. Decompose the model into modules for function implementation. Here are examples of standard modules used in LLMs: Layernorm/RMSNorm, RotaryEmbedding, Attention, or Multilayer Perceptron (MLP). -3. Compose all modules into higher level modules. Decoder Layer and Full Model are examples of higher level modules. -4. Implement decode and prefill modes. Both must be included for the model to function. -5. Unit test each module. Start with the smallest module working up to composite modules. -6. Implement the module in TTNN, then pass the same inputs to the reference module and the TTNN module to check for correctness. -7. Create a full model test. Use real inputs to produce real outputs; for LLMs, input text to output decoded tokens. -8. Run the same inputs through the reference model and TTNN model to check the accuracy of your implementation. Teacher forcing is the ideal method to use with LLMs. -9. Generate a token from the reference model and TTNN model. Input these reference tokens into both models in the next iteration. Depending on the differences in the outputs, you can check accuracy metrics. -10. See: [LLMs Bring up in TT-NN](https://github.com/tenstorrent/tt-metal/blob/main/tech_reports/LLMs/llms.md) or [ViT in TTNN](https://github.com/tenstorrent/tt-metal/blob/main/tech_reports/ViT-TTNN/vit.md) for more information on these steps. 
- -## Model Performance Optimization - -Optimization tools like Metal Trace, async mode, and multiple command queues improve the performance of your model. - -- Metal Trace - Metal Trace is a performance optimization tool that removes host overhead of constructing and dispatching operations. -- Async Mode - Async mode allows the host to continuously send commands without blocking util data is read back from the device. -- Multiple Command Queues - Metalium can support two command queues. These command queues are independent from each other and allow for parallel dispatches on the same device. -- See [Advanced Performance Optimizations for Models](https://github.com/tenstorrent/tt-metal/blob/main/tech_reports/AdvancedPerformanceOptimizationsForModels/AdvancedPerformanceOptimizationsForModels.md#1-metal-trace) for more information on performance optimization. - -## Run the Demo - -1. Download weights. -2. Setup environment variables. -3. Cache weights. -4. Execute the demo. diff --git a/models/demos/grayskull/resnet50/README.md b/models/demos/grayskull/resnet50/README.md index 5cb994d54b5..a42974646e8 100644 --- a/models/demos/grayskull/resnet50/README.md +++ b/models/demos/grayskull/resnet50/README.md @@ -7,7 +7,7 @@ ResNet50 is a deep convolutional neural network architecture with 50 layers, des ## Details -+ The entry point to the Metal ResNet model is `ResNet` in `ttnn_functional_resnet50_new_conv_api.py`. ++ The entry point to the Metal ResNet model is `ResNet` in `ttnn_functional_resnet50.py`. + The model picks up certain configs and weights from TorchVision pretrained model. We have used `torchvision.models.ResNet50_Weights.IMAGENET1K_V1` version from TorchVision as our reference. + Our ImageProcessor on the other hand is based on `microsoft/resnet-50` from huggingface. @@ -38,7 +38,7 @@ pytest --disable-warnings models/demos/grayskull/demo/demo.py::test_demo_imagene ### Single Device #### Grayskull Device Performance -+ To obtain device performance, run ++ To obtain device performance, run ```python pytest models/demos/grayskull/resnet50/tests/test_perf_device_resnet50.py::test_perf_device ``` diff --git a/models/demos/t3000/resnet50/README.md b/models/demos/t3000/resnet50/README.md index 5a045b1ec6b..78ce2e378fc 100644 --- a/models/demos/t3000/resnet50/README.md +++ b/models/demos/t3000/resnet50/README.md @@ -7,7 +7,7 @@ ResNet50 is a deep convolutional neural network architecture with 50 layers, des ## Details -+ The entry point to the Metal ResNet model is `ResNet` in `ttnn_functional_resnet50_new_conv_api.py`. ++ The entry point to the Metal ResNet model is `ResNet` in `ttnn_functional_resnet50.py`. + The model picks up certain configs and weights from TorchVision pretrained model. We have used `torchvision.models.ResNet50_Weights.IMAGENET1K_V1` version from TorchVision as our reference. + Our ImageProcessor on the other hand is based on `microsoft/resnet-50` from huggingface. diff --git a/models/demos/tg/resnet50/README.md b/models/demos/tg/resnet50/README.md index c48b7c38d55..f815cf22e79 100644 --- a/models/demos/tg/resnet50/README.md +++ b/models/demos/tg/resnet50/README.md @@ -7,7 +7,7 @@ ResNet50 is a deep convolutional neural network architecture with 50 layers, des ## Details -+ The entry point to the Metal ResNet model is `ResNet` in `ttnn_functional_resnet50_new_conv_api.py`. ++ The entry point to the Metal ResNet model is `ResNet` in `ttnn_functional_resnet50.py`. + The model picks up certain configs and weights from TorchVision pretrained model. 
We have used `torchvision.models.ResNet50_Weights.IMAGENET1K_V1` version from TorchVision as our reference. + Our ImageProcessor on the other hand is based on `microsoft/resnet-50` from huggingface. diff --git a/models/demos/tgg/resnet50/README.md b/models/demos/tgg/resnet50/README.md index 30dec4ec5b2..f119befd4d0 100644 --- a/models/demos/tgg/resnet50/README.md +++ b/models/demos/tgg/resnet50/README.md @@ -7,7 +7,7 @@ ResNet50 is a deep convolutional neural network architecture with 50 layers, des ## Details -+ The entry point to the Metal ResNet model is `ResNet` in `ttnn_functional_resnet50_new_conv_api.py`. ++ The entry point to the Metal ResNet model is `ResNet` in `ttnn_functional_resnet50.py`. + The model picks up certain configs and weights from TorchVision pretrained model. We have used `torchvision.models.ResNet50_Weights.IMAGENET1K_V1` version from TorchVision as our reference. + Our ImageProcessor on the other hand is based on `microsoft/resnet-50` from huggingface. diff --git a/models/demos/ttnn_resnet/README.md b/models/demos/ttnn_resnet/README.md index 655cfbe35f8..c4fabb9cb3d 100644 --- a/models/demos/ttnn_resnet/README.md +++ b/models/demos/ttnn_resnet/README.md @@ -7,7 +7,7 @@ ResNet50 is a deep convolutional neural network architecture with 50 layers, des ## Details -+ The entry point to the Metal ResNet model is `ResNet` in `ttnn_functional_resnet50_new_conv_api.py`. ++ The entry point to the Metal ResNet model is `ResNet` in `ttnn_functional_resnet50.py`. + The model picks up certain configs and weights from TorchVision pretrained model. We have used `torchvision.models.ResNet50_Weights.IMAGENET1K_V1` version from TorchVision as our reference. + Our ImageProcessor on the other hand is based on `microsoft/resnet-50` from huggingface. 
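Reviewer note on the README hunks above: they repoint the entry point at `ResNet` in `ttnn_functional_resnet50.py` and keep `torchvision.models.ResNet50_Weights.IMAGENET1K_V1` as the golden reference. As a minimal sketch of how that TorchVision reference can be loaded for a PCC comparison (assumes torchvision >= 0.13 is installed; the repo's own test infra wraps this differently and adds the `microsoft/resnet-50` image preprocessing):

```python
import torch
from torchvision.models import ResNet50_Weights, resnet50

# Load the TorchVision reference named in the READMEs and run a dummy batch.
weights = ResNet50_Weights.IMAGENET1K_V1
torch_model = resnet50(weights=weights).eval()

with torch.no_grad():
    golden = torch_model(torch.randn(1, 3, 224, 224))  # logits used as the PCC golden
```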
diff --git a/models/demos/ttnn_resnet/tests/resnet50_test_infra.py b/models/demos/ttnn_resnet/tests/resnet50_test_infra.py index 0d7795b9c5e..c7a25d71e09 100644 --- a/models/demos/ttnn_resnet/tests/resnet50_test_infra.py +++ b/models/demos/ttnn_resnet/tests/resnet50_test_infra.py @@ -19,9 +19,10 @@ divup, ) -from tests.ttnn.utils_for_testing import assert_with_pcc +from tests.ttnn.utils_for_testing import assert_with_pcc, check_with_pcc from models.demos.ttnn_resnet.tt.custom_preprocessing import create_custom_mesh_preprocessor -from models.demos.ttnn_resnet.tt.ttnn_functional_resnet50_new_conv_api import resnet50 + +from models.demos.ttnn_resnet.tt.ttnn_functional_resnet50 import resnet50 def load_resnet50_model(model_location_generator): @@ -145,6 +146,7 @@ def load_resnet50_model(model_location_generator): golden_pcc = { ttnn.device.Arch.WORMHOLE_B0: copy.deepcopy(golden_pcc_obj), ttnn.device.Arch.GRAYSKULL: copy.deepcopy(golden_pcc_obj), + ttnn.device.Arch.BLACKHOLE: copy.deepcopy(golden_pcc_obj), } golden_pcc[ttnn.device.Arch.GRAYSKULL][16][ @@ -331,12 +333,14 @@ def validate(self, output_tensor=None): valid_pcc = 0.93 else: valid_pcc = 0.982 - self.pcc_passed, self.pcc_message = assert_with_pcc(self.torch_output_tensor, output_tensor, pcc=valid_pcc) + self.pcc_passed, self.pcc_message = check_with_pcc(self.torch_output_tensor, output_tensor, pcc=valid_pcc) logger.info( f"ResNet50 batch_size={batch_size}, act_dtype={self.act_dtype}, weight_dtype={self.weight_dtype}, math_fidelity={self.math_fidelity}, PCC={self.pcc_message}" ) + return self.pcc_passed, self.pcc_message + def create_test_infra( device, diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50.py similarity index 83% rename from models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py rename to models/demos/ttnn_resnet/tt/ttnn_functional_resnet50.py index 2ebb7f869fa..2150dfc7d1d 100644 --- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py +++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50.py @@ -8,6 +8,7 @@ from models.utility_functions import ( is_grayskull, is_wormhole_b0, + is_blackhole, _nearest_y, ) from typing import List @@ -50,6 +51,10 @@ ), } +ops_parallel_config = { + "layer1_module1_input": None, +} + def ResnetLinear( in_features: int, @@ -153,7 +158,7 @@ def run_downsample_if_req( reshard_if_not_optimal=False, height_sharding=None, transpose_shards=True, - packer_l1_accum_enabled=True if is_wormhole_b0() else False, + packer_l1_accum_enabled=True if not is_grayskull() else False, enable_act_double_buffer=False, enable_split_reader=False, enable_subblock_padding=False, @@ -179,7 +184,7 @@ def run_downsample_if_req( if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED, deallocate_activation=True, - reallocate_halo_output=not (is_wormhole_b0() and batch_size == 16), + reallocate_halo_output=True, reshard_if_not_optimal=reshard_if_not_optimal, transpose_shards=transpose_shards, enable_act_double_buffer=enable_act_double_buffer @@ -192,6 +197,8 @@ def run_downsample_if_req( enable_subblock_padding=enable_subblock_padding, ), } + if is_blackhole(): + conv_kwargs["conv_config"].enable_split_reader = False if not ttnn.is_tensor_storage_on_device(self.ds_conv_weight_tensor): self.ds_conv_weight_tensor = ttnn.prepare_conv_weights( @@ -223,6 +230,8 @@ def run_downsample_if_req( packer_l1_acc=packer_l1_accum_enabled, ), conv_op_cache=conv_op_cache, + return_output_dim=False, + 
return_weights_and_bias=False, ) ttnn.deallocate(x) ds_out = ttnn.reallocate(ds_out) @@ -242,18 +251,24 @@ def __call__( height_sharding=None, eltwise_binary_out_in_place=True, transpose_shards=True, - packer_l1_acc=True if is_wormhole_b0() else False, + packer_l1_acc=True if not is_grayskull() else False, enable_act_double_buffer=False, enable_split_reader=False, enable_subblock_padding=False, + ops_parallel_config=None, + layer_module=None, ): logger.debug( f"==== Running {batch_size}, {input_height}, {input_width}, {self.conv1_input_channels}, {self.conv1_output_channels}" ) + ds_input_height = input_height + ds_input_width = input_width + # conv1 is 1x1 conv logger.debug(f"Running conv1") module_input_height = input_height + module_input_width = input_width conv_kwargs_1 = { "in_channels": self.conv1_input_channels, "out_channels": self.conv1_output_channels, @@ -277,6 +292,8 @@ def __call__( transpose_shards=transpose_shards, ), } + if is_blackhole(): + conv_kwargs_1["conv_config"].enable_split_reader = False if not ttnn.is_tensor_storage_on_device(self.conv1_weight_tensor): self.conv1_weight_tensor = ttnn.prepare_conv_weights( @@ -313,46 +330,30 @@ def __call__( ) act_block_h_override = 0 + run_downsample_before_conv2 = True + ds_out = None + if is_grayskull(): if self.conv2_output_channels == 64 and input_height == 56 and batch_size == 20: act_block_h_override = 320 elif is_wormhole_b0(): - if ( - self.conv2_input_channels == 128 - and self.conv2_output_channels == 128 - and input_height == 56 - and batch_size == 20 - ): - act_block_h_override = 160 - - run_downsample_before_conv2 = False - if not (input_height == 56 and self.conv1_input_channels == 64): - run_downsample_before_conv2 = True - if ( - is_wormhole_b0() - and batch_size == 16 - and ( - (input_height == 56 and self.conv1_input_channels == 256 and self.conv1_output_channels == 128) - or (input_height == 28 and self.conv1_input_channels == 512 and self.conv1_output_channels == 256) - or (input_height == 14 and self.conv1_input_channels == 1024 and self.conv1_output_channels == 512) - ) - ): - run_downsample_before_conv2 = True + run_downsample_before_conv2 = False - # ds_mem_config_grid = None if run_downsample_before_conv2: - if input_height == 56 and self.conv1_input_channels == 256 and self.downsample: - x_rm = ttnn.to_layout(x, ttnn.ROW_MAJOR_LAYOUT) - ttnn.deallocate(x) - if is_wormhole_b0(): - out = ttnn.reallocate(out) - x = ttnn.reallocate(x_rm) + if layer_module and layer_module == "layer4_module1": + if ops_parallel_config and "layer4_module1_downsample" in ops_parallel_config: + x = ttnn.to_memory_config(x, ops_parallel_config["layer4_module1_downsample"]) + if is_grayskull(): + if input_height == 56 and self.conv1_input_channels == 256 and self.downsample: + x_rm = ttnn.to_layout(x, ttnn.ROW_MAJOR_LAYOUT) + ttnn.deallocate(x) + x = ttnn.reallocate(x_rm) ds_out = self.run_downsample_if_req( x, device, batch_size, - input_height, - input_width, + ds_input_height, + ds_input_width, conv_op_cache, reshard_if_not_optimal, height_sharding, @@ -362,9 +363,24 @@ def __call__( enable_split_reader=enable_split_reader, enable_subblock_padding=enable_subblock_padding, ) + if layer_module and layer_module == "layer4_module1": + if ops_parallel_config and "layer4_module1_downsample" not in ops_parallel_config: + x_memory_config = ttnn.get_memory_config(ds_out) + sharded_config = ttnn.create_sharded_memory_config_( + ttnn.Shape([batch_size, ds_input_height, ds_input_width, self.conv1_input_channels]), + 
x_memory_config.shard_spec.grid, + x_memory_config.memory_layout, + x_memory_config.shard_spec.orientation, + tile_layout=True, + ) + ops_parallel_config["layer4_module1_downsample"] = sharded_config - reallocate_halo_output = batch_size == 20 logger.debug(f"Running conv2") + + if layer_module and layer_module == "layer4_module1": + if ops_parallel_config and "layer4_module1_input" in ops_parallel_config: + out = ttnn.to_memory_config(out, ops_parallel_config["layer4_module1_input"]) + conv_kwargs_2 = { "in_channels": self.conv2_input_channels, "out_channels": self.conv2_output_channels, @@ -382,7 +398,7 @@ def __call__( weights_dtype=self.model_config["WEIGHTS_DTYPE"], activation="relu", deallocate_activation=True, - reallocate_halo_output=reallocate_halo_output, + reallocate_halo_output=not is_wormhole_b0(), act_block_h_override=act_block_h_override, shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding @@ -396,6 +412,11 @@ def __call__( ), } + if is_blackhole(): + conv_kwargs_2["conv_config"].act_block_h_override = 2 * 32 + conv_kwargs_2["conv_config"].enable_subblock_padding = False + conv_kwargs_2["conv_config"].enable_split_reader = False + if not ttnn.is_tensor_storage_on_device(self.conv2_weight_tensor): self.conv2_weight_tensor = ttnn.prepare_conv_weights( weight_tensor=self.conv2_weight_tensor, @@ -428,20 +449,17 @@ def __call__( return_output_dim=True, return_weights_and_bias=False, ) - - logger.debug( - f"{batch_size} and {input_height} and {self.conv1_input_channels} and {self.conv1_output_channels}" - ) - - if ( - is_wormhole_b0() - and batch_size == 20 - and input_height == 28 - and self.conv1_input_channels == 256 - and self.conv1_output_channels == 128 - ): - logger.info(f"==== Reallocating conv2 output") - out = ttnn.reallocate(out) + if layer_module and layer_module == "layer4_module1": + if ops_parallel_config and "layer4_module1_input" not in ops_parallel_config: + x_memory_config = ttnn.get_memory_config(out) + sharded_config = ttnn.create_sharded_memory_config_( + ttnn.Shape([batch_size, module_input_height, module_input_width, self.conv2_input_channels]), + x_memory_config.shard_spec.grid, + x_memory_config.memory_layout, + x_memory_config.shard_spec.orientation, + tile_layout=True, + ) + ops_parallel_config["layer4_module1_input"] = sharded_config # conv3 is 1x1 conv logger.debug(f"Running conv3") @@ -467,6 +485,8 @@ def __call__( transpose_shards=transpose_shards, ), } + if is_blackhole(): + conv_kwargs_3["conv_config"].enable_split_reader = False if not ttnn.is_tensor_storage_on_device(self.conv3_weight_tensor): self.conv3_weight_tensor = ttnn.prepare_conv_weights( @@ -496,29 +516,19 @@ def __call__( packer_l1_acc=packer_l1_acc, ), conv_op_cache=conv_op_cache, + return_output_dim=False, + return_weights_and_bias=False, ) if not run_downsample_before_conv2: - ds_reshard = ( - False - if is_grayskull() - and batch_size == 20 - and ( - input_height == 28 - and self.conv1_input_channels == 256 - or input_height == 14 - and self.conv1_input_channels == 512 - ) - else reshard_if_not_optimal - ) ds_out = self.run_downsample_if_req( x, device, batch_size, - input_height, - input_width, + ds_input_height, + ds_input_width, conv_op_cache, - ds_reshard, + reshard_if_not_optimal, height_sharding, transpose_shards=transpose_shards, packer_l1_accum_enabled=packer_l1_acc, @@ -527,6 +537,8 @@ def __call__( enable_subblock_padding=enable_subblock_padding, ) + assert ds_out is not None, "ds_out is None" + assert ttnn.get_memory_config(out) == 
ttnn.get_memory_config( ds_out ), f"{ttnn.get_memory_config(out)} != {ttnn.get_memory_config(ds_out)}" @@ -544,10 +556,8 @@ def __call__( ds_out, activations=[ttnn.UnaryWithParam(ttnn.UnaryOpType.RELU)], memory_config=ttnn.L1_MEMORY_CONFIG, - ) ## TODO: check why not out mem config??? + ) ttnn.deallocate(ds_out) - if batch_size == 20 and (is_wormhole_b0() or (module_input_height == 56 and self.conv1_input_channels == 64)): - out = ttnn.reallocate(out) return out, input_height, input_width @@ -575,18 +585,13 @@ def __init__( self.conv_op_cache = {} self.inplanes = 64 self.final_output_mem_config = final_output_mem_config - if is_grayskull(): - compute_kernel_config = ttnn.GrayskullComputeKernelConfig( - math_fidelity=model_config["MATH_FIDELITY"], - math_approx_mode=True, - ) - else: - compute_kernel_config = ttnn.WormholeComputeKernelConfig( - math_fidelity=model_config["MATH_FIDELITY"], - math_approx_mode=True, - fp32_dest_acc_en=False, - packer_l1_acc=True, - ) + compute_kernel_config = ttnn.init_device_compute_kernel_config( + device.arch(), + math_fidelity=model_config["MATH_FIDELITY"], + math_approx_mode=True, + fp32_dest_acc_en=False, + packer_l1_acc=True, + ) self.conv1_weight_tensor = parameters.conv1.weight self.conv1_bias_tensor = parameters.conv1.bias self.conv1_input_channels = self.conv1_weight_tensor.shape[1] @@ -663,22 +668,12 @@ def __init__( self.transpose_shards = True act_block_h_override = 1568 - if is_wormhole_b0(): + if is_wormhole_b0() or is_blackhole(): self.transpose_shards = False - if batch_size == 16: - act_block_h_override = 1568 - elif batch_size == 20: - act_block_h_override = 640 else: act_block_h_override = 0 - # input_channels_alignment = 16 if not is_wormhole_b0() else 32 - whb0_and_b16 = is_wormhole_b0() and self.batch_size == 16 - if not is_wormhole_b0(): - input_channels_alignment = 16 - elif whb0_and_b16: - input_channels_alignment = 16 - else: - input_channels_alignment = 32 + + input_channels_alignment = 16 self.conv1_config = ttnn.Conv2dConfig( dtype=self.model_config["ACTIVATIONS_DTYPE"], weights_dtype=self.model_config["WEIGHTS_DTYPE"], @@ -687,8 +682,8 @@ def __init__( input_channels_alignment=input_channels_alignment, act_block_h_override=act_block_h_override, transpose_shards=self.transpose_shards, - enable_act_double_buffer=True if whb0_and_b16 else False, - enable_split_reader=True if whb0_and_b16 or not is_wormhole_b0() else False, + enable_act_double_buffer=is_wormhole_b0() or is_blackhole(), + enable_split_reader=True, enable_subblock_padding=False, shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED, reshard_if_not_optimal=False, @@ -696,9 +691,9 @@ def __init__( self.conv1_compute_config = ttnn.init_device_compute_kernel_config( device.arch(), math_fidelity=self.model_config["MATH_FIDELITY"], - packer_l1_acc=True if whb0_and_b16 else False, + packer_l1_acc=True, ) - if whb0_and_b16: + if is_wormhole_b0(): # Issue #13145: Temp workaround for Galaxy to avoid hangs if type(device) == ttnn.MeshDevice and device.get_num_devices() > 8: self.conv1_config.act_block_h_override = 64 @@ -706,6 +701,10 @@ def __init__( # Todo: restore after issue #16895 is fixed # self.conv1_config.act_block_h_override = 49 * 32 self.conv1_config.act_block_h_override = 2 * 32 + if is_blackhole(): + # self.conv1_config.act_block_h_override = 7 * 32 + # self.conv1_config.act_block_h_override = 2 * 32 + self.conv1_config.enable_split_reader = False self.conv1_kernel_size = (4, 4) self.conv1_stride = (1, 1) @@ -736,6 +735,8 @@ def __init__( w // self.fold_stride_w, C * 
(self.fold_stride_h * self.fold_stride_w), ) + num_cores_x = 8 + num_cores_y = 8 if self.batch_size == 16: num_cores_x = 8 num_cores_y = 8 @@ -743,9 +744,12 @@ def __init__( if is_grayskull(): num_cores_x = 10 num_cores_y = 8 - elif is_wormhole_b0(): # untested due to unsupported batch20 on WH + elif is_wormhole_b0(): num_cores_x = 8 num_cores_y = 5 + elif is_blackhole(): + num_cores_x = 8 + num_cores_y = 10 self.fold_compute_grid_size = (num_cores_x, num_cores_y) conv_dummy_tensor = torch.rand((self.fold_output_shape), dtype=torch.bfloat16) @@ -765,9 +769,8 @@ def __init__( ) def __del__(self): - # Need to clear global configs for each Resnet run - self.conv_op_cache.clear() - self.max_pool_reader_patterns_cache.clear() + # Nothing to do + pass def _make_layer( self, @@ -900,34 +903,42 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt x_width = 56 x = ttnn.reshape(x, (1, 1, x_height * x_width * self.batch_size, 64)) - if is_wormhole_b0(): - # TODO: fix the need to do the reshard here + if is_blackhole(): + core_range_set = ttnn.CoreRangeSet( + { + ttnn.CoreRange( + ttnn.CoreCoord(0, 0), + ttnn.CoreCoord(12, 7), + ), + ttnn.CoreRange( + ttnn.CoreCoord(0, 8), + ttnn.CoreCoord(7, 8), + ), + } + ) + elif is_wormhole_b0(): + core_range_set = ttnn.CoreGrid(x=8, y=7) + + if is_blackhole() or is_wormhole_b0(): mem_config = ttnn.create_sharded_memory_config_( ttnn.Shape([self.batch_size * x_height * x_width, 64]), - ttnn.CoreGrid(x=8, y=7), + core_range_set, ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.ShardOrientation.ROW_MAJOR, tile_layout=True, ) x = ttnn.to_memory_config(x, mem_config) + x = ttnn.to_layout(x, ttnn.TILE_LAYOUT, dtype=self.model_config["ACTIVATIONS_DTYPE"]) - if self.batch_size == 20 and not is_wormhole_b0(): + if self.batch_size == 20 and is_grayskull(): x = ttnn.reallocate(x) logger.debug(f"==== Running layer 1 module 1") layer1_module1_input_shape = ttnn.Shape(x.padded_shape) reshard = False - height_shard = False - if is_wormhole_b0() and self.batch_size == 20: - if is_first_run: - reshard = True - height_shard = True - else: - x = ttnn.to_memory_config(x, ops_parallel_config["layer1_module1_input"]) - - whb0_and_b16 = is_wormhole_b0() and self.batch_size == 16 + height_shard = True x, x_height, x_width = self.layer1_module1( x, @@ -939,9 +950,9 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt reshard_if_not_optimal=reshard, height_sharding=height_shard, transpose_shards=self.transpose_shards, - enable_act_double_buffer=True if whb0_and_b16 else False, - enable_split_reader=True if whb0_and_b16 else False, - enable_subblock_padding=True if whb0_and_b16 else False, + enable_act_double_buffer=True, + enable_split_reader=True, + enable_subblock_padding=not is_grayskull(), ) if is_first_run: @@ -964,8 +975,8 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt conv_op_cache, transpose_shards=self.transpose_shards, enable_act_double_buffer=False, - enable_split_reader=True if whb0_and_b16 else False, - enable_subblock_padding=True if whb0_and_b16 else False, + enable_split_reader=True, + enable_subblock_padding=not is_grayskull(), ) logger.debug(f"==== Running layer 1 module 3") @@ -978,24 +989,37 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt conv_op_cache, transpose_shards=self.transpose_shards, enable_act_double_buffer=False, - enable_split_reader=True if whb0_and_b16 else False, - enable_subblock_padding=True if whb0_and_b16 else False, + 
enable_split_reader=True, + enable_subblock_padding=not is_grayskull(), ) - if self.batch_size == 20 and is_wormhole_b0(): - x = ttnn.reallocate(x) - layer2_module1_input_shape = ttnn.Shape(x.padded_shape) - reshard = False - height_shard = False - is_gs = is_grayskull() - if is_wormhole_b0() and self.batch_size == 20: - if is_first_run: - reshard = True if not is_wormhole_b0() else False - height_shard = True - else: - x = ttnn.to_memory_config(x, ops_parallel_config["layer2_module1_input"]) + reshard = not (is_wormhole_b0() or is_grayskull()) + height_shard = True + + if is_blackhole(): + ## 98 + core_range_set = ttnn.CoreRangeSet( + { + ttnn.CoreRange( + ttnn.CoreCoord(0, 0), + ttnn.CoreCoord(12, 6), + ), + ttnn.CoreRange( + ttnn.CoreCoord(0, 7), + ttnn.CoreCoord(6, 7), + ), + } + ) + mem_config = ttnn.create_sharded_memory_config_( + layer2_module1_input_shape, + core_range_set, + ttnn.TensorMemoryLayout.HEIGHT_SHARDED, + ttnn.ShardOrientation.ROW_MAJOR, + tile_layout=True, + ) + x = ttnn.to_memory_config(x, mem_config) logger.debug(f"==== Running layer 2 module 1") x, x_height, x_width = self.layer2_module1( @@ -1008,7 +1032,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt reshard_if_not_optimal=reshard, height_sharding=height_shard, transpose_shards=self.transpose_shards, - enable_act_double_buffer=True if whb0_and_b16 or is_gs else False, + enable_act_double_buffer=True, enable_split_reader=False, enable_subblock_padding=False, ) @@ -1046,7 +1070,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt x_width, conv_op_cache, transpose_shards=self.transpose_shards, - enable_act_double_buffer=True if whb0_and_b16 or is_gs else False, + enable_act_double_buffer=True, enable_split_reader=False, enable_subblock_padding=False, ) @@ -1060,20 +1084,34 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt x_width, conv_op_cache, transpose_shards=self.transpose_shards, - enable_act_double_buffer=True if whb0_and_b16 or is_gs else False, + enable_act_double_buffer=True, enable_split_reader=False, enable_subblock_padding=False, ) layer3_module1_input_shape = ttnn.Shape(x.padded_shape) - reshard = False + reshard = is_wormhole_b0() or is_grayskull() height_shard = False - if is_first_run: - reshard = True - height_shard = False - else: - x = ttnn.to_memory_config(x, ops_parallel_config["layer3_module1_input"]) + + if is_blackhole(): + ## 104 + core_range_set = ttnn.CoreRangeSet( + { + ttnn.CoreRange( + ttnn.CoreCoord(0, 0), + ttnn.CoreCoord(12, 7), + ), + } + ) + mem_config = ttnn.create_sharded_memory_config_( + layer3_module1_input_shape, + core_range_set, + ttnn.TensorMemoryLayout.BLOCK_SHARDED, + ttnn.ShardOrientation.COL_MAJOR, + tile_layout=True, + ) + x = ttnn.to_memory_config(x, mem_config) logger.debug(f"==== Running layer 3 module 1") x, x_height, x_width = self.layer3_module1( @@ -1086,7 +1124,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt reshard_if_not_optimal=reshard, height_sharding=height_shard, transpose_shards=self.transpose_shards, - enable_act_double_buffer=True if whb0_and_b16 or is_gs else False, + enable_act_double_buffer=True, enable_split_reader=False, enable_subblock_padding=False, ) @@ -1110,7 +1148,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt x_width, conv_op_cache, transpose_shards=self.transpose_shards, - enable_act_double_buffer=True if whb0_and_b16 or is_gs else False, + 
enable_act_double_buffer=True, enable_split_reader=False, enable_subblock_padding=False, ) @@ -1124,9 +1162,10 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt x_width, conv_op_cache, transpose_shards=self.transpose_shards, - enable_act_double_buffer=True if whb0_and_b16 or is_gs else False, + enable_act_double_buffer=True, enable_split_reader=False, enable_subblock_padding=False, + layer_module="layer3_module3", ) logger.debug(f"==== Running layer 3 module 4") @@ -1138,9 +1177,10 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt x_width, conv_op_cache, transpose_shards=self.transpose_shards, - enable_act_double_buffer=True if whb0_and_b16 or is_gs else False, + enable_act_double_buffer=True, enable_split_reader=False, enable_subblock_padding=False, + layer_module="layer3_module4", ) logger.debug(f"==== Running layer 3 module 5") @@ -1152,9 +1192,10 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt x_width, conv_op_cache, transpose_shards=self.transpose_shards, - enable_act_double_buffer=True if whb0_and_b16 or is_gs else False, + enable_act_double_buffer=True, enable_split_reader=False, enable_subblock_padding=False, + layer_module="layer3_module5", ) logger.debug(f"==== Running layer 3 module 6") @@ -1167,36 +1208,44 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt conv_op_cache, eltwise_binary_out_in_place=True, transpose_shards=self.transpose_shards, - enable_act_double_buffer=True if whb0_and_b16 or is_gs else False, + enable_act_double_buffer=True, enable_split_reader=False, enable_subblock_padding=False, ) - if is_wormhole_b0() and self.batch_size == 16: - xshape = x.shape - x = ttnn.slice( - x, starts=(0, 0, 0, 0), ends=(xshape[0], xshape[1], xshape[2], xshape[3]), steps=(1, 1, 1, 1) - ) + reshard = is_grayskull() + height_shard = False layer4_module1_input_shape = ttnn.Shape(x.padded_shape) - - if is_wormhole_b0(): + if is_blackhole(): + # 104 + grid_size = (13, 8) + core_range_set = ttnn.CoreRangeSet( + { + ttnn.CoreRange( + ttnn.CoreCoord(0, 0), + ttnn.CoreCoord(grid_size[0] - 1, grid_size[1] - 1), + ), + } + ) + mem_config = ttnn.create_sharded_memory_config_( + layer4_module1_input_shape, + core_range_set, + ttnn.TensorMemoryLayout.BLOCK_SHARDED, + ttnn.ShardOrientation.COL_MAJOR, + tile_layout=True, + ) + x = ttnn.to_memory_config(x, mem_config) + elif is_wormhole_b0(): + core_range_set = ttnn.CoreGrid(x=8, y=7) shard_config = ttnn.create_sharded_memory_config_( layer4_module1_input_shape, - ttnn.CoreGrid(x=8, y=7), + core_range_set, ttnn.TensorMemoryLayout.BLOCK_SHARDED, ttnn.ShardOrientation.ROW_MAJOR, tile_layout=True, ) x = ttnn.to_memory_config(x, shard_config) - else: - reshard = False - height_shard = False - if is_first_run: - reshard = True - height_shard = False - else: - x = ttnn.to_memory_config(x, ops_parallel_config["layer4_module1_input"]) logger.debug(f"==== Running layer 4 module 1") x, x_height, x_width = self.layer4_module1( @@ -1209,21 +1258,13 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt reshard_if_not_optimal=reshard, height_sharding=height_shard, transpose_shards=self.transpose_shards, - enable_act_double_buffer=True if whb0_and_b16 or is_gs else False, + enable_act_double_buffer=True, enable_split_reader=False, enable_subblock_padding=False, + ops_parallel_config=ops_parallel_config, + layer_module="layer4_module1", ) - if is_first_run: - x_memory_config = 
ttnn.get_memory_config(x) - ops_parallel_config["layer4_module1_input"] = ttnn.create_sharded_memory_config_( - layer4_module1_input_shape, - x_memory_config.shard_spec.grid, - x_memory_config.memory_layout, - x_memory_config.shard_spec.orientation, - tile_layout=True, - ) - logger.debug(f"==== Running layer 4 module 2") x, x_height, x_width = self.layer4_module2( x, @@ -1233,7 +1274,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt x_width, conv_op_cache, transpose_shards=self.transpose_shards, - enable_act_double_buffer=True if whb0_and_b16 or is_gs else False, + enable_act_double_buffer=True, enable_split_reader=False, enable_subblock_padding=False, ) @@ -1247,7 +1288,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt x_width, conv_op_cache, transpose_shards=self.transpose_shards, - enable_act_double_buffer=True if whb0_and_b16 or is_gs else False, + enable_act_double_buffer=True, enable_split_reader=False, enable_subblock_padding=False, ) diff --git a/models/demos/wormhole/resnet50/README.md b/models/demos/wormhole/resnet50/README.md index 4d0fed882ce..f7052631ff6 100644 --- a/models/demos/wormhole/resnet50/README.md +++ b/models/demos/wormhole/resnet50/README.md @@ -7,7 +7,7 @@ ResNet50 is a deep convolutional neural network architecture with 50 layers, des ## Details -+ The entry point to the Metal ResNet model is `ResNet` in `ttnn_functional_resnet50_new_conv_api.py`. ++ The entry point to the Metal ResNet model is `ResNet` in `ttnn_functional_resnet50.py`. + The model picks up certain configs and weights from TorchVision pretrained model. We have used `torchvision.models.ResNet50_Weights.IMAGENET1K_V1` version from TorchVision as our reference. + Our ImageProcessor on the other hand is based on `microsoft/resnet-50` from huggingface. 
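One pattern worth calling out from the `ttnn_functional_resnet50.py` changes above: instead of hard-coding per-batch reshard decisions, the first run records the sharded memory config an op actually produced into `ops_parallel_config`, and later runs reshard straight to it. A rough sketch of that caching pattern using the same ttnn calls as the diff (the helper name and its arguments are illustrative, not part of the model code):

```python
import ttnn

def reshard_or_record(tensor, shape, key, ops_parallel_config):
    # Later runs: jump straight to the config captured on the first run.
    if key in ops_parallel_config:
        return ttnn.to_memory_config(tensor, ops_parallel_config[key])
    # First run: capture the config the op produced and cache it under `key`.
    mem_config = ttnn.get_memory_config(tensor)
    ops_parallel_config[key] = ttnn.create_sharded_memory_config_(
        shape,
        mem_config.shard_spec.grid,
        mem_config.memory_layout,
        mem_config.shard_spec.orientation,
        tile_layout=True,
    )
    return tensor
```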
diff --git a/models/demos/wormhole/stable_diffusion/tests/test_cross_attn_up_block_2d.py b/models/demos/wormhole/stable_diffusion/tests/test_cross_attn_up_block_2d.py index 3402240f4ec..517e6d85cfe 100644 --- a/models/demos/wormhole/stable_diffusion/tests/test_cross_attn_up_block_2d.py +++ b/models/demos/wormhole/stable_diffusion/tests/test_cross_attn_up_block_2d.py @@ -2,31 +2,18 @@ # SPDX-License-Identifier: Apache-2.0 -import torch from diffusers import StableDiffusionPipeline -from loguru import logger -import ttnn import pytest -from torch import nn - -from models.utility_functions import tt_to_torch_tensor, torch_random -from tests.ttnn.utils_for_testing import assert_with_pcc +import torch +import ttnn -from models.utility_functions import ( - skip_for_grayskull, -) +from models.demos.wormhole.stable_diffusion.custom_preprocessing import custom_preprocessor from models.demos.wormhole.stable_diffusion.tt.ttnn_functional_cross_attn_upblock_new_conv import ( cross_attention_upblock2d, ) - -from models.demos.wormhole.stable_diffusion.custom_preprocessing import custom_preprocessor - +from models.utility_functions import skip_for_grayskull, torch_random from ttnn.model_preprocessing import preprocess_model_parameters -from models.demos.wormhole.stable_diffusion.tt.ttnn_functional_utility_functions import ( - pre_process_input, - weight_to_bfp8, - post_process_output, -) +from tests.ttnn.utils_for_testing import assert_with_pcc def ttnn_to_torch(input): @@ -36,15 +23,59 @@ def ttnn_to_torch(input): return input +def prepare_input_and_push_to_device(input, device, memory_config): + input = torch.permute(input, (0, 2, 3, 1)) + input = torch.reshape( + input, + ( + 1, + 1, + input.shape[0] * input.shape[1] * input.shape[2], + input.shape[3], + ), + ) + + input = ttnn.from_torch(input, ttnn.bfloat16) + input = ttnn.to_layout(input, ttnn.TILE_LAYOUT) + input = ttnn.to_dtype(input, ttnn.bfloat8_b) + return ttnn.to_device(input, device, memory_config=memory_config) + + @skip_for_grayskull() -@pytest.mark.skip(reason="#9599: Tests are failing.") @pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True) @pytest.mark.parametrize( - "hidden_states, res_hidden_states_tuple, index, prev_output_channel, in_channels ,out_channels", + "hidden_states, res_hidden_states_tuple, index, prev_output_channel, in_channels, out_channels, shard_end_core, shard_shape", [ - ((2, 1280, 16, 16), ([2, 640, 16, 16], [2, 1280, 16, 16], [2, 1280, 16, 16]), 1, 1280, 640, 1280), - ((2, 1280, 32, 32), ([2, 320, 32, 32], [2, 640, 32, 32], [2, 640, 32, 32]), 2, 1280, 320, 640), - ((2, 640, 64, 64), ([2, 320, 64, 64], [2, 320, 64, 64], [2, 320, 64, 64]), 3, 640, 320, 320), + ( + (2, 1280, 16, 16), + ([2, 640, 16, 16], [2, 1280, 16, 16], [2, 1280, 16, 16]), + 1, + 1280, + 640, + 1280, + (7, 3), + [128, 160], + ), + ( + (2, 1280, 32, 32), + ([2, 320, 32, 32], [2, 640, 32, 32], [2, 640, 32, 32]), + 2, + 1280, + 320, + 640, + (7, 7), + [256, 160], + ), + ( + (2, 640, 64, 64), + ([2, 320, 64, 64], [2, 320, 64, 64], [2, 320, 64, 64]), + 3, + 640, + 320, + 320, + (4, 7), + [1024, 128], + ), ], ) @pytest.mark.parametrize("temb", [[1, 1, 2, 1280]]) @@ -66,6 +97,8 @@ def test_cross_attn_up_block_2d_512x512( prev_output_channel, in_channels, out_channels, + shard_end_core, + shard_shape, ): # TODO # setup pytorch model @@ -73,7 +106,6 @@ def test_cross_attn_up_block_2d_512x512( unet = pipe.unet unet.eval() config = unet.config - state_dict = unet.state_dict() unet_upblock = pipe.unet.up_blocks[index] parameters 
= preprocess_model_parameters( @@ -122,36 +154,40 @@ def test_cross_attn_up_block_2d_512x512( cross_attention_kwargs = (None,) return_dict = True num_layers_transformer = 1 - norm_num_groups = 32 cross_attention_dim = 768 - attention_bias = False - sample_size = None - num_vector_embeds = None patch_size = None - activation_fn = "geglu" num_embeds_ada_norm = None use_linear_projection = False only_cross_attention = False upcast_attention = False norm_type = "layer_norm" - norm_elementwise_affine = True attn_num_head_channels = 8 - hidden_state = ttnn.from_torch(hidden_state, ttnn.bfloat16) - hidden_state = ttnn.to_layout(hidden_state, ttnn.TILE_LAYOUT) - hidden_state = ttnn.to_device(hidden_state, device, memory_config=ttnn.L1_MEMORY_CONFIG) - - res0 = ttnn.from_torch(res0, ttnn.bfloat16) - res0 = ttnn.to_layout(res0, ttnn.TILE_LAYOUT) - res0 = ttnn.to_device(res0, device, memory_config=ttnn.DRAM_MEMORY_CONFIG) - - res1 = ttnn.from_torch(res1, ttnn.bfloat16) - res1 = ttnn.to_layout(res1, ttnn.TILE_LAYOUT) - res1 = ttnn.to_device(res1, device, memory_config=ttnn.DRAM_MEMORY_CONFIG) + hidden_state = prepare_input_and_push_to_device( + hidden_state, + device, + ttnn.MemoryConfig( + ttnn.TensorMemoryLayout.BLOCK_SHARDED, + ttnn.BufferType.L1, + ttnn.ShardSpec( + ttnn.CoreRangeSet( + { + ttnn.CoreRange( + ttnn.CoreCoord(0, 0), + ttnn.CoreCoord(shard_end_core[0], shard_end_core[1]), + ), + } + ), + shard_shape, + ttnn.ShardOrientation.ROW_MAJOR, + ), + ), + ) - res2 = ttnn.from_torch(res2, ttnn.bfloat16) - res2 = ttnn.to_layout(res2, ttnn.TILE_LAYOUT) - res2 = ttnn.to_device(res2, device, memory_config=ttnn.DRAM_MEMORY_CONFIG) + res0 = prepare_input_and_push_to_device(res0, device, ttnn.DRAM_MEMORY_CONFIG) + res1 = prepare_input_and_push_to_device(res1, device, ttnn.DRAM_MEMORY_CONFIG) + res2 = prepare_input_and_push_to_device(res2, device, ttnn.DRAM_MEMORY_CONFIG) + res_hidden_states_tuple = (res0, res1, res2) temb = temb.permute(2, 0, 1, 3) # pre-permute temb temb = ttnn.from_torch(temb, ttnn.bfloat16) @@ -166,12 +202,7 @@ def test_cross_attn_up_block_2d_512x512( add_upsample = True if index == 3: add_upsample = False - hidden_state = weight_to_bfp8(pre_process_input(device, hidden_state)) - res_hidden_states_tuple = ( - weight_to_bfp8(pre_process_input(device, res0)), - weight_to_bfp8(pre_process_input(device, res1)), - weight_to_bfp8(pre_process_input(device, res2)), - ) + op = model( hidden_state, res_hidden_states_tuple, @@ -180,7 +211,7 @@ def test_cross_attn_up_block_2d_512x512( out_channels, temb_channels, num_layers=3, - resnet_eps=1e-6, + resnet_eps=1e-5, resnet_time_scale_shift="default", resnet_act_fn="silu", resnet_groups=32, @@ -214,4 +245,4 @@ def test_cross_attn_up_block_2d_512x512( op = torch.reshape(op, (N, H * 2, W * 2, Cout)) op = op.permute(0, 3, 1, 2) - assert_with_pcc(torch_output, op, 0.92) + assert_with_pcc(torch_output, op, 0.91) diff --git a/models/utility_functions.py b/models/utility_functions.py index f945fb3c6bf..a1072d966bc 100644 --- a/models/utility_functions.py +++ b/models/utility_functions.py @@ -889,6 +889,10 @@ def skip_for_grayskull(reason_str="not working for Grayskull"): return pytest.mark.skipif(is_grayskull(), reason=reason_str) +def run_for_blackhole(reason_str="only runs for Blackhole"): + return pytest.mark.skipif(not is_blackhole(), reason=reason_str) + + def run_for_wormhole_b0(reason_str="only runs for Wormhole B0"): return pytest.mark.skipif(not is_wormhole_b0(), reason=reason_str) diff --git a/scripts/docker/build_docker_image.sh 
b/scripts/docker/build_docker_image.sh deleted file mode 100755 index e29af753ac0..00000000000 --- a/scripts/docker/build_docker_image.sh +++ /dev/null @@ -1,10 +0,0 @@ -#!/bin/bash - -TT_METAL_DOCKERFILE="${1:-ubuntu-20.04-amd64}" -TT_METAL_DOCKER_IMAGE_TAG="${2:-$TT_METAL_DOCKERFILE}" - -TT_METAL_HOME=$(git rev-parse --show-toplevel) -( - cd ${TT_METAL_HOME} || exit - docker build -f dockerfile/${TT_METAL_DOCKERFILE}.Dockerfile -t ${TT_METAL_DOCKER_IMAGE_TAG} . -) diff --git a/scripts/docker/install_test_deps.sh b/scripts/docker/install_test_deps.sh deleted file mode 100755 index 88face8c94f..00000000000 --- a/scripts/docker/install_test_deps.sh +++ /dev/null @@ -1,17 +0,0 @@ -#!/bin/bash - -# Check if two arguments are provided -if [ "$#" -ne 1 ]; then - echo "Usage: $0 " - exit 1 -fi - -DOXYGEN_VERSION=$1 - -# Install doxygen -mkdir -p /opt/tt_metal_infra/doxygen -wget -O /opt/tt_metal_infra/doxygen/doxygen-${DOXYGEN_VERSION}.linux.bin.tar.gz "https://www.doxygen.nl/files/doxygen-${DOXYGEN_VERSION}.linux.bin.tar.gz" -tar -xzf /opt/tt_metal_infra/doxygen/doxygen-${DOXYGEN_VERSION}.linux.bin.tar.gz -C /opt/tt_metal_infra/doxygen/ -rm /opt/tt_metal_infra/doxygen/doxygen-${DOXYGEN_VERSION}.linux.bin.tar.gz -cd /opt/tt_metal_infra/doxygen/doxygen-${DOXYGEN_VERSION} -make install diff --git a/scripts/docker/requirements-20.04.txt b/scripts/docker/requirements-20.04.txt deleted file mode 100644 index 0f95b0da25a..00000000000 --- a/scripts/docker/requirements-20.04.txt +++ /dev/null @@ -1,23 +0,0 @@ -apt-utils -dialog -software-properties-common=0.99.9.12 -build-essential=12.8ubuntu1.1 -git -pandoc -libtbb-dev -libcapstone-dev -pkg-config -cmake -curl -wget -python3-pip -libhwloc-dev -libhdf5-serial-dev -ruby=1:2.7+1 -python3-dev=3.8.2-0ubuntu2 -python3.8-venv -cargo -ninja-build -patchelf -graphviz -bc diff --git a/scripts/docker/requirements-22.04.txt b/scripts/docker/requirements-22.04.txt deleted file mode 100644 index 6038b4e8f6a..00000000000 --- a/scripts/docker/requirements-22.04.txt +++ /dev/null @@ -1,22 +0,0 @@ -apt-utils -dialog -build-essential -gcc-12 -g++-12 -git -pandoc -libtbb-dev -libcapstone-dev -pkg-config -cmake -curl -wget -python3-pip -libhwloc-dev -python3-dev -python3-venv -cargo -ninja-build -libxml2-dev -libxslt-dev -bc diff --git a/scripts/docker/requirements_dev.txt b/scripts/docker/requirements_dev.txt deleted file mode 100644 index e7029ab3bc7..00000000000 --- a/scripts/docker/requirements_dev.txt +++ /dev/null @@ -1,9 +0,0 @@ -acl -emacs -jq -less -libmpfr-dev -nano -openssh-server -sudo -vim diff --git a/scripts/docker/run_docker_cmd.sh b/scripts/docker/run_docker_cmd.sh deleted file mode 100755 index cfa4a315451..00000000000 --- a/scripts/docker/run_docker_cmd.sh +++ /dev/null @@ -1,51 +0,0 @@ -#!/bin/bash - -set -ex - -docker_tag="ubuntu-20.04-amd64" -docker_opts= -docker_cmd= -# Function to display help -show_help() { - echo "Usage: $0 [-h] [-t docker tag] [-o docker launch options] -[c docker exec command]" - echo " -h Show this help message." - echo " -t Docker tag to run." - echo " -o Docker options." - echo " -c Docker exec command" -} - -while getopts "t:o:c:" opt; do - case ${opt} in - h ) - show_help - exit 0 - ;; - t ) - docker_tag="$OPTARG" - ;; - o ) - docker_opts="$OPTARG" - ;; - c ) - docker_cmd="$OPTARG" - ;; - \? 
) - show_help - exit 1 - ;; - esac -done - -if [[ -z "${ARCH_NAME}" ]]; then - echo "Must provide ARCH_NAME in environment" 1>&2 - exit 1 -fi - -TT_METAL_HOME=$(git rev-parse --show-toplevel) - -source $TT_METAL_HOME/scripts/docker/build_docker_image.sh $docker_tag - -# Allows this script to be called anywhere in the tt-metal repo -source $TT_METAL_HOME/scripts/docker/run_docker_func.sh - -run_docker_common $docker_opts $docker_tag $docker_cmd diff --git a/scripts/docker/run_docker_func.sh b/scripts/docker/run_docker_func.sh deleted file mode 100755 index 98e993c886c..00000000000 --- a/scripts/docker/run_docker_func.sh +++ /dev/null @@ -1,35 +0,0 @@ -#!/bin/bash - -set -e - - -GID=$(id -g "${USER}") - -if [[ -z "${TT_METAL_HOME}" ]]; then - TT_METAL_HOME=$(git rev-parse --show-toplevel) -fi - -function run_docker_common { - - docker run \ - --rm \ - -v ${TT_METAL_HOME}:/${TT_METAL_HOME} \ - -v /home:/home \ - -v /dev/hugepages-1G:/dev/hugepages-1G \ - -v /etc/group:/etc/group:ro \ - -v /etc/passwd:/etc/passwd:ro \ - -v /etc/shadow:/etc/shadow:ro \ - -w ${TT_METAL_HOME} \ - -e TT_METAL_HOME=${TT_METAL_HOME} \ - -e LOGURU_LEVEL=${LOGURU_LEVEL} \ - -e LD_LIBRARY_PATH=${LD_LIBRARY_PATH} \ - -e ARCH_NAME=${ARCH_NAME} \ - -e PYTHONPATH=${TT_METAL_HOME} \ - -e SILENT=${SILENT} \ - -e VERBOSE=${VERBOSE} \ - -u ${UID}:${GID} \ - --net host \ - "$1" \ - "$2:latest" \ - "$3" -} diff --git a/tech_reports/prog_examples/add_2_integers_in_compute/Tutorial_Add_Two_Integers_in_a_Compute_Kernel.md b/tech_reports/prog_examples/add_2_integers_in_compute/Tutorial_Add_Two_Integers_in_a_Compute_Kernel.md index 80f1279b093..61c569241b7 100644 --- a/tech_reports/prog_examples/add_2_integers_in_compute/Tutorial_Add_Two_Integers_in_a_Compute_Kernel.md +++ b/tech_reports/prog_examples/add_2_integers_in_compute/Tutorial_Add_Two_Integers_in_a_Compute_Kernel.md @@ -122,7 +122,7 @@ cb_push_back(cb_id_in1, 1); 12. 
Unpack, compute, and pack the data: ```binary_op_init_common(cb_in0, cb_in1, cb_out0); -add_tiles_init(); +add_tiles_init(cb_in0, cb_in1); // wait for a block of tiles in each of input CBs cb_wait_front(cb_in0, 1); diff --git a/tech_reports/prog_examples/add_2_integers_in_compute/add_2_integers_in_compute.md b/tech_reports/prog_examples/add_2_integers_in_compute/add_2_integers_in_compute.md index c150ad397b3..1af2b521a6c 100644 --- a/tech_reports/prog_examples/add_2_integers_in_compute/add_2_integers_in_compute.md +++ b/tech_reports/prog_examples/add_2_integers_in_compute/add_2_integers_in_compute.md @@ -157,7 +157,7 @@ The reader kernel reads in a one tile from each of the two source vectors that a ``` cpp binary_op_init_common(cb_in0, cb_in1, cb_out0); -add_tiles_init(); +add_tiles_init(cb_in0, cb_in1); // wait for a block of tiles in each of input CBs cb_wait_front(cb_in0, 1); diff --git a/tech_reports/prog_examples/shard_data_rm/shard_data_rm.md b/tech_reports/prog_examples/shard_data_rm/shard_data_rm.md index ba4ed41a58a..5a2a65bc725 100644 --- a/tech_reports/prog_examples/shard_data_rm/shard_data_rm.md +++ b/tech_reports/prog_examples/shard_data_rm/shard_data_rm.md @@ -60,12 +60,12 @@ uint32_t shard_size = shard_height * shard_width; uint32_t input_unit_size = sizeof(uint32_t); uint32_t shard_width_bytes = shard_width * data_size; uint32_t num_units_per_row = shard_width * input_unit_size; -uint32_t padded_offset_bytes = align(input_unit_size, device->allocator()->get_config().alignment); +uint32_t padded_offset_bytes = align(input_unit_size, device->allocator()->get_alignment(BufferType::L1)); ``` In order to shard the correct data segments to the respective core, we indicate the shard height, width, size, and other data for the kernel function. For this situation, 16 units of data will be sharded across 4 cores; each core will have 4 units of data in their corresponding circular buffer. -The `padded_offset_bytes` is set to ensure that the correct address is read from the kernel function when moving data to the circular buffer; in this case, the addresses are aligned to L1 memory. +The `padded_offset_bytes` is set to ensure that the correct address is read from the kernel function when moving data to the circular buffer; in this case, the addresses are aligned to L1 memory by explicitly referencing `BufferType::L1`. This example demonstrates height sharding; the shard height is therefore set to evenly distribute the number of vector values across the cores. If the sharding strategy was different (i.e. width sharding or block sharding), the appropriate values for both the shard height and width would need to be set.
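To make the alignment step above concrete, the sketch below reproduces the round-up arithmetic that `padded_offset_bytes` relies on. This is a standalone illustration, not tt-metal code: the `align_up` helper name and the 16-byte alignment value are assumptions chosen for the example; in the real program the rounding is done by the `align` utility shown in the diff, and the alignment value is queried from `device->allocator()->get_alignment(BufferType::L1)`.

```cpp
// Standalone sketch of the round-up-to-alignment arithmetic behind padded_offset_bytes.
// The helper name and the 16-byte L1 alignment are assumptions for illustration only.
#include <cassert>
#include <cstdint>

// Round `size` up to the next multiple of `alignment`.
static uint32_t align_up(uint32_t size, uint32_t alignment) {
    return ((size + alignment - 1) / alignment) * alignment;
}

int main() {
    constexpr uint32_t input_unit_size = sizeof(uint32_t);  // 4 bytes per element, as in the example
    constexpr uint32_t l1_alignment = 16;                    // assumed; the real value comes from the allocator
    uint32_t padded_offset_bytes = align_up(input_unit_size, l1_alignment);
    assert(padded_offset_bytes == 16);  // each unit then starts on an L1-aligned address
    return 0;
}
```

Querying the alignment per buffer type rather than from a single allocator-wide config appears to be the point of the `get_alignment(BufferType::L1)` change: the padding is computed against the requirements of the memory the circular buffer actually lives in, which may differ between DRAM and L1 on some architectures.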
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 4942725477c..034ec2c7051 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -11,6 +11,7 @@ target_link_libraries( magic_enum fmt::fmt-header-only span + small_vector ) if(TT_METAL_BUILD_TESTS) diff --git a/tests/scripts/run_python_model_tests.sh b/tests/scripts/run_python_model_tests.sh index e5b1a8912f2..0290537e6e3 100755 --- a/tests/scripts/run_python_model_tests.sh +++ b/tests/scripts/run_python_model_tests.sh @@ -21,7 +21,7 @@ run_python_model_tests_grayskull() { pytest models/experimental/bert_large_performant/unit_tests/fused_ops/test_bert_large_fused_ln.py -k "in0_L1-out_L1 and batch_9" pytest models/experimental/bert_large_performant/unit_tests/fused_ops/test_bert_large_fused_softmax.py -k "in0_L1 and batch_9" - pytest tests/ttnn/integration_tests/resnet/test_ttnn_functional_resnet50_new.py -k "pretrained_weight_false" + pytest tests/ttnn/integration_tests/resnet/test_ttnn_functional_resnet50.py -k "pretrained_weight_false" # Falcon tests pytest models/demos/falcon7b_common/tests/unit_tests/test_falcon_matmuls_and_bmms_with_mixed_precision.py -k "seq_len_128 and in0_BFLOAT16-in1_BFLOAT8_B-out_BFLOAT16-weights_DRAM" pytest models/demos/falcon7b_common/tests/unit_tests/test_falcon_matmuls_and_bmms_with_mixed_precision.py -k "seq_len_512 and in0_BFLOAT16-in1_BFLOAT8_B-out_BFLOAT16-weights_DRAM" @@ -34,7 +34,7 @@ run_python_model_tests_wormhole_b0() { pytest models/demos/falcon7b_common/tests/unit_tests/test_falcon_attn_matmul.py -k "not attn_matmul_from_cache" # higher sequence lengths and different formats trigger memory issues pytest models/demos/falcon7b_common/tests/unit_tests/test_falcon_matmuls_and_bmms_with_mixed_precision.py -k "seq_len_128 and in0_BFLOAT16-in1_BFLOAT8_B-out_BFLOAT16-weights_DRAM" - pytest tests/ttnn/integration_tests/resnet/test_ttnn_functional_resnet50_new.py -k "pretrained_weight_false" + pytest tests/ttnn/integration_tests/resnet/test_ttnn_functional_resnet50.py -k "pretrained_weight_false" WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/yolov4/demo/demo.py -k "pretrained_weight_false" # Unet Shallow diff --git a/tests/scripts/t3000/run_t3000_unit_tests.sh b/tests/scripts/t3000/run_t3000_unit_tests.sh index b2112c7493e..87df13c964e 100755 --- a/tests/scripts/t3000/run_t3000_unit_tests.sh +++ b/tests/scripts/t3000/run_t3000_unit_tests.sh @@ -43,9 +43,15 @@ run_t3000_ttfabric_tests() { echo "LOG_METAL: Running run_t3000_ttfabric_tests" TT_METAL_SLOW_DISPATCH_MODE=1 ./build/test/tt_metal/tt_fabric/fabric_unit_tests --gtest_filter=ControlPlaneFixture.*T3k* + # Unicast tests TT_METAL_SLOW_DISPATCH_MODE=1 ./build/test/tt_metal/perf_microbenchmark/routing/test_tt_fabric_sanity_wormhole_b0 --fabric_command 1 --board_type t3k --data_kb_per_tx 10 --num_src_endpoints 20 --num_dest_endpoints 8 --num_links 16 TT_METAL_SLOW_DISPATCH_MODE=1 ./build/test/tt_metal/perf_microbenchmark/routing/test_tt_fabric_sanity_wormhole_b0 --fabric_command 64 --board_type t3k --data_kb_per_tx 10 --num_src_endpoints 20 --num_dest_endpoints 8 --num_links 16 TT_METAL_SLOW_DISPATCH_MODE=1 ./build/test/tt_metal/perf_microbenchmark/routing/test_tt_fabric_sanity_wormhole_b0 --fabric_command 65 --board_type t3k --data_kb_per_tx 10 --num_src_endpoints 20 --num_dest_endpoints 8 --num_links 16 + # Line Mcast tests + TT_METAL_SLOW_DISPATCH_MODE=1 ./build/test/tt_metal/perf_microbenchmark/routing/test_tt_fabric_sanity_wormhole_b0 --fabric_command 1 --board_type t3k --data_kb_per_tx 10 
--num_src_endpoints 20 --num_dest_endpoints 8 --num_links 16 --e_depth 3 + TT_METAL_SLOW_DISPATCH_MODE=1 ./build/test/tt_metal/perf_microbenchmark/routing/test_tt_fabric_sanity_wormhole_b0 --fabric_command 1 --board_type t3k --data_kb_per_tx 10 --num_src_endpoints 20 --num_dest_endpoints 8 --num_links 16 --w_depth 3 + TT_METAL_SLOW_DISPATCH_MODE=1 ./build/test/tt_metal/perf_microbenchmark/routing/test_tt_fabric_sanity_wormhole_b0 --fabric_command 1 --board_type t3k --data_kb_per_tx 10 --num_src_endpoints 20 --num_dest_endpoints 8 --num_links 16 --n_depth 1 + TT_METAL_SLOW_DISPATCH_MODE=1 ./build/test/tt_metal/perf_microbenchmark/routing/test_tt_fabric_sanity_wormhole_b0 --fabric_command 1 --board_type t3k --data_kb_per_tx 10 --num_src_endpoints 20 --num_dest_endpoints 8 --num_links 16 --s_depth 1 # Record the end time end_time=$(date +%s) diff --git a/tests/scripts/tg/run_tg_unit_tests.sh b/tests/scripts/tg/run_tg_unit_tests.sh index dac68fdd870..c82a51861b7 100755 --- a/tests/scripts/tg/run_tg_unit_tests.sh +++ b/tests/scripts/tg/run_tg_unit_tests.sh @@ -114,10 +114,15 @@ run_tg_tests() { elif [[ "$1" == "fabric" ]]; then echo "LOG_FABRIC: running run_tg_fabric_tests" TT_METAL_SLOW_DISPATCH_MODE=1 ./build/test/tt_metal/tt_fabric/fabric_unit_tests --gtest_filter=ControlPlaneFixture.*TG* + # Unicast tests TT_METAL_SLOW_DISPATCH_MODE=1 ./build/test/tt_metal/perf_microbenchmark/routing/test_tt_fabric_sanity_wormhole_b0 --fabric_command 1 --board_type glx32 --data_kb_per_tx 10 --num_src_endpoints 20 --num_dest_endpoints 8 --num_links 16 TT_METAL_SLOW_DISPATCH_MODE=1 ./build/test/tt_metal/perf_microbenchmark/routing/test_tt_fabric_sanity_wormhole_b0 --fabric_command 64 --board_type glx32 --data_kb_per_tx 10 --num_src_endpoints 20 --num_dest_endpoints 8 --num_links 16 TT_METAL_SLOW_DISPATCH_MODE=1 ./build/test/tt_metal/perf_microbenchmark/routing/test_tt_fabric_sanity_wormhole_b0 --fabric_command 65 --board_type glx32 --data_kb_per_tx 10 --num_src_endpoints 20 --num_dest_endpoints 8 --num_links 16 - + # Line Mcast tests + TT_METAL_SLOW_DISPATCH_MODE=1 ./build/test/tt_metal/perf_microbenchmark/routing/test_tt_fabric_sanity_wormhole_b0 --fabric_command 1 --board_type glx32 --data_kb_per_tx 10 --num_src_endpoints 20 --num_dest_endpoints 8 --num_links 16 --e_depth 7 + TT_METAL_SLOW_DISPATCH_MODE=1 ./build/test/tt_metal/perf_microbenchmark/routing/test_tt_fabric_sanity_wormhole_b0 --fabric_command 1 --board_type glx32 --data_kb_per_tx 10 --num_src_endpoints 20 --num_dest_endpoints 8 --num_links 16 --w_depth 7 + TT_METAL_SLOW_DISPATCH_MODE=1 ./build/test/tt_metal/perf_microbenchmark/routing/test_tt_fabric_sanity_wormhole_b0 --fabric_command 1 --board_type glx32 --data_kb_per_tx 10 --num_src_endpoints 20 --num_dest_endpoints 8 --num_links 16 --n_depth 3 + TT_METAL_SLOW_DISPATCH_MODE=1 ./build/test/tt_metal/perf_microbenchmark/routing/test_tt_fabric_sanity_wormhole_b0 --fabric_command 1 --board_type glx32 --data_kb_per_tx 10 --num_src_endpoints 20 --num_dest_endpoints 8 --num_links 16 --s_depth 3 elif [[ "$1" == "llama3-70b" ]]; then run_tg_llama3.1-70b_tests diff --git a/tests/sweep_framework/sweeps/data_movement/interleaved_to_sharded/interleaved_to_sharded_e2e.py b/tests/sweep_framework/sweeps/data_movement/interleaved_to_sharded/interleaved_to_sharded_e2e.py index bcfefce3f7b..6d4122a7ef5 100644 --- a/tests/sweep_framework/sweeps/data_movement/interleaved_to_sharded/interleaved_to_sharded_e2e.py +++ b/tests/sweep_framework/sweeps/data_movement/interleaved_to_sharded/interleaved_to_sharded_e2e.py @@ 
-13,6 +13,7 @@ from models.utility_functions import torch_random TIMEOUT = 15 +TILE_HEIGHT = TILE_WIDTH = 32 # seed for random random.seed(0) @@ -22,7 +23,7 @@ {"shape": [1, 1, 1, 16], "shard_shape": None}, {"shape": [1, 1, 32, 16], "shard_shape": None}, {"shape": [1, 1, 16, 32], "shard_shape": None}, - {"shape": [1, 1, 32, 32], "shard_shape": None}, + {"shape": [1, 1, 128, 32], "shard_shape": None}, {"shape": [1, 1, 64, 64], "shard_shape": None}, {"shape": [1, 1, 128, 128], "shard_shape": None}, {"shape": [1, 1, 1, 16], "shard_shape": [1, 1, 1, 16]}, @@ -31,16 +32,17 @@ {"shape": [1, 1, 32, 32], "shard_shape": [1, 1, 16, 16]}, {"shape": [1, 1, 64, 64], "shard_shape": [1, 1, 16, 16]}, {"shape": [1, 1, 128, 128], "shard_shape": [1, 1, 32, 16]}, + {"shape": [1, 1, 128, 128], "shard_shape": [1, 1, 32, 32]}, ], "strategy": [ttnn.ShardStrategy.WIDTH, ttnn.ShardStrategy.HEIGHT], - "orientation": [ttnn.ShardOrientation.ROW_MAJOR, ttnn.ShardOrientation.ROW_MAJOR], + "orientation": [ttnn.ShardOrientation.ROW_MAJOR, ttnn.ShardOrientation.COL_MAJOR], "core_grid": [ ttnn.CoreGrid(y=1, x=1), ttnn.CoreGrid(y=2, x=1), ttnn.CoreGrid(y=1, x=2), ttnn.CoreGrid(y=2, x=2), ], - "dtype": [ttnn.bfloat16], + "dtype": [ttnn.bfloat16, ttnn.bfloat8_b], "layout": [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT], "input_buffer_type": [ttnn.L1_MEMORY_CONFIG, ttnn.DRAM_MEMORY_CONFIG], "output_buffer_type": [ttnn.L1_MEMORY_CONFIG, ttnn.DRAM_MEMORY_CONFIG], @@ -55,7 +57,26 @@ def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]: if test_vector["layout"] == ttnn.ROW_MAJOR_LAYOUT: if test_vector["dtype"] == ttnn.bfloat8_b: return True, "bfloat8_b not supported with ROW_MAJOR_LAYOUT" - + elif test_vector["layout"] == ttnn.TILE_LAYOUT: + if test_vector["shard_specs"]["shard_shape"] is not None and ( + test_vector["shard_specs"]["shard_shape"][-2] % TILE_HEIGHT != 0 + or test_vector["shard_specs"]["shard_shape"][-1] % TILE_WIDTH != 0 + ): + return True, "shard_shape not supported with TILE_LAYOUT" + elif test_vector["shard_specs"]["shard_shape"] is None: + ncores = test_vector["core_grid"].x * test_vector["core_grid"].y + sizey = ( + test_vector["shard_specs"]["shape"][-2] // ncores + if test_vector["strategy"] == ttnn.ShardStrategy.HEIGHT + else test_vector["shard_specs"]["shape"][-2] + ) + sizex = ( + test_vector["shard_specs"]["shape"][-1] // ncores + if test_vector["strategy"] == ttnn.ShardStrategy.WIDTH + else test_vector["shard_specs"]["shape"][-1] + ) + if sizex % TILE_HEIGHT != 0 or sizey % TILE_WIDTH != 0: + return True, "shard_shape not supported with TILE_LAYOUT" return False, None @@ -102,7 +123,7 @@ def run( device=device, layout=layout, memory_config=input_buffer_type, - dtype=ttnn.bfloat16, + dtype=dtype, ) # Measure performance of the split operation in ttnn diff --git a/tests/sweep_framework/sweeps/eltwise/unary/logit/logit.py b/tests/sweep_framework/sweeps/eltwise/unary/logit/logit.py index 5e581673762..3baa2df0d11 100644 --- a/tests/sweep_framework/sweeps/eltwise/unary/logit/logit.py +++ b/tests/sweep_framework/sweeps/eltwise/unary/logit/logit.py @@ -14,10 +14,6 @@ from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time from models.utility_functions import torch_random -# Override the default timeout in seconds for hang detection. -TIMEOUT = 30 - -random.seed(0) # Parameters provided to the test vector generator are defined here. 
# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. @@ -61,13 +57,13 @@ def run( *, device, ) -> list: - data_seed = random.randint(0, 20000000) - torch.manual_seed(data_seed) + torch.manual_seed(0) torch_input_tensor_a = gen_func_with_cast_tt( partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype )(input_shape) - torch_output_tensor = torch.logit(torch_input_tensor_a, eps) + golden_function = ttnn.get_golden_function(ttnn.logit) + torch_output_tensor = golden_function(torch_input_tensor_a, eps=eps, device=device) input_tensor_a = ttnn.from_torch( torch_input_tensor_a, @@ -83,5 +79,4 @@ def run( e2e_perf = stop_measuring_time(start_time) pcc = check_with_pcc(torch_output_tensor, output_tensor, 0.99) - # print(f"eps {eps} pcc {pcc}") return [pcc, e2e_perf] diff --git a/tests/tt_eager/integration_tests/test_bert.cpp b/tests/tt_eager/integration_tests/test_bert.cpp index 1066e8d04eb..54aed669997 100644 --- a/tests/tt_eager/integration_tests/test_bert.cpp +++ b/tests/tt_eager/integration_tests/test_bert.cpp @@ -230,7 +230,7 @@ void test_bert() { auto attention_mask = ttnn::random::uniform( bfloat16(-1.0f), bfloat16(1.0f), ttnn::Shape({batch_size, 1, TILE_HEIGHT, sequence_size}), Layout::TILE) - .to(device, l1_memory_config); + .to_device(device, l1_memory_config); auto parameters = Parameters{}; for (auto encoder_index = 0; encoder_index < num_encoders; encoder_index++) { @@ -238,74 +238,74 @@ void test_bert() { fmt::format("fused_qkv_weight_{}", encoder_index), ttnn::random::uniform( bfloat16(-1.0f), bfloat16(1.0f), ttnn::Shape({1, 1, hidden_size, hidden_size * 3}), Layout::TILE) - .to(device, dram_memory_config)); + .to_device(device, dram_memory_config)); parameters.emplace( fmt::format("fused_qkv_bias_{}", encoder_index), ttnn::random::uniform( bfloat16(-1.0f), bfloat16(1.0f), ttnn::Shape({1, 1, TILE_HEIGHT, hidden_size * 3}), Layout::TILE) - .to(device, dram_memory_config)); + .to_device(device, dram_memory_config)); parameters.emplace( fmt::format("selfout_weight_{}", encoder_index), ttnn::random::uniform( bfloat16(-1.0f), bfloat16(1.0f), ttnn::Shape({1, 1, hidden_size, hidden_size}), Layout::TILE) - .to(device, dram_memory_config)); + .to_device(device, dram_memory_config)); parameters.emplace( fmt::format("selfout_bias_{}", encoder_index), ttnn::random::uniform( bfloat16(-1.0f), bfloat16(1.0f), ttnn::Shape({1, 1, TILE_HEIGHT, hidden_size}), Layout::TILE) - .to(device, dram_memory_config)); + .to_device(device, dram_memory_config)); parameters.emplace( fmt::format("attention_layernorm_weight_{}", encoder_index), ttnn::random::uniform( bfloat16(-1.0f), bfloat16(1.0f), ttnn::Shape({1, 1, TILE_HEIGHT, TILE_WIDTH}), Layout::ROW_MAJOR) - .to(device, dram_memory_config)); + .to_device(device, dram_memory_config)); parameters.emplace( fmt::format("attention_layernorm_bias_{}", encoder_index), ttnn::random::uniform( bfloat16(-1.0f), bfloat16(1.0f), ttnn::Shape({1, 1, TILE_HEIGHT, TILE_WIDTH}), Layout::ROW_MAJOR) - .to(device, dram_memory_config)); + .to_device(device, dram_memory_config)); parameters.emplace( fmt::format("ff1_weight_{}", encoder_index), ttnn::random::uniform( bfloat16(-1.0f), bfloat16(1.0f), ttnn::Shape({1, 1, hidden_size, intermediate_size}), Layout::TILE) - .to(device, dram_memory_config)); + .to_device(device, dram_memory_config)); parameters.emplace( fmt::format("ff1_bias_{}", encoder_index), ttnn::random::uniform( bfloat16(-1.0f), bfloat16(1.0f), ttnn::Shape({1, 1, 
TILE_HEIGHT, intermediate_size}), Layout::TILE) - .to(device, dram_memory_config)); + .to_device(device, dram_memory_config)); parameters.emplace( fmt::format("ff2_weight_{}", encoder_index), ttnn::random::uniform( bfloat16(-1.0f), bfloat16(1.0f), ttnn::Shape({1, 1, intermediate_size, hidden_size}), Layout::TILE) - .to(device, dram_memory_config)); + .to_device(device, dram_memory_config)); parameters.emplace( fmt::format("ff2_bias_{}", encoder_index), ttnn::random::uniform( bfloat16(-1.0f), bfloat16(1.0f), ttnn::Shape({1, 1, TILE_HEIGHT, hidden_size}), Layout::TILE) - .to(device, dram_memory_config)); + .to_device(device, dram_memory_config)); parameters.emplace( fmt::format("feedforward_layernorm_weight_{}", encoder_index), ttnn::random::uniform( bfloat16(-1.0f), bfloat16(1.0f), ttnn::Shape({1, 1, TILE_HEIGHT, TILE_WIDTH}), Layout::ROW_MAJOR) - .to(device, dram_memory_config)); + .to_device(device, dram_memory_config)); parameters.emplace( fmt::format("feedforward_layernorm_bias_{}", encoder_index), ttnn::random::uniform( bfloat16(-1.0f), bfloat16(1.0f), ttnn::Shape({1, 1, TILE_HEIGHT, TILE_WIDTH}), Layout::ROW_MAJOR) - .to(device, dram_memory_config)); + .to_device(device, dram_memory_config)); }; parameters.emplace( "qa_head_weight", ttnn::random::uniform( bfloat16(-1.0f), bfloat16(1.0f), ttnn::Shape({1, 1, hidden_size, TILE_WIDTH}), Layout::TILE) - .to(device, dram_memory_config)); + .to_device(device, dram_memory_config)); parameters.emplace( "qa_head_bias", ttnn::reshape( ttnn::random::uniform( bfloat16(-1.0f), bfloat16(1.0f), ttnn::Shape({1, 1, TILE_HEIGHT, TILE_WIDTH}), Layout::TILE) - .to(device, dram_memory_config), + .to_device(device, dram_memory_config), ttnn::Shape({1, 1, 1, TILE_WIDTH}))); auto run_bert = [&]() { @@ -314,7 +314,7 @@ void test_bert() { auto hidden_states = ttnn::random::uniform( bfloat16(-1.0f), bfloat16(1.0f), ttnn::Shape({batch_size, 1, sequence_size, hidden_size}), Layout::TILE) - .to(device, l1_memory_config); + .to_device(device, l1_memory_config); for (auto encoder_index = 0; encoder_index < num_encoders; encoder_index++) { hidden_states = encoder(std::move(hidden_states), attention_mask, parameters, encoder_index, head_size); } diff --git a/tests/tt_eager/ops/test_bcast_op.cpp b/tests/tt_eager/ops/test_bcast_op.cpp index a91c50809fc..8913161cd05 100644 --- a/tests/tt_eager/ops/test_bcast_op.cpp +++ b/tests/tt_eager/ops/test_bcast_op.cpp @@ -49,7 +49,7 @@ int main(int argc, char** argv) { throw std::runtime_error("Unsupported Dim!"); } - Tensor a = ttnn::random::random(input_shape_a).to(Layout::TILE).to(device); + Tensor a = ttnn::random::random(input_shape_a).to_layout(Layout::TILE).to_device(device); Tensor b = ttnn::zeros( ttnn::Shape({1, 1, TILE_HEIGHT, TILE_WIDTH}), DataType::BFLOAT16, Layout::TILE, *device); @@ -67,28 +67,28 @@ int main(int argc, char** argv) { } { - Tensor a = ttnn::random::random(Shape({1, 1, 32, 4544})).to(Layout::TILE).to(device); + Tensor a = ttnn::random::random(Shape({1, 1, 32, 4544})).to_layout(Layout::TILE).to_device(device); Tensor b = ttnn::zeros(ttnn::Shape({1, 1, 32, 4544}), DataType::BFLOAT16, Layout::TILE, *device); Tensor c = ttnn::bcast(0, a, b, ttnn::BcastOpMath::MUL, ttnn::BcastOpDim::H); Tensor d = c.cpu(); } { - Tensor a = ttnn::random::random(Shape({1, 1, 32, 4544})).to(Layout::TILE).to(device); + Tensor a = ttnn::random::random(Shape({1, 1, 32, 4544})).to_layout(Layout::TILE).to_device(device); Tensor b = ttnn::zeros(ttnn::Shape({1, 1, 32, 4544}), DataType::BFLOAT16, Layout::TILE, *device); Tensor c = 
ttnn::bcast(0, a, b, ttnn::BcastOpMath::ADD, ttnn::BcastOpDim::H); Tensor d = c.cpu(); } { - Tensor a = ttnn::random::random(Shape({1, 71, 32, 32})).to(Layout::TILE).to(device); + Tensor a = ttnn::random::random(Shape({1, 71, 32, 32})).to_layout(Layout::TILE).to_device(device); Tensor b = ttnn::zeros(ttnn::Shape({1, 1, 32, 32}), DataType::BFLOAT16, Layout::TILE, *device); Tensor c = ttnn::bcast(0, a, b, ttnn::BcastOpMath::MUL, ttnn::BcastOpDim::HW); Tensor d = c.cpu(); } { - Tensor a = ttnn::random::random(Shape({1, 71, 32, 64})).to(Layout::TILE).to(device); + Tensor a = ttnn::random::random(Shape({1, 71, 32, 64})).to_layout(Layout::TILE).to_device(device); Tensor b = ttnn::zeros(ttnn::Shape({1, 1, 32, 32}), DataType::BFLOAT16, Layout::TILE, *device); Tensor c = ttnn::bcast(0, a, b, ttnn::BcastOpMath::MUL, ttnn::BcastOpDim::HW); Tensor d = c.cpu(); diff --git a/tests/tt_eager/ops/test_bmm_op.cpp b/tests/tt_eager/ops/test_bmm_op.cpp index 5f369c747a7..b8c2b10d05b 100644 --- a/tests/tt_eager/ops/test_bmm_op.cpp +++ b/tests/tt_eager/ops/test_bmm_op.cpp @@ -40,7 +40,7 @@ int main(int argc, char** argv) { ttnn::Shape shapeb1({1, 1, Kt * TILE_HEIGHT, Nt * TILE_WIDTH}); // Allocates a DRAM buffer on device populated with values specified by initialize - Tensor a = ttnn::random::random(shapea).to(Layout::TILE).to(device); + Tensor a = ttnn::random::random(shapea).to_layout(Layout::TILE).to_device(device); Tensor b = ttnn::zeros(shapeb, DataType::BFLOAT16, Layout::TILE, *device); Tensor b1 = ttnn::zeros(shapeb1, DataType::BFLOAT16, Layout::TILE, *device); diff --git a/tests/tt_eager/ops/test_conv_prepare_weights_and_biases.cpp b/tests/tt_eager/ops/test_conv_prepare_weights_and_biases.cpp index 6bd886eb857..8dc88558494 100644 --- a/tests/tt_eager/ops/test_conv_prepare_weights_and_biases.cpp +++ b/tests/tt_eager/ops/test_conv_prepare_weights_and_biases.cpp @@ -463,7 +463,7 @@ static void test_convert_conv_bias_tensor_to_tiled_layout_block_sharded() { tt::log_info(tt::LogTest, "Running {}", __func__); for (auto i = 0; i < bias_tensor_shape.size(); i++) { auto input_tensor = - ttnn::random::random(Shape(bias_tensor_shape[i]), DataType::BFLOAT16).to(Layout::ROW_MAJOR).cpu(); + ttnn::random::random(Shape(bias_tensor_shape[i]), DataType::BFLOAT16).to_layout(Layout::ROW_MAJOR).cpu(); auto input_buffer = owned_buffer::get_as(input_tensor); auto output_tensor = ttnn::operations::conv::convert_conv_bias_tensor_to_tiled_layout_block_sharded( input_tensor, shards[i], DataType::BFLOAT16); diff --git a/tests/tt_eager/ops/test_eltwise_binary_op.cpp b/tests/tt_eager/ops/test_eltwise_binary_op.cpp index 4f6223692cc..32769454b8e 100644 --- a/tests/tt_eager/ops/test_eltwise_binary_op.cpp +++ b/tests/tt_eager/ops/test_eltwise_binary_op.cpp @@ -39,10 +39,11 @@ bool run_test(const ttnn::Shape& shape, const DeviceFunction& device_function, I auto input_tensor_b = ttnn::random::random(shape, DataType::BFLOAT16); auto host_output = HostFunction(input_tensor_a, input_tensor_b); - auto device_output = - device_function(input_tensor_a.to(Layout::TILE).to(device), input_tensor_b.to(Layout::TILE).to(device)) - .cpu() - .to(Layout::ROW_MAJOR); + auto device_output = device_function( + input_tensor_a.to_layout(Layout::TILE).to_device(device), + input_tensor_b.to_layout(Layout::TILE).to_device(device)) + .cpu() + .to_layout(Layout::ROW_MAJOR); return ttnn::allclose(host_output, device_output, args...); } @@ -111,8 +112,9 @@ int main() { run_binary_ops(); // Allocate a tensor to show that the addresses aren't cached - auto 
input_tensor = - ttnn::random::uniform(bfloat16(0.0f), bfloat16(0.0f), Shape({1, 1, 32, 32})).to(Layout::TILE).to(device); + auto input_tensor = ttnn::random::uniform(bfloat16(0.0f), bfloat16(0.0f), Shape({1, 1, 32, 32})) + .to_layout(Layout::TILE) + .to_device(device); run_binary_ops(); diff --git a/tests/tt_eager/ops/test_eltwise_unary_op.cpp b/tests/tt_eager/ops/test_eltwise_unary_op.cpp index 3370fd6d6a5..17839c2f228 100644 --- a/tests/tt_eager/ops/test_eltwise_unary_op.cpp +++ b/tests/tt_eager/ops/test_eltwise_unary_op.cpp @@ -57,42 +57,42 @@ Tensor host_function(const Tensor& input_tensor) { template bool run_test(IDevice* device, const ttnn::Shape& shape, float low, float high, Args... args) { - auto input_tensor = ttnn::random::uniform(bfloat16(low), bfloat16(high), shape).to(Layout::TILE); + auto input_tensor = ttnn::random::uniform(bfloat16(low), bfloat16(high), shape).to_layout(Layout::TILE); using ttnn::operations::unary::UnaryOpType; using ttnn::operations::unary::UnaryWithParam; if constexpr (unary_op_type == UnaryOpType::SQRT) { auto host_output = host_function<::detail::sqrt>(input_tensor); - auto device_output = ttnn::sqrt(input_tensor.to(device)).cpu(); + auto device_output = ttnn::sqrt(input_tensor.to_device(device)).cpu(); return ttnn::allclose(host_output, device_output, args...); } else if constexpr (unary_op_type == UnaryOpType::EXP) { auto host_output = host_function<::detail::exp>(input_tensor); - auto device_output = ttnn::exp(input_tensor.to(device)).cpu(); + auto device_output = ttnn::exp(input_tensor.to_device(device)).cpu(); return ttnn::allclose(host_output, device_output, args...); } else if constexpr (unary_op_type == UnaryOpType::RECIP) { auto host_output = host_function<::detail::recip>(input_tensor); - auto device_output = ttnn::reciprocal(input_tensor.to(device)).cpu(); + auto device_output = ttnn::reciprocal(input_tensor.to_device(device)).cpu(); return ttnn::allclose(host_output, device_output, args...); } else if constexpr (unary_op_type == UnaryOpType::GELU) { auto host_output = host_function<::detail::gelu>(input_tensor); - auto device_output = ttnn::gelu(input_tensor.to(device)).cpu(); + auto device_output = ttnn::gelu(input_tensor.to_device(device)).cpu(); return ttnn::allclose(host_output, device_output, args...); } else if constexpr (unary_op_type == UnaryOpType::RELU) { auto host_output = host_function<::detail::relu>(input_tensor); - auto device_output = ttnn::relu(input_tensor.to(device)).cpu(); + auto device_output = ttnn::relu(input_tensor.to_device(device)).cpu(); return ttnn::allclose(host_output, device_output, args...); } else if constexpr (unary_op_type == UnaryOpType::SIGMOID) { auto host_output = host_function<::detail::sigmoid>(input_tensor); - auto device_output = ttnn::sigmoid(input_tensor.to(device)).cpu(); + auto device_output = ttnn::sigmoid(input_tensor.to_device(device)).cpu(); return ttnn::allclose(host_output, device_output, args...); } else if constexpr (unary_op_type == UnaryOpType::LOG) { auto host_output = host_function<::detail::log>(input_tensor); - auto device_output = ttnn::log(input_tensor.to(device)).cpu(); + auto device_output = ttnn::log(input_tensor.to_device(device)).cpu(); return ttnn::allclose(host_output, device_output, args...); } else if constexpr (unary_op_type == UnaryOpType::TANH) { auto host_output = host_function<::detail::tanh>(input_tensor); - auto device_output = ttnn::tanh(input_tensor.to(device)).cpu(); + auto device_output = ttnn::tanh(input_tensor.to_device(device)).cpu(); return 
ttnn::allclose(host_output, device_output, args...); } TT_ASSERT(false, "Unsupported function"); @@ -110,7 +110,8 @@ void test_operation_infrastructure() { auto device = tt::tt_metal::CreateDevice(device_id); auto shape = ttnn::Shape({1, 1, TILE_HEIGHT, TILE_WIDTH}); - auto input_tensor = ttnn::random::uniform(bfloat16(0), bfloat16(1), shape).to(Layout::TILE).to(device); + auto input_tensor = + ttnn::random::uniform(bfloat16(0), bfloat16(1), shape).to_layout(Layout::TILE).to_device(device); ttnn::operations::unary::operation_attributes_t op_args{ {UnaryWithParam{UnaryOpType::SQRT}}, @@ -142,8 +143,8 @@ void test_shape_padding() { auto padded_input_tensor = ttnn::pad(input_tensor, padded_input_shape, tt::tt_metal::Array4D({0, 0, 0, 0}), 0); - padded_input_tensor = padded_input_tensor.to(Layout::TILE); - padded_input_tensor = padded_input_tensor.to(device); + padded_input_tensor = padded_input_tensor.to_layout(Layout::TILE); + padded_input_tensor = padded_input_tensor.to_device(device); auto output_tensor = ttnn::sqrt(padded_input_tensor); output_tensor = output_tensor.cpu(); @@ -250,8 +251,8 @@ void test_program_cache() { // Allocate a tensor to show that the addresses aren't cached auto input_tensor = ttnn::random::uniform(bfloat16(0.0f), bfloat16(0.0f), ttnn::Shape({1, 1, 32, 32})) - .to(Layout::TILE) - .to(device); + .to_layout(Layout::TILE) + .to_device(device); // Program Cache Hit run_test(device, ttnn::Shape({1, 1, TILE_HEIGHT, TILE_WIDTH}), 0.0f, 1.0f, 1e-1f, 1e-5f); diff --git a/tests/tt_eager/ops/test_fold_op.cpp b/tests/tt_eager/ops/test_fold_op.cpp index e8e69512121..0d8129a2155 100644 --- a/tests/tt_eager/ops/test_fold_op.cpp +++ b/tests/tt_eager/ops/test_fold_op.cpp @@ -16,7 +16,7 @@ using namespace tt::tt_metal; using namespace constants; void run_fold(IDevice* device, const ttnn::Shape& shape) { - Tensor input_tensor = ttnn::random::random(shape).to(Layout::ROW_MAJOR).to(device); + Tensor input_tensor = ttnn::random::random(shape).to_layout(Layout::ROW_MAJOR).to_device(device); uint32_t stride_h = 2; uint32_t stride_w = 2; uint8_t queue_id = 0; diff --git a/tests/tt_eager/ops/test_layernorm_op.cpp b/tests/tt_eager/ops/test_layernorm_op.cpp index ce51ff89e09..b1d775aa1af 100644 --- a/tests/tt_eager/ops/test_layernorm_op.cpp +++ b/tests/tt_eager/ops/test_layernorm_op.cpp @@ -29,7 +29,7 @@ int main(int argc, char** argv) { int device_id = 0; tt_metal::IDevice* device = tt_metal::CreateDevice(device_id); ttnn::Shape shape({1, 1, TILE_HEIGHT, TILE_WIDTH}); - Tensor a = ttnn::random::random(shape).to(Layout::TILE).to(device); + Tensor a = ttnn::random::random(shape).to_layout(Layout::TILE).to_device(device); Tensor c = ttnn::layer_norm(a, 1e-4f); Tensor d = c.cpu(); Tensor host_a = a.cpu(); // Move tensor a to host to validate diff --git a/tests/tt_eager/ops/test_sliding_window_ops.cpp b/tests/tt_eager/ops/test_sliding_window_ops.cpp index 16f98f5c1bf..67a8e78b3db 100644 --- a/tests/tt_eager/ops/test_sliding_window_ops.cpp +++ b/tests/tt_eager/ops/test_sliding_window_ops.cpp @@ -381,9 +381,9 @@ int main() { ttnn::Shape filter_tensor_shape({config.window_hw.first, config.window_hw.second}); Tensor input_padded_tensor = - ttnn::random::random(input_tensor_shape, DataType::BFLOAT16).to(Layout::ROW_MAJOR).cpu(); + ttnn::random::random(input_tensor_shape, DataType::BFLOAT16).to_layout(Layout::ROW_MAJOR).cpu(); Tensor filter_tensor = - ttnn::random::random(filter_tensor_shape, DataType::BFLOAT16).to(Layout::ROW_MAJOR).cpu(); + ttnn::random::random(filter_tensor_shape, 
DataType::BFLOAT16).to_layout(Layout::ROW_MAJOR).cpu(); auto input_padded_tensor_buf = owned_buffer::get_as(input_padded_tensor); auto filter_tensor_buf = owned_buffer::get_as(filter_tensor); diff --git a/tests/tt_eager/ops/test_softmax_op.cpp b/tests/tt_eager/ops/test_softmax_op.cpp index 58679e8a116..f064b603b26 100644 --- a/tests/tt_eager/ops/test_softmax_op.cpp +++ b/tests/tt_eager/ops/test_softmax_op.cpp @@ -16,7 +16,7 @@ using namespace tt::tt_metal; using namespace constants; void run_softmax(IDevice* device, const ttnn::Shape& shape) { - Tensor input_tensor = ttnn::random::random(shape).to(Layout::TILE).to(device); + Tensor input_tensor = ttnn::random::random(shape).to_layout(Layout::TILE).to_device(device); Tensor device_output_tensor = ttnn::softmax_in_place(input_tensor); Tensor output_tensor = device_output_tensor.cpu(); } diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_untilize_with_unpadding.py b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_untilize_with_unpadding.py index 0df078da76d..505eef2d275 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_untilize_with_unpadding.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_untilize_with_unpadding.py @@ -21,22 +21,53 @@ def create_grid(x, y): params = [ - pytest.param([[5, 5, 32, 32]], untilize_with_unpadding_args) - for untilize_with_unpadding_args in generation_funcs.gen_untilize_with_unpadding_args([[5, 5, 32, 32]]) + pytest.param( + [[5, 5, 32, 32]], + { + "dtype": [ttnn.bfloat16], + "layout": [ttnn.TILE_LAYOUT], + "input_mem_config": [ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM)], + "output_mem_config": ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.L1), + "output_tensor_end": [4, 4, 31, 28], + }, + ) ] + params += [ - pytest.param([[5, 5, 64, 96]], untilize_with_unpadding_args) - for untilize_with_unpadding_args in generation_funcs.gen_untilize_with_unpadding_args([[5, 5, 64, 96]]) + pytest.param( + [[5, 5, 64, 96]], + { + "dtype": [ttnn.bfloat16], + "layout": [ttnn.TILE_LAYOUT], + "input_mem_config": [ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM)], + "output_mem_config": ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM), + "output_tensor_end": [4, 4, 60, 90], + }, + ) ] params += [ pytest.param( - [[1, 1, 128, 7328]], + [[5, 5, 64, 96]], { "dtype": [ttnn.bfloat16], "layout": [ttnn.TILE_LAYOUT], - "input_mem_config": [ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM)], + "input_mem_config": [ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.L1)], "output_mem_config": ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM), + "output_tensor_end": [4, 4, 60, 90], + }, + ) +] + + +params += [ + pytest.param( + [[1, 1, 128, 7328]], + { + "dtype": [ttnn.bfloat16], + "layout": [ttnn.TILE_LAYOUT], + "input_mem_config": [ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.L1)], + "output_mem_config": ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.L1), "output_tensor_end": [0, 0, 119, 7299], }, ) diff --git a/tests/tt_eager/python_api_testing/trace_testing/misc/test_average_pool.py b/tests/tt_eager/python_api_testing/trace_testing/misc/test_average_pool.py deleted file mode 100644 index 7784996ce58..00000000000 --- a/tests/tt_eager/python_api_testing/trace_testing/misc/test_average_pool.py +++ /dev/null @@ 
-1,111 +0,0 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. - -# SPDX-License-Identifier: Apache-2.0 - -import pytest -from loguru import logger - -import torch - -import ttnn - -from tt_lib.utils import _nearest_32 -from models.utility_functions import comp_pcc -import ttnn - -TILE_HEIGHT = TILE_WIDTH = 32 - - -def shape_padded(shape): - return [shape[0], shape[1], _nearest_32(shape[2]), _nearest_32(shape[3])] - - -@pytest.mark.parametrize( - "act_shape", - ( - pytest.param([1, 7, 7, 2048]), - ([1, 1, 32, 64]), - ), - ids=["resnet50_unpadded", "tile_divisible"], -) -@pytest.mark.parametrize( - "dtype", - (ttnn.bfloat16,), - ids=[ - "BFLOAT16", - ], -) -@pytest.mark.parametrize("enable_async", [True, False]) -@pytest.mark.parametrize("device_params", [{"trace_region_size": 11264}], indirect=True) -def test_run_average_pool(act_shape, dtype, device, use_program_cache, enable_async): - device.enable_async(enable_async) - - batch_size, _, _, channels = act_shape - - torch.manual_seed(0) - - interleaved_mem_config_L1 = ttnn.MemoryConfig( - memory_layout=ttnn.TensorMemoryLayout.INTERLEAVED, - buffer_type=ttnn.BufferType.L1, - ) - - trace_loops = 10 - - out_shape = [1] * len(act_shape) - out_shape[-1] = act_shape[-1] - out_shape_padded = shape_padded(out_shape) - - act = torch.randn(act_shape, dtype=torch.bfloat16).float() - ttact = ttnn.Tensor(act, ttnn.bfloat16) - act_shape_padded = shape_padded(act_shape) - if act_shape != act_shape_padded: - ttact = ttact.pad_to_tile(0.0) - - ttact_res = ttact.to(device) - - def run_ops(ttact_res): - return ttnn.global_avg_pool2d(ttact_res) - - # Compile - run_ops(ttact_res) - # Trace - logger.info("Start Trace capture") - tid = ttnn.begin_trace_capture(device, cq_id=0) - out_res = run_ops(ttact_res) - ttnn.end_trace_capture(device, tid, cq_id=0) - logger.info("Trace captured") - - for iter in range(trace_loops): - act = torch.randn(act_shape, dtype=torch.bfloat16).float() - ttact_updated = ttnn.Tensor(act, ttnn.bfloat16) - act_shape_padded = shape_padded(act_shape) - if act_shape != act_shape_padded: - ttact_updated = ttact_updated.pad_to_tile(0.0) - ttnn.copy_host_to_device_tensor(ttact_updated, ttact_res) - - logger.info(f"Running iteration {iter}") - ttnn.execute_trace(device, tid, cq_id=0, blocking=True) - - out = out_res.cpu().to(ttnn.ROW_MAJOR_LAYOUT) - out_shape = [batch_size, 1, 1, channels] - out_shape_padded = shape_padded(out_shape) - if out_shape != out_shape_padded: - out = out.unpad_from_tile(out_shape) - - out_pytorch = out.to_torch() - out = out.pad_to_tile(0) # Undo, so next loop unpad_from_tile works again. - - ## reference - act_channels_first = torch.permute(act, (0, 3, 1, 2)) # Torch operates on channels-first tensors - golden_pytorch = torch.nn.AdaptiveAvgPool2d((1, 1))(act_channels_first) - - ## test for equivalance - passing_pcc, output_pcc = comp_pcc(golden_pytorch, out_pytorch) - logger.debug(f"Passing PCC = {passing_pcc}") - logger.debug(f"Output PCC = {output_pcc}") - - assert passing_pcc - - # Done with the trace, can deallocate the buffers now. - ttnn.release_trace(device, tid) - device.enable_async(False) diff --git a/tests/tt_eager/python_api_testing/trace_testing/misc/test_bert_ops.py b/tests/tt_eager/python_api_testing/trace_testing/misc/test_bert_ops.py deleted file mode 100644 index 634eacdc740..00000000000 --- a/tests/tt_eager/python_api_testing/trace_testing/misc/test_bert_ops.py +++ /dev/null @@ -1,257 +0,0 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
- -# SPDX-License-Identifier: Apache-2.0 - -import pytest -import torch -import math -import ttnn - - -from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import ( - comp_pcc, -) -from models.utility_functions import is_wormhole_b0, is_grayskull, is_wormhole_b0, is_blackhole -from loguru import logger -from models.utility_functions import torch2tt_tensor, tt2torch_tensor, pad_by_zero - - -@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported parallelizations for WH B0 and BH") -@pytest.mark.parametrize("fidelity", [ttnn.MathFidelity.LoFi, ttnn.MathFidelity.HiFi2], ids=["LoFi", "HiFi2"]) -@pytest.mark.parametrize( - "in1_in_dram, out_sharded, in0_sharded, M, K, N, activation", - [ - # (False, True, True, 12*128, 1024, 1024, None), - # (False, True, True, 12*128, 4096, 1024, None), - # (False, True, True, 12*128, 8192, 1024, None), - # one core - # (False, False, False, 128, 256, 128, None), - # # in1-L1-fusedQKV - (False, True, True, 4608, 1024, 3072, None), # both sharded - (False, True, False, 4608, 1024, 3072, None), # out sharded, in0 interleaved - (False, False, True, 4608, 1024, 3072, None), # out interleaved, in0 sharded - (False, False, False, 4608, 1024, 3072, None), # out interleaved, in0 interleaved - ], -) -@pytest.mark.parametrize("enable_async", [True, False]) -class TestBertOpsTrace: - # TODO: Not all ops here take in cq id, only works with 0 for now - def run_bert_linear( - self, - device, - fidelity, - in0_sharded, - out_sharded, - in1_in_dram, - M, - K, - N, - activation, - enable_async, - cq_id, - ): - device.enable_async(enable_async) - has_bias = False - in0_shape = [1, 1, M, K] - in1_shape = [1, 1, K, N] - bias_shape = [1, 1, N] - out_shape = [1, 1, M, N] - grid_size = (12, 8) - # grid_size = (2, 2) - shard_shape = [M // grid_size[0], K // grid_size[1]] # shard height, width - - in0_block_w = K // grid_size[1] // 32 # 16 - in0_block_h = M // grid_size[0] // 32 - out_block_h = M // grid_size[0] // 32 - out_block_w = N // grid_size[1] // 32 - - if out_block_w <= 8: - out_subblock_w = out_block_w - out_subblock_h = 8 // out_subblock_w - else: - out_subblock_h = 1 - out_subblock_w = 8 // out_subblock_h - while out_block_w % out_subblock_w != 0: - out_subblock_w = out_block_w // 2 - - # in0_block_w = K // grid_size[1] // 32 - # out_subblock_w = 4 - # out_subblock_h = 4 - - logger.debug("in0 block w h " + str(in0_block_w * 32) + " " + str(in0_block_h * 32)) - logger.debug("in1 block w h " + str(out_block_w * 32) + " " + str(in0_block_w * 32)) - logger.debug("out block w h " + str(out_block_w * 32) + " " + str(out_block_h * 32)) - logger.debug("out subblock w h " + str(out_subblock_w * 32) + " " + str(out_subblock_h * 32)) - - interleaved_mem_config_L1 = ttnn.MemoryConfig( - memory_layout=ttnn.TensorMemoryLayout.INTERLEAVED, - buffer_type=ttnn.BufferType.L1, - ) - interleaved_mem_config_DRAM = ttnn.MemoryConfig( - memory_layout=ttnn.TensorMemoryLayout.INTERLEAVED, - buffer_type=ttnn.BufferType.DRAM, - ) - sharded_mem_config = ttnn.MemoryConfig( - memory_layout=ttnn.TensorMemoryLayout.BLOCK_SHARDED, - buffer_type=ttnn.BufferType.L1, - ) - - in0 = torch.randn(in0_shape).bfloat16().float() - in1 = torch.randn(in1_shape).bfloat16().float() - bias = torch.randn(bias_shape).bfloat16().float() - in0_t_res = torch2tt_tensor(in0, device, tt_memory_config=interleaved_mem_config_DRAM, tt_dtype=ttnn.bfloat8_b) - - if in1_in_dram: - in1_t = torch2tt_tensor(in1, device, tt_memory_config=interleaved_mem_config_DRAM, tt_dtype=ttnn.bfloat8_b) - else: - 
in1_t = torch2tt_tensor(in1, device, tt_memory_config=interleaved_mem_config_L1, tt_dtype=ttnn.bfloat8_b) - - output_mem_config = sharded_mem_config if out_sharded else interleaved_mem_config_L1 - - bias_t = pad_by_zero(bias, device, tt_memory_config=interleaved_mem_config_L1, tt_dtype=ttnn.bfloat8_b)[0] - - program_config = ttnn.MatmulMultiCoreReuseMultiCastProgramConfig( - compute_with_storage_grid_size=grid_size, - in0_block_w=in0_block_w, - out_subblock_h=out_subblock_h, - out_subblock_w=out_subblock_w, - per_core_M=out_block_h, - per_core_N=out_block_w, - transpose_mcast=True, - # transpose_mcast=False, - fused_activation=activation, - ) - - compute_kernel_config = ttnn.GrayskullComputeKernelConfig(math_fidelity=fidelity, math_approx_mode=True) - - trace_loops = 4 - - def run_ops(in0_t_res): - if in0_sharded: - in0_t = ttnn.interleaved_to_sharded( - in0_t_res, - grid_size, - [M // grid_size[0], K // grid_size[1]], - ttnn.TensorMemoryLayout.BLOCK_SHARDED, - ttnn.ShardOrientation.COL_MAJOR, - ) - else: - in0_t = ttnn.clone(in0_t_res, memory_config=interleaved_mem_config_L1) - - if has_bias: - output_t = ttnn.linear( - in0_t, - in1_t, - bias=bias_t, - program_config=program_config, - memory_config=output_mem_config, - compute_kernel_config=compute_kernel_config, - ) - else: - output_t = ttnn.matmul( - in0_t, - in1_t, - program_config=program_config, - memory_config=output_mem_config, - compute_kernel_config=compute_kernel_config, - ) - if out_sharded: - output_t = ttnn.sharded_to_interleaved(output_t, interleaved_mem_config_L1) - return output_t - - # Compile - run_ops(in0_t_res) - # Capture - logger.info("Start Trace capture") - tid = ttnn.begin_trace_capture(device, cq_id=cq_id) - output_t_res = run_ops(in0_t_res) - ttnn.end_trace_capture(device, tid, cq_id=cq_id) - logger.info("Trace captured") - - for iter in range(trace_loops): - in0 = torch.randn(in0_shape).bfloat16().float() - in0_t_updated = torch2tt_tensor( - in0, None, tt_memory_config=interleaved_mem_config_DRAM, tt_dtype=ttnn.bfloat8_b - ) - ttnn.copy_host_to_device_tensor(in0_t_updated, in0_t_res) - logger.info(f"Running iteration {iter}") - ttnn.execute_trace(device, tid, cq_id=cq_id, blocking=True) - - pt_out = in0 @ in1 - - if has_bias: - pt_out = pt_out + bias - - if activation != None: - pt_out = torch.nn.functional.gelu(pt_out) - tt_out = tt2torch_tensor(output_t_res) - - passing, output = comp_pcc(pt_out, tt_out) - logger.info(output) - assert passing - - # Done with the trace, can deallocate the buffers now. 
- ttnn.release_trace(device, tid) - device.enable_async(False) - - @pytest.mark.parametrize("device_params", [{"trace_region_size": 34816}], indirect=True) - def test_bert_linear_1cq_initialized( - self, - device, - fidelity, - in0_sharded, - out_sharded, - in1_in_dram, - M, - K, - N, - activation, - use_program_cache, - function_level_defaults, - enable_async, - ): - self.run_bert_linear( - device, - fidelity, - in0_sharded, - out_sharded, - in1_in_dram, - M, - K, - N, - activation, - enable_async, - 0, - ) - - @pytest.mark.parametrize("cq_id", [0]) - @pytest.mark.parametrize("device_params", [{"trace_region_size": 34816, "num_command_queues": 2}], indirect=True) - def test_bert_linear_2cqs_initialized( - self, - device, - fidelity, - in0_sharded, - out_sharded, - in1_in_dram, - M, - K, - N, - activation, - use_program_cache, - function_level_defaults, - enable_async, - cq_id, - ): - self.run_bert_linear( - device, - fidelity, - in0_sharded, - out_sharded, - in1_in_dram, - M, - K, - N, - activation, - enable_async, - cq_id, - ) diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_repeat.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_repeat.py deleted file mode 100644 index cec1659bca5..00000000000 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_repeat.py +++ /dev/null @@ -1,163 +0,0 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. - -# SPDX-License-Identifier: Apache-2.0 - -import math -from pathlib import Path -import sys - -import torch - -import ttnn -from models.utility_functions import print_diff_argmax, is_blackhole -import pytest -from loguru import logger - -from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import ( - comp_pcc, - comp_equal, -) - - -def run_repeat(input_shape, repeats, device, layout, dtype, input_mem_config, output_mem_config): - if layout == ttnn.ROW_MAJOR_LAYOUT and dtype == ttnn.bfloat8_b: - pytest.skip("Illegal config") - rm_last_dim_repeat_on_bh = layout == ttnn.ROW_MAJOR_LAYOUT and is_blackhole() and repeats[-1] != 1 - if layout == ttnn.TILE_LAYOUT or rm_last_dim_repeat_on_bh: - alignment = 64 if is_blackhole() else 32 - if rm_last_dim_repeat_on_bh and input_shape[-1] % alignment != 0: - pytest.skip(f"Illegal config for BH see #14518") - elif (input_shape[-2] % alignment != 0 or input_shape[-1] % alignment != 0) and layout == ttnn.TILE_LAYOUT: - pytest.skip("Illegal config") - input = torch.rand(input_shape).to(torch.bfloat16) - tt_input = ( - ttnn.Tensor( - input, - dtype, - ) - .to(layout) - .to(device, input_mem_config) - ) - - tt_cpu = input.repeat(torch.Size(repeats)) - - tt = ttnn.repeat(tt_input, ttnn.Shape(repeats), memory_config=output_mem_config) - - tt_dev = tt.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch().to(torch.bfloat16) - - if dtype == ttnn.bfloat8_b: - passing, output = comp_pcc(tt_cpu, tt_dev) - else: - passing, output = comp_equal(tt_cpu, tt_dev) - logger.info(output) - assert passing - - -@pytest.mark.parametrize( - "input_shape, repeats", - ( - ((1, 2, 64, 64), [1, 1, 1, 1]), - ((1, 1, 64, 64), [1, 1, 1, 2]), - ((1, 1, 32, 128), [5, 3, 4, 2]), - ((2, 4, 32, 1280), [3, 1, 1, 5]), - ((1, 1, 32, 16), [1, 1, 1, 2048]), - ), -) -@pytest.mark.parametrize( - "layout, dtype", - ( - (ttnn.TILE_LAYOUT, ttnn.bfloat16), - (ttnn.TILE_LAYOUT, ttnn.bfloat8_b), - (ttnn.ROW_MAJOR_LAYOUT, ttnn.bfloat16), - ), -) -@pytest.mark.parametrize( - "input_mem_config", - ( - ttnn.MemoryConfig( - memory_layout=ttnn.TensorMemoryLayout.INTERLEAVED, - buffer_type=ttnn.BufferType.DRAM, - ), - 
ttnn.MemoryConfig( - memory_layout=ttnn.TensorMemoryLayout.INTERLEAVED, - buffer_type=ttnn.BufferType.L1, - ), - ), -) -@pytest.mark.parametrize( - "output_mem_config", - ( - ttnn.MemoryConfig( - memory_layout=ttnn.TensorMemoryLayout.INTERLEAVED, - buffer_type=ttnn.BufferType.DRAM, - ), - ttnn.MemoryConfig( - memory_layout=ttnn.TensorMemoryLayout.INTERLEAVED, - buffer_type=ttnn.BufferType.L1, - ), - ), -) -def test_repeat( - input_shape, repeats, device, layout, dtype, input_mem_config, output_mem_config, function_level_defaults -): - run_repeat(input_shape, repeats, device, layout, dtype, input_mem_config, output_mem_config) - - -@pytest.mark.parametrize( - "input_shape, repeats", - ( - ((1, 2, 64, 64), [1, 1, 1, 1]), - ((1, 1, 64, 64), [1, 1, 1, 2]), - ((1, 1, 32, 128), [5, 3, 4, 2]), - ((2, 4, 32, 1280), [3, 1, 1, 5]), - ((1, 1, 32, 16), [1, 1, 1, 2048]), - ), -) -@pytest.mark.parametrize( - "layout, dtype", - ( - (ttnn.TILE_LAYOUT, ttnn.bfloat16), - (ttnn.TILE_LAYOUT, ttnn.bfloat8_b), - (ttnn.ROW_MAJOR_LAYOUT, ttnn.bfloat16), - ), -) -@pytest.mark.parametrize( - "input_mem_config", - ( - ttnn.MemoryConfig( - memory_layout=ttnn.TensorMemoryLayout.INTERLEAVED, - buffer_type=ttnn.BufferType.DRAM, - ), - ttnn.MemoryConfig( - memory_layout=ttnn.TensorMemoryLayout.INTERLEAVED, - buffer_type=ttnn.BufferType.L1, - ), - ), -) -@pytest.mark.parametrize( - "output_mem_config", - ( - ttnn.MemoryConfig( - memory_layout=ttnn.TensorMemoryLayout.INTERLEAVED, - buffer_type=ttnn.BufferType.DRAM, - ), - ttnn.MemoryConfig( - memory_layout=ttnn.TensorMemoryLayout.INTERLEAVED, - buffer_type=ttnn.BufferType.L1, - ), - ), -) -def test_repeat_with_program_cache( - input_shape, - repeats, - device, - layout, - dtype, - input_mem_config, - output_mem_config, - use_program_cache, - function_level_defaults, -): - run_repeat(input_shape, repeats, device, layout, dtype, input_mem_config, output_mem_config) - tmp = ttnn.zeros([1, 256, 32, 32], ttnn.bfloat16, ttnn.TILE_LAYOUT, device) - run_repeat(input_shape, repeats, device, layout, dtype, input_mem_config, output_mem_config) diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_sharded.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_sharded.py index 4d3f555a8ad..d123cec54f9 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_sharded.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_sharded.py @@ -164,7 +164,7 @@ def test_sharded_rm( ), ) - yt = ttnn.interleaved_to_sharded(xt, grid_size, shard_size, shard_scheme, shard_orientation) + yt = ttnn.interleaved_to_sharded(xt, grid_size, shard_size, shard_scheme, shard_orientation, keep_l1_aligned=True) zt = ttnn.sharded_to_interleaved( yt, @@ -172,6 +172,7 @@ def test_sharded_rm( memory_layout=ttnn.TensorMemoryLayout.INTERLEAVED, buffer_type=ttnn.BufferType.L1, ), + is_l1_aligned=True, ) tt_og = xt.cpu().to_torch() diff --git a/tests/tt_eager/tensors/test_async_tensor_apis.cpp b/tests/tt_eager/tensors/test_async_tensor_apis.cpp index 8f8624dd480..884160d86c3 100644 --- a/tests/tt_eager/tensors/test_async_tensor_apis.cpp +++ b/tests/tt_eager/tensors/test_async_tensor_apis.cpp @@ -61,8 +61,8 @@ TEST_F(DispatchFixture, TestTensorOwnershipSanity) { host_tensor.get_storage()); // Send tensor to device, read it back and copy it to empty tensor initialized by main thread Tensor reshaped_tensor = ttnn::experimental::view(host_tensor, ttnn::Shape{1, 1, 32, 128}); - auto device_tensor = reshaped_tensor.to(Layout::TILE).to(device); - auto thread_local_tensor = 
device_tensor.cpu().to(Layout::ROW_MAJOR); + auto device_tensor = reshaped_tensor.to_layout(Layout::TILE).to_device(device); + auto thread_local_tensor = device_tensor.cpu().to_layout(Layout::ROW_MAJOR); readback_tensor.set_storage(thread_local_tensor.get_storage()); readback_tensor.set_tensor_spec(thread_local_tensor.get_tensor_spec()); readback_tensor.tensor_attributes->metadata_populated = true; @@ -292,8 +292,8 @@ TEST_F(DispatchFixture, TestTensorAsyncDataMovement) { host_tensor.get_storage()); Tensor reshaped_tensor = ttnn::experimental::view(host_tensor, ttnn::Shape{1, 1, 32, tensor_stop / 32}); - auto device_tensor = reshaped_tensor.to(Layout::TILE).to(device); - auto thread_local_tensor = device_tensor.cpu().to(Layout::ROW_MAJOR); + auto device_tensor = reshaped_tensor.to_layout(Layout::TILE).to_device(device); + auto thread_local_tensor = device_tensor.cpu().to_layout(Layout::ROW_MAJOR); log_info(LogTest, "Worker populating empty host readback_tensor"); readback_tensor.set_storage(thread_local_tensor.get_storage()); readback_tensor.set_tensor_spec(thread_local_tensor.get_tensor_spec()); diff --git a/tests/tt_eager/tensors/test_copy_and_move.cpp b/tests/tt_eager/tensors/test_copy_and_move.cpp index 82b040be944..551460d0f89 100644 --- a/tests/tt_eager/tensors/test_copy_and_move.cpp +++ b/tests/tt_eager/tensors/test_copy_and_move.cpp @@ -22,14 +22,14 @@ bool test_tensor_copy_semantics(IDevice* device) { ttnn::Shape single_tile_shape({1, 1, TILE_HEIGHT, TILE_WIDTH}); // host tensor to host tensor copy constructor - Tensor host_a = ttnn::random::random(single_tile_shape).to(Layout::TILE); + Tensor host_a = ttnn::random::random(single_tile_shape).to_layout(Layout::TILE); Tensor host_a_copy = host_a; auto host_a_data = owned_buffer::get_as(host_a); auto host_a_copy_data = owned_buffer::get_as(host_a_copy); pass &= host_a_data == host_a_copy_data; // dev tensor to dev tensor copy constructor - Tensor dev_a = ttnn::random::random(single_tile_shape).to(Layout::TILE).to(device); + Tensor dev_a = ttnn::random::random(single_tile_shape).to_layout(Layout::TILE).to_device(device); Tensor dev_a_copy = dev_a; auto dev_a_on_host = dev_a.cpu(); auto dev_a_copy_on_host = dev_a_copy.cpu(); @@ -40,15 +40,15 @@ bool test_tensor_copy_semantics(IDevice* device) { // host tensor updated with host tensor copy assignment Tensor host_c = ttnn::experimental::view( ttnn::arange(/*start=*/0, /*stop=*/single_tile_shape.volume(), /*step=*/1), single_tile_shape) - .to(Layout::TILE); - Tensor host_c_copy = ttnn::random::random(single_tile_shape).to(Layout::TILE); + .to_layout(Layout::TILE); + Tensor host_c_copy = ttnn::random::random(single_tile_shape).to_layout(Layout::TILE); host_c_copy = host_c; auto host_c_data = owned_buffer::get_as(host_c); auto host_c_copy_data = owned_buffer::get_as(host_c_copy); pass &= host_c_data == host_c_copy_data; // host tensor updated with dev tensor copy assignment - Tensor host_d_copy = ttnn::random::random(single_tile_shape).to(Layout::TILE); + Tensor host_d_copy = ttnn::random::random(single_tile_shape).to_layout(Layout::TILE); host_d_copy = dev_a; pass &= (host_d_copy.storage_type() == StorageType::DEVICE); auto host_d_copy_on_host = host_d_copy.cpu(); @@ -57,7 +57,7 @@ bool test_tensor_copy_semantics(IDevice* device) { // dev tensor updated with host tensor copy assignment Tensor host_e = ttnn::ones(single_tile_shape, DataType::BFLOAT16, Layout::TILE); - Tensor dev_e_copy = ttnn::random::random(single_tile_shape).to(Layout::TILE).to(device); + Tensor dev_e_copy = 
ttnn::random::random(single_tile_shape).to_layout(Layout::TILE).to_device(device); dev_e_copy = host_e; pass &= (dev_e_copy.storage_type() == StorageType::OWNED); auto host_e_data = owned_buffer::get_as(host_e); @@ -92,7 +92,8 @@ bool test_tensor_move_semantics(IDevice* device) { pass &= host_a_copy_data == bfloat_data; // dev tensor to dev tensor move constructor - Tensor dev_a = Tensor(OwnedStorage{bfloat_data}, single_tile_shape, DataType::BFLOAT16, Layout::TILE).to(device); + Tensor dev_a = + Tensor(OwnedStorage{bfloat_data}, single_tile_shape, DataType::BFLOAT16, Layout::TILE).to_device(device); auto og_buffer_a = dev_a.buffer(); Tensor dev_a_copy = std::move(dev_a); pass &= dev_a_copy.buffer() == og_buffer_a; @@ -122,7 +123,7 @@ bool test_tensor_move_semantics(IDevice* device) { auto bfloat_data_four = owned_buffer::get_as(random_tensor_four); Tensor host_e = Tensor(random_tensor_four.get_storage(), single_tile_shape, DataType::BFLOAT16, Layout::TILE); Tensor dev_e_copy = - Tensor(host_c_copy.get_storage(), single_tile_shape, DataType::BFLOAT16, Layout::TILE).to(device); + Tensor(host_c_copy.get_storage(), single_tile_shape, DataType::BFLOAT16, Layout::TILE).to_device(device); dev_e_copy = std::move(host_e); pass &= (dev_e_copy.storage_type() == StorageType::OWNED); auto dev_e_copy_data = owned_buffer::get_as(dev_e_copy); @@ -132,9 +133,9 @@ bool test_tensor_move_semantics(IDevice* device) { auto random_tensor_five = ttnn::random::uniform(bfloat16(-1.0f), bfloat16(1.0f), single_tile_shape); auto bfloat_data_five = owned_buffer::get_as(random_tensor_five); Tensor dev_b = - Tensor(random_tensor_four.get_storage(), single_tile_shape, DataType::BFLOAT16, Layout::TILE).to(device); + Tensor(random_tensor_four.get_storage(), single_tile_shape, DataType::BFLOAT16, Layout::TILE).to_device(device); Tensor dev_b_copy = - Tensor(dev_e_copy.get_storage(), single_tile_shape, DataType::BFLOAT16, Layout::TILE).to(device); + Tensor(dev_e_copy.get_storage(), single_tile_shape, DataType::BFLOAT16, Layout::TILE).to_device(device); dev_b_copy = std::move(dev_b); pass &= (dev_b_copy.storage_type() == StorageType::DEVICE); auto dev_b_copy_on_host = dev_b_copy.cpu(); @@ -154,32 +155,32 @@ bool test_tensor_deallocate_semantics(IDevice* device) { MemoryConfig{.memory_layout = TensorMemoryLayout::INTERLEAVED, .buffer_type = BufferType::L1}; // dev tensor allocate, deallocate, reallocate same address DRAM - Tensor dev_a = ttnn::random::random(single_tile_shape).to(Layout::TILE).to(device, dram_mem_config); + Tensor dev_a = ttnn::random::random(single_tile_shape).to_layout(Layout::TILE).to_device(device, dram_mem_config); uint32_t address_a = dev_a.buffer()->address(); dev_a.deallocate(); - Tensor dev_b = ttnn::random::random(single_tile_shape).to(Layout::TILE).to(device, dram_mem_config); + Tensor dev_b = ttnn::random::random(single_tile_shape).to_layout(Layout::TILE).to_device(device, dram_mem_config); uint32_t address_b = dev_b.buffer()->address(); pass &= address_a == address_b; // dev tensor allocate, allocate, deallocate, reallocate same address DRAM - Tensor dev_c = ttnn::random::random(single_tile_shape).to(Layout::TILE).to(device, dram_mem_config); + Tensor dev_c = ttnn::random::random(single_tile_shape).to_layout(Layout::TILE).to_device(device, dram_mem_config); dev_b.deallocate(); - Tensor dev_d = ttnn::random::random(single_tile_shape).to(Layout::TILE).to(device, dram_mem_config); + Tensor dev_d = ttnn::random::random(single_tile_shape).to_layout(Layout::TILE).to_device(device, dram_mem_config); 
uint32_t address_d = dev_d.buffer()->address(); pass &= address_b == address_d; // dev tensor allocate, deallocate, reallocate same address L1 - Tensor dev_e = ttnn::random::random(single_tile_shape).to(Layout::TILE).to(device, l1_mem_config); + Tensor dev_e = ttnn::random::random(single_tile_shape).to_layout(Layout::TILE).to_device(device, l1_mem_config); uint32_t address_e = dev_e.buffer()->address(); dev_e.deallocate(); - Tensor dev_f = ttnn::random::random(single_tile_shape).to(Layout::TILE).to(device, l1_mem_config); + Tensor dev_f = ttnn::random::random(single_tile_shape).to_layout(Layout::TILE).to_device(device, l1_mem_config); uint32_t address_f = dev_f.buffer()->address(); pass &= address_e == address_f; // dev tensor allocate, allocate, deallocate, reallocate same address DRAM - Tensor dev_g = ttnn::random::random(single_tile_shape).to(Layout::TILE).to(device, l1_mem_config); + Tensor dev_g = ttnn::random::random(single_tile_shape).to_layout(Layout::TILE).to_device(device, l1_mem_config); dev_f.deallocate(); - Tensor dev_h = ttnn::random::random(single_tile_shape).to(Layout::TILE).to(device, l1_mem_config); + Tensor dev_h = ttnn::random::random(single_tile_shape).to_layout(Layout::TILE).to_device(device, l1_mem_config); uint32_t address_h = dev_h.buffer()->address(); pass &= address_f == address_h; @@ -196,7 +197,7 @@ bool test_tensor_deallocate_and_close_device(IDevice* device) { MemoryConfig{.memory_layout = TensorMemoryLayout::INTERLEAVED, .buffer_type = BufferType::L1}; // dev tensor allocate, deallocate, reallocate same address DRAM - Tensor dev_a = ttnn::random::random(single_tile_shape).to(Layout::TILE).to(device, dram_mem_config); + Tensor dev_a = ttnn::random::random(single_tile_shape).to_layout(Layout::TILE).to_device(device, dram_mem_config); uint32_t address_a = dev_a.buffer()->address(); pass &= tt_metal::CloseDevice(device); dev_a.deallocate(); diff --git a/tests/tt_eager/tensors/test_host_device_loopback.cpp b/tests/tt_eager/tensors/test_host_device_loopback.cpp index 0b3ecf13eb6..98efd738541 100644 --- a/tests/tt_eager/tensors/test_host_device_loopback.cpp +++ b/tests/tt_eager/tensors/test_host_device_loopback.cpp @@ -21,8 +21,8 @@ bool test_single_tile_single_dram_bank_loopback(IDevice* device) { bool pass = true; ttnn::Shape single_tile_shape({1, 1, TILE_HEIGHT, TILE_WIDTH}); - Tensor host_a = ttnn::random::random(single_tile_shape).to(Layout::TILE); - Tensor device_a = host_a.to(device); + Tensor host_a = ttnn::random::random(single_tile_shape).to_layout(Layout::TILE); + Tensor device_a = host_a.to_device(device); Tensor loopbacked_a = device_a.cpu(); auto host_a_data = owned_buffer::get_as(host_a); auto loopbacked_a_data = owned_buffer::get_as(loopbacked_a); @@ -35,8 +35,8 @@ bool test_multi_tile_multi_dram_bank_loopback(IDevice* device) { bool pass = true; ttnn::Shape multi_tile_shape({1, 1, 4 * TILE_HEIGHT, 3 * TILE_WIDTH}); - Tensor host_a = ttnn::random::random(multi_tile_shape).to(Layout::TILE); - Tensor device_a = host_a.to(device); + Tensor host_a = ttnn::random::random(multi_tile_shape).to_layout(Layout::TILE); + Tensor device_a = host_a.to_device(device); Tensor loopbacked_a = device_a.cpu(); auto host_a_data = owned_buffer::get_as(host_a); auto loopbacked_a_data = owned_buffer::get_as(loopbacked_a); diff --git a/tests/tt_eager/tensors/test_ranks.cpp b/tests/tt_eager/tensors/test_ranks.cpp index 74dafe5ac98..ed27959ac92 100644 --- a/tests/tt_eager/tensors/test_ranks.cpp +++ b/tests/tt_eager/tensors/test_ranks.cpp @@ -25,8 +25,8 @@ bool 
test_2d_tensor(IDevice* device) { ttnn::Shape shape({30, 30}); Tensor tensor = ttnn::random::random(shape); tensor = tensor.pad_to_tile(0.0f); - tensor = tensor.to(Layout::TILE); - tensor = tensor.to(device); + tensor = tensor.to_layout(Layout::TILE); + tensor = tensor.to_device(device); pass &= tensor.get_logical_shape().rank() == 2; return pass; @@ -38,8 +38,8 @@ bool test_3d_tensor(IDevice* device) { ttnn::Shape shape({3, 30, 30}); Tensor tensor = ttnn::random::random(shape); tensor = tensor.pad_to_tile(0.0f); - tensor = tensor.to(Layout::TILE); - tensor = tensor.to(device); + tensor = tensor.to_layout(Layout::TILE); + tensor = tensor.to_device(device); pass &= tensor.get_logical_shape().rank() == 3; return pass; @@ -51,8 +51,8 @@ bool test_4d_tensor(IDevice* device) { ttnn::Shape shape({2, 3, 30, 30}); Tensor tensor = ttnn::random::random(shape); tensor = tensor.pad_to_tile(0.0f); - tensor = tensor.to(Layout::TILE); - tensor = tensor.to(device); + tensor = tensor.to_layout(Layout::TILE); + tensor = tensor.to_device(device); pass &= tensor.get_logical_shape().rank() == 4; return pass; @@ -64,8 +64,8 @@ bool test_5d_tensor(IDevice* device) { ttnn::Shape shape({2, 2, 3, 30, 30}); Tensor tensor = ttnn::random::random(shape); tensor = tensor.pad_to_tile(0.0f); - tensor = tensor.to(Layout::TILE); - tensor = tensor.to(device); + tensor = tensor.to_layout(Layout::TILE); + tensor = tensor.to_device(device); pass &= tensor.get_logical_shape().rank() == 5; return pass; @@ -77,8 +77,8 @@ bool test_6d_tensor(IDevice* device) { ttnn::Shape shape({2, 2, 2, 3, 30, 30}); Tensor tensor = ttnn::random::random(shape); tensor = tensor.pad_to_tile(0.0f); - tensor = tensor.to(Layout::TILE); - tensor = tensor.to(device); + tensor = tensor.to_layout(Layout::TILE); + tensor = tensor.to_device(device); pass &= tensor.get_logical_shape().rank() == 6; return pass; @@ -90,8 +90,8 @@ bool test_7d_tensor(IDevice* device) { ttnn::Shape shape({2, 2, 2, 2, 3, 30, 30}); Tensor tensor = ttnn::random::random(shape); tensor = tensor.pad_to_tile(0.0f); - tensor = tensor.to(Layout::TILE); - tensor = tensor.to(device); + tensor = tensor.to_layout(Layout::TILE); + tensor = tensor.to_device(device); pass &= tensor.get_logical_shape().rank() == 7; return pass; @@ -103,8 +103,8 @@ bool test_8d_tensor(IDevice* device) { ttnn::Shape shape({2, 2, 2, 2, 2, 3, 30, 30}); Tensor tensor = ttnn::random::random(shape); tensor = tensor.pad_to_tile(0.0f); - tensor = tensor.to(Layout::TILE); - tensor = tensor.to(device); + tensor = tensor.to_layout(Layout::TILE); + tensor = tensor.to_device(device); pass &= tensor.get_logical_shape().rank() == 8; return pass; diff --git a/tests/tt_eager/tensors/test_raw_host_memory_pointer.cpp b/tests/tt_eager/tensors/test_raw_host_memory_pointer.cpp index 79eef80267f..c097b2fc99a 100644 --- a/tests/tt_eager/tensors/test_raw_host_memory_pointer.cpp +++ b/tests/tt_eager/tensors/test_raw_host_memory_pointer.cpp @@ -28,14 +28,14 @@ 07: a_cpu = np.array([[1,2,3,4],[5,6,7,8]], dtype=np.bfloat16) 08: 09: // define tensors on the device with CPU tensors -10: a_dev = torch.from_numpy(a_cpu).to(device) +10: a_dev = torch.from_numpy(a_cpu).to_device(device) 11: 12: c_dev = torch.sqrt(a_dev) 13: 14: print(c_dev[1][0]) 15: 16: d_cpu = np.array([[11,12,13,14],[15,16,17,18]]) -17: d_dev = d_cpu.to(device) +17: d_dev = d_cpu.to_device(device) 18: 19: e_dev = c_dev + d_dev 20: print(e_dev) @@ -105,7 +105,7 @@ void test_raw_host_memory_pointer() { /* Sanity Check End */ /* Run and Print Start */ - Tensor a_dev = 
a_cpu.to(device); + Tensor a_dev = a_cpu.to_device(device); Tensor c_dev = ttnn::sqrt(a_dev); diff --git a/tests/tt_metal/microbenchmarks/ethernet/test_ethernet_link_write_worker_with_transaction_id_bandwidth.py b/tests/tt_metal/microbenchmarks/ethernet/test_ethernet_link_write_worker_with_transaction_id_bandwidth.py new file mode 100644 index 00000000000..eeaa1c399af --- /dev/null +++ b/tests/tt_metal/microbenchmarks/ethernet/test_ethernet_link_write_worker_with_transaction_id_bandwidth.py @@ -0,0 +1,215 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# +# SPDX-License-Identifier: Apache-2.0 + +import os +import sys + +from loguru import logger +import pytest +import csv +from tt_metal.tools.profiler.process_device_log import import_log_run_stats +import tt_metal.tools.profiler.device_post_proc_config as device_post_proc_config +from tests.tt_metal.microbenchmarks.ethernet.test_ethernet_link_write_worker_with_transaction_id_common import ( + profile_results, +) + +from models.utility_functions import is_grayskull + +from tt_metal.tools.profiler.common import PROFILER_LOGS_DIR, PROFILER_DEVICE_SIDE_LOG + +profiler_log_path = PROFILER_LOGS_DIR / PROFILER_DEVICE_SIDE_LOG + +FILE_NAME = PROFILER_LOGS_DIR / "test_ethernet_link_write_worker_bandwidth.csv" + +if os.path.exists(FILE_NAME): + os.remove(FILE_NAME) + + +def run_erisc_write_worker_bw( + sample_count, sample_size_expected_bw, channel_count, num_directions, enable_worker, disable_trid, file_name +): + os.system(f"rm -rf {os.environ['TT_METAL_HOME']}/generated/profiler/.logs/profile_log_device.csv") + + test_latency = 0 + sample_size = sample_size_expected_bw[0] + sample_size_expected_bw = sample_size_expected_bw[1] + expected_bw_lower_bound = sample_size_expected_bw - 0.5 + expected_bw_upper_bound = sample_size_expected_bw + 0.5 + + ARCH_NAME = os.getenv("ARCH_NAME") + cmd = f"TT_METAL_DEVICE_PROFILER=1 \ + {os.environ['TT_METAL_HOME']}/build/test/tt_metal/perf_microbenchmark/ethernet/test_ethernet_write_worker_latency_no_edm_{ARCH_NAME} \ + {sample_count} \ + {sample_size} \ + {channel_count} \ + {num_directions} \ + {test_latency} \ + {enable_worker} \ + {disable_trid}" + rc = os.system(cmd) + if rc != 0: + logger.info("Error in running the test") + assert False + + main_loop_latency = profile_results( + sample_size, sample_count, channel_count, num_directions, test_latency, file_name + ) + main_loop_bw = sample_size / main_loop_latency + logger.info(f"sender_loop_latency {main_loop_latency}") + logger.info(f"sender_loop_bw {main_loop_bw}") + + assert expected_bw_lower_bound <= main_loop_bw <= expected_bw_upper_bound + + +##################################### BW test ####################################################### +# uni-direction test for eth-sender <---> eth-receiver ---> worker +@pytest.mark.skipif(is_grayskull(), reason="Unsupported on GS") +@pytest.mark.parametrize("sample_count", [256]) +@pytest.mark.parametrize("channel_count", [16]) +@pytest.mark.parametrize("num_directions", [1]) +@pytest.mark.parametrize("enable_worker", [1]) +@pytest.mark.parametrize("disable_trid", [0]) +@pytest.mark.parametrize( + "sample_size_expected_bw", + [(16, 0.21), (128, 1.72), (256, 3.44), (512, 6.89), (1024, 11.73), (2048, 11.83), (4096, 12.04), (8192, 12.07)], +) +def test_erisc_write_worker_bw_uni_dir( + sample_count, sample_size_expected_bw, channel_count, num_directions, enable_worker, disable_trid +): + run_erisc_write_worker_bw( + sample_count, + sample_size_expected_bw, + channel_count, + num_directions, + enable_worker, + 
disable_trid, + FILE_NAME, + ) + + +# bi-direction test for eth-sender <---> eth-receiver ---> worker +@pytest.mark.skipif(is_grayskull(), reason="Unsupported on GS") +@pytest.mark.parametrize("sample_count", [1000]) +@pytest.mark.parametrize("channel_count", [16]) +@pytest.mark.parametrize("num_directions", [2]) +@pytest.mark.parametrize("enable_worker", [1]) +@pytest.mark.parametrize("disable_trid", [0]) +@pytest.mark.parametrize( + "sample_size_expected_bw", + [(16, 0.13), (128, 1.03), (256, 2.08), (512, 4.15), (1024, 8.31), (2048, 11.40), (4096, 11.82)], +) +def test_erisc_write_worker_bw_bi_dir( + sample_count, sample_size_expected_bw, channel_count, num_directions, enable_worker, disable_trid +): + run_erisc_write_worker_bw( + sample_count, + sample_size_expected_bw, + channel_count, + num_directions, + enable_worker, + disable_trid, + FILE_NAME, + ) + + +##################################### No Worker BW test ####################################################### +# uni-direction test for eth-sender <---> eth-receiver +@pytest.mark.skipif(is_grayskull(), reason="Unsupported on GS") +@pytest.mark.parametrize("sample_count", [256]) +@pytest.mark.parametrize("channel_count", [16]) +@pytest.mark.parametrize("num_directions", [1]) +@pytest.mark.parametrize("enable_worker", [0]) +@pytest.mark.parametrize("disable_trid", [0]) +@pytest.mark.parametrize( + "sample_size_expected_bw", + [(16, 0.28), (128, 2.25), (256, 4.39), (512, 8.35), (1024, 11.74), (2048, 11.84), (4096, 12.04), (8192, 12.07)], +) +def test_erisc_bw_uni_dir( + sample_count, sample_size_expected_bw, channel_count, num_directions, enable_worker, disable_trid +): + run_erisc_write_worker_bw( + sample_count, + sample_size_expected_bw, + channel_count, + num_directions, + enable_worker, + disable_trid, + FILE_NAME, + ) + + +# bi-direction test for eth-sender <---> eth-receiver +@pytest.mark.skipif(is_grayskull(), reason="Unsupported on GS") +@pytest.mark.parametrize("sample_count", [1000]) +@pytest.mark.parametrize("channel_count", [16]) +@pytest.mark.parametrize("num_directions", [2]) +@pytest.mark.parametrize("enable_worker", [0]) +@pytest.mark.parametrize("disable_trid", [0]) +@pytest.mark.parametrize( + "sample_size_expected_bw", + [(16, 0.19), (128, 1.59), (256, 3.19), (512, 6.39), (1024, 10.9), (2048, 11.4), (4096, 11.82)], +) +def test_erisc_bw_bi_dir( + sample_count, sample_size_expected_bw, channel_count, num_directions, enable_worker, disable_trid +): + run_erisc_write_worker_bw( + sample_count, + sample_size_expected_bw, + channel_count, + num_directions, + enable_worker, + disable_trid, + FILE_NAME, + ) + + +##################################### No Transaction ID BW test ####################################################### +# uni-direction test for eth-sender <---> eth-receiver ---> worker +@pytest.mark.skipif(is_grayskull(), reason="Unsupported on GS") +@pytest.mark.parametrize("sample_count", [256]) +@pytest.mark.parametrize("channel_count", [16]) +@pytest.mark.parametrize("num_directions", [1]) +@pytest.mark.parametrize("enable_worker", [1]) +@pytest.mark.parametrize("disable_trid", [1]) +@pytest.mark.parametrize( + "sample_size_expected_bw", + [(16, 0.18), (128, 1.46), (256, 2.93), (512, 5.73), (1024, 9.15), (2048, 11.83), (4096, 12.04), (8192, 12.07)], +) +def test_erisc_write_worker_bw_uni_dir_no_trid( + sample_count, sample_size_expected_bw, channel_count, num_directions, enable_worker, disable_trid +): + run_erisc_write_worker_bw( + sample_count, + sample_size_expected_bw, + channel_count, + 
num_directions, + enable_worker, + disable_trid, + FILE_NAME, + ) + + +# bi-direction test for eth-sender <---> eth-receiver ---> worker +@pytest.mark.skipif(is_grayskull(), reason="Unsupported on GS") +@pytest.mark.parametrize("sample_count", [1000]) +@pytest.mark.parametrize("channel_count", [16]) +@pytest.mark.parametrize("num_directions", [2]) +@pytest.mark.parametrize("enable_worker", [1]) +@pytest.mark.parametrize("disable_trid", [1]) +@pytest.mark.parametrize( + "sample_size_expected_bw", + [(16, 0.10), (128, 0.87), (256, 1.73), (512, 3.44), (1024, 5.99), (2048, 9.70), (4096, 11.82)], +) +def test_erisc_write_worker_bw_bi_dir_no_trid( + sample_count, sample_size_expected_bw, channel_count, num_directions, enable_worker, disable_trid +): + run_erisc_write_worker_bw( + sample_count, + sample_size_expected_bw, + channel_count, + num_directions, + enable_worker, + disable_trid, + FILE_NAME, + ) diff --git a/tests/tt_metal/microbenchmarks/ethernet/test_ethernet_link_write_worker_with_transaction_id.py b/tests/tt_metal/microbenchmarks/ethernet/test_ethernet_link_write_worker_with_transaction_id_common.py similarity index 50% rename from tests/tt_metal/microbenchmarks/ethernet/test_ethernet_link_write_worker_with_transaction_id.py rename to tests/tt_metal/microbenchmarks/ethernet/test_ethernet_link_write_worker_with_transaction_id_common.py index b532a5bc6e8..30343e6ae81 100644 --- a/tests/tt_metal/microbenchmarks/ethernet/test_ethernet_link_write_worker_with_transaction_id.py +++ b/tests/tt_metal/microbenchmarks/ethernet/test_ethernet_link_write_worker_with_transaction_id_common.py @@ -17,11 +17,6 @@ profiler_log_path = PROFILER_LOGS_DIR / PROFILER_DEVICE_SIDE_LOG -FILE_NAME = PROFILER_LOGS_DIR / "test_ethernet_link_write_worker_latency.csv" - -if os.path.exists(FILE_NAME): - os.remove(FILE_NAME) - def append_to_csv(file_path, header, data, write_header=True): file_exists = os.path.isfile(file_path) @@ -40,7 +35,7 @@ def get_device_freq(): return freq -def profile_results(sample_size, sample_count, channel_count): +def profile_results(sample_size, sample_count, channel_count, num_directions, test_latency, file_name): freq = get_device_freq() / 1000.0 setup = device_post_proc_config.default_setup() setup.deviceInputLog = profiler_log_path @@ -61,51 +56,29 @@ def profile_results(sample_size, sample_count, channel_count): main_loop_cycle = devices_data["devices"][device_0]["cores"]["DEVICE"]["analysis"][main_test_body_string]["stats"][ "Average" ] - main_loop_latency = main_loop_cycle / freq / sample_count / channel_count - bw = sample_size / main_loop_latency - header = [ - "SAMPLE_SIZE", - "BW (B/c)", - ] - write_header = not os.path.exists(FILE_NAME) + if test_latency == 1: + main_loop_latency = main_loop_cycle / freq + header = [ + "NUM_DIRECTIONS", + "SAMPLE_SIZE", + "LATENCY (ns)", + ] + res = main_loop_latency + else: + main_loop_latency = main_loop_cycle / freq / sample_count / channel_count + bw = sample_size / main_loop_latency + header = [ + "NUM_DIRECTIONS", + "SAMPLE_SIZE", + "BW (B/c)", + ] + res = bw + write_header = not os.path.exists(file_name) append_to_csv( - FILE_NAME, + file_name, header, - [sample_size, bw], + [num_directions, sample_size, res], write_header, ) return main_loop_latency - - -@pytest.mark.skipif(is_grayskull(), reason="Unsupported on GS") -@pytest.mark.parametrize("sample_count", [256]) -@pytest.mark.parametrize("channel_count", [16]) -@pytest.mark.parametrize( - "sample_size_expected_latency", - [(16, 86.2), (128, 86.2), (256, 86.4), (512, 86.5), (1024, 
87.2), (2048, 172.9), (4096, 339.9), (8192, 678.4)], -) -def test_erisc_write_worker_latency(sample_count, sample_size_expected_latency, channel_count): - os.system(f"rm -rf {os.environ['TT_METAL_HOME']}/generated/profiler/.logs/profile_log_device.csv") - - sample_size = sample_size_expected_latency[0] - expected_latency = sample_size_expected_latency[1] - expected_latency_lower_bound = expected_latency - 0.5 - expected_latency_upper_bound = expected_latency + 0.5 - - ARCH_NAME = os.getenv("ARCH_NAME") - cmd = f"TT_METAL_DEVICE_PROFILER=1 \ - {os.environ['TT_METAL_HOME']}/build/test/tt_metal/perf_microbenchmark/ethernet/test_ethernet_write_worker_latency_no_edm_{ARCH_NAME} \ - {sample_count} \ - {sample_size} \ - {channel_count} " - rc = os.system(cmd) - if rc != 0: - logger.info("Error in running the test") - assert False - - main_loop_latency = profile_results(sample_size, sample_count, channel_count) - logger.info(f"sender_loop_latency {main_loop_latency}") - logger.info(f"result BW (B/c): {sample_size / main_loop_latency}") - - assert expected_latency_lower_bound <= main_loop_latency <= expected_latency_upper_bound diff --git a/tests/tt_metal/microbenchmarks/ethernet/test_ethernet_link_write_worker_with_transaction_id_latency.py b/tests/tt_metal/microbenchmarks/ethernet/test_ethernet_link_write_worker_with_transaction_id_latency.py new file mode 100644 index 00000000000..190a7f265f9 --- /dev/null +++ b/tests/tt_metal/microbenchmarks/ethernet/test_ethernet_link_write_worker_with_transaction_id_latency.py @@ -0,0 +1,179 @@ +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# +# SPDX-License-Identifier: Apache-2.0 + +import os +import sys + +from loguru import logger +import pytest +import csv +from tt_metal.tools.profiler.process_device_log import import_log_run_stats +import tt_metal.tools.profiler.device_post_proc_config as device_post_proc_config +from tests.tt_metal.microbenchmarks.ethernet.test_ethernet_link_write_worker_with_transaction_id_common import ( + profile_results, +) + +from models.utility_functions import is_grayskull + +from tt_metal.tools.profiler.common import PROFILER_LOGS_DIR, PROFILER_DEVICE_SIDE_LOG + +profiler_log_path = PROFILER_LOGS_DIR / PROFILER_DEVICE_SIDE_LOG + +FILE_NAME = PROFILER_LOGS_DIR / "test_ethernet_link_write_worker_latency.csv" + +if os.path.exists(FILE_NAME): + os.remove(FILE_NAME) + + +def run_erisc_write_worker_latency( + sample_count, sample_size_expected_latency, channel_count, num_directions, enable_worker, disable_trid, file_name +): + os.system(f"rm -rf {os.environ['TT_METAL_HOME']}/generated/profiler/.logs/profile_log_device.csv") + + test_latency = 1 + sample_size = sample_size_expected_latency[0] + sample_size_expected_latency = sample_size_expected_latency[1] + diff = sample_size_expected_latency * 0.1 + expected_latency_lower_bound = sample_size_expected_latency - diff + expected_latency_upper_bound = sample_size_expected_latency + diff + + ARCH_NAME = os.getenv("ARCH_NAME") + cmd = f"TT_METAL_DEVICE_PROFILER=1 \ + {os.environ['TT_METAL_HOME']}/build/test/tt_metal/perf_microbenchmark/ethernet/test_ethernet_write_worker_latency_no_edm_{ARCH_NAME} \ + {sample_count} \ + {sample_size} \ + {channel_count} \ + {num_directions} \ + {test_latency} \ + {enable_worker} \ + {disable_trid} " + rc = os.system(cmd) + if rc != 0: + logger.info("Error in running the test") + assert False + + main_loop_latency = profile_results( + sample_size, sample_count, channel_count, num_directions, test_latency, file_name + ) + 
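+    # With test_latency == 1, profile_results() reports the averaged main-loop latency in nanoseconds (the 'LATENCY (ns)' CSV column).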
logger.info(f"sender_loop_latency {main_loop_latency}") + + assert expected_latency_lower_bound <= main_loop_latency <= expected_latency_upper_bound + + +# uni-direction test for eth-sender <---> eth-receiver ---> worker +@pytest.mark.skipif(is_grayskull(), reason="Unsupported on GS") +@pytest.mark.parametrize("sample_count", [1]) +@pytest.mark.parametrize("channel_count", [16]) +@pytest.mark.parametrize("num_directions", [1]) +@pytest.mark.parametrize("enable_worker", [1]) +@pytest.mark.parametrize("disable_trid", [0]) +@pytest.mark.parametrize( + "sample_size_expected_latency", + [ + (16, 984.0), + (128, 1002.0), + (256, 1019.0), + (512, 1074.0), + (1024, 1164.0), + (2048, 1308.0), + (4096, 1560.0), + (8192, 2048.0), + ], +) +def test_erisc_write_worker_latency_uni_dir( + sample_count, sample_size_expected_latency, channel_count, num_directions, enable_worker, disable_trid +): + run_erisc_write_worker_latency( + sample_count, + sample_size_expected_latency, + channel_count, + num_directions, + enable_worker, + disable_trid, + FILE_NAME, + ) + + +# bi-direction test for eth-sender <---> eth-receiver ---> worker +@pytest.mark.skipif(is_grayskull(), reason="Unsupported on GS") +@pytest.mark.parametrize("sample_count", [1]) +@pytest.mark.parametrize("channel_count", [16]) +@pytest.mark.parametrize("num_directions", [2]) +@pytest.mark.parametrize("enable_worker", [1]) +@pytest.mark.parametrize("disable_trid", [0]) +@pytest.mark.parametrize( + "sample_size_expected_latency", + [(16, 1077.0), (128, 1079.0), (256, 1077.0), (512, 1175.0), (1024, 1231.0), (2048, 1389.0), (4096, 1596.0)], +) +def test_erisc_write_worker_latency_bi_dir( + sample_count, sample_size_expected_latency, channel_count, num_directions, enable_worker, disable_trid +): + run_erisc_write_worker_latency( + sample_count, + sample_size_expected_latency, + channel_count, + num_directions, + enable_worker, + disable_trid, + FILE_NAME, + ) + + +# uni-direction test for eth-sender <---> eth-receiver +@pytest.mark.skipif(is_grayskull(), reason="Unsupported on GS") +@pytest.mark.parametrize("sample_count", [1]) +@pytest.mark.parametrize("channel_count", [16]) +@pytest.mark.parametrize("num_directions", [1]) +@pytest.mark.parametrize("enable_worker", [0]) +@pytest.mark.parametrize("disable_trid", [0]) +@pytest.mark.parametrize( + "sample_size_expected_latency", + [ + (16, 894.0), + (128, 911.0), + (256, 966.0), + (512, 984.0), + (1024, 1074.0), + (2048, 1200.0), + (4096, 1362.0), + (8192, 1686.0), + ], +) +def test_erisc_latency_uni_dir( + sample_count, sample_size_expected_latency, channel_count, num_directions, enable_worker, disable_trid +): + run_erisc_write_worker_latency( + sample_count, + sample_size_expected_latency, + channel_count, + num_directions, + enable_worker, + disable_trid, + FILE_NAME, + ) + + +# bi-direction test for eth-sender <---> eth-receiver ---> worker +@pytest.mark.skipif(is_grayskull(), reason="Unsupported on GS") +@pytest.mark.parametrize("sample_count", [1]) +@pytest.mark.parametrize("channel_count", [16]) +@pytest.mark.parametrize("num_directions", [2]) +@pytest.mark.parametrize("enable_worker", [0]) +@pytest.mark.parametrize("disable_trid", [0]) +@pytest.mark.parametrize( + "sample_size_expected_latency", + [(16, 918.0), (128, 919.0), (256, 952.0), (512, 988.0), (1024, 1122.0), (2048, 1224.0), (4096, 1394.0)], +) +def test_erisc_latency_bi_dir( + sample_count, sample_size_expected_latency, channel_count, num_directions, enable_worker, disable_trid +): + run_erisc_write_worker_latency( + sample_count, + 
sample_size_expected_latency, + channel_count, + num_directions, + enable_worker, + disable_trid, + FILE_NAME, + ) diff --git a/tests/tt_metal/tools/profiler/test_device_profiler.py b/tests/tt_metal/tools/profiler/test_device_profiler.py index 5a31fcdbd6d..f235f7a29b5 100644 --- a/tests/tt_metal/tools/profiler/test_device_profiler.py +++ b/tests/tt_metal/tools/profiler/test_device_profiler.py @@ -21,7 +21,7 @@ clear_profiler_runtime_artifacts, ) -from models.utility_functions import skip_for_grayskull +from models.utility_functions import skip_for_grayskull, skip_for_blackhole PROG_EXMP_DIR = "programming_examples/profiler" @@ -82,6 +82,7 @@ def test_multi_op(): REF_COUNT_DICT = { "grayskull": [108 * OP_COUNT * RUN_COUNT, 88 * OP_COUNT * RUN_COUNT], "wormhole_b0": [72 * OP_COUNT * RUN_COUNT, 64 * OP_COUNT * RUN_COUNT, 56 * OP_COUNT * RUN_COUNT], + "blackhole": [130 * OP_COUNT * RUN_COUNT, 120 * OP_COUNT * RUN_COUNT, 110 * OP_COUNT * RUN_COUNT], } ENV_VAR_ARCH_NAME = os.getenv("ARCH_NAME") @@ -152,6 +153,11 @@ def test_full_buffer(): 64 * OP_COUNT * RISC_COUNT * ZONE_COUNT, 56 * OP_COUNT * RISC_COUNT * ZONE_COUNT, ], + "blackhole": [ + 130 * OP_COUNT * RISC_COUNT * ZONE_COUNT, + 120 * OP_COUNT * RISC_COUNT * ZONE_COUNT, + 110 * OP_COUNT * RISC_COUNT * ZONE_COUNT, + ], } ENV_VAR_ARCH_NAME = os.getenv("ARCH_NAME") @@ -189,6 +195,10 @@ def test_dispatch_cores(): "Tensix CQ Dispatch": 16, "Tensix CQ Prefetch": 25, }, + "blackhole": { + "Tensix CQ Dispatch": 16, + "Tensix CQ Prefetch": 25, + }, } ENV_VAR_ARCH_NAME = os.getenv("ARCH_NAME") @@ -216,6 +226,7 @@ def test_dispatch_cores(): os.environ["TT_METAL_DEVICE_PROFILER_DISPATCH"] = "0" +@skip_for_blackhole() @skip_for_grayskull() def test_ethernet_dispatch_cores(): REF_COUNT_DICT = { @@ -297,20 +308,29 @@ def test_timestamped_events(): OP_COUNT = 2 RISC_COUNT = 5 ZONE_COUNT = 100 - ERISC_COUNTS = [0, 1, 5] - TENSIX_COUNTS = [72, 64, 56] + WH_ERISC_COUNTS = [0, 1, 5] + WH_TENSIX_COUNTS = [72, 64, 56] + BH_ERISC_COUNTS = [0, 1, 5] + BH_TENSIX_COUNTS = [130, 120, 110] + + WH_COMBO_COUNTS = [] + for T in WH_TENSIX_COUNTS: + for E in WH_ERISC_COUNTS: + WH_COMBO_COUNTS.append((T, E)) - COMBO_COUNTS = [] - for T in TENSIX_COUNTS: - for E in ERISC_COUNTS: - COMBO_COUNTS.append((T, E)) + BH_COMBO_COUNTS = [] + for T in BH_TENSIX_COUNTS: + for E in BH_ERISC_COUNTS: + BH_COMBO_COUNTS.append((T, E)) REF_COUNT_DICT = { "grayskull": [108 * OP_COUNT * RISC_COUNT * ZONE_COUNT, 88 * OP_COUNT * RISC_COUNT * ZONE_COUNT], - "wormhole_b0": [(T * RISC_COUNT + E) * OP_COUNT * ZONE_COUNT for T, E in COMBO_COUNTS], + "wormhole_b0": [(T * RISC_COUNT + E) * OP_COUNT * ZONE_COUNT for T, E in WH_COMBO_COUNTS], + "blackhole": [(T * RISC_COUNT + E) * OP_COUNT * ZONE_COUNT for T, E in BH_COMBO_COUNTS], } REF_ERISC_COUNT = { - "wormhole_b0": [C * OP_COUNT * ZONE_COUNT for C in ERISC_COUNTS], + "wormhole_b0": [C * OP_COUNT * ZONE_COUNT for C in WH_ERISC_COUNTS], + "blackhole": [C * OP_COUNT * ZONE_COUNT for C in BH_ERISC_COUNTS], } ENV_VAR_ARCH_NAME = os.getenv("ARCH_NAME") diff --git a/tests/tt_metal/tt_metal/CMakeLists.txt b/tests/tt_metal/tt_metal/CMakeLists.txt index 1e1da2ac982..e162b7cbc13 100644 --- a/tests/tt_metal/tt_metal/CMakeLists.txt +++ b/tests/tt_metal/tt_metal/CMakeLists.txt @@ -69,6 +69,7 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/llk) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/perf_microbenchmark) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/stl) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/noc) +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/lightmetal) 
add_custom_target( metal_tests @@ -92,4 +93,5 @@ add_custom_target( unit_tests_llk unit_tests_stl unit_tests_noc + unit_tests_lightmetal ) diff --git a/tests/tt_metal/tt_metal/api/CMakeLists.txt b/tests/tt_metal/tt_metal/api/CMakeLists.txt index 16118245df1..b1a63fe8274 100644 --- a/tests/tt_metal/tt_metal/api/CMakeLists.txt +++ b/tests/tt_metal/tt_metal/api/CMakeLists.txt @@ -26,6 +26,7 @@ set(UNIT_TESTS_API_SRC ${CMAKE_CURRENT_SOURCE_DIR}/test_noc.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test_runtime_args.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test_semaphores.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test_shape_base.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test_sharded_l1_buffer.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test_simple_dram_buffer.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test_simple_l1_buffer.cpp diff --git a/tests/tt_metal/tt_metal/api/allocator/test_l1_banking_allocator.cpp b/tests/tt_metal/tt_metal/api/allocator/test_l1_banking_allocator.cpp index 432104cf3e5..d0cea1bdf29 100644 --- a/tests/tt_metal/tt_metal/api/allocator/test_l1_banking_allocator.cpp +++ b/tests/tt_metal/tt_metal/api/allocator/test_l1_banking_allocator.cpp @@ -19,7 +19,7 @@ uint64_t get_alloc_limit(const tt::tt_metal::IDevice* device) { auto dispatch_core_config = dispatch_core_manager::instance().get_dispatch_core_config(device->id()); auto storage_core_bank_size = tt::get_storage_core_bank_size(device->id(), device->num_hw_cqs(), dispatch_core_config); - const uint32_t allocator_alignment = device->allocator()->get_config().alignment; + const uint32_t allocator_alignment = device->allocator()->get_alignment(BufferType::L1); const uint32_t interleaved_l1_bank_size = storage_core_bank_size.has_value() ? storage_core_bank_size.value() : (soc_desc.worker_l1_size - l1_unreserved_base); diff --git a/tests/ttnn/unit_tests/gtests/tensor/test_shape_base.cpp b/tests/tt_metal/tt_metal/api/test_shape_base.cpp similarity index 98% rename from tests/ttnn/unit_tests/gtests/tensor/test_shape_base.cpp rename to tests/tt_metal/tt_metal/api/test_shape_base.cpp index f3f36aeb535..455a7714c1d 100644 --- a/tests/ttnn/unit_tests/gtests/tensor/test_shape_base.cpp +++ b/tests/tt_metal/tt_metal/api/test_shape_base.cpp @@ -5,7 +5,7 @@ #include #include "gtest/gtest.h" -#include "ttnn/tensor/shape/shape_base.hpp" +#include TEST(TensorShapeBaseTests, General4D) { tt::tt_metal::ShapeBase vec({20, 30, 40, 50}); diff --git a/tests/tt_metal/tt_metal/debug_tools/watcher/test_noc_sanitize.cpp b/tests/tt_metal/tt_metal/debug_tools/watcher/test_noc_sanitize.cpp index 2c3f3e95bb5..5962ae29275 100644 --- a/tests/tt_metal/tt_metal/debug_tools/watcher/test_noc_sanitize.cpp +++ b/tests/tt_metal/tt_metal/debug_tools/watcher/test_noc_sanitize.cpp @@ -118,9 +118,13 @@ void RunTestOnCore(WatcherFixture* fixture, IDevice* device, CoreCoord &core, bo case SanitizeZeroL1Write: output_l1_buffer_addr = 0; break; case SanitizeMailboxWrite: // This is illegal because we'd be writing to the mailbox memory - l1_buffer_addr = hal.get_dev_addr( - (is_eth_core) ? 
HalProgrammableCoreType::ACTIVE_ETH : HalProgrammableCoreType::TENSIX, - HalL1MemAddrType::MAILBOX); + if (is_eth_core) { + l1_buffer_addr = std::min( + hal.get_dev_addr(HalProgrammableCoreType::ACTIVE_ETH, HalL1MemAddrType::MAILBOX), + hal.get_dev_addr(HalProgrammableCoreType::IDLE_ETH, HalL1MemAddrType::MAILBOX)); + } else { + l1_buffer_addr = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalL1MemAddrType::MAILBOX); + } break; default: log_warning(LogTest, "Unrecognized feature to test ({}), skipping...", feature); diff --git a/tests/tt_metal/tt_metal/dispatch/dispatch_buffer/test_sub_device.cpp b/tests/tt_metal/tt_metal/dispatch/dispatch_buffer/test_sub_device.cpp index 440b93639a9..12c3bdaa3cd 100644 --- a/tests/tt_metal/tt_metal/dispatch/dispatch_buffer/test_sub_device.cpp +++ b/tests/tt_metal/tt_metal/dispatch/dispatch_buffer/test_sub_device.cpp @@ -84,7 +84,7 @@ TEST_F(CommandQueueSingleCardFixture, TensixTestSubDeviceAllocations) { device->load_sub_device_manager(sub_device_manager_1); auto buffer_1 = CreateBuffer(shard_config_1, SubDeviceId{0}); - EXPECT_EQ(buffer_1->address(), max_addr - buffer_1->aligned_page_size()); + EXPECT_TRUE(buffer_1->address() <= max_addr - buffer_1->aligned_page_size()); EnqueueWriteBuffer(device->command_queue(), buffer_1, input_1, false); std::vector output_1; EnqueueReadBuffer(device->command_queue(), buffer_1, output_1, true); @@ -105,7 +105,7 @@ TEST_F(CommandQueueSingleCardFixture, TensixTestSubDeviceAllocations) { device->load_sub_device_manager(sub_device_manager_2); auto buffer_3 = CreateBuffer(shard_config_2, SubDeviceId{1}); - EXPECT_EQ(buffer_3->address(), max_addr - buffer_3->aligned_page_size()); + EXPECT_TRUE(buffer_3->address() <= max_addr - buffer_3->aligned_page_size()); EnqueueWriteBuffer(device->command_queue(), buffer_3, input_2, false); std::vector output_2; EnqueueReadBuffer(device->command_queue(), buffer_3, output_2, true); @@ -118,7 +118,7 @@ TEST_F(CommandQueueSingleCardFixture, TensixTestSubDeviceAllocations) { } auto buffer_4 = CreateBuffer(shard_config_1, SubDeviceId{0}); - EXPECT_EQ(buffer_4->address(), max_addr - buffer_4->aligned_page_size()); + EXPECT_TRUE(buffer_4->address() <= max_addr - buffer_4->aligned_page_size()); EXPECT_THROW(CreateBuffer(interleaved_config, SubDeviceId{0}), std::exception); } diff --git a/tests/tt_metal/tt_metal/lightmetal/CMakeLists.txt b/tests/tt_metal/tt_metal/lightmetal/CMakeLists.txt new file mode 100644 index 00000000000..c8d1015f344 --- /dev/null +++ b/tests/tt_metal/tt_metal/lightmetal/CMakeLists.txt @@ -0,0 +1,23 @@ +set(UNIT_TESTS_LIGHTMETAL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/test_lightmetal.cpp) + +add_executable(unit_tests_lightmetal ${UNIT_TESTS_LIGHTMETAL_SRC}) +TT_ENABLE_UNITY_BUILD(unit_tests_lightmetal) + +target_link_libraries(unit_tests_lightmetal PUBLIC test_metal_common_libs) + +target_include_directories( + unit_tests_lightmetal + PRIVATE + "$" + ${PROJECT_SOURCE_DIR}/tests + ${PROJECT_SOURCE_DIR}/tests/tt_metal/tt_metal/common + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}/common +) + +set_target_properties( + unit_tests_lightmetal + PROPERTIES + RUNTIME_OUTPUT_DIRECTORY + ${PROJECT_BINARY_DIR}/test/tt_metal +) diff --git a/tests/tt_metal/tt_metal/lightmetal/lightmetal_fixture.hpp b/tests/tt_metal/tt_metal/lightmetal/lightmetal_fixture.hpp new file mode 100644 index 00000000000..add02b77e4b --- /dev/null +++ b/tests/tt_metal/tt_metal/lightmetal/lightmetal_fixture.hpp @@ -0,0 +1,81 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "dispatch_fixture.hpp" +#include +#include +#include +#include +#include +#include +#include +#include "lightmetal/lightmetal_replay.hpp" +#include "command_queue_fixture.hpp" +#include + +class SingleDeviceLightMetalFixture : public CommandQueueFixture { +protected: + bool replay_binary_; + std::string trace_bin_path_; + bool write_bin_to_disk_; + + void SetUp() override { + this->validate_dispatch_mode(); + this->arch_ = tt::get_arch_from_string(tt::test_utils::get_umd_arch_name()); + } + + void CreateDeviceAndBeginCapture( + const size_t trace_region_size, const bool replay_binary = true, const std::string trace_bin_path = "") { + // Skip writing to disk by default, unless user sets env var for local testing + write_bin_to_disk_ = tt::parse_env("LIGHTMETAL_SAVE_BINARY", false); + + // If user didn't provide a specific trace bin path, set a default here based on test name + if (trace_bin_path == "") { + const auto test_info = ::testing::UnitTest::GetInstance()->current_test_info(); + auto trace_filename = test_info ? std::string(test_info->name()) + ".bin" : "lightmetal_trace.bin"; + this->trace_bin_path_ = "/tmp/" + trace_filename; + } + + this->create_device(trace_region_size); + this->replay_binary_ = replay_binary && !tt::parse_env("LIGHTMETAL_DISABLE_RUN", false); + // TODO (kmabee) - revisit placement. CreateDevice() path calls CreateKernel() on programs not + // created with CreateProgram() traced API which leads to "program not in global_id map" + LightMetalBeginCapture(); + } + + // End light metal tracing, write to optional filename and optionally run from binary blob + void TearDown() override { + LightMetalBinary binary = LightMetalEndCapture(); + + if (binary.is_empty()) { + FAIL() << "Light Metal Binary is empty for test, unexpected."; + } + if (write_bin_to_disk_ && !this->trace_bin_path_.empty() && !binary.is_empty()) { + log_info(tt::LogTest, "Writing light metal binary {} bytes to {}", binary.size(), this->trace_bin_path_); + binary.save_to_file(this->trace_bin_path_); + } + + if (!this->IsSlowDispatch()) { + tt::tt_metal::CloseDevice(this->device_); + } + + // We could gaurd this to not attempt to replay empty binary, and still allow test to pass + // but, would rather catch the case if the feature gets disabled at compile time. + if (replay_binary_) { + RunLightMetalBinary(std::move(binary)); + } + } + + // Mimic the light-metal standalone run replay tool by executing the binary. + void RunLightMetalBinary(LightMetalBinary&& binary) { + tt::tt_metal::LightMetalReplay lm_replay(std::move(binary)); + if (!lm_replay.execute_binary()) { + FAIL() << "Light Metal Binary failed to execute or encountered errors."; + } else { + log_info(tt::LogMetalTrace, "Light Metal Binary executed successfully!"); + } + } +}; diff --git a/tests/tt_metal/tt_metal/lightmetal/test_lightmetal.cpp b/tests/tt_metal/tt_metal/lightmetal/test_lightmetal.cpp new file mode 100644 index 00000000000..083e072a322 --- /dev/null +++ b/tests/tt_metal/tt_metal/lightmetal/test_lightmetal.cpp @@ -0,0 +1,379 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include + +#include "lightmetal_fixture.hpp" +#include +#include "env_lib.hpp" +#include "gtest/gtest.h" +#include +#include +#include +#include +#include +#include +#include "lightmetal_capture_utils.hpp" + +using std::vector; +using namespace tt; +using namespace tt::tt_metal; + +namespace tt::tt_metal { +namespace { + +// Single RISC, no CB's here. Very simple. +Program create_simple_datamovement_program(Buffer& input, Buffer& output, Buffer& l1_buffer) { + Program program = CreateProgram(); + IDevice* device = input.device(); + constexpr CoreCoord core = {0, 0}; + + KernelHandle dram_copy_kernel_id = CreateKernel( + program, + "tt_metal/programming_examples/loopback/kernels/loopback_dram_copy.cpp", + core, + DataMovementConfig{.processor = DataMovementProcessor::RISCV_0, .noc = NOC::RISCV_0_default}); + + // Since all interleaved buffers have size == page_size, they are entirely contained in the first DRAM bank + const uint32_t input_bank_id = 0; + const uint32_t output_bank_id = 0; + + // Handle Runtime Args + const std::vector runtime_args = { + l1_buffer.address(), input.address(), input_bank_id, output.address(), output_bank_id, l1_buffer.size()}; + + // Note - this interface doesn't take Buffer, just data. + SetRuntimeArgs(program, dram_copy_kernel_id, core, runtime_args); + + return program; +} + +// Copied from test_EnqueueTrace.cpp +Program create_simple_unary_program(Buffer& input, Buffer& output, Buffer* cb_input_buffer = nullptr) { + Program program = CreateProgram(); + IDevice* device = input.device(); + CoreCoord worker = {0, 0}; + auto reader_kernel = CreateKernel( + program, + "tt_metal/kernels/dataflow/reader_unary.cpp", + worker, + DataMovementConfig{.processor = DataMovementProcessor::RISCV_1, .noc = NOC::RISCV_1_default}); + + auto writer_kernel = CreateKernel( + program, + "tt_metal/kernels/dataflow/writer_unary.cpp", + worker, + DataMovementConfig{.processor = DataMovementProcessor::RISCV_0, .noc = NOC::RISCV_0_default}); + + auto sfpu_kernel = CreateKernel( + program, + "tt_metal/kernels/compute/eltwise_sfpu.cpp", + worker, + ComputeConfig{ + .math_approx_mode = true, + .compile_args = {1, 1}, + .defines = {{"SFPU_OP_EXP_INCLUDE", "1"}, {"SFPU_OP_CHAIN_0", "exp_tile_init(); exp_tile(0);"}}}); + + CircularBufferConfig input_cb_config = CircularBufferConfig(2048, {{tt::CBIndex::c_0, tt::DataFormat::Float16_b}}) + .set_page_size(tt::CBIndex::c_0, 2048); + + // For testing dynamic CB for which CB config has a shadow buffer ptr to test. 
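+    // A supplied cb_input_buffer makes the CB "dynamic": instead of the CB allocating its own L1 space,
+    // it is pointed at the caller's globally allocated L1 buffer below.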
+ if (cb_input_buffer) { + input_cb_config.set_globally_allocated_address(*cb_input_buffer); + } + + CoreRange core_range({0, 0}); + CreateCircularBuffer(program, core_range, input_cb_config); + std::shared_ptr writer_runtime_args = std::make_shared(); + std::shared_ptr reader_runtime_args = std::make_shared(); + + *writer_runtime_args = {&output, (uint32_t)0, output.num_pages()}; + + *reader_runtime_args = {&input, (uint32_t)0, input.num_pages()}; + + SetRuntimeArgs(device, detail::GetKernel(program, writer_kernel), worker, writer_runtime_args); + SetRuntimeArgs(device, detail::GetKernel(program, reader_kernel), worker, reader_runtime_args); + + CircularBufferConfig output_cb_config = CircularBufferConfig(2048, {{tt::CBIndex::c_16, tt::DataFormat::Float16_b}}) + .set_page_size(tt::CBIndex::c_16, 2048); + + CreateCircularBuffer(program, core_range, output_cb_config); + return program; +} + +void write_junk_to_buffer(CommandQueue& command_queue, Buffer& buffer) { + vector dummy_write_data(buffer.size() / sizeof(uint32_t), 0xDEADBEEF); + vector dummy_read_data(buffer.size() / sizeof(uint32_t), 0); + EnqueueWriteBuffer(command_queue, buffer, dummy_write_data.data(), true); + EnqueueReadBuffer(command_queue, buffer, dummy_read_data.data(), true); + for (size_t i = 0; i < dummy_read_data.size(); i++) { + log_trace(tt::LogMetalTrace, "i: {:3d} output: {:x} after write+read of dummy data", i, dummy_read_data[i]); + } + EXPECT_TRUE(dummy_write_data == dummy_read_data); +} + +// TODO (kmabee) - consider looping over blocking_flags in some/all tests once stable. +constexpr bool kBlocking = true; +constexpr bool kNonBlocking = false; +vector blocking_flags = {kBlocking, kNonBlocking}; + +using LightMetalBasicTest = SingleDeviceLightMetalFixture; + +// Test that create buffer, write, readback, and verify works when traced + replayed. +TEST_F(LightMetalBasicTest, CreateBufferEnqueueWriteRead) { + CreateDeviceAndBeginCapture(4096); + + CommandQueue& command_queue = this->device_->command_queue(); + uint32_t num_loops = 5; + bool keep_buffers_alive = true; + std::vector> buffers_vec; + + for (uint32_t loop_idx = 0; loop_idx < num_loops; loop_idx++) { + log_debug(tt::LogTest, "Running loop: {}", loop_idx); + + // Switch to use top level CreateBuffer API that has trace support. + uint32_t size_bytes = 64; // 16 elements. + auto buffer = CreateBuffer(InterleavedBufferConfig{this->device_, size_bytes, size_bytes, BufferType::DRAM}); + log_debug( + tt::LogTest, + "created buffer loop: {} with size: {} bytes addr: 0x{:x}", + loop_idx, + buffer->size(), + buffer->address()); + + if (keep_buffers_alive && loop_idx > 1) { + buffers_vec.push_back(buffer); + } + + // We don't want to capture inputs in binary, but do it to start for testing. + uint32_t start_val = loop_idx * 100; + vector input_data(buffer->size() / sizeof(uint32_t), 0); + for (uint32_t i = 0; i < input_data.size(); i++) { + input_data[i] = start_val + i; + } + log_debug(tt::LogTest, "initialize input_data with {} elements start_val: {}", input_data.size(), start_val); + + vector readback_data; + readback_data.resize(input_data.size()); // This is required. + + // Write data to buffer, then read outputs and verify against expected. 
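+        // Blocking write so the data is fully on the device before the capture-time readback/compare below.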
+ EnqueueWriteBuffer(command_queue, *buffer, input_data.data(), /*blocking=*/true); + // This will verify that readback matches between capture + replay + LightMetalCompareToCapture(command_queue, *buffer, readback_data.data()); + + EXPECT_TRUE(input_data == readback_data); + + // For dev/debug go ahead and print the results. Had a replay bug, was seeing wrong data. + for (size_t i = 0; i < readback_data.size(); i++) { + log_debug(tt::LogMetalTrace, "loop: {} rd_data i: {:3d} => data: {}", loop_idx, i, readback_data[i]); + } + } + + // If any Buffers were kept alive for testing, Deallocate them now to exercise that path for capture/replay. + if (keep_buffers_alive) { + log_info(tt::LogTest, "Explicitly deallocating {} buffers now.", buffers_vec.size()); + for (auto& buffer : buffers_vec) { + DeallocateBuffer(*buffer); + } + } + + Finish(command_queue); +} + +// Test simple case of single datamovement program on single RISC works for trace + replay. +TEST_F(LightMetalBasicTest, SingleRISCDataMovement) { + CreateDeviceAndBeginCapture(4096); + + uint32_t size_bytes = 64; // 16 elements. + auto input = CreateBuffer(InterleavedBufferConfig{this->device_, size_bytes, size_bytes, BufferType::DRAM}); + auto output = CreateBuffer(InterleavedBufferConfig{this->device_, size_bytes, size_bytes, BufferType::DRAM}); + auto l1_buffer = CreateBuffer(InterleavedBufferConfig{this->device_, size_bytes, size_bytes, BufferType::L1}); + log_debug( + tt::LogTest, + "Created 3 Buffers. input: 0x{:x} output: 0x{:x} l1_buffer: 0x{:x}", + input->address(), + output->address(), + l1_buffer->address()); + + CommandQueue& command_queue = this->device_->command_queue(); + + Program simple_program = create_simple_datamovement_program(*input, *output, *l1_buffer); + vector input_data(input->size() / sizeof(uint32_t), 0); + for (uint32_t i = 0; i < input_data.size(); i++) { + input_data[i] = i; + } + + vector eager_output_data; + eager_output_data.resize(input_data.size()); + + // Write data to buffer, enqueue program, then read outputs and verify against expected. + EnqueueWriteBuffer(command_queue, *input, input_data.data(), /*blocking=*/true); + EnqueueProgram(command_queue, simple_program, /*blocking=*/true); + // This will verify that outputs matches between capture + replay + LightMetalCompareToCapture(command_queue, *output, eager_output_data.data()); + + EXPECT_TRUE(eager_output_data == input_data); + + // For dev/debug go ahead and print the results + for (size_t i = 0; i < eager_output_data.size(); i++) { + log_debug(tt::LogMetalTrace, "i: {:3d} input: {} output: {}", i, input_data[i], eager_output_data[i]); + } + + Finish(command_queue); +} + +// Test simple case of 3 riscs used for datamovement and compute works for trace + replay. +TEST_F(LightMetalBasicTest, ThreeRISCDataMovementCompute) { + CreateDeviceAndBeginCapture(4096); + + uint32_t size_bytes = 64; // 16 elements. + auto input = CreateBuffer(InterleavedBufferConfig{this->device_, size_bytes, size_bytes, BufferType::DRAM}); + auto output = CreateBuffer(InterleavedBufferConfig{this->device_, size_bytes, size_bytes, BufferType::DRAM}); + + CommandQueue& command_queue = this->device_->command_queue(); + + // TODO (kmabee) - There is issue with using make_shared, revisit this. 
+ // auto simple_program = std::make_shared(create_simple_unary_program(*input, + // *output)); + auto simple_program = create_simple_unary_program(*input, *output); + + vector input_data(input->size() / sizeof(uint32_t), 0); + for (uint32_t i = 0; i < input_data.size(); i++) { + input_data[i] = i; + } + + // Write data to buffer, enqueue program, then read outputs. + EnqueueWriteBuffer(command_queue, *input, input_data.data(), /*blocking=*/true); + EnqueueProgram(command_queue, simple_program, /*blocking=*/true); + // This will verify that outputs matches between capture + replay + LightMetalCompareToCapture(command_queue, *output); // No read return + + Finish(command_queue); +} + +// Test simple case of 3 riscs used for datamovement and compute works for trace + replay. Also include dynamic CB. +TEST_F(LightMetalBasicTest, ThreeRISCDataMovementComputeDynamicCB) { + CreateDeviceAndBeginCapture(4096); + + uint32_t buf_size_bytes = 64; // 16 elements. + uint32_t cb_size_bytes = 2048; + auto input = CreateBuffer(InterleavedBufferConfig{this->device_, buf_size_bytes, buf_size_bytes, BufferType::DRAM}); + auto output = + CreateBuffer(InterleavedBufferConfig{this->device_, buf_size_bytes, buf_size_bytes, BufferType::DRAM}); + auto cb_in_buf = CreateBuffer(InterleavedBufferConfig{this->device_, cb_size_bytes, cb_size_bytes, BufferType::L1}); + log_info( + tt::LogTest, + "Created 3 Buffers. 0x{:x} 0x{:x} 0x{:x}", + input->address(), + output->address(), + cb_in_buf->address()); + + CommandQueue& command_queue = this->device_->command_queue(); + auto simple_program = create_simple_unary_program(*input, *output, cb_in_buf.get()); + + vector input_data(input->size() / sizeof(uint32_t), 0); + for (uint32_t i = 0; i < input_data.size(); i++) { + input_data[i] = i; + } + + // Write data to buffer, enqueue program, then read outputs. + EnqueueWriteBuffer(command_queue, *input, input_data.data(), /*blocking=*/true); + EnqueueProgram(command_queue, simple_program, /*blocking=*/true); + // This will verify that outputs matches between capture + replay + LightMetalCompareToCapture(command_queue, *output); // No read return + + Finish(command_queue); +} + +// Test simple compute test with metal trace, but no explicit trace replay (added automatically by light metal trace). +TEST_F(LightMetalBasicTest, SingleProgramTraceCapture) { + CreateDeviceAndBeginCapture(4096); + + uint32_t size_bytes = 64; // 16 elements. Was 2048 in original test. + auto input = CreateBuffer(InterleavedBufferConfig{this->device_, size_bytes, size_bytes, BufferType::DRAM}); + auto output = CreateBuffer(InterleavedBufferConfig{this->device_, size_bytes, size_bytes, BufferType::DRAM}); + + CommandQueue& command_queue = this->device_->command_queue(); + Program simple_program = create_simple_unary_program(*input, *output); + + // Setup input data for program with some simple values. + vector input_data(input->size() / sizeof(uint32_t), 0); + for (uint32_t i = 0; i < input_data.size(); i++) { + input_data[i] = i; + } + + std::vector eager_output_data(input_data.size()); + + // Initial run w/o trace. Preloads binary cache, and captures golden output. + EnqueueWriteBuffer(command_queue, *input, input_data.data(), /*blocking=*/true); + EnqueueProgram(command_queue, simple_program, /*blocking=*/true); + // This will verify that outputs matches between capture + replay. + LightMetalCompareToCapture(command_queue, *output, eager_output_data.data()); + + // Write junk to output buffer to help make sure trace run from standalone binary works. 
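+    // Overwriting the output with junk ensures a standalone replay of the binary cannot pass on stale results left by the eager run above.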
+ write_junk_to_buffer(command_queue, *output); + + // Now enable Metal Trace and run program again for capture. + uint32_t tid = BeginTraceCapture(this->device_, command_queue.id()); + EnqueueProgram(command_queue, simple_program, false); + EndTraceCapture(this->device_, command_queue.id(), tid); + + // Verify trace output during replay matches expected output from original capture. + LightMetalCompareToGolden(command_queue, *output, eager_output_data.data()); + + // Done + Finish(command_queue); + ReleaseTrace(this->device_, tid); +} + +// Test simple compute test with metal trace, but no explicit trace replay (added automatically by light metal trace). +TEST_F(LightMetalBasicTest, TwoProgramTraceCapture) { + CreateDeviceAndBeginCapture(4096); + + uint32_t size_bytes = 64; // 16 elements. Was 2048 in original test. + auto input = CreateBuffer(InterleavedBufferConfig{this->device_, size_bytes, size_bytes, BufferType::DRAM}); + auto interm = CreateBuffer(InterleavedBufferConfig{this->device_, size_bytes, size_bytes, BufferType::DRAM}); + auto output = CreateBuffer(InterleavedBufferConfig{this->device_, size_bytes, size_bytes, BufferType::DRAM}); + + CommandQueue& command_queue = this->device_->command_queue(); + + Program op0 = create_simple_unary_program(*input, *interm); + Program op1 = create_simple_unary_program(*interm, *output); + + // Setup input data for program with some simple values. + vector input_data(input->size() / sizeof(uint32_t), 0); + for (uint32_t i = 0; i < input_data.size(); i++) { + input_data[i] = i; + } + + std::vector eager_output_data(input_data.size()); + + // Initial run w/o trace. Preloads binary cache, and captures golden output. + EnqueueWriteBuffer(command_queue, *input, input_data.data(), /*blocking=*/true); + EnqueueProgram(command_queue, op0, /*blocking=*/true); + EnqueueProgram(command_queue, op1, /*blocking=*/true); + // This will verify that outputs matches between capture + replay. + LightMetalCompareToCapture(command_queue, *output, eager_output_data.data()); + Finish(command_queue); + + // Write junk to output buffer to help make sure trace run from standalone binary works. + write_junk_to_buffer(command_queue, *output); + + // Now enable Metal Trace and run program again for capture. + uint32_t tid = BeginTraceCapture(this->device_, command_queue.id()); + EnqueueProgram(command_queue, op0, false); + EnqueueProgram(command_queue, op1, false); + EndTraceCapture(this->device_, command_queue.id(), tid); + + // Verify trace output during replay matches expected output from original capture. 
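+    // The expected data passed here is eager_output_data, the host-side golden captured from the initial eager run (not a fresh readback).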
+ LightMetalCompareToGolden(command_queue, *output, eager_output_data.data()); + + // Done + Finish(command_queue); + ReleaseTrace(this->device_, tid); +} + +} // namespace +} // namespace tt::tt_metal diff --git a/tests/tt_metal/tt_metal/llk/CMakeLists.txt b/tests/tt_metal/tt_metal/llk/CMakeLists.txt index 2072fd8afca..cdf53614bb3 100644 --- a/tests/tt_metal/tt_metal/llk/CMakeLists.txt +++ b/tests/tt_metal/tt_metal/llk/CMakeLists.txt @@ -10,6 +10,7 @@ set(UNIT_TESTS_LLK_SRC ${CMAKE_CURRENT_SOURCE_DIR}/test_single_core_binary_compute.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test_single_core_matmul_compute.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test_transpose.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test_unary_broadcast.cpp ${CMAKE_CURRENT_SOURCE_DIR}/test_untilize_tilize.cpp ) diff --git a/tests/tt_metal/tt_metal/llk/test_unary_broadcast.cpp b/tests/tt_metal/tt_metal/llk/test_unary_broadcast.cpp new file mode 100644 index 00000000000..a2af8f5ec80 --- /dev/null +++ b/tests/tt_metal/tt_metal/llk/test_unary_broadcast.cpp @@ -0,0 +1,332 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include + +#include "device_fixture.hpp" +#include +#include "tt_metal/test_utils/comparison.hpp" +#include "tt_metal/test_utils/df/df.hpp" +#include "tt_metal/test_utils/stimulus.hpp" +#include "test_golden_impls.hpp" + +using std::map; +using namespace tt; +using namespace tt::test_utils; +using namespace tt::test_utils::df; +using namespace tt::tt_metal; + +namespace unit_tests::compute::broadcast { + +enum BroadcastDim : uint8_t { ROW, COL, SCALAR, NONE, NUM_DIMS }; + +const map broadcast_dim_to_type = { + {BroadcastDim::ROW, "BroadcastType::ROW"}, + {BroadcastDim::COL, "BroadcastType::COL"}, + {BroadcastDim::SCALAR, "BroadcastType::SCALAR"}, + {BroadcastDim::NONE, "BroadcastType::NONE"}}; + +struct UnaryBroadcastConfig { + BroadcastDim broadcast_dim_0; + BroadcastDim broadcast_dim_1; + tt::DataFormat in0_t; + tt::DataFormat in1_t; + tt::DataFormat out0_t; + tt::DataFormat out1_t; +}; + +// Assume 1Xn tiles. +template +std::vector get_broadcasted_vec(std::vector& src, const std::vector& shape, BroadcastDim dim) { + int num_tiles = shape.at(0); + int num_rows = shape.at(1); + int num_cols = shape.at(2); + int tile_elem_count = num_rows * num_cols; + + std::vector vBroadcast(num_tiles * num_cols * num_rows); + + if (dim == BroadcastDim::NONE) { + vBroadcast = src; + } else { + for (int t = 0; t < num_tiles; t++) { + int tile_offset = tile_elem_count * t; + for (int i = 0; i < num_rows; i++) { + for (int j = 0; j < num_cols; j++) { + T broadcast_value; + switch (dim) { + case BroadcastDim::ROW: { + broadcast_value = src[tile_offset + j]; + break; + } + case BroadcastDim::COL: { + broadcast_value = src[tile_offset + (i * num_cols)]; + break; + } + case BroadcastDim::SCALAR: { + broadcast_value = src[tile_offset]; + break; + } + default: { + TT_THROW("Unsupported BroadcastDim={}", dim); + break; + } + } + + vBroadcast[tile_offset + (i * num_cols + j)] = broadcast_value; + } + } + } + } + + return vBroadcast; +} + +// T_in : type of src vector +// T_out : type of data the packer will pack out +// Assume nx1 tiles, row major data layout. 
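+// Broadcasts src along the requested dim, then tilizes and packs the result into the T_out device format so it can be compared with what the device packs out.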
+template +std::vector get_tilized_packed_golden_broadcast( + std::vector& src, const std::vector& shape, BroadcastDim dim, tt::DataFormat T_out) { + static_assert( + std::is_same::value || std::is_same::value, + "Only float & Float_16b type as input allowed"); + std::vector tilized_packed_res; + unit_tests::compute::GoldenConfig config = {.num_tiles_r_dim = shape.at(0), .num_tiles_c_dim = 1}; + std::vector vBroadcast = get_broadcasted_vec(src, shape, dim); + if constexpr (std::is_same::value) { + if (T_out == tt::DataFormat::Float16_b) { + auto packed_vec = pack_vector(vBroadcast); + tilized_packed_res = unit_tests::compute::gold_standard_tilize(packed_vec, config); + } else if (T_out == tt::DataFormat::Bfp8_b) { + std::vector tempfp32v; + tempfp32v.resize(vBroadcast.size()); + for (int i = 0; i < vBroadcast.size(); i++) { + tempfp32v[i] = vBroadcast[i].to_float(); + } + tilized_packed_res = pack_fp32_vec_as_bfp8_tiles(tempfp32v, true, false); + } else { + TT_THROW("Testing infrastructure not setup for output data type {}", T_out); + } + } else if constexpr (std::is_same::value) { + if (T_out == tt::DataFormat::Float16_b) { + std::vector tempfp16bv; + tempfp16bv.resize(vBroadcast.size()); + for (int i = 0; i < vBroadcast.size(); i++) { + tempfp16bv[i] = vBroadcast[i]; + } + auto packed_vec = pack_vector(tempfp16bv); + tilized_packed_res = unit_tests::compute::gold_standard_tilize(packed_vec, config); + } else if (T_out == tt::DataFormat::Bfp8_b) { + tilized_packed_res = pack_fp32_vec_as_bfp8_tiles(vBroadcast, true, false); + } else { + TT_THROW("Testing infrastructure not setup for output data type {}", T_out); + } + } + return tilized_packed_res; +} + +bool check_is_close(std::vector& packed_golden, std::vector& device_res, tt::DataFormat T_out) { + bool result = true; + if (T_out == tt::DataFormat::Float16_b) { + result = is_close_packed_vectors( + packed_golden, device_res, [&](const bfloat16& a, const bfloat16& b) { return is_close(a, b, 0.0); }); + } else if (T_out == tt::DataFormat::Bfp8_b) { + // Host side may do nearest to even but device side may do nearest rounding, with rounding up + // in case of tie. Also need to note packer source format, which may lead to additional rounding. 
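+        // 0.03125 == 2^-5: a small absolute tolerance to absorb those host/device rounding differences.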
+        float atol = 0.03125f;
+        auto gold_refloat = unpack_bfp8_tiles_into_float_vec(packed_golden, true, false);
+        auto res_refloat = unpack_bfp8_tiles_into_float_vec(device_res, true, false);
+        if (gold_refloat.size() != res_refloat.size()) {
+            TT_THROW(
+                "Mismatch in size of vectors for comparison A.size={} B.size={}",
+                gold_refloat.size(),
+                res_refloat.size());
+        }
+        for (int i = 0; i < gold_refloat.size(); i++) {
+            if (std::fabs(gold_refloat[i] - res_refloat[i]) > atol) {
+                TT_THROW("Mismatch A={} B={} atol={}", gold_refloat[i], res_refloat[i], atol);
+                result = false;
+                break;
+            }
+        }
+    } else {
+        TT_THROW("Testing infrastructure not setup for output data type {}", T_out);
+    }
+
+    return result;
+}
+
+auto CreateDramBuffer(tt_metal::IDevice* device, tt::DataFormat dformat, uint32_t num_tiles) {
+    uint32_t single_tile_size = tile_size(dformat);
+    uint32_t dram_buffer_size = single_tile_size * num_tiles;
+    tt_metal::InterleavedBufferConfig dram_config{
+        .device = device,
+        .size = dram_buffer_size,
+        .page_size = dram_buffer_size,
+        .buffer_type = tt_metal::BufferType::DRAM};
+
+    return CreateBuffer(dram_config);
+}
+
+CBHandle CreateCircularBufferHelper(
+    Program& program, CoreCoord& core, uint32_t num_pages, tt::DataFormat dformat, uint32_t id) {
+    uint32_t page_size = tile_size(dformat);
+    tt_metal::CircularBufferConfig l1_cb_config =
+        tt_metal::CircularBufferConfig(num_pages * page_size, {{id, dformat}}).set_page_size(id, page_size);
+    return tt_metal::CreateCircularBuffer(program, core, l1_cb_config);
+}
+
+void get_packed_tilized_input_output_pair(
+    tt::DataFormat in_t,
+    tt::DataFormat out_t,
+    uint32_t num_tiles,
+    BroadcastDim bcast_dim,
+    std::vector<uint32_t>& packed_tilized_input,
+    std::vector<uint32_t>& packed_tilized_output) {
+    constexpr uint32_t tile_width = 32;
+    constexpr uint32_t tile_height = 32;
+    constexpr uint32_t num_single_tile_elem = tile_width * tile_height;
+    if (in_t == tt::DataFormat::Float16_b) {
+        std::vector<bfloat16> input = generate_uniform_random_vector<bfloat16>(
+            1.0f, 2.0f, num_tiles * num_single_tile_elem, std::chrono::system_clock::now().time_since_epoch().count());
+
+        unit_tests::compute::GoldenConfig config = {.num_tiles_r_dim = num_tiles, .num_tiles_c_dim = 1};
+        auto packed_input = pack_vector<uint32_t, bfloat16>(input);
+        packed_tilized_input = unit_tests::compute::gold_standard_tilize(packed_input, config);
+        packed_tilized_output =
+            get_tilized_packed_golden_broadcast(input, {num_tiles, tile_width, tile_height}, bcast_dim, out_t);
+    } else if (in_t == tt::DataFormat::Bfp8_b) {
+        packed_tilized_input = create_random_vector_of_bfp8(num_tiles * tile_size(in_t), false, 1, 1.0);
+        std::vector<float> input = unpack_bfp8_tiles_into_float_vec(packed_tilized_input, true, false);
+        packed_tilized_output =
+            get_tilized_packed_golden_broadcast(input, {num_tiles, tile_width, tile_height}, bcast_dim, out_t);
+    }
+}
+
+void run_single_core_unary_broadcast(tt_metal::IDevice* device, const UnaryBroadcastConfig& test_config) {
+    Program program = tt_metal::CreateProgram();
+
+    CoreCoord core = {0, 0};
+
+    constexpr uint32_t num_tiles = 32;
+    constexpr uint32_t num_blocks = 4;
+    constexpr uint32_t block_size = num_tiles / num_blocks;
+    tt::DataFormat in0_t = test_config.in0_t;
+    tt::DataFormat out0_t = test_config.out0_t;
+    tt::DataFormat in1_t = test_config.in1_t;
+    tt::DataFormat out1_t = test_config.out1_t;
+
+    auto src_dram_buffer_0 = CreateDramBuffer(device, in0_t, num_tiles);
+    auto dst_dram_buffer_0 = CreateDramBuffer(device, out0_t, num_tiles);
+    auto src_dram_buffer_1 = CreateDramBuffer(device, in1_t, num_tiles);
+    auto dst_dram_buffer_1 = CreateDramBuffer(device, out1_t, num_tiles);
+    auto l1_src_cb_0 = CreateCircularBufferHelper(program, core, block_size * 2, in0_t, 0);
+    auto l1_dst_cb_0 = CreateCircularBufferHelper(program, core, block_size * 2, out0_t, 16);
+    auto l1_src_cb_1 = CreateCircularBufferHelper(program, core, block_size * 2, in1_t, 1);
+    auto l1_dst_cb_1 = CreateCircularBufferHelper(program, core, block_size * 2, out1_t, 17);
+
+    std::map<std::string, std::string> defines = {
+        {"BCAST_DIM_0", broadcast_dim_to_type.at(test_config.broadcast_dim_0)},
+        {"BCAST_DIM_1", broadcast_dim_to_type.at(test_config.broadcast_dim_1)}};
+
+    auto reader_kernel = tt_metal::CreateKernel(
+        program,
+        "tests/tt_metal/tt_metal/test_kernels/dataflow/reader_dual_unary.cpp",
+        core,
+        tt_metal::DataMovementConfig{
+            .processor = tt_metal::DataMovementProcessor::RISCV_1, .noc = tt_metal::NOC::RISCV_1_default});
+
+    auto writer_kernel = tt_metal::CreateKernel(
+        program,
+        "tests/tt_metal/tt_metal/test_kernels/dataflow/writer_dual_unary.cpp",
+        core,
+        tt_metal::DataMovementConfig{
+            .processor = tt_metal::DataMovementProcessor::RISCV_0, .noc = tt_metal::NOC::RISCV_0_default});
+
+    auto binary_kernel = tt_metal::CreateKernel(
+        program,
+        "tests/tt_metal/tt_metal/test_kernels/compute/unary_bcast.cpp",
+        core,
+        tt_metal::ComputeConfig{.compile_args = {num_blocks, block_size}, .defines = defines});
+
+    tt_metal::SetRuntimeArgs(
+        program,
+        reader_kernel,
+        core,
+        {
+            (uint32_t)(src_dram_buffer_0->address()),
+            (uint32_t)0,  // dram bank id
+            (uint32_t)(src_dram_buffer_1->address()),
+            (uint32_t)0,  // dram bank id
+            (uint32_t)num_tiles,  // num tiles
+        });
+
+    tt_metal::SetRuntimeArgs(
+        program,
+        writer_kernel,
+        core,
+        {
+            (uint32_t)(dst_dram_buffer_0->address()),
+            (uint32_t)0,  // dram bank id
+            (uint32_t)(dst_dram_buffer_1->address()),
+            (uint32_t)0,  // dram bank id
+            (uint32_t)num_tiles,  // num tiles
+        });
+
+    std::vector<uint32_t> packed_tilized_input_0, golden_packed_tilized_output_0;
+    get_packed_tilized_input_output_pair(
+        in0_t, out0_t, num_tiles, test_config.broadcast_dim_0, packed_tilized_input_0, golden_packed_tilized_output_0);
+    tt_metal::detail::WriteToBuffer(src_dram_buffer_0, packed_tilized_input_0);
+
+    std::vector<uint32_t> packed_tilized_input_1, golden_packed_tilized_output_1;
+    get_packed_tilized_input_output_pair(
+        in1_t, out1_t, num_tiles, test_config.broadcast_dim_1, packed_tilized_input_1, golden_packed_tilized_output_1);
+    tt_metal::detail::WriteToBuffer(src_dram_buffer_1, packed_tilized_input_1);
+
+    tt_metal::detail::LaunchProgram(device, program);
+
+    std::vector<uint32_t> dest_buffer_data_0;
+    tt_metal::detail::ReadFromBuffer(dst_dram_buffer_0, dest_buffer_data_0);
+    std::vector<uint32_t> dest_buffer_data_1;
+    tt_metal::detail::ReadFromBuffer(dst_dram_buffer_1, dest_buffer_data_1);
+
+    bool result = check_is_close(golden_packed_tilized_output_0, dest_buffer_data_0, out0_t);
+    result &= check_is_close(golden_packed_tilized_output_1, dest_buffer_data_1, out1_t);
+
+    ASSERT_TRUE(result);
+}
+}  // namespace unit_tests::compute::broadcast
+
+using namespace unit_tests::compute::broadcast;
+
+TEST_F(DeviceFixture, TensixComputeSingleTileUnaryBroadcast) {
+    if (this->arch_ == tt::ARCH::GRAYSKULL) {
+        GTEST_SKIP();
+    }
+
+    for (BroadcastDim bcast_dim : {BroadcastDim::NONE, BroadcastDim::ROW, BroadcastDim::COL, BroadcastDim::SCALAR}) {
+        for (tt::DataFormat in0_t_ : {tt::DataFormat::Bfp8_b, tt::DataFormat::Float16_b}) {
+            for (tt::DataFormat out0_t_ : {tt::DataFormat::Bfp8_b, tt::DataFormat::Float16_b}) {
+                UnaryBroadcastConfig
test_config = { + .broadcast_dim_0 = bcast_dim, + .broadcast_dim_1 = (BroadcastDim)((bcast_dim + 1) % BroadcastDim::NUM_DIMS), + .in0_t = in0_t_, + .in1_t = (in0_t_ == tt::DataFormat::Bfp8_b) ? tt::DataFormat::Float16_b : tt::DataFormat::Bfp8_b, + .out0_t = out0_t_, + .out1_t = (out0_t_ == tt::DataFormat::Bfp8_b) ? tt::DataFormat::Float16_b : tt::DataFormat::Bfp8_b}; + + log_info( + "Testing UNARY BROADCAST BCAST_DIM_0={} in0_t={} out0_t={} | BCAST_DIM_1={} in1_t={} out1_t={}", + broadcast_dim_to_type.at(test_config.broadcast_dim_0), + test_config.in0_t, + test_config.out0_t, + broadcast_dim_to_type.at(test_config.broadcast_dim_1), + test_config.in1_t, + test_config.out1_t); + unit_tests::compute::broadcast::run_single_core_unary_broadcast(this->devices_.at(0), test_config); + } + } + } +} diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/common.h b/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/common.h index 9fadafc1a0c..2d9742cb83d 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/common.h +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/common.h @@ -651,7 +651,7 @@ inline void generate_random_paged_payload( words_per_page); // Note: the dst address marches in unison regardless of whether or not a core is written to - uint32_t page_size_alignment_bytes = device->allocator()->get_config().alignment; + uint32_t page_size_alignment_bytes = device->allocator()->get_alignment(buf_type); for (uint32_t page_id = start_page; page_id < start_page + cmd.write_paged.pages; page_id++) { CoreCoord bank_core; uint32_t bank_id = page_id % num_banks; @@ -931,8 +931,9 @@ inline void gen_dispatcher_paged_write_cmd( uint32_t start_page, uint32_t page_size, uint32_t pages) { - uint32_t page_size_alignment_bytes = device->allocator()->get_config().alignment; - uint32_t num_banks = device->allocator()->get_num_banks(is_dram ? BufferType::DRAM : BufferType::L1); + BufferType buffer_type = is_dram ? BufferType::DRAM : BufferType::L1; + uint32_t page_size_alignment_bytes = device->allocator()->get_alignment(buffer_type); + uint32_t num_banks = device->allocator()->get_num_banks(buffer_type); CoreType core_type = is_dram ? CoreType::DRAM : CoreType::WORKER; // Not safe to mix paged L1 and paged DRAM writes currently in this test since same book-keeping. 
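
Note on the two hunks above: they replace the allocator's single global alignment (get_config().alignment) with a per-buffer-type query, presumably because DRAM and L1 need not share the same alignment requirement. A minimal sketch of the pattern, reusing only the APIs visible in this diff (the padded_page_size helper below is illustrative, not part of the change):

    // Pick the buffer type being targeted, then ask the allocator for its alignment and bank count.
    BufferType buffer_type = is_dram ? BufferType::DRAM : BufferType::L1;
    uint32_t page_size_alignment_bytes = device->allocator()->get_alignment(buffer_type);
    uint32_t num_banks = device->allocator()->get_num_banks(buffer_type);
    // Illustrative use: round a payload page size up to that alignment before generating a paged write.
    uint32_t padded_page_size =
        ((page_size + page_size_alignment_bytes - 1) / page_size_alignment_bytes) * page_size_alignment_bytes;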
diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/pgm_dispatch_golden.json b/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/pgm_dispatch_golden.json index 508b6faf624..7c26e13390b 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/pgm_dispatch_golden.json +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/pgm_dispatch_golden.json @@ -1748,6 +1748,186 @@ "time_unit": "ns", "IterationTime": 3.1806549999999998e-05 }, + { + "name": "BM_pgm_dispatch/10000_kernel_all_cores_all_processors_32_cbs_trace/256/manual_time", + "family_index": 18, + "per_family_instance_index": 0, + "run_name": "BM_pgm_dispatch/10000_kernel_all_cores_all_processors_32_cbs_trace/256/manual_time", + "run_type": "iteration", + "repetitions": 1, + "repetition_index": 0, + "threads": 1, + "iterations": 6, + "real_time": 1.1880933333333333e+08, + "cpu_time": 3.3488333333053786e+04, + "time_unit": "ns", + "IterationTime": 1.1880933333333333e-05 + }, + { + "name": "BM_pgm_dispatch/10000_kernel_all_cores_all_processors_32_cbs_trace/512/manual_time", + "family_index": 18, + "per_family_instance_index": 1, + "run_name": "BM_pgm_dispatch/10000_kernel_all_cores_all_processors_32_cbs_trace/512/manual_time", + "run_type": "iteration", + "repetitions": 1, + "repetition_index": 0, + "threads": 1, + "iterations": 6, + "real_time": 1.1882700000000000e+08, + "cpu_time": 3.7786666666761448e+04, + "time_unit": "ns", + "IterationTime": 1.1882700000000001e-05 + }, + { + "name": "BM_pgm_dispatch/10000_kernel_all_cores_all_processors_32_cbs_trace/1024/manual_time", + "family_index": 18, + "per_family_instance_index": 2, + "run_name": "BM_pgm_dispatch/10000_kernel_all_cores_all_processors_32_cbs_trace/1024/manual_time", + "run_type": "iteration", + "repetitions": 1, + "repetition_index": 0, + "threads": 1, + "iterations": 6, + "real_time": 1.1891783333333333e+08, + "cpu_time": 3.1728499999180334e+04, + "time_unit": "ns", + "IterationTime": 1.1891783333333332e-05 + }, + { + "name": "BM_pgm_dispatch/10000_kernel_all_cores_all_processors_32_cbs_trace/2048/manual_time", + "family_index": 18, + "per_family_instance_index": 3, + "run_name": "BM_pgm_dispatch/10000_kernel_all_cores_all_processors_32_cbs_trace/2048/manual_time", + "run_type": "iteration", + "repetitions": 1, + "repetition_index": 0, + "threads": 1, + "iterations": 6, + "real_time": 1.1946583333333336e+08, + "cpu_time": 2.6834999999891807e+04, + "time_unit": "ns", + "IterationTime": 1.1946583333333335e-05 + }, + { + "name": "BM_pgm_dispatch/10000_kernel_all_cores_all_processors_32_cbs_trace/4096/manual_time", + "family_index": 18, + "per_family_instance_index": 4, + "run_name": "BM_pgm_dispatch/10000_kernel_all_cores_all_processors_32_cbs_trace/4096/manual_time", + "run_type": "iteration", + "repetitions": 1, + "repetition_index": 0, + "threads": 1, + "iterations": 6, + "real_time": 1.2124800000000000e+08, + "cpu_time": 2.6059999999716863e+04, + "time_unit": "ns", + "IterationTime": 1.2124800000000001e-05 + }, + { + "name": "BM_pgm_dispatch/10000_kernel_all_cores_all_processors_32_cbs_trace/8192/manual_time", + "family_index": 18, + "per_family_instance_index": 5, + "run_name": "BM_pgm_dispatch/10000_kernel_all_cores_all_processors_32_cbs_trace/8192/manual_time", + "run_type": "iteration", + "repetitions": 1, + "repetition_index": 0, + "threads": 1, + "iterations": 4, + "real_time": 1.6583399999999997e+08, + "cpu_time": 2.6357500001239488e+04, + "time_unit": "ns", + "IterationTime": 1.6583399999999998e-05 + }, + { + "name": 
"BM_pgm_dispatch/5000_kernel_all_cores_all_processors_32_cbs_trace/256/manual_time", + "family_index": 19, + "per_family_instance_index": 0, + "run_name": "BM_pgm_dispatch/5000_kernel_all_cores_all_processors_32_cbs_trace/256/manual_time", + "run_type": "iteration", + "repetitions": 1, + "repetition_index": 0, + "threads": 1, + "iterations": 10, + "real_time": 6.8061800000000000e+07, + "cpu_time": 2.9687800000033349e+04, + "time_unit": "ns", + "IterationTime": 6.8061799999999988e-06 + }, + { + "name": "BM_pgm_dispatch/5000_kernel_all_cores_all_processors_32_cbs_trace/512/manual_time", + "family_index": 19, + "per_family_instance_index": 1, + "run_name": "BM_pgm_dispatch/5000_kernel_all_cores_all_processors_32_cbs_trace/512/manual_time", + "run_type": "iteration", + "repetitions": 1, + "repetition_index": 0, + "threads": 1, + "iterations": 10, + "real_time": 6.8067600000000000e+07, + "cpu_time": 2.2842899999631070e+04, + "time_unit": "ns", + "IterationTime": 6.8067600000000012e-06 + }, + { + "name": "BM_pgm_dispatch/5000_kernel_all_cores_all_processors_32_cbs_trace/1024/manual_time", + "family_index": 19, + "per_family_instance_index": 2, + "run_name": "BM_pgm_dispatch/5000_kernel_all_cores_all_processors_32_cbs_trace/1024/manual_time", + "run_type": "iteration", + "repetitions": 1, + "repetition_index": 0, + "threads": 1, + "iterations": 10, + "real_time": 6.8170400000000015e+07, + "cpu_time": 2.2918400000548900e+04, + "time_unit": "ns", + "IterationTime": 6.8170400000000012e-06 + }, + { + "name": "BM_pgm_dispatch/5000_kernel_all_cores_all_processors_32_cbs_trace/2048/manual_time", + "family_index": 19, + "per_family_instance_index": 3, + "run_name": "BM_pgm_dispatch/5000_kernel_all_cores_all_processors_32_cbs_trace/2048/manual_time", + "run_type": "iteration", + "repetitions": 1, + "repetition_index": 0, + "threads": 1, + "iterations": 10, + "real_time": 6.8726600000000000e+07, + "cpu_time": 2.5596999999777381e+04, + "time_unit": "ns", + "IterationTime": 6.8726600000000009e-06 + }, + { + "name": "BM_pgm_dispatch/5000_kernel_all_cores_all_processors_32_cbs_trace/4096/manual_time", + "family_index": 19, + "per_family_instance_index": 4, + "run_name": "BM_pgm_dispatch/5000_kernel_all_cores_all_processors_32_cbs_trace/4096/manual_time", + "run_type": "iteration", + "repetitions": 1, + "repetition_index": 0, + "threads": 1, + "iterations": 10, + "real_time": 7.0519899999999985e+07, + "cpu_time": 2.6065000000130567e+04, + "time_unit": "ns", + "IterationTime": 7.0519899999999989e-06 + }, + { + "name": "BM_pgm_dispatch/5000_kernel_all_cores_all_processors_32_cbs_trace/8192/manual_time", + "family_index": 19, + "per_family_instance_index": 5, + "run_name": "BM_pgm_dispatch/5000_kernel_all_cores_all_processors_32_cbs_trace/8192/manual_time", + "run_type": "iteration", + "repetitions": 1, + "repetition_index": 0, + "threads": 1, + "iterations": 6, + "real_time": 1.1566050000000000e+08, + "cpu_time": 3.1591666666959860e+04, + "time_unit": "ns", + "IterationTime": 1.1566049999999999e-05 + }, { "name": "BM_pgm_dispatch/kernel_groups_4_shadow/256/manual_time", "family_index": 18, diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/test_pgm_dispatch.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/test_pgm_dispatch.cpp index 700672869db..bedd3d9d8f8 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/test_pgm_dispatch.cpp +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/test_pgm_dispatch.cpp @@ -561,6 +561,19 @@ BENCHMARK_CAPTURE( 
TestInfo{.warmup_iterations = 5000, .n_kgs = 8, .use_trace = true, .use_all_cores = true}) ->Apply(Max8192Args) ->UseManualTime(); + +BENCHMARK_CAPTURE( + BM_pgm_dispatch, + 10000_kernel_all_cores_all_processors_32_cbs_trace, + TestInfo{.warmup_iterations = 5000, .slow_kernel_cycles = 10000, .n_cbs = 32, .use_trace = true, .use_all_cores = true}) + ->Apply(Max8192Args) + ->UseManualTime(); +BENCHMARK_CAPTURE( + BM_pgm_dispatch, + 5000_kernel_all_cores_all_processors_32_cbs_trace, + TestInfo{.warmup_iterations = 5000, .slow_kernel_cycles = 5000, .n_cbs = 32, .use_trace = true, .use_all_cores = true}) + ->Apply(Max8192Args) + ->UseManualTime(); int main(int argc, char** argv) { std::vector input_args(argv, argv + argc); if (test_args::has_command_option(input_args, "--custom")) { diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_write_worker_latency_no_edm.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_write_worker_latency_no_edm.cpp index 95109747866..3a4ed7661f8 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_write_worker_latency_no_edm.cpp +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/ethernet/test_ethernet_write_worker_latency_no_edm.cpp @@ -72,14 +72,14 @@ class N300TestDevice { bool device_open; }; -void validation(const std::shared_ptr& worker_buffer) { - std::vector golden_vec(worker_buffer->size(), 0); - std::vector result_vec(worker_buffer->size(), 0); +void validation(const std::shared_ptr& worker_buffer_0) { + std::vector golden_vec(worker_buffer_0->size(), 0); + std::vector result_vec(worker_buffer_0->size(), 0); - for (int i = 0; i < worker_buffer->size(); ++i) { + for (int i = 0; i < worker_buffer_0->size(); ++i) { golden_vec[i] = i; } - tt::tt_metal::detail::ReadFromBuffer(worker_buffer, result_vec); + tt::tt_metal::detail::ReadFromBuffer(worker_buffer_0, result_vec); bool pass = golden_vec == result_vec; TT_FATAL(pass, "validation failed"); @@ -94,9 +94,14 @@ std::vector build( std::size_t num_samples, std::size_t sample_page_size, std::size_t num_buffer_slots, + std::size_t num_directions, KernelHandle& local_kernel, KernelHandle& remote_kernel, - std::shared_ptr& worker_buffer) { + std::shared_ptr& worker_buffer_0, + std::shared_ptr& worker_buffer_1, + bool test_latency, + bool enable_worker, + bool disable_trid) { Program program0; Program program1; @@ -104,12 +109,15 @@ std::vector build( uint32_t worker_noc_x = device1->worker_core_from_logical_core(worker_core).x; uint32_t worker_noc_y = device1->worker_core_from_logical_core(worker_core).y; - uint32_t worker_buffer_addr = worker_buffer->address(); + uint32_t worker_buffer_0_addr = worker_buffer_0->address(); + uint32_t worker_buffer_1_addr = worker_buffer_1->address(); // eth core ct args - const std::vector& eth_sender_ct_args = {num_buffer_slots}; + const std::vector& eth_sender_ct_args = { + num_buffer_slots, worker_noc_x, worker_noc_y, worker_buffer_0_addr}; + const std::vector& eth_receiver_ct_args = { - num_buffer_slots, worker_noc_x, worker_noc_y, worker_buffer_addr}; + num_buffer_slots, worker_noc_x, worker_noc_y, worker_buffer_1_addr}; // eth core rt args const std::vector& eth_sender_receiver_rt_args = { @@ -117,12 +125,29 @@ std::vector build( static_cast(num_samples), static_cast(sample_page_size)}; + std::map sender_receiver_defines; + if (num_directions == 2) { + sender_receiver_defines["ENABLE_BI_DIRECTION"] = "1"; + } + if (test_latency) { + sender_receiver_defines["TEST_LATENCY"] = "1"; + } + if (enable_worker) 
{ + sender_receiver_defines["ENABLE_WORKER"] = "1"; + } + if (disable_trid) { + sender_receiver_defines["DISABLE_TRID"] = "1"; + } + local_kernel = tt_metal::CreateKernel( program0, "tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/" "ethernet_write_worker_latency_ubench_sender.cpp", eth_sender_core, - tt_metal::EthernetConfig{.noc = tt_metal::NOC::RISCV_0_default, .compile_args = eth_sender_ct_args}); + tt_metal::EthernetConfig{ + .noc = tt_metal::NOC::RISCV_0_default, + .compile_args = eth_sender_ct_args, + .defines = sender_receiver_defines}); tt_metal::SetRuntimeArgs(program0, local_kernel, eth_sender_core, eth_sender_receiver_rt_args); remote_kernel = tt_metal::CreateKernel( @@ -130,7 +155,10 @@ std::vector build( "tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/" "ethernet_write_worker_latency_ubench_receiver.cpp", eth_receiver_core, - tt_metal::EthernetConfig{.noc = tt_metal::NOC::RISCV_0_default, .compile_args = eth_receiver_ct_args}); + tt_metal::EthernetConfig{ + .noc = tt_metal::NOC::RISCV_0_default, + .compile_args = eth_receiver_ct_args, + .defines = sender_receiver_defines}); tt_metal::SetRuntimeArgs(program1, remote_kernel, eth_receiver_core, eth_sender_receiver_rt_args); // Launch @@ -149,7 +177,14 @@ std::vector build( } void run( - IDevice* device0, IDevice* device1, Program& program0, Program& program1, std::shared_ptr& worker_buffer) { + IDevice* device0, + IDevice* device1, + Program& program0, + Program& program1, + std::size_t num_directions, + std::shared_ptr& worker_buffer_0, + std::shared_ptr& worker_buffer_1, + bool enable_worker) { if (std::getenv("TT_METAL_SLOW_DISPATCH_MODE")) { std::thread th2 = std::thread([&] { tt_metal::detail::LaunchProgram(device0, program0); }); std::thread th1 = std::thread([&] { tt_metal::detail::LaunchProgram(device1, program1); }); @@ -167,7 +202,12 @@ void run( tt::tt_metal::detail::DumpDeviceProfileResults(device0); tt::tt_metal::detail::DumpDeviceProfileResults(device1); - validation(worker_buffer); + if (enable_worker) { + validation(worker_buffer_1); + if (num_directions == 2) { + validation(worker_buffer_0); + } + } } int main(int argc, char** argv) { @@ -175,6 +215,11 @@ int main(int argc, char** argv) { std::size_t num_samples = std::stoi(argv[arg_idx++]); std::size_t sample_page_size = std::stoi(argv[arg_idx++]); std::size_t num_buffer_slots = std::stoi(argv[arg_idx++]); + std::size_t num_directions = std::stoi(argv[arg_idx++]); + bool test_latency = std::stoi(argv[arg_idx++]); + bool enable_worker = std::stoi(argv[arg_idx++]); + bool disable_trid = std::stoi(argv[arg_idx++]); + TT_FATAL(num_directions == 1 or num_directions == 2, "either uni-dir or bi-dir test"); auto arch = tt::get_arch_from_string(tt::test_utils::get_umd_arch_name()); auto num_devices = tt::tt_metal::GetNumAvailableDevices(); @@ -220,10 +265,11 @@ int main(int argc, char** argv) { try { log_info( tt::LogTest, - "num_samples: {}, sample_page_size: {}, num_buffer_slots: {}", + "num_samples: {}, sample_page_size: {}, num_buffer_slots: {}, num_directions: {}", num_samples, sample_page_size, - num_buffer_slots); + num_buffer_slots, + num_directions); KernelHandle local_kernel; KernelHandle remote_kernel; try { @@ -233,7 +279,13 @@ int main(int argc, char** argv) { ShardOrientation::ROW_MAJOR, {1, sample_page_size}, {1, sample_page_size}); - auto worker_buffer = CreateBuffer(tt::tt_metal::ShardedBufferConfig{ + auto worker_buffer_0 = CreateBuffer(tt::tt_metal::ShardedBufferConfig{ + .device = device_0, + .size = sample_page_size, + 
.page_size = sample_page_size, + .buffer_layout = TensorMemoryLayout::HEIGHT_SHARDED, + .shard_parameters = shard_spec}); + auto worker_buffer_1 = CreateBuffer(tt::tt_metal::ShardedBufferConfig{ .device = device_1, .size = sample_page_size, .page_size = sample_page_size, @@ -249,10 +301,22 @@ int main(int argc, char** argv) { num_samples, sample_page_size, num_buffer_slots, + num_directions, local_kernel, remote_kernel, - worker_buffer); - run(device_0, device_1, programs[0], programs[1], worker_buffer); + worker_buffer_0, + worker_buffer_1, + test_latency, + enable_worker, + disable_trid); + run(device_0, + device_1, + programs[0], + programs[1], + num_directions, + worker_buffer_0, + worker_buffer_1, + enable_worker); } catch (std::exception& e) { log_error(tt::LogTest, "Caught exception: {}", e.what()); test_fixture.TearDown(); diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/routing/kernels/tt_fabric_traffic_controller.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/routing/kernels/tt_fabric_traffic_controller.cpp new file mode 100644 index 00000000000..0b093070666 --- /dev/null +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/routing/kernels/tt_fabric_traffic_controller.cpp @@ -0,0 +1,33 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +// clang-format off +#include "dataflow_api.h" +#include "debug/dprint.h" +#include "tests/tt_metal/tt_metal/perf_microbenchmark/common/kernel_utils.hpp" +// clang-format on + +void kernel_main() { + uint32_t rt_args_idx = 0; + uint32_t time_seed = get_arg_val(increment_arg_idx(rt_args_idx)); + uint32_t num_tx_workers = get_arg_val(increment_arg_idx(rt_args_idx)); + uint32_t tx_signal_addr = get_arg_val(increment_arg_idx(rt_args_idx)); + uint32_t host_signal_address = get_arg_val(increment_arg_idx(rt_args_idx)); + uint32_t num_mcast_dests = get_arg_val(increment_arg_idx(rt_args_idx)); + uint32_t mcast_encoding = get_arg_val(increment_arg_idx(rt_args_idx)); + + // wait for sync from tx kernels + while (*(volatile tt_l1_ptr uint32_t*)tx_signal_addr != num_tx_workers); + + // wait for signal from host + // this is needed to know that all the routers are up and running on all the chips + while (*(volatile tt_l1_ptr uint32_t*)host_signal_address == 0); + + tt_l1_ptr uint32_t* mcast_sem = reinterpret_cast(0x100000); + *mcast_sem = 1; + + // do a noc multicast to tx kernels + uint64_t mcast_dest_addr = get_noc_addr_helper(mcast_encoding, tx_signal_addr); + noc_async_write_multicast_one_packet((uint32_t)mcast_sem, mcast_dest_addr, sizeof(uint32_t), num_mcast_dests); +} diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/routing/kernels/tt_fabric_traffic_gen_tx.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/routing/kernels/tt_fabric_traffic_gen_tx.cpp index a70eea475bd..152f52e5767 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/routing/kernels/tt_fabric_traffic_gen_tx.cpp +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/routing/kernels/tt_fabric_traffic_gen_tx.cpp @@ -17,7 +17,6 @@ using namespace tt::tt_fabric; uint32_t src_endpoint_id; // constexpr uint32_t src_endpoint_id = get_compile_time_arg_val(0); constexpr uint32_t num_dest_endpoints = get_compile_time_arg_val(1); -static_assert(is_power_of_2(num_dest_endpoints), "num_dest_endpoints must be a power of 2"); constexpr uint32_t dest_endpoint_start_id = get_compile_time_arg_val(2); constexpr uint32_t data_buffer_start_addr = get_compile_time_arg_val(3); @@ -61,6 +60,12 @@ constexpr uint32_t client_interface_addr = 
get_compile_time_arg_val(20); constexpr bool fixed_async_wr_notif_addr = get_compile_time_arg_val(22); +constexpr bool mcast_data = get_compile_time_arg_val(23); +constexpr uint32_t e_depth = get_compile_time_arg_val(24); +constexpr uint32_t w_depth = get_compile_time_arg_val(25); +constexpr uint32_t n_depth = get_compile_time_arg_val(26); +constexpr uint32_t s_depth = get_compile_time_arg_val(27); + uint32_t max_packet_size_mask; auto input_queue_state = select_input_queue(); @@ -83,11 +88,28 @@ uint32_t rx_addr_hi; uint32_t gk_interface_addr_l; uint32_t gk_interface_addr_h; +uint32_t controller_noc_offset; + // flag to check if need to zero out notification addr bool reset_notif_addr = true; uint32_t time_seed; +inline void notify_traffic_controller() { + // send semaphore increment to traffic controller kernel on this device. + uint64_t dest_addr = get_noc_addr_helper(controller_noc_offset, signal_address); + noc_fast_atomic_increment( + noc_index, + NCRISC_AT_CMD_BUF, + dest_addr, + NOC_UNICAST_WRITE_VC, + 1, + 31, + false, + false, + MEM_NOC_ATOMIC_RET_VAL_ADDR); +} + // generates packets with random size and payload on the input sideß inline bool test_buffer_handler_async_wr() { if (input_queue_state.all_packets_done()) { @@ -128,7 +150,7 @@ inline bool test_buffer_handler_async_wr() { target_address = base_target_address; } - packet_header.routing.flags = FORWARD; + packet_header.routing.flags = FORWARD | (mcast_data ? MCAST_DATA : 0); packet_header.routing.packet_size_bytes = input_queue_state.curr_packet_size_words * PACKET_WORD_SIZE_BYTES; packet_header.routing.dst_mesh_id = dest_device >> 16; packet_header.routing.dst_dev_id = dest_device & 0xFFFF; @@ -147,6 +169,12 @@ inline bool test_buffer_handler_async_wr() { packet_header.session.target_offset_l = target_address; packet_header.session.target_offset_h = noc_offset; target_address += packet_header.routing.packet_size_bytes - PACKET_HEADER_SIZE_BYTES; + if constexpr (mcast_data) { + packet_header.packet_parameters.mcast_parameters.east = e_depth; + packet_header.packet_parameters.mcast_parameters.west = w_depth; + packet_header.packet_parameters.mcast_parameters.north = n_depth; + packet_header.packet_parameters.mcast_parameters.south = s_depth; + } tt_fabric_add_header_checksum(&packet_header); uint32_t words_left = words_to_init - words_initialized; bool split_header = words_left < PACKET_HEADER_SIZE_WORDS; @@ -363,6 +391,7 @@ void kernel_main() { time_seed = get_arg_val(increment_arg_idx(rt_args_idx)); src_endpoint_id = get_arg_val(increment_arg_idx(rt_args_idx)); noc_offset = get_arg_val(increment_arg_idx(rt_args_idx)); + controller_noc_offset = get_arg_val(increment_arg_idx(rt_args_idx)); uint32_t router_x = get_arg_val(increment_arg_idx(rt_args_idx)); uint32_t router_y = get_arg_val(increment_arg_idx(rt_args_idx)); dest_device = get_arg_val(increment_arg_idx(rt_args_idx)); @@ -405,7 +434,7 @@ void kernel_main() { input_queue_state.init(src_endpoint_id, prng_seed); } - test_producer.init(data_buffer_start_addr, data_buffer_size_words, 0x0); + test_producer.init(data_buffer_start_addr, data_buffer_size_words); fvcc_test_producer.init(data_buffer_start_addr, 0x0, 0x0); uint32_t temp = max_packet_size_words; @@ -421,9 +450,13 @@ void kernel_main() { max_packet_size_mask = (max_packet_size_mask << 1) + 1; } - // wait till test sends start signal. This is set by test - // once tt_fabric kernels have been launched on all the test devices. 
-    while (*(tt_l1_ptr volatile uint32_t*)signal_address == 0);
+    // notify the controller kernel that this worker is ready to proceed
+    notify_traffic_controller();
+
+    // wait till controller sends start signal. This is set by controller
+    // once tt_fabric kernels have been launched on all the test devices and
+    // all the tx workers are ready on this chip
+    while (*(volatile tt_l1_ptr uint32_t*)signal_address == 0);
 
     test_results[PQ_TEST_MISC_INDEX] = 0xff000001;
diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/routing/kernels/tt_fabric_tx_ubench.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/routing/kernels/tt_fabric_tx_ubench.cpp
index a3ea542cd51..d749c799ec8 100644
--- a/tests/tt_metal/tt_metal/perf_microbenchmark/routing/kernels/tt_fabric_tx_ubench.cpp
+++ b/tests/tt_metal/tt_metal/perf_microbenchmark/routing/kernels/tt_fabric_tx_ubench.cpp
@@ -18,7 +18,6 @@ using namespace tt::tt_fabric;
 uint32_t src_endpoint_id;
 // constexpr uint32_t src_endpoint_id = get_compile_time_arg_val(0);
 constexpr uint32_t num_dest_endpoints = get_compile_time_arg_val(1);
-static_assert(is_power_of_2(num_dest_endpoints), "num_dest_endpoints must be a power of 2");
 constexpr uint32_t dest_endpoint_start_id = get_compile_time_arg_val(2);
 constexpr uint32_t data_buffer_start_addr = get_compile_time_arg_val(3);
@@ -60,6 +59,12 @@ uint32_t dest_device;
 constexpr uint32_t signal_address = get_compile_time_arg_val(19);
 constexpr uint32_t client_interface_addr = get_compile_time_arg_val(20);
+constexpr bool mcast_data = get_compile_time_arg_val(23);
+constexpr uint32_t e_depth = get_compile_time_arg_val(24);
+constexpr uint32_t w_depth = get_compile_time_arg_val(25);
+constexpr uint32_t n_depth = get_compile_time_arg_val(26);
+constexpr uint32_t s_depth = get_compile_time_arg_val(27);
+
 volatile local_pull_request_t* local_pull_request = (volatile local_pull_request_t*)(data_buffer_start_addr - 1024);
 volatile tt_l1_ptr fabric_router_l1_config_t* routing_table = reinterpret_cast<volatile tt_l1_ptr fabric_router_l1_config_t*>(routing_table_start_addr);
@@ -70,9 +75,24 @@ uint32_t target_address;
 uint32_t noc_offset;
 uint32_t gk_interface_addr_l;
 uint32_t gk_interface_addr_h;
-
+uint32_t controller_noc_offset;
 uint32_t time_seed;
+inline void notify_traffic_controller() {
+    // send semaphore increment to traffic controller kernel on this device.
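+    // Handshake (as implemented by tt_fabric_traffic_controller.cpp earlier in this change): each tx
+    // worker atomically increments the controller core's signal_address by one, the controller spins
+    // until that count reaches the number of tx workers and the host flag arrives, then multicasts a
+    // start flag back to every worker's signal_address, which kernel_main() polls before transmitting.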
+ uint64_t dest_addr = get_noc_addr_helper(controller_noc_offset, signal_address); + noc_fast_atomic_increment( + noc_index, + NCRISC_AT_CMD_BUF, + dest_addr, + NOC_UNICAST_WRITE_VC, + 1, + 31, + false, + false, + MEM_NOC_ATOMIC_RET_VAL_ADDR); +} + void kernel_main() { tt_fabric_init(); @@ -80,13 +100,19 @@ void kernel_main() { time_seed = get_arg_val(increment_arg_idx(rt_args_idx)); src_endpoint_id = get_arg_val(increment_arg_idx(rt_args_idx)); noc_offset = get_arg_val(increment_arg_idx(rt_args_idx)); + controller_noc_offset = get_arg_val(increment_arg_idx(rt_args_idx)); uint32_t router_x = get_arg_val(increment_arg_idx(rt_args_idx)); uint32_t router_y = get_arg_val(increment_arg_idx(rt_args_idx)); dest_device = get_arg_val(increment_arg_idx(rt_args_idx)); uint32_t rx_buf_size = get_arg_val(increment_arg_idx(rt_args_idx)); gk_interface_addr_l = get_arg_val(increment_arg_idx(rt_args_idx)); gk_interface_addr_h = get_arg_val(increment_arg_idx(rt_args_idx)); - target_address = get_arg_val(increment_arg_idx(rt_args_idx)); + + if constexpr (ASYNC_WR & test_command) { + base_target_address = get_arg_val(increment_arg_idx(rt_args_idx)); + } + + target_address = base_target_address; // Read in the routing table uint64_t router_config_addr = @@ -113,20 +139,36 @@ void kernel_main() { uint32_t packet_count = 0; uint64_t dst_addr = ((uint64_t)noc_offset << 32 | target_address); - - fabric_async_write_add_header( - data_buffer_start_addr, // source address in sender’s memory - dest_device >> 16, - dest_device & 0xFFFF, - dst_addr, // destination write address - max_packet_size_words * 16 // number of bytes to write to remote destination - ); + if constexpr (mcast_data) { + fabric_async_write_multicast_add_header( + data_buffer_start_addr, // source address in sender’s memory + dest_device >> 16, + dest_device & 0xFFFF, + dst_addr, // destination write address + max_packet_size_words * 16, // number of bytes to write to remote destination + e_depth, + w_depth, + n_depth, + s_depth); + } else { + fabric_async_write_add_header( + data_buffer_start_addr, // source address in sender’s memory + dest_device >> 16, + dest_device & 0xFFFF, + dst_addr, // destination write address + max_packet_size_words * 16 // number of bytes to write to remote destination + ); + } // make sure fabric node gatekeeper is available. fabric_endpoint_init(); + // notify the controller kernel that this worker is ready to proceed + notify_traffic_controller(); + // wait till test sends start signal. This is set by test - // once tt_fabric kernels have been launched on all the test devices. 
+ // once tt_fabric kernels have been launched on all the test devices and + // all tx workers are ready to send data while (*(volatile tt_l1_ptr uint32_t*)signal_address == 0); uint64_t start_timestamp = get_timestamp(); @@ -136,19 +178,35 @@ void kernel_main() { ); while (true) { - client_interface->local_pull_request.pull_request.rd_ptr = 0; - fabric_async_write( - 0, // the network plane to use for this transaction - data_buffer_start_addr, // source address in sender’s memory - dest_device >> 16, - dest_device & 0xFFFF, - dst_addr, // destination write address - max_packet_size_words * 16 // number of bytes to write to remote destination - ); + client_interface->local_pull_request.pull_request.words_read = 0; + if constexpr (mcast_data) { + fabric_async_write_multicast( + 0, // the network plane to use for this transaction + data_buffer_start_addr, // source address in sender’s memory + dest_device >> 16, + dest_device & 0xFFFF, + dst_addr, // destination write address + max_packet_size_words * 16, // number of bytes to write to remote destination + e_depth, + w_depth, + n_depth, + s_depth + ); + } else { + fabric_async_write( + 0, // the network plane to use for this transaction + data_buffer_start_addr, // source address in sender’s memory + dest_device >> 16, + dest_device & 0xFFFF, + dst_addr, // destination write address + max_packet_size_words * 16 // number of bytes to write to remote destination + ); + } + data_words_sent += max_packet_size_words; packet_count++; - uint32_t wr_ptr = client_interface->local_pull_request.pull_request.wr_ptr; - while (client_interface->local_pull_request.pull_request.rd_ptr != wr_ptr) { + uint32_t words_written = client_interface->local_pull_request.pull_request.words_written; + while (client_interface->local_pull_request.pull_request.words_read != words_written) { #pragma GCC unroll 4 for (int i = 0; i < 4; i++) { asm("nop"); diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/routing/test_tt_fabric_sanity.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/routing/test_tt_fabric_sanity.cpp index 0bcc7e273a4..233f9530438 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/routing/test_tt_fabric_sanity.cpp +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/routing/test_tt_fabric_sanity.cpp @@ -50,9 +50,14 @@ bool bidirectional_traffic; // benchmark test mode bool benchmark_mode; +uint32_t tx_signal_address; +uint32_t host_signal_address; + // kernels const std::string gatekeeper_kernel_src = "tt_fabric/impl/kernels/tt_fabric_gatekeeper.cpp"; const std::string router_kernel_src = "tt_fabric/impl/kernels/tt_fabric_router.cpp"; +const std::string traffic_controller_src = + "tests/tt_metal/tt_metal/perf_microbenchmark/routing/kernels/tt_fabric_traffic_controller.cpp"; const std::string rx_kernel_src = "tests/tt_metal/tt_metal/perf_microbenchmark/routing/kernels/tt_fabric_traffic_gen_rx.cpp"; std::string tx_kernel_src; @@ -87,7 +92,7 @@ inline std::vector get_random_numbers_from_range(uint32_t start, uint3 typedef struct test_board { std::vector available_chip_ids; std::vector physical_chip_ids; - std::vector> unicast_map; + std::vector>> tx_rx_map; std::map device_handle_map; std::unique_ptr control_plane; uint32_t num_chips_to_use; @@ -223,7 +228,41 @@ typedef struct test_board { } } - void generate_unicast_map(uint32_t num_hops) { + // TODO: This only supports 1d mcast right now, needs to be updated to support 2D mcast + // Note that this currently only considers intra-mesh mcast + // physical_start_chip_id here refers to the sender, not the 
mcast origin due to how we count depth + std::vector get_physical_mcast_chip_ids( + chip_id_t physical_start_chip_id, const std::unordered_map& mcast_depth) { + std::vector physical_dsts; + // APIs use mesh chip id, so convert physical chip id to mesh chip id + auto [mesh_id, chip_id] = this->get_mesh_chip_id(physical_start_chip_id); + bool valid = true; + for (const auto& [routing_direction, num_hops_in_direction] : mcast_depth) { + for (auto j = 0; j < num_hops_in_direction; j++) { + auto neighbors = this->get_intra_chip_neighbors(mesh_id, chip_id, routing_direction); + if (neighbors.empty()) { + valid = false; + break; + } + // Assumes all neighbors are the same chip + chip_id = neighbors[0]; + // convert mesh chip id to physical chip id + physical_dsts.push_back( + this->control_plane->get_physical_chip_id_from_mesh_chip_id({mesh_id, chip_id})); + } + if (!valid) { + break; + } + } + if (valid) { + return physical_dsts; + } else { + return {}; + } + } + + void generate_tx_rx_map( + uint32_t num_hops, bool mcast, const std::unordered_map& mcast_depth) { std::unordered_map> chip_neighbors; std::unordered_map> chip_n_hop_neighbors; std::vector> n_hop_neighbors_cnt; @@ -236,8 +275,20 @@ typedef struct test_board { // for default setting, generate a random unicast map if (DEFAULT_NUM_HOPS == num_hops) { + if (mcast) { + for (auto i = 0; i < physical_chip_ids.size(); i++) { + auto physical_mcast_chip_ids = this->get_physical_mcast_chip_ids(physical_chip_ids[i], mcast_depth); + if (!physical_mcast_chip_ids.empty()) { + tx_rx_map.push_back({physical_chip_ids[i], std::move(physical_mcast_chip_ids)}); + // Generate only one mcast for now to avoid overlapping cores + break; + } + } + TT_FATAL(!tx_rx_map.empty(), "Failed to generate multicast map"); + return; + } for (auto i = 0; i < physical_chip_ids.size(); i += 2) { - unicast_map.push_back({physical_chip_ids[i], physical_chip_ids[i + 1]}); + tx_rx_map.push_back({physical_chip_ids[i], {physical_chip_ids[i + 1]}}); } return; } @@ -340,7 +391,17 @@ typedef struct test_board { // throw std::runtime_error("No neighbor found for this chip"); } - unicast_map.push_back({chip_id, selected_chip_id}); + if (mcast) { + // TODO: This assumes line mcast from neighbor with 1 hop + auto physical_mcast_chip_ids = this->get_physical_mcast_chip_ids(chip_id, mcast_depth); + if (!physical_mcast_chip_ids.empty() && (physical_mcast_chip_ids[0] == selected_chip_id)) { + tx_rx_map.push_back({chip_id, std::move(physical_mcast_chip_ids)}); + } else { + continue; + } + } else { + tx_rx_map.push_back({chip_id, {selected_chip_id}}); + } // remove selected chip as it should not be picked again chip_n_hop_neighbors.erase(selected_chip_id); @@ -348,6 +409,12 @@ typedef struct test_board { // remove the entry for current chip as it should not be picked again chip_n_hop_neighbors.erase(chip_id); } + + // error out if no valid tx rx mapping was found + // We should only be able to hit this assertion when looking for mcast destinations + if (!tx_rx_map.size()) { + throw std::runtime_error("No valid tx rx mapping found"); + } } inline uint32_t get_num_available_devices() { return physical_chip_ids.size(); } @@ -378,6 +445,11 @@ typedef struct test_board { return control_plane->get_fabric_route(src_mesh_id, src_chip_id, dst_mesh_id, dst_chip_id, src_chan_id); } + inline std::vector get_intra_chip_neighbors( + mesh_id_t src_mesh_id, chip_id_t src_chip_id, RoutingDirection routing_direction) { + return control_plane->get_intra_chip_neighbors(src_mesh_id, src_chip_id, 
routing_direction); + } + inline void close_devices() { tt::tt_metal::detail::CloseDevices(device_handle_map); } } test_board_t; @@ -390,6 +462,8 @@ typedef struct test_device { std::vector worker_logical_cores; std::vector router_logical_cores; std::vector router_virtual_cores; + CoreCoord core_range_start_virtual; + CoreCoord core_range_end_virtual; CoreCoord gk_logical_core; CoreCoord gk_phys_core; mesh_id_t mesh_id; @@ -419,6 +493,9 @@ typedef struct test_device { } } + core_range_start_virtual = device_handle->worker_core_from_logical_core(CoreCoord(0, 0)); + core_range_end_virtual = device_handle->worker_core_from_logical_core(CoreCoord(7, 7)); + // populate router cores auto neighbors = tt::Cluster::instance().get_ethernet_connected_device_ids(physical_chip_id); for (auto neighbor : neighbors) { @@ -718,11 +795,15 @@ typedef struct test_device { } } + inline std::vector get_intra_chip_neighbors(RoutingDirection routing_direction) { + return board_handle->get_intra_chip_neighbors(mesh_id, logical_chip_id, routing_direction); + } + } test_device_t; typedef struct test_traffic { std::shared_ptr tx_device; - std::shared_ptr rx_device; + std::vector> rx_devices; uint32_t num_tx_workers; uint32_t num_rx_workers; uint32_t target_address; @@ -730,12 +811,14 @@ typedef struct test_traffic { std::vector> rx_workers; std::vector tx_virtual_cores; std::vector rx_virtual_cores; + CoreCoord controller_logical_core; + CoreCoord controller_virtual_core; std::vector tx_to_rx_map; std::vector> rx_to_tx_map; std::vector tx_to_rx_address_map; std::vector> rx_to_tx_address_map; std::vector> tx_results; - std::vector> rx_results; + std::vector>> rx_results; uint32_t test_results_address; uint32_t rx_buf_size; uint32_t num_links_to_use; @@ -743,14 +826,14 @@ typedef struct test_traffic { test_traffic( std::shared_ptr& tx_device_, - std::shared_ptr& rx_device_, + std::vector>& rx_devices_, uint32_t num_src_endpoints, uint32_t num_dest_endpoints, uint32_t target_address_, uint32_t num_hops, uint32_t num_links_) { tx_device = tx_device_; - rx_device = rx_device_; + rx_devices = rx_devices_; num_tx_workers = num_src_endpoints; num_rx_workers = num_dest_endpoints; target_address = target_address_; @@ -767,7 +850,9 @@ typedef struct test_traffic { std::vector src_routers; std::vector dest_routers; - tx_device->get_available_router_cores(num_hops, rx_device, src_routers, dest_routers); + // For Unicast there is only one rx device + // For mcast, this only supports line mcast, we pass the last device as the rx device + tx_device->get_available_router_cores(num_hops, *rx_devices.rbegin(), src_routers, dest_routers); num_links_to_use = std::min(num_links_, (uint32_t)src_routers.size()); _generate_tx_to_rx_mapping(); @@ -779,14 +864,19 @@ typedef struct test_traffic { // for bi-directional traffic leave the higher priority cores on the rx chip for tx kernels num_cores_to_skip = (num_rx_workers + num_links_to_use - 1) / num_links_to_use; } - rx_workers = rx_device->select_worker_cores(dest_routers, num_links_to_use, num_rx_workers, num_cores_to_skip); + // Assumes uniform worker grid across receiver chips + rx_workers = rx_devices[0]->select_worker_cores(dest_routers, num_links_to_use, num_rx_workers, num_cores_to_skip); + + // TODO: not the most optimum selection, might impact somewhat in bidirectional mode + controller_logical_core = tx_device->select_random_worker_cores(1)[0]; + controller_virtual_core = tx_device->device_handle->worker_core_from_logical_core(controller_logical_core); for (auto& [router, noc, 
worker] : tx_workers) { tx_virtual_cores.push_back(tx_device->device_handle->worker_core_from_logical_core(worker)); } for (auto& [router, noc, worker] : rx_workers) { - rx_virtual_cores.push_back(rx_device->device_handle->worker_core_from_logical_core(worker)); + rx_virtual_cores.push_back(rx_devices[0]->device_handle->worker_core_from_logical_core(worker)); } } @@ -795,17 +885,52 @@ typedef struct test_traffic { std::vector& rx_compile_args, std::map& defines, uint32_t fabric_command, - uint32_t tx_signal_address, uint32_t test_results_address_) { CoreCoord tx_core, rx_core; tt_metal::NOC noc_id; std::vector zero_buf(2, 0); CoreCoord router_virtual_core; - uint32_t mesh_chip_id = rx_device->mesh_chip_id; + uint32_t mesh_chip_id = rx_devices[0]->mesh_chip_id; // update the test results address, which will be used later for polling, collecting results test_results_address = test_results_address_; + { + uint32_t mcast_encoding = tt::tt_metal::hal.noc_multicast_encoding( + tx_device->core_range_start_virtual.x, + tx_device->core_range_start_virtual.y, + tx_device->core_range_end_virtual.x, + tx_device->core_range_end_virtual.y); + + // launch controller kernel + // TODO: remove hardcoding + std::vector runtime_args = { + time_seed, // 0: time based seed + num_tx_workers, // 1: number of workers for mcast + tx_signal_address, // 2: address to send signal on to workers + host_signal_address, // 3: address to receive signal from host + 64, // 4: num mcast dest + mcast_encoding, // 5: mcast dest noc encoding + }; + + // zero out the signal address + tt::llrt::write_hex_vec_to_core( + tx_device->physical_chip_id, controller_virtual_core, zero_buf, tx_signal_address); + + // zero out host sync address + tt::llrt::write_hex_vec_to_core( + tx_device->physical_chip_id, controller_virtual_core, zero_buf, host_signal_address); + + auto kernel = tt_metal::CreateKernel( + tx_device->program_handle, + traffic_controller_src, + {controller_logical_core}, + tt_metal::DataMovementConfig{ + .processor = tt_metal::DataMovementProcessor::RISCV_0, .noc = tt_metal::NOC::RISCV_0_default}); + + tt_metal::SetRuntimeArgs(tx_device->program_handle, kernel, controller_logical_core, runtime_args); + } + // launch tx kernels for (auto i = 0; i < num_tx_workers; i++) { router_virtual_core = std::get<0>(tx_workers[i]); @@ -815,15 +940,16 @@ typedef struct test_traffic { // setup runtime args std::vector runtime_args = { - time_seed, // 0: time based seed - tx_device->get_endpoint_id(tx_core), // 1: src_endpoint_id - rx_device->get_noc_offset(rx_core), // 2: dest_noc_offset - router_virtual_core.x, // 3: router_x - router_virtual_core.y, // 4: router_y - mesh_chip_id, // 5: mesh and chip id - rx_buf_size, // 6: space in rx's L1 - gk_interface_addr, // 7: gk_message_addr_l - tx_device->gk_noc_offset, // 8: gk_message_addr_h + time_seed, // 0: time based seed + tx_device->get_endpoint_id(tx_core), // 1: src_endpoint_id + rx_devices[0]->get_noc_offset(rx_core), // 2: dest_noc_offset + tx_device->get_noc_offset(controller_logical_core), // 3: controller noc offset + router_virtual_core.x, // 4: router_x + router_virtual_core.y, // 5: router_y + mesh_chip_id, // 6: mesh and chip id + rx_buf_size, // 7: space in rx's L1 + gk_interface_addr, // 8: gk_message_addr_l + tx_device->gk_noc_offset, // 9: gk_message_addr_h }; if (ASYNC_WR & fabric_command) { @@ -832,11 +958,12 @@ typedef struct test_traffic { // zero out the signal address tt::llrt::write_hex_vec_to_core( - tx_device->device_handle->id(), tx_virtual_cores[i], zero_buf, 
tx_signal_address); + tx_device->physical_chip_id, tx_virtual_cores[i], zero_buf, tx_signal_address); log_info( LogTest, - "run traffic_gen_tx at logical: x={},y={}; virtual: x={},y={}", + "Device: {}, TX kernel running on: logical: x={},y={}; virtual: x={},y={}", + tx_device->physical_chip_id, tx_core.x, tx_core.y, tx_virtual_cores[i].x, @@ -877,49 +1004,62 @@ typedef struct test_traffic { runtime_args.push_back(address); } } else if (ATOMIC_INC == fabric_command) { - tt::llrt::write_hex_vec_to_core( - rx_device->device_handle->id(), rx_virtual_cores[i], zero_buf, target_address); + for (const auto& rx_device : rx_devices) { + tt::llrt::write_hex_vec_to_core( + rx_device->physical_chip_id, rx_virtual_cores[i], zero_buf, target_address); + } } - // zero out the test results address, which will be used for polling - tt::llrt::write_hex_vec_to_core( - rx_device->device_handle->id(), rx_virtual_cores[i], zero_buf, test_results_address); - - log_info( - LogTest, - "run traffic_gen_rx at logical: x={},y={}; virtual: x={},y={}", - rx_core.x, - rx_core.y, - rx_virtual_cores[i].x, - rx_virtual_cores[i].y); - auto kernel = tt_metal::CreateKernel( - rx_device->program_handle, - rx_kernel_src, - {rx_core}, - tt_metal::DataMovementConfig{ - .processor = tt_metal::DataMovementProcessor::RISCV_0, - .noc = noc_id, - .compile_args = rx_compile_args, - .defines = defines}); - - tt_metal::SetRuntimeArgs(rx_device->program_handle, kernel, rx_core, runtime_args); + for (const auto& rx_device : rx_devices) { + // zero out the test results address, which will be used for polling + tt::llrt::write_hex_vec_to_core( + rx_device->physical_chip_id, rx_virtual_cores[i], zero_buf, test_results_address); + + log_info( + LogTest, + "Device: {}, RX kernel running on: logical: x={},y={}; virtual: x={},y={}", + rx_device->physical_chip_id, + rx_core.x, + rx_core.y, + rx_virtual_cores[i].x, + rx_virtual_cores[i].y); + auto kernel = tt_metal::CreateKernel( + rx_device->program_handle, + rx_kernel_src, + {rx_core}, + tt_metal::DataMovementConfig{ + .processor = tt_metal::DataMovementProcessor::RISCV_0, + .noc = noc_id, + .compile_args = rx_compile_args, + .defines = defines}); + + tt_metal::SetRuntimeArgs(rx_device->program_handle, kernel, rx_core, runtime_args); + } } } + void notify_tx_controller() { + std::vector start_signal(1, 1); + tt::llrt::write_hex_vec_to_core( + tx_device->physical_chip_id, controller_virtual_core, start_signal, host_signal_address); + } + void notify_tx_workers(uint32_t address) { std::vector start_signal(1, 1); for (auto core : tx_virtual_cores) { - tt::llrt::write_hex_vec_to_core(tx_device->device_handle->id(), core, start_signal, address); + tt::llrt::write_hex_vec_to_core(tx_device->physical_chip_id, core, start_signal, address); } } void wait_for_rx_workers_to_finish() { - for (auto& rx_core : rx_virtual_cores) { - while (true) { - auto tx_status = - tt::llrt::read_hex_vec_from_core(rx_device->device_handle->id(), rx_core, test_results_address, 4); - if ((tx_status[0] & 0xFFFF) != 0) { - break; + for (const auto& rx_device : rx_devices) { + for (auto& rx_core : rx_virtual_cores) { + while (true) { + auto tx_status = + tt::llrt::read_hex_vec_from_core(rx_device->physical_chip_id, rx_core, test_results_address, 4); + if ((tx_status[0] & 0xFFFF) != 0) { + break; + } } } } @@ -931,25 +1071,30 @@ typedef struct test_traffic { // collect tx results for (uint32_t i = 0; i < num_tx_workers; i++) { tx_results.push_back(tt::llrt::read_hex_vec_from_core( - tx_device->device_handle->id(), 
tx_virtual_cores[i], test_results_address, 128)); + tx_device->physical_chip_id, tx_virtual_cores[i], test_results_address, 128)); log_info( LogTest, - "TX{} status = {}", + "Device {} TX{} status = {}", + tx_device->physical_chip_id, i, packet_queue_test_status_to_string(tx_results[i][PQ_TEST_STATUS_INDEX])); pass &= (tx_results[i][PQ_TEST_STATUS_INDEX] == PACKET_QUEUE_TEST_PASS); } // collect rx results - for (uint32_t i = 0; i < num_rx_workers; i++) { - rx_results.push_back(tt::llrt::read_hex_vec_from_core( - rx_device->device_handle->id(), rx_virtual_cores[i], test_results_address, 128)); - log_info( - LogTest, - "RX{} status = {}", - i, - packet_queue_test_status_to_string(rx_results[i][PQ_TEST_STATUS_INDEX])); - pass &= (rx_results[i][PQ_TEST_STATUS_INDEX] == PACKET_QUEUE_TEST_PASS); + rx_results.resize(rx_devices.size()); + for (uint32_t d = 0; d < rx_devices.size(); d++) { + for (uint32_t i = 0; i < num_rx_workers; i++) { + rx_results[d].push_back(tt::llrt::read_hex_vec_from_core( + rx_devices[d]->physical_chip_id, rx_virtual_cores[i], test_results_address, 128)); + log_info( + LogTest, + "Device {} RX{} status = {}", + rx_devices[d]->physical_chip_id, + i, + packet_queue_test_status_to_string(rx_results[d][i][PQ_TEST_STATUS_INDEX])); + pass &= (rx_results[d][i][PQ_TEST_STATUS_INDEX] == PACKET_QUEUE_TEST_PASS); + } } return pass; @@ -960,19 +1105,21 @@ typedef struct test_traffic { uint64_t num_tx_words, num_tx_packets; // tally-up data words and number of packets from rx and tx - for (uint32_t i = 0; i < num_rx_workers; i++) { - num_tx_words = 0; - num_tx_packets = 0; + for (uint32_t d = 0; d < rx_devices.size(); d++) { + for (uint32_t i = 0; i < num_rx_workers; i++) { + num_tx_words = 0; + num_tx_packets = 0; - for (auto j : rx_to_tx_map[i]) { - num_tx_words += get_64b_result(tx_results[j], PQ_TEST_WORD_CNT_INDEX); - num_tx_packets += get_64b_result(tx_results[j], TX_TEST_IDX_NPKT); - } - pass &= (get_64b_result(rx_results[i], PQ_TEST_WORD_CNT_INDEX) == num_tx_words); - pass &= (get_64b_result(rx_results[i], TX_TEST_IDX_NPKT) == num_tx_packets); + for (auto j : rx_to_tx_map[i]) { + num_tx_words += get_64b_result(tx_results[j], PQ_TEST_WORD_CNT_INDEX); + num_tx_packets += get_64b_result(tx_results[j], TX_TEST_IDX_NPKT); + } + pass &= (get_64b_result(rx_results[d][i], PQ_TEST_WORD_CNT_INDEX) == num_tx_words); + pass &= (get_64b_result(rx_results[d][i], TX_TEST_IDX_NPKT) == num_tx_packets); - if (!pass) { - break; + if (!pass) { + break; + } } } @@ -1002,7 +1149,8 @@ typedef struct test_traffic { log_info( LogTest, - "TX {} words sent = {}, elapsed cycles = {} -> BW = {:.2f} B/cycle", + "Device: {}, TX {} words sent: {}, elapsed cycles: {} -> BW: {:.2f} B/cycle", + tx_device->physical_chip_id, i, tx_words_sent, tx_elapsed_cycles, @@ -1022,10 +1170,18 @@ typedef struct test_traffic { */ } total_tx_bw_2 = ((double)total_tx_words_sent) * PACKET_WORD_SIZE_BYTES / max_tx_elapsed_cycles; - for (uint32_t i = 0; i < num_rx_workers; i++) { - uint64_t words_received = get_64b_result(rx_results[i], PQ_TEST_WORD_CNT_INDEX); - uint32_t num_tx = rx_to_tx_map[i].size(); - log_info(LogTest, "RX {}, num producers = {}, words received = {}", i, num_tx, words_received); + for (uint32_t d = 0; d < rx_devices.size(); d++) { + for (uint32_t i = 0; i < num_rx_workers; i++) { + uint64_t words_received = get_64b_result(rx_results[d][i], PQ_TEST_WORD_CNT_INDEX); + uint32_t num_tx = rx_to_tx_map[i].size(); + log_info( + LogTest, + "Device: {}, RX {}, num producers = {}, words received = {}", + 
rx_devices[d]->physical_chip_id, + i, + num_tx, + words_received); + } } // log_info(LogTest, "Total TX BW = {:.2f} B/cycle", total_tx_bw); log_info(LogTest, "Total TX BW = {:.2f} B/cycle", total_tx_bw_2); @@ -1110,7 +1266,10 @@ int main(int argc, char **argv) { constexpr uint32_t default_tx_queue_size_bytes = 0x10000; constexpr uint32_t default_rx_queue_start_addr = 0xa0000; constexpr uint32_t default_rx_queue_size_bytes = 0x20000; - constexpr uint32_t default_tx_signal_address = 0x70000; + + // if this is used for multicast on all workers, carefully set it to a value that + // doesnt interfere with rx payload checking + constexpr uint32_t default_tx_signal_address = 0x28000; constexpr uint32_t default_test_results_addr = 0x100000; constexpr uint32_t default_test_results_size = 0x40000; @@ -1150,6 +1309,8 @@ int main(int argc, char **argv) { constexpr uint32_t default_atomic_increment = 4; + constexpr uint32_t default_multicast = 0; + constexpr const char* default_board_type = "glx32"; constexpr uint32_t default_num_traffic_devices = 0; @@ -1162,6 +1323,8 @@ int main(int argc, char **argv) { constexpr uint32_t default_packet_size_kb = 4; + constexpr uint32_t default_host_signal_address = 0x60000; + std::vector input_args(argv, argv + argc); if (test_args::has_command_option(input_args, "-h") || test_args::has_command_option(input_args, "--help")) { @@ -1239,6 +1402,24 @@ int main(int argc, char **argv) { uint32_t atomic_increment = test_args::get_command_option_uint32(input_args, "--atomic_increment", default_atomic_increment); + // Note here that currently mcast_depth considers the mcast origin as a hop, and not the distance from the origin + // This has side effects that specifying a depth of 0 or 1 will result in the same behavior + std::unordered_map mcast_depth; + mcast_depth[RoutingDirection::E] = test_args::get_command_option_uint32(input_args, "--e_depth", default_multicast); + mcast_depth[RoutingDirection::W] = test_args::get_command_option_uint32(input_args, "--w_depth", default_multicast); + mcast_depth[RoutingDirection::N] = test_args::get_command_option_uint32(input_args, "--n_depth", default_multicast); + mcast_depth[RoutingDirection::S] = test_args::get_command_option_uint32(input_args, "--s_depth", default_multicast); + bool mcast = false; + for (const auto& [dir, depth] : mcast_depth) { + if (depth) { + // TODO: Remove once generic mcast is supported + if (mcast) { + throw std::runtime_error("Only 1 mcast direction is supported right now"); + } + mcast = true; + } + } + // assert((pkt_dest_size_choices_t)tx_pkt_dest_size_choice == pkt_dest_size_choices_t::SAME_START_RNDROBIN_FIX_SIZE // && rx_disable_header_check || (pkt_dest_size_choices_t)tx_pkt_dest_size_choice == // pkt_dest_size_choices_t::RANDOM); @@ -1255,7 +1436,8 @@ int main(int argc, char **argv) { allow_1st_noc_hop = test_args::has_command_option(input_args, "--allow_1st_noc_hop"); bidirectional_traffic = test_args::has_command_option(input_args, "--bidirectional"); - uint32_t tx_signal_address = default_tx_signal_address; + tx_signal_address = default_tx_signal_address; + host_signal_address = default_host_signal_address; std::string board_type = test_args::get_command_option(input_args, "--board_type", std::string(default_board_type)); @@ -1272,6 +1454,14 @@ int main(int argc, char **argv) { uint32_t packet_size_kb = test_args::get_command_option_uint32(input_args, "--packet_size_kb", default_packet_size_kb); uint32_t max_packet_size_words = packet_size_kb * 1024 / PACKET_WORD_SIZE_BYTES; + // Only supports 
line mcast from neighbour + if (mcast && num_hops != default_num_hops && num_hops != 1) { + throw std::runtime_error("Only line mcast is supported right now"); + } + + if (mcast && bidirectional_traffic) { + throw std::runtime_error("Bidirectional traffic is not supported for mcast"); + } bool pass = true; uint32_t num_available_devices, num_allocated_devices = 0; @@ -1305,6 +1495,7 @@ int main(int argc, char **argv) { } global_rng.seed(prng_seed); + log_info(LogTest, "PRNG seed = {}", prng_seed); time_seed = std::chrono::system_clock::now().time_since_epoch().count(); @@ -1334,12 +1525,22 @@ int main(int argc, char **argv) { // if both left and right device IDs are specified, launch traffic only b/w them if ((default_test_device_id_l != test_device_id_l) && (default_test_device_id_r != test_device_id_r)) { - if (test_device_id_l == test_device_id_r) { - throw std::runtime_error("Left and right chips should be different"); + if (mcast) { + // TODO: We require mcast origin to be the neighbor for now + // So get the path from test_device_id_l and verify the next chip is test_device_id_r + auto physical_mcast_chip_ids = test_board.get_physical_mcast_chip_ids(test_device_id_l, mcast_depth); + if (physical_mcast_chip_ids.empty() || physical_mcast_chip_ids[0] != test_device_id_r) { + throw std::runtime_error("No multicast path found"); + } + test_board.tx_rx_map.push_back({test_device_id_l, std::move(physical_mcast_chip_ids)}); + } else { + if (test_device_id_l == test_device_id_r) { + throw std::runtime_error("Left and right chips should be different"); + } + test_board.tx_rx_map.push_back({test_device_id_l, {test_device_id_r}}); } - test_board.unicast_map.push_back({test_device_id_l, test_device_id_r}); } else { - test_board.generate_unicast_map(num_hops); + test_board.generate_tx_rx_map(num_hops, mcast, mcast_depth); } std::unordered_map> test_devices; @@ -1352,14 +1553,20 @@ int main(int argc, char **argv) { // init traffic chip_id_t tx_chip_id, rx_chip_id; - for (auto& [tx_chip_id, rx_chip_id] : test_board.unicast_map) { + for (auto& [tx_chip_id, rx_chip_ids] : test_board.tx_rx_map) { if (num_allocated_devices >= num_traffic_devices) { break; } + std::vector> rx_devices; + rx_devices.reserve(rx_chip_ids.size()); + for (auto& rx_chip_id : rx_chip_ids) { + rx_devices.push_back(test_devices[rx_chip_id]); + } + test_traffic_t traffic( test_devices[tx_chip_id], - test_devices[rx_chip_id], + rx_devices, num_src_endpoints, num_dest_endpoints, target_address, @@ -1368,9 +1575,10 @@ int main(int argc, char **argv) { fabric_traffic.push_back(traffic); if (bidirectional_traffic) { + std::vector> rx_devices = {test_devices[tx_chip_id]}; test_traffic_t traffic_r( - test_devices[rx_chip_id], - test_devices[tx_chip_id], + test_devices[rx_chip_ids[0]], + rx_devices, num_src_endpoints, num_dest_endpoints, target_address, @@ -1379,7 +1587,7 @@ int main(int argc, char **argv) { fabric_traffic.push_back(traffic_r); } - num_allocated_devices += 2; + num_allocated_devices += 1 + rx_chip_ids.size(); } // TODO: check this in a loop for all the devices involved in the traffic @@ -1460,6 +1668,11 @@ int main(int argc, char **argv) { client_interface_addr, // 20: client_pull_req_buf_addr, // 21: fixed_async_wr_notif_addr, // 22: use fixed addr for async wr atomic inc + mcast, // 23: mcast + mcast_depth[RoutingDirection::E], // 24: mcast_e + mcast_depth[RoutingDirection::W], // 25: mcast_w + mcast_depth[RoutingDirection::N], // 26: mcast_n + mcast_depth[RoutingDirection::S], // 27: mcast_s }; std::vector 
rx_compile_args = { @@ -1478,8 +1691,7 @@ int main(int argc, char **argv) { // TODO: launch traffic kernels for (auto& traffic : fabric_traffic) { - traffic.create_kernels( - tx_compile_args, rx_compile_args, defines, fabric_command, tx_signal_address, test_results_addr); + traffic.create_kernels(tx_compile_args, rx_compile_args, defines, fabric_command, test_results_addr); } if (check_txrx_timeout) { @@ -1498,16 +1710,15 @@ int main(int argc, char **argv) { test_device->wait_for_gatekeeper_sync(); } - // notify tx kernels to start transmitting + // notify tx controller to signal the tx workers for (auto& traffic : fabric_traffic) { - traffic.notify_tx_workers(tx_signal_address); + traffic.notify_tx_controller(); } // wait for rx kernels to finish for (auto& traffic : fabric_traffic) { traffic.wait_for_rx_workers_to_finish(); } - // terminate fabric routers for (auto& [chip_id, test_device] : test_devices) { test_device->terminate_gatekeeper_kernel(); diff --git a/tests/tt_metal/tt_metal/test_kernels/compute/broadcast.cpp b/tests/tt_metal/tt_metal/test_kernels/compute/broadcast.cpp index 5db2626708b..18d6a0f3ed3 100644 --- a/tests/tt_metal/tt_metal/test_kernels/compute/broadcast.cpp +++ b/tests/tt_metal/tt_metal/test_kernels/compute/broadcast.cpp @@ -13,7 +13,7 @@ void MAIN { #ifndef BCAST_OP_INIT init_bcast(tt::CBIndex::c_0, tt::CBIndex::c_1, tt::CBIndex::c_16); #else - binary_op_init_common(tt::CBIndex::c_0, tt::CBIndex::c_1); + binary_op_init_common(tt::CBIndex::c_0, tt::CBIndex::c_1, tt::CBIndex::c_16); BCAST_OP_INIT(tt::CBIndex::c_0, tt::CBIndex::c_1); #endif diff --git a/tests/tt_metal/tt_metal/test_kernels/compute/layernorm.cpp b/tests/tt_metal/tt_metal/test_kernels/compute/layernorm.cpp index c417f6b247d..c9e10192782 100644 --- a/tests/tt_metal/tt_metal/test_kernels/compute/layernorm.cpp +++ b/tests/tt_metal/tt_metal/test_kernels/compute/layernorm.cpp @@ -27,9 +27,9 @@ void MAIN { constexpr uint32_t do_beta = get_compile_time_arg_val(3); #ifdef FUSE_PRE_ADD - binary_op_init_common(tt::CBIndex::c_0, tt::CBIndex::c_1); + binary_op_init_common(tt::CBIndex::c_0, tt::CBIndex::c_1, tt::CBIndex::c_16); #else - binary_op_init_common(tt::CBIndex::c_0, tt::CBIndex::c_0); + binary_op_init_common(tt::CBIndex::c_0, tt::CBIndex::c_0, tt::CBIndex::c_16); #endif constexpr uint32_t onetile = 1; @@ -71,7 +71,7 @@ void MAIN { * X + Y */ #ifdef FUSE_PRE_ADD - add_tiles_init(); + add_tiles_init(cb_in, cb_inb); for (uint32_t wt = 0; wt < Wt; wt += blk) { ACQ(); // UNPACK(( { DPRINT << "Waiting on cb_x" << ENDL(); } )); @@ -134,7 +134,7 @@ void MAIN { /* (x - E[x])^2 * compute temp = xmm*xmm = (x-E[x])^2 */ - mul_tiles_init(); + mul_tiles_init(cb_xmm, cb_xmm); for (uint32_t wt = 0; wt < Wt; wt += blk) { cb_wait_front(cb_xmm, wt + blk); // cumulative wait cb_reserve_back(cb_xmm2, blk); // can probably use less space for this if we block @@ -177,7 +177,7 @@ void MAIN { * add epsilon E[(x-E[x])^2]+eps */ ACQ(); - add_tiles_init(); + add_tiles_init(cb_ex2, cb_eps); add_tiles(cb_ex2, cb_eps, 0, 0, dst0); cb_reserve_back(cb_ex2pe, 1); // 1 diff --git a/tests/tt_metal/tt_metal/test_kernels/compute/rmsnorm.cpp b/tests/tt_metal/tt_metal/test_kernels/compute/rmsnorm.cpp index d189d0e1a3a..209ec0762ff 100644 --- a/tests/tt_metal/tt_metal/test_kernels/compute/rmsnorm.cpp +++ b/tests/tt_metal/tt_metal/test_kernels/compute/rmsnorm.cpp @@ -27,9 +27,9 @@ void MAIN { constexpr uint32_t do_beta = get_compile_time_arg_val(3); #ifdef FUSE_PRE_ADD - binary_op_init_common(tt::CBIndex::c_0, tt::CBIndex::c_1); + 
binary_op_init_common(tt::CBIndex::c_0, tt::CBIndex::c_1, tt::CBIndex::c_16); #else - binary_op_init_common(tt::CBIndex::c_0, tt::CBIndex::c_0); + binary_op_init_common(tt::CBIndex::c_0, tt::CBIndex::c_0, tt::CBIndex::c_16); #endif constexpr uint32_t onetile = 1; @@ -69,7 +69,7 @@ void MAIN { * X + Y */ #ifdef FUSE_PRE_ADD - add_tiles_init(); + add_tiles_init(cb_in, cb_inb); for (uint32_t wt = 0; wt < Wt; wt += blk) { ACQ(); // UNPACK(( { DPRINT << "Waiting on cb_x" << ENDL(); } )); @@ -93,7 +93,7 @@ void MAIN { /* (x)^2 * compute temp = x^2 */ - mul_tiles_init(); + mul_tiles_init(cb_x, cb_x); for (uint32_t wt = 0; wt < Wt; wt += blk) { cb_wait_front(cb_x, wt + blk); cb_reserve_back(cb_x2, blk); // can probably use less space for this if we block @@ -134,7 +134,7 @@ void MAIN { * add epsilon E[(x-E[x])^2]+eps */ ACQ(); - add_tiles_init(); + add_tiles_init(cb_ex2, cb_eps); add_tiles(cb_ex2, cb_eps, 0, 0, dst0); cb_reserve_back(cb_ex2pe, 1); // 1 diff --git a/tests/tt_metal/tt_metal/test_kernels/compute/rotary_embedding.cpp b/tests/tt_metal/tt_metal/test_kernels/compute/rotary_embedding.cpp index 52dfead0dd1..39ed9fe9ede 100644 --- a/tests/tt_metal/tt_metal/test_kernels/compute/rotary_embedding.cpp +++ b/tests/tt_metal/tt_metal/test_kernels/compute/rotary_embedding.cpp @@ -30,7 +30,7 @@ ALWI void MUL_TILES(uint32_t in0_cb, uint32_t in1_cb, uint32_t out_cb, uint32_t // We don't pop in1 in decode which is sin/cos since we don't stream #else ACQ(); - mul_tiles_init(); + mul_tiles_init(in0_cb, in1_cb); mul_tiles(in0_cb, in1_cb, 0, 0, 0); pack_tile(0, out_cb); REL(); @@ -80,7 +80,7 @@ void MAIN { constexpr uint32_t Wt = get_compile_time_arg_val(10); constexpr uint32_t half_Wt = get_compile_time_arg_val(11); - binary_op_init_common(in_cb, cos_cb); + binary_op_init_common(in_cb, cos_cb, out_cb); cb_wait_front(scalar_cb, onetile); @@ -134,7 +134,7 @@ void MAIN { cb_reserve_back(out_cb, onetile); ACQ(); - add_tiles_init(); + add_tiles_init(cos_interm_cb, sin_interm_cb); add_tiles(cos_interm_cb, sin_interm_cb, 0, 0, 0); pack_tile(0, out_cb); REL(); diff --git a/tests/tt_metal/tt_metal/test_kernels/compute/softmax.cpp b/tests/tt_metal/tt_metal/test_kernels/compute/softmax.cpp index 8a806c72d73..ab2d2ebe8e5 100644 --- a/tests/tt_metal/tt_metal/test_kernels/compute/softmax.cpp +++ b/tests/tt_metal/tt_metal/test_kernels/compute/softmax.cpp @@ -30,7 +30,7 @@ void MAIN { const uint32_t Wt = get_arg_val(2); const uint32_t ndst = get_arg_val(3); const uint32_t start_ht = get_arg_val(4); - binary_op_init_common(tt::CBIndex::c_0, tt::CBIndex::c_2); + binary_op_init_common(tt::CBIndex::c_0, tt::CBIndex::c_2, tt::CBIndex::c_16); constexpr uint32_t onetile = 1; // reserve one tile for zeros on cb_in2 diff --git a/tests/tt_metal/tt_metal/test_kernels/compute/unary_bcast.cpp b/tests/tt_metal/tt_metal/test_kernels/compute/unary_bcast.cpp new file mode 100644 index 00000000000..1dfdc05f23a --- /dev/null +++ b/tests/tt_metal/tt_metal/test_kernels/compute/unary_bcast.cpp @@ -0,0 +1,55 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
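// Illustrative before/after sketch (not itself part of the patch) of the init-API migration the
// compute-kernel hunks above perform; CB names follow those kernels:
//
//   binary_op_init_common(tt::CBIndex::c_0, tt::CBIndex::c_1);                      // old: output CB implicit
//   add_tiles_init();                                                               // old: operand CBs implicit
//
//   binary_op_init_common(tt::CBIndex::c_0, tt::CBIndex::c_1, tt::CBIndex::c_16);   // new: output CB passed
//   add_tiles_init(cb_in, cb_inb);                                                  // new: operand CBs passed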
+// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include "compute_kernel_api/bcast.h" +#include "compute_kernel_api/eltwise_binary.h" + +namespace NAMESPACE { +void MAIN { + uint32_t per_core_block_cnt = get_compile_time_arg_val(0); + uint32_t per_core_block_dim = get_compile_time_arg_val(1); + + unary_bcast_init(tt::CBIndex::c_0, tt::CBIndex::c_16); + + for (uint32_t block_index = 0; block_index < per_core_block_cnt; block_index++) { + cb_wait_front(tt::CBIndex::c_0, per_core_block_dim); + acquire_dst(); + for (uint32_t tile_index = 0; tile_index < per_core_block_dim; ++tile_index) { + unary_bcast(tt::CBIndex::c_0, tile_index, tile_index); + } + + cb_pop_front(tt::CBIndex::c_0, per_core_block_dim); + cb_reserve_back(tt::CBIndex::c_16, per_core_block_dim); + + for (uint32_t tile_index = 0; tile_index < per_core_block_dim; ++tile_index) { + pack_tile(tile_index, tt::CBIndex::c_16); + } + + cb_push_back(tt::CBIndex::c_16, per_core_block_dim); + release_dst(); + } + + reconfigure_unary_bcast( + tt::CBIndex::c_0, tt::CBIndex::c_1, tt::CBIndex::c_16, tt::CBIndex::c_17); + + for (uint32_t block_index = 0; block_index < per_core_block_cnt; block_index++) { + cb_wait_front(tt::CBIndex::c_1, per_core_block_dim); + acquire_dst(); + for (uint32_t tile_index = 0; tile_index < per_core_block_dim; ++tile_index) { + unary_bcast(tt::CBIndex::c_1, tile_index, tile_index); + } + + cb_pop_front(tt::CBIndex::c_1, per_core_block_dim); + cb_reserve_back(tt::CBIndex::c_17, per_core_block_dim); + + for (uint32_t tile_index = 0; tile_index < per_core_block_dim; ++tile_index) { + pack_tile(tile_index, tt::CBIndex::c_17); + } + + cb_push_back(tt::CBIndex::c_17, per_core_block_dim); + release_dst(); + } +} +} // namespace NAMESPACE diff --git a/tests/tt_metal/tt_metal/test_kernels/dataflow/reader_dual_unary.cpp b/tests/tt_metal/tt_metal/test_kernels/dataflow/reader_dual_unary.cpp new file mode 100644 index 00000000000..97f3adb6cd7 --- /dev/null +++ b/tests/tt_metal/tt_metal/test_kernels/dataflow/reader_dual_unary.cpp @@ -0,0 +1,40 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
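// Descriptive note on the unary_bcast test kernel above: it makes two passes over the same block
// structure. Pass one broadcasts tiles from c_0 into c_16; reconfigure_unary_bcast then switches
// the operand/output pair so pass two broadcasts c_1 into c_17. Each block follows the standard
// CB handshake:
//   cb_wait_front -> acquire_dst -> unary_bcast per tile -> cb_pop_front ->
//   cb_reserve_back -> pack_tile -> cb_push_back -> release_dst.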
+// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "dataflow_api.h" + +inline void read_tiles(uint32_t num_tiles, uint32_t src_addr, uint32_t bank_id, uint32_t cb_id_in) { + // ublocks size defined in tiles + constexpr uint32_t ublock_size_tiles = 1; + uint32_t ublock_size_bytes = get_tile_size(cb_id_in) * ublock_size_tiles; + + // read a ublock of tiles from src to CB, and then push the ublock to unpacker + for (uint32_t i = 0; i < num_tiles; i += ublock_size_tiles) { + uint64_t src_noc_addr = get_noc_addr_from_bank_id(bank_id, src_addr); + + cb_reserve_back(cb_id_in, ublock_size_tiles); + uint32_t l1_write_addr = get_write_ptr(cb_id_in); + noc_async_read(src_noc_addr, l1_write_addr, ublock_size_bytes); + + noc_async_read_barrier(); + + cb_push_back(cb_id_in, ublock_size_tiles); + src_addr += ublock_size_bytes; + } +} + +void kernel_main() { + uint32_t src_addr_0 = get_arg_val(0); + uint32_t bank_id_0 = get_arg_val(1); + uint32_t src_addr_1 = get_arg_val(2); + uint32_t bank_id_1 = get_arg_val(3); + uint32_t num_tiles = get_arg_val(4); + constexpr uint32_t cb_id_in0 = 0; + constexpr uint32_t cb_id_in1 = 1; + + read_tiles(num_tiles, src_addr_0, bank_id_0, cb_id_in0); + read_tiles(num_tiles, src_addr_1, bank_id_1, cb_id_in1); +} diff --git a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/eth_non_blocking_receive_fwd_to_dram.cpp b/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/eth_non_blocking_receive_fwd_to_dram.cpp index ffdd827ed41..d222ae567a5 100644 --- a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/eth_non_blocking_receive_fwd_to_dram.cpp +++ b/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/eth_non_blocking_receive_fwd_to_dram.cpp @@ -18,7 +18,7 @@ // Initiate DRAM write -> advances write pointer template -void write_chunk( +void write_chunk_legacy( const uint32_t eth_l1_buffer_address_base, const uint32_t num_pages, const uint32_t num_pages_per_l1_buffer, @@ -62,7 +62,7 @@ bool eth_initiate_noc_write_sequence( // and the receiver ackptr != next write pointer // // DPRINT << "rx: accepting payload, sending receive ack on channel " << // (uint32_t)noc_writer_buffer_wrptr << "\n"; - write_chunk( + write_chunk_legacy( transaction_channel_receiver_buffer_addresses[noc_writer_buffer_wrptr.index()], num_pages, num_pages_per_l1_buffer, diff --git a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/ethernet_write_worker_latency_ubench_common.hpp b/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/ethernet_write_worker_latency_ubench_common.hpp index 23826835c81..34825404d9a 100644 --- a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/ethernet_write_worker_latency_ubench_common.hpp +++ b/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/ethernet_write_worker_latency_ubench_common.hpp @@ -40,3 +40,229 @@ template bool is_power_of_two(T val) { return (val & (val - 1)) == T(0); } + +// ******************************* Common Ct Args ************************************************ + +constexpr uint32_t NUM_BUFFER_SLOTS = get_compile_time_arg_val(0); +constexpr uint32_t MAX_NUM_TRANSACTION_ID = + NUM_BUFFER_SLOTS / 2; // the algorithm only works for NUM_BUFFER_SLOTS divisible by MAX_NUM_TRANSACTION_ID +constexpr uint32_t worker_noc_x = get_compile_time_arg_val(1); +constexpr uint32_t worker_noc_y = get_compile_time_arg_val(2); +constexpr uint32_t worker_buffer_addr = get_compile_time_arg_val(3); + +// ******************************* Sender APIs 
*************************************************** + +FORCE_INLINE uint32_t setup_sender_buffer( + std::array& buffer_slot_addrs, + std::array& buffer_slot_sync_addrs, + uint32_t buffer_slot_addr, + uint32_t message_size) { + for (uint8_t i = 0; i < NUM_BUFFER_SLOTS; i++) { + buffer_slot_addrs[i] = buffer_slot_addr; + buffer_slot_addr += message_size; + buffer_slot_sync_addrs[i] = reinterpret_cast(buffer_slot_addr); + buffer_slot_addr += sizeof(eth_buffer_slot_sync_t); + } + + // reset bytes_sent to 1s so first iter it will block on receiver ack + for (uint32_t i = 0; i < NUM_BUFFER_SLOTS; i++) { + buffer_slot_sync_addrs[i]->bytes_sent = 1; + } + + // assemble a packet filled with values + for (uint32_t i = 0; i < NUM_BUFFER_SLOTS; i++) { + tt_l1_ptr uint8_t* ptr = reinterpret_cast(buffer_slot_addrs[i]); + for (uint32_t j = 0; j < message_size; j++) { + ptr[j] = j; + } + } + + uint32_t buffer_end_addr = buffer_slot_addr; + return buffer_end_addr; +} + +FORCE_INLINE uint32_t advance_buffer_slot_ptr(uint32_t curr_ptr) { return (curr_ptr + 1) % NUM_BUFFER_SLOTS; } + +FORCE_INLINE void write_receiver( + uint32_t buffer_slot_addr, volatile eth_buffer_slot_sync_t* buffer_slot_sync_addr, uint32_t full_payload_size) { + buffer_slot_sync_addr->bytes_sent = 1; + + while (eth_txq_is_busy()) { + switch_context_if_debug(); + } + + eth_send_bytes_over_channel_payload_only_unsafe_one_packet(buffer_slot_addr, buffer_slot_addr, full_payload_size); +} + +FORCE_INLINE bool has_receiver_ack(volatile eth_buffer_slot_sync_t* buffer_slot_sync_addr) { + return buffer_slot_sync_addr->bytes_sent == 0; +} + +FORCE_INLINE void check_buffer_full_and_send_packet( + const std::array& buffer_slot_addrs, + const std::array& buffer_slot_sync_addrs, + uint32_t read_ptr, + uint32_t& write_ptr, + uint64_t full_payload_size, + uint32_t& num_messages_send) { + uint32_t next_write_ptr = advance_buffer_slot_ptr(write_ptr); + bool buffer_not_full = next_write_ptr != read_ptr; + + if (buffer_not_full && num_messages_send != 0) { + write_receiver(buffer_slot_addrs[write_ptr], buffer_slot_sync_addrs[write_ptr], full_payload_size); + + write_ptr = next_write_ptr; + num_messages_send--; + } +} + +FORCE_INLINE void check_receiver_done( + const std::array& buffer_slot_sync_addrs, + uint32_t& read_ptr, + uint32_t& num_messages_ack) { + if (has_receiver_ack(buffer_slot_sync_addrs[read_ptr])) { + uint32_t next_read_ptr = advance_buffer_slot_ptr(read_ptr); + + buffer_slot_sync_addrs[read_ptr]->bytes_sent = 1; + read_ptr = next_read_ptr; + num_messages_ack++; + } +} + +FORCE_INLINE void update_sender_state( + const std::array& buffer_slot_addrs, + const std::array& buffer_slot_sync_addrs, + uint32_t full_payload_size, + uint32_t& num_messages_ack, + uint32_t& num_messages_send, + uint32_t& buffer_read_ptr, + uint32_t& buffer_write_ptr) { + // Check if current buffer slot is ready and send packet to receiver + check_buffer_full_and_send_packet( + buffer_slot_addrs, + buffer_slot_sync_addrs, + buffer_read_ptr, + buffer_write_ptr, + full_payload_size, + num_messages_send); + // Check if the write for trid is done, and ack sender if the current buffer slot is done + check_receiver_done(buffer_slot_sync_addrs, buffer_read_ptr, num_messages_ack); +} + +// ******************************* Receiver APIs ************************************************* + +FORCE_INLINE uint32_t setup_receiver_buffer( + std::array& buffer_slot_addrs, + std::array& buffer_slot_sync_addrs, + uint32_t buffer_slot_addr, + uint32_t message_size) { + for (uint8_t i = 0; i < 
NUM_BUFFER_SLOTS; i++) { + buffer_slot_addrs[i] = buffer_slot_addr; + buffer_slot_addr += message_size; + buffer_slot_sync_addrs[i] = reinterpret_cast(buffer_slot_addr); + buffer_slot_sync_addrs[i]->bytes_sent = 0; + buffer_slot_sync_addrs[i]->receiver_ack = 0; + buffer_slot_addr += sizeof(eth_buffer_slot_sync_t); + } + + uint32_t buffer_end_addr = buffer_slot_addr; + return buffer_end_addr; +} + +FORCE_INLINE uint32_t get_buffer_slot_trid(uint32_t curr_ptr) { return curr_ptr % MAX_NUM_TRANSACTION_ID + 1; } + +FORCE_INLINE bool has_incoming_packet(volatile eth_buffer_slot_sync_t* buffer_slot_sync_addr) { + return buffer_slot_sync_addr->bytes_sent != 0; +} + +FORCE_INLINE bool write_worker_done(uint32_t trid) { + return ncrisc_noc_nonposted_write_with_transaction_id_flushed(noc_index, trid); +} + +FORCE_INLINE void ack_complete(volatile eth_buffer_slot_sync_t* buffer_slot_sync_addr) { + buffer_slot_sync_addr->bytes_sent = 0; + + while (eth_txq_is_busy()) { + switch_context_if_debug(); + } + + eth_send_bytes_over_channel_payload_only_unsafe_one_packet( + reinterpret_cast(buffer_slot_sync_addr), + reinterpret_cast(buffer_slot_sync_addr), + sizeof(eth_buffer_slot_sync_t)); +} + +FORCE_INLINE void write_worker( + uint32_t buffer_slot_addr, + volatile eth_buffer_slot_sync_t* buffer_slot_sync_addr, + uint64_t worker_noc_addr, + uint32_t message_size, + uint32_t curr_trid_to_write) { + // write to local +#ifdef DISABLE_TRID + noc_async_write(buffer_slot_addr, worker_noc_addr, message_size); + noc_async_writes_flushed(); +#else + noc_async_write_one_packet_with_trid_with_state( + buffer_slot_addr, worker_noc_addr, message_size, curr_trid_to_write); +#endif + // reset sync + buffer_slot_sync_addr->bytes_sent = 0; +} + +FORCE_INLINE void check_incomping_packet_and_write_worker( + const std::array& buffer_slot_addrs, + const std::array& buffer_slot_sync_addrs, + uint32_t read_ptr, + uint32_t& write_ptr, + uint64_t worker_noc_addr, + uint32_t message_size) { + uint32_t next_write_ptr = advance_buffer_slot_ptr(write_ptr); + bool buffer_not_full = next_write_ptr != read_ptr; + + if (buffer_not_full && has_incoming_packet(buffer_slot_sync_addrs[write_ptr])) { +#ifdef ENABLE_WORKER + uint32_t curr_trid = get_buffer_slot_trid(write_ptr); + write_worker( + buffer_slot_addrs[write_ptr], buffer_slot_sync_addrs[write_ptr], worker_noc_addr, message_size, curr_trid); +#endif + write_ptr = next_write_ptr; + } +} + +FORCE_INLINE void check_write_worker_done_and_send_ack( + const std::array& buffer_slot_sync_addrs, + uint32_t& read_ptr, + uint32_t write_ptr, + uint32_t& num_messages_ack) { + bool buffer_not_empty = read_ptr != write_ptr; + +#if defined(ENABLE_WORKER) and !defined(DISABLE_TRID) + uint32_t curr_trid = get_buffer_slot_trid(read_ptr); + if (buffer_not_empty && write_worker_done(curr_trid)) { +#else + if (buffer_not_empty) { +#endif + // DPRINT << "read_ptr " << read_ptr <& buffer_slot_addrs, + const std::array& buffer_slot_sync_addrs, + uint64_t worker_noc_addr, + uint32_t message_size, + uint32_t& num_messages_ack, + uint32_t& buffer_read_ptr, + uint32_t& buffer_write_ptr) { + // Check if there's an incoming packet for current buffer slot and write to worker if there's new packet + check_incomping_packet_and_write_worker( + buffer_slot_addrs, buffer_slot_sync_addrs, buffer_read_ptr, buffer_write_ptr, worker_noc_addr, message_size); + // Check if the write for trid is done, and ack sender if the current buffer slot is done + check_write_worker_done_and_send_ack(buffer_slot_sync_addrs, buffer_read_ptr, 
buffer_write_ptr, num_messages_ack); +} diff --git a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/ethernet_write_worker_latency_ubench_receiver.cpp b/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/ethernet_write_worker_latency_ubench_receiver.cpp index 5f241b1b48d..dc11308f5bb 100644 --- a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/ethernet_write_worker_latency_ubench_receiver.cpp +++ b/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/ethernet_write_worker_latency_ubench_receiver.cpp @@ -4,111 +4,94 @@ #include "ethernet_write_worker_latency_ubench_common.hpp" -static constexpr uint32_t NUM_BUFFER_SLOTS = get_compile_time_arg_val(0); -static constexpr uint32_t MAX_NUM_TRANSACTION_ID = - NUM_BUFFER_SLOTS / 2; // the algorithm only works for NUM_BUFFER_SLOTS divisible by MAX_NUM_TRANSACTION_ID -static constexpr uint32_t worker_noc_x = get_compile_time_arg_val(1); -static constexpr uint32_t worker_noc_y = get_compile_time_arg_val(2); -static constexpr uint32_t worker_buffer_addr = get_compile_time_arg_val(3); - -FORCE_INLINE uint32_t advance_buffer_slot_ptr(uint32_t curr_ptr) { return (curr_ptr + 1) % NUM_BUFFER_SLOTS; } - -FORCE_INLINE uint32_t get_buffer_slot_trid(uint32_t curr_ptr) { return curr_ptr % MAX_NUM_TRANSACTION_ID + 1; } - -FORCE_INLINE bool has_incoming_packet(volatile eth_buffer_slot_sync_t* buffer_slot_sync_addr) { - return buffer_slot_sync_addr->bytes_sent != 0; -} - -FORCE_INLINE bool write_worker_done(uint32_t trid) { - return ncrisc_noc_nonposted_write_with_transaction_id_flushed(noc_index, trid); -} - -FORCE_INLINE void ack_complete(volatile eth_buffer_slot_sync_t* buffer_slot_sync_addr) { - buffer_slot_sync_addr->bytes_sent = 0; - - eth_send_bytes_over_channel_payload_only_unsafe( - reinterpret_cast(buffer_slot_sync_addr), - reinterpret_cast(buffer_slot_sync_addr), - sizeof(eth_buffer_slot_sync_t), - sizeof(eth_buffer_slot_sync_t), - sizeof(eth_buffer_slot_sync_t) >> 4); -} - -FORCE_INLINE void write_worker( - uint32_t buffer_slot_addr, - volatile eth_buffer_slot_sync_t* buffer_slot_sync_addr, - uint64_t worker_noc_addr, +FORCE_INLINE void main_loop_uni_dir( + const std::array& receiver_buffer_slot_addrs, + const std::array& receiver_buffer_slot_sync_addrs, uint32_t message_size, - uint32_t curr_trid_to_write) { - // write to local - noc_async_write_one_packet_with_trid(buffer_slot_addr, worker_noc_addr, message_size, curr_trid_to_write); + uint32_t num_messages, + uint64_t worker_noc_addr) { + uint32_t total_msgs = +#ifdef TEST_LATENCY + num_messages; +#else + num_messages * NUM_BUFFER_SLOTS; +#endif - // reset sync - buffer_slot_sync_addr->bytes_sent = 0; -} - -FORCE_INLINE void check_incomping_packet_and_write_worker( - const std::array& buffer_slot_addrs, - const std::array& buffer_slot_sync_addrs, - uint32_t read_ptr, - uint32_t& write_ptr, - uint64_t worker_noc_addr, - uint32_t message_size) { - uint32_t next_write_ptr = advance_buffer_slot_ptr(write_ptr); - bool buffer_not_full = next_write_ptr != read_ptr; - - if (buffer_not_full && has_incoming_packet(buffer_slot_sync_addrs[write_ptr])) { - uint32_t curr_trid = get_buffer_slot_trid(write_ptr); - write_worker( - buffer_slot_addrs[write_ptr], buffer_slot_sync_addrs[write_ptr], worker_noc_addr, message_size, curr_trid); - - write_ptr = next_write_ptr; - } -} + DPRINT << "RECEIVER MAIN LOOP" << ENDL(); -FORCE_INLINE void check_write_worker_done_and_send_ack( - const std::array& buffer_slot_sync_addrs, - uint32_t& read_ptr, - uint32_t write_ptr, - 
uint32_t& num_messages_ack) { - bool buffer_not_empty = read_ptr != write_ptr; - uint32_t curr_trid = get_buffer_slot_trid(read_ptr); + uint32_t receiver_buffer_read_ptr = 0; + uint32_t receiver_buffer_write_ptr = 0; + uint32_t receiver_num_messages_ack = 0; - if (buffer_not_empty && write_worker_done(curr_trid) && !eth_txq_is_busy()) { - ack_complete(buffer_slot_sync_addrs[read_ptr]); + noc_async_write_one_packet_with_trid_set_state(worker_noc_addr); - read_ptr = advance_buffer_slot_ptr(read_ptr); + while (receiver_num_messages_ack < total_msgs) { + update_receiver_state( + receiver_buffer_slot_addrs, + receiver_buffer_slot_sync_addrs, + worker_noc_addr, + message_size, + receiver_num_messages_ack, + receiver_buffer_read_ptr, + receiver_buffer_write_ptr); - num_messages_ack++; + // not called in normal execution mode + switch_context_if_debug(); } } -FORCE_INLINE void receiver_main_loop( - const std::array& buffer_slot_addrs, - const std::array& buffer_slot_sync_addrs, - uint64_t worker_noc_addr, +FORCE_INLINE void main_loop_bi_dir( + const std::array& sender_buffer_slot_addrs, + const std::array& sender_buffer_slot_sync_addrs, + const std::array& receiver_buffer_slot_addrs, + const std::array& receiver_buffer_slot_sync_addrs, + uint32_t full_payload_size, uint32_t message_size, - uint32_t num_messages) { - uint32_t total_msgs = num_messages * NUM_BUFFER_SLOTS; + uint32_t num_messages, + uint64_t worker_noc_addr) { + uint32_t total_msgs = +#ifdef TEST_LATENCY + num_messages * 2; +#else + num_messages * NUM_BUFFER_SLOTS * 2; +#endif DPRINT << "RECEIVER MAIN LOOP" << ENDL(); - uint32_t buffer_read_ptr = 0; - uint32_t buffer_write_ptr = 0; + uint32_t sender_buffer_read_ptr = 0; + uint32_t sender_buffer_write_ptr = 0; + + uint32_t receiver_buffer_read_ptr = 0; + uint32_t receiver_buffer_write_ptr = 0; uint32_t num_messages_ack = 0; + uint32_t sender_num_messages_send = +#ifdef TEST_LATENCY + num_messages; +#else + num_messages * NUM_BUFFER_SLOTS; +#endif + + noc_async_write_one_packet_with_trid_set_state(worker_noc_addr); + while (num_messages_ack < total_msgs) { - // Check if there's an incoming packet for current buffer slot and write to worker if there's new packet - check_incomping_packet_and_write_worker( - buffer_slot_addrs, - buffer_slot_sync_addrs, - buffer_read_ptr, - buffer_write_ptr, + update_sender_state( + sender_buffer_slot_addrs, + sender_buffer_slot_sync_addrs, + full_payload_size, + num_messages_ack, + sender_num_messages_send, + sender_buffer_read_ptr, + sender_buffer_write_ptr); + + update_receiver_state( + receiver_buffer_slot_addrs, + receiver_buffer_slot_sync_addrs, worker_noc_addr, - message_size); - // Check if the write for trid is done, and ack sender if the current buffer slot is done - check_write_worker_done_and_send_ack( - buffer_slot_sync_addrs, buffer_read_ptr, buffer_write_ptr, num_messages_ack); + message_size, + num_messages_ack, + receiver_buffer_read_ptr, + receiver_buffer_write_ptr); // not called in normal execution mode switch_context_if_debug(); @@ -123,19 +106,21 @@ void kernel_main() { ASSERT(is_power_of_two(NUM_BUFFER_SLOTS)); - std::array buffer_slot_addrs; - std::array buffer_slot_sync_addrs; - { - uint32_t buffer_slot_addr = handshake_addr + sizeof(eth_buffer_slot_sync_t); - for (uint8_t i = 0; i < NUM_BUFFER_SLOTS; i++) { - buffer_slot_addrs[i] = buffer_slot_addr; - buffer_slot_addr += message_size; - buffer_slot_sync_addrs[i] = reinterpret_cast(buffer_slot_addr); - buffer_slot_sync_addrs[i]->bytes_sent = 0; - 
buffer_slot_sync_addrs[i]->receiver_ack = 0; - buffer_slot_addr += sizeof(eth_buffer_slot_sync_t); - } - } + const uint32_t full_payload_size = message_size + sizeof(eth_buffer_slot_sync_t); + const uint32_t full_payload_size_eth_words = full_payload_size >> 4; + + uint32_t buffer_start_addr = handshake_addr + sizeof(eth_buffer_slot_sync_t); + + std::array receiver_buffer_slot_addrs; + std::array receiver_buffer_slot_sync_addrs; + buffer_start_addr = setup_receiver_buffer( + receiver_buffer_slot_addrs, receiver_buffer_slot_sync_addrs, buffer_start_addr, message_size); + +#ifdef ENABLE_BI_DIRECTION + std::array sender_buffer_slot_addrs; + std::array sender_buffer_slot_sync_addrs; + setup_sender_buffer(sender_buffer_slot_addrs, sender_buffer_slot_sync_addrs, buffer_start_addr, message_size); +#endif // Avoids hang in issue https://github.com/tenstorrent/tt-metal/issues/9963 for (uint32_t i = 0; i < 2000000000; i++) { @@ -149,6 +134,24 @@ void kernel_main() { { DeviceZoneScopedN("MAIN-TEST-BODY"); - receiver_main_loop(buffer_slot_addrs, buffer_slot_sync_addrs, worker_noc_addr, message_size, num_messages); +#ifdef ENABLE_BI_DIRECTION + main_loop_bi_dir( + sender_buffer_slot_addrs, + sender_buffer_slot_sync_addrs, + receiver_buffer_slot_addrs, + receiver_buffer_slot_sync_addrs, + full_payload_size, + message_size, + num_messages, + worker_noc_addr); +#else + main_loop_uni_dir( + receiver_buffer_slot_addrs, receiver_buffer_slot_sync_addrs, message_size, num_messages, worker_noc_addr); +#endif + } + // need to do a delay as trid writes are not waiting for acks, so need to make sure noc response is back. + for (int i = 0; i < 1000; ++i) { + asm volatile("nop"); } + ncrisc_noc_counters_init(); } diff --git a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/ethernet_write_worker_latency_ubench_sender.cpp b/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/ethernet_write_worker_latency_ubench_sender.cpp index cdf37185e7a..799df166e6d 100644 --- a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/ethernet_write_worker_latency_ubench_sender.cpp +++ b/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/erisc/ethernet_write_worker_latency_ubench_sender.cpp @@ -4,81 +4,92 @@ #include "ethernet_write_worker_latency_ubench_common.hpp" -static constexpr uint32_t NUM_BUFFER_SLOTS = get_compile_time_arg_val(0); - -FORCE_INLINE uint32_t advance_buffer_slot_ptr(uint32_t curr_ptr) { return (curr_ptr + 1) % NUM_BUFFER_SLOTS; } - -FORCE_INLINE void write_receiver( - uint32_t buffer_slot_addr, - volatile eth_buffer_slot_sync_t* buffer_slot_sync_addr, +FORCE_INLINE void main_loop_uni_dir( + const std::array& buffer_slot_addrs, + const std::array& buffer_slot_sync_addrs, uint32_t full_payload_size, - uint32_t full_payload_size_eth_words) { - buffer_slot_sync_addr->bytes_sent = 1; + uint32_t num_messages) { + uint32_t total_msgs = +#ifdef TEST_LATENCY + num_messages; +#else + num_messages * NUM_BUFFER_SLOTS; +#endif - eth_send_bytes_over_channel_payload_only_unsafe( - buffer_slot_addr, buffer_slot_addr, full_payload_size, full_payload_size, full_payload_size_eth_words); -} + DPRINT << "SENDER MAIN LOOP" << ENDL(); -FORCE_INLINE bool has_receiver_ack(volatile eth_buffer_slot_sync_t* buffer_slot_sync_addr) { - return buffer_slot_sync_addr->bytes_sent == 0; -} + uint32_t sender_buffer_read_ptr = 0; + uint32_t sender_buffer_write_ptr = 0; + uint32_t sender_num_messages_ack = 0; + uint32_t sender_num_messages_send = total_msgs; -FORCE_INLINE void 
check_buffer_full_and_send_packet( - const std::array& buffer_slot_addrs, - const std::array& buffer_slot_sync_addrs, - uint32_t read_ptr, - uint32_t& write_ptr, - uint64_t full_payload_size, - uint32_t full_payload_size_eth_words) { - uint32_t next_write_ptr = advance_buffer_slot_ptr(write_ptr); - bool buffer_not_full = next_write_ptr != read_ptr; - - if (buffer_not_full && !eth_txq_is_busy()) { - write_receiver( - buffer_slot_addrs[write_ptr], - buffer_slot_sync_addrs[write_ptr], + while (sender_num_messages_ack < total_msgs) { + update_sender_state( + buffer_slot_addrs, + buffer_slot_sync_addrs, full_payload_size, - full_payload_size_eth_words); + sender_num_messages_ack, + sender_num_messages_send, + sender_buffer_read_ptr, + sender_buffer_write_ptr); - write_ptr = next_write_ptr; - } -} - -FORCE_INLINE void check_receiver_done( - const std::array& buffer_slot_sync_addrs, - uint32_t& read_ptr, - uint32_t& num_messages_ack) { - if (has_receiver_ack(buffer_slot_sync_addrs[read_ptr])) { - read_ptr = advance_buffer_slot_ptr(read_ptr); - num_messages_ack++; + // not called in normal execution mode + switch_context_if_debug(); } } -FORCE_INLINE void sender_main_loop( - const std::array& buffer_slot_addrs, - const std::array& buffer_slot_sync_addrs, +FORCE_INLINE void main_loop_bi_dir( + const std::array& sender_buffer_slot_addrs, + const std::array& sender_buffer_slot_sync_addrs, + const std::array& receiver_buffer_slot_addrs, + const std::array& receiver_buffer_slot_sync_addrs, uint32_t full_payload_size, - uint32_t num_messages) { - uint32_t full_payload_size_eth_words = full_payload_size >> 4; - uint32_t total_msgs = num_messages * NUM_BUFFER_SLOTS; + uint32_t message_size, + uint32_t num_messages, + uint64_t worker_noc_addr) { + uint32_t total_msgs = +#ifdef TEST_LATENCY + num_messages * 2; +#else + num_messages * NUM_BUFFER_SLOTS * 2; +#endif DPRINT << "SENDER MAIN LOOP" << ENDL(); - uint32_t buffer_read_ptr = 0; - uint32_t buffer_write_ptr = 0; + uint32_t sender_buffer_read_ptr = 0; + uint32_t sender_buffer_write_ptr = 0; + + uint32_t receiver_buffer_read_ptr = 0; + uint32_t receiver_buffer_write_ptr = 0; uint32_t num_messages_ack = 0; + uint32_t sender_num_messages_send = +#ifdef TEST_LATENCY + num_messages; +#else + num_messages * NUM_BUFFER_SLOTS; +#endif + + noc_async_write_one_packet_with_trid_set_state(worker_noc_addr); + while (num_messages_ack < total_msgs) { - // Check if current buffer slot is ready and send packet to receiver - check_buffer_full_and_send_packet( - buffer_slot_addrs, - buffer_slot_sync_addrs, - buffer_read_ptr, - buffer_write_ptr, + update_sender_state( + sender_buffer_slot_addrs, + sender_buffer_slot_sync_addrs, full_payload_size, - full_payload_size_eth_words); - // Check if the write for trid is done, and ack sender if the current buffer slot is done - check_receiver_done(buffer_slot_sync_addrs, buffer_read_ptr, num_messages_ack); + num_messages_ack, + sender_num_messages_send, + sender_buffer_read_ptr, + sender_buffer_write_ptr); + + update_receiver_state( + receiver_buffer_slot_addrs, + receiver_buffer_slot_sync_addrs, + worker_noc_addr, + message_size, + num_messages_ack, + receiver_buffer_read_ptr, + receiver_buffer_write_ptr); // not called in normal execution mode switch_context_if_debug(); @@ -90,39 +101,24 @@ void kernel_main() { const uint32_t handshake_addr = get_arg_val(arg_idx++); const uint32_t num_messages = get_arg_val(arg_idx++); const uint32_t message_size = get_arg_val(arg_idx++); - bool is_sender_offset_0 = get_arg_val(arg_idx++) == 1; 
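// Layout sketch (illustrative only; follows setup_sender_buffer / setup_receiver_buffer in the
// shared header) for the slot setup used below: starting right after the handshake word, each of
// the NUM_BUFFER_SLOTS slots holds a payload immediately followed by its sync word.
//
//   uint32_t addr = handshake_addr + sizeof(eth_buffer_slot_sync_t);
//   for (uint32_t i = 0; i < NUM_BUFFER_SLOTS; ++i) {
//       // slot i payload:   [addr, addr + message_size)
//       // slot i sync word: [addr + message_size, addr + message_size + sizeof(eth_buffer_slot_sync_t))
//       addr += message_size + sizeof(eth_buffer_slot_sync_t);
//   }
//   // The running address is returned so receiver-side slots (bi-directional mode) stack directly after.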
ASSERT(is_power_of_two(NUM_BUFFER_SLOTS)); - const uint32_t message_size_eth_words = message_size >> 4; - const uint32_t full_payload_size = message_size + sizeof(eth_buffer_slot_sync_t); const uint32_t full_payload_size_eth_words = full_payload_size >> 4; - std::array buffer_slot_addrs; - std::array buffer_slot_sync_addrs; - { - uint32_t channel_addr = handshake_addr + sizeof(eth_buffer_slot_sync_t); - for (uint8_t i = 0; i < NUM_BUFFER_SLOTS; i++) { - buffer_slot_addrs[i] = channel_addr; - channel_addr += message_size; - buffer_slot_sync_addrs[i] = reinterpret_cast(channel_addr); - channel_addr += sizeof(eth_buffer_slot_sync_t); - } - } + uint32_t buffer_start_addr = handshake_addr + sizeof(eth_buffer_slot_sync_t); - // reset bytes_sent to 0s so first iter it won't block - for (uint32_t i = 0; i < NUM_BUFFER_SLOTS; i++) { - buffer_slot_sync_addrs[i]->bytes_sent = 0; - } + std::array sender_buffer_slot_addrs; + std::array sender_buffer_slot_sync_addrs; + buffer_start_addr = + setup_sender_buffer(sender_buffer_slot_addrs, sender_buffer_slot_sync_addrs, buffer_start_addr, message_size); - // assemble a packet filled with values - for (uint32_t i = 0; i < NUM_BUFFER_SLOTS; i++) { - tt_l1_ptr uint8_t* ptr = reinterpret_cast(buffer_slot_addrs[i]); - for (uint32_t j = 0; j < message_size; j++) { - ptr[j] = j; - } - } +#ifdef ENABLE_BI_DIRECTION + std::array receiver_buffer_slot_addrs; + std::array receiver_buffer_slot_sync_addrs; + setup_receiver_buffer(receiver_buffer_slot_addrs, receiver_buffer_slot_sync_addrs, buffer_start_addr, message_size); +#endif // Avoids hang in issue https://github.com/tenstorrent/tt-metal/issues/9963 for (uint32_t i = 0; i < 2000000000; i++) { @@ -130,8 +126,31 @@ void kernel_main() { } eth_setup_handshake(handshake_addr, true); + // worker noc address +#ifdef ENABLE_BI_DIRECTION + uint64_t worker_noc_addr = get_noc_addr(worker_noc_x, worker_noc_y, worker_buffer_addr); +#endif + { DeviceZoneScopedN("MAIN-TEST-BODY"); - sender_main_loop(buffer_slot_addrs, buffer_slot_sync_addrs, full_payload_size, num_messages); +#ifdef ENABLE_BI_DIRECTION + main_loop_bi_dir( + sender_buffer_slot_addrs, + sender_buffer_slot_sync_addrs, + receiver_buffer_slot_addrs, + receiver_buffer_slot_sync_addrs, + full_payload_size, + message_size, + num_messages, + worker_noc_addr); +#else + main_loop_uni_dir(sender_buffer_slot_addrs, sender_buffer_slot_sync_addrs, full_payload_size, num_messages); +#endif + } + + // need to do a delay as trid writes are not waiting for acks, so need to make sure noc response is back. + for (int i = 0; i < 1000; ++i) { + asm volatile("nop"); } + ncrisc_noc_counters_init(); } diff --git a/tests/tt_metal/tt_metal/test_kernels/dataflow/writer_dual_unary.cpp b/tests/tt_metal/tt_metal/test_kernels/dataflow/writer_dual_unary.cpp new file mode 100644 index 00000000000..167b3ae81e1 --- /dev/null +++ b/tests/tt_metal/tt_metal/test_kernels/dataflow/writer_dual_unary.cpp @@ -0,0 +1,38 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
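// Worked example for the message counts used by the latency/bandwidth kernels above (numbers are
// illustrative, not from the patch): with num_messages = 1000 and NUM_BUFFER_SLOTS = 8,
// TEST_LATENCY mode sends and waits on 1000 messages per direction, while bandwidth mode scales
// to num_messages * NUM_BUFFER_SLOTS = 8000 per direction; in ENABLE_BI_DIRECTION builds the ack
// target doubles (2000 or 16000), since num_messages_ack counts both sender-path and
// receiver-path completions on each side.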
+// +// SPDX-License-Identifier: Apache-2.0 + +#include "dataflow_api.h" +#include "debug/dprint.h" + +inline void write_tiles(uint32_t num_tiles, uint32_t dst_addr, uint32_t bank_id, uint32_t cb_id_out) { + // single-tile ublocks + uint32_t ublock_size_tiles = 1; + uint32_t ublock_size_bytes = get_tile_size(cb_id_out); + + for (uint32_t i = 0; i < num_tiles; i += ublock_size_tiles) { + uint64_t dst_noc_addr = get_noc_addr_from_bank_id(bank_id, dst_addr); + + cb_wait_front(cb_id_out, ublock_size_tiles); + uint32_t l1_read_addr = get_read_ptr(cb_id_out); + noc_async_write(l1_read_addr, dst_noc_addr, ublock_size_bytes); + + noc_async_write_barrier(); + + cb_pop_front(cb_id_out, ublock_size_tiles); + dst_addr += ublock_size_bytes; + } +} + +void kernel_main() { + uint32_t dst_addr_0 = get_arg_val(0); + uint32_t bank_id_0 = get_arg_val(1); + uint32_t dst_addr_1 = get_arg_val(2); + uint32_t bank_id_1 = get_arg_val(3); + uint32_t num_tiles = get_arg_val(4); + constexpr uint32_t cb_id_out0 = tt::CBIndex::c_16; + constexpr uint32_t cb_id_out1 = tt::CBIndex::c_17; + + write_tiles(num_tiles, dst_addr_0, bank_id_0, cb_id_out0); + write_tiles(num_tiles, dst_addr_1, bank_id_1, cb_id_out1); +} diff --git a/tests/ttnn/distributed/test_data_parallel_example_TG.py b/tests/ttnn/distributed/test_data_parallel_example_TG.py index 5486e7d8276..66b8bcacb5b 100644 --- a/tests/ttnn/distributed/test_data_parallel_example_TG.py +++ b/tests/ttnn/distributed/test_data_parallel_example_TG.py @@ -28,6 +28,7 @@ def __call__(self, x: ttnn.Tensor) -> ttnn.Tensor: @pytest.mark.parametrize("mesh_device", [pytest.param((1, 4), id="1x4_grid")], indirect=True) def test_data_parallel_falcon_mlp(mesh_device): + torch.manual_seed(0) # Load Falcon MLP model from huggingface config = transformers.FalconConfig.from_pretrained("tiiuae/falcon-7b-instruct") model = transformers.models.falcon.modeling_falcon.FalconMLP(config).eval() diff --git a/tests/ttnn/distributed/test_distributed_reshape.cpp b/tests/ttnn/distributed/test_distributed_reshape.cpp index a537c45eb8a..9b84cb3fec0 100644 --- a/tests/ttnn/distributed/test_distributed_reshape.cpp +++ b/tests/ttnn/distributed/test_distributed_reshape.cpp @@ -83,7 +83,7 @@ TEST_P(MeshReshapeTest, TestReshapeBetweenConfigurations) { EXPECT_EQ(mesh->num_rows(), old_shape.num_rows); EXPECT_EQ(mesh->num_cols(), old_shape.num_cols); - auto original_order = get_physical_device_ids(*mesh); + auto original_order = mesh->get_device_ids(); // Attempt reshape mesh->reshape({new_shape.num_rows, new_shape.num_cols}); @@ -93,7 +93,7 @@ TEST_P(MeshReshapeTest, TestReshapeBetweenConfigurations) { EXPECT_EQ(mesh->num_cols(), new_shape.num_cols); // Verify device ordering is preserved - EXPECT_EQ(get_physical_device_ids(*mesh), original_order); + EXPECT_EQ(mesh->get_device_ids(), original_order); } // Generate all possible combinations of shapes from kMeshShapes @@ -121,35 +121,34 @@ TEST_F(T3000ReshapeTest, InvalidReshapeDimensions) { EXPECT_EQ(mesh->num_cols(), 8); } -TEST_F(T3000ReshapeTest, From1x8To2x4) { +TEST_F(T3000ReshapeTest, From1x8To2x4ThenBackTo1x8) { auto mesh = ttnn::distributed::open_mesh_device( {1, 8}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER); EXPECT_EQ(mesh->num_rows(), 1); EXPECT_EQ(mesh->num_cols(), 8); - auto original_order = get_physical_device_ids(*mesh); - - mesh->reshape({2, 4}); - EXPECT_EQ(mesh->num_rows(), 2); - EXPECT_EQ(mesh->num_cols(), 4); - auto new_order = get_physical_device_ids(*mesh); - EXPECT_EQ(original_order, 
new_order); -} - -TEST_F(T3000ReshapeTest, OnRingTopology) { - auto mesh = ttnn::distributed::open_mesh_device( - {1, 8}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER); - - EXPECT_EQ(mesh->num_rows(), 1); - EXPECT_EQ(mesh->num_cols(), 8); - auto original_order = get_physical_device_ids(*mesh); + auto original_order = mesh->get_device_ids(); mesh->reshape({2, 4}); EXPECT_EQ(mesh->num_rows(), 2); EXPECT_EQ(mesh->num_cols(), 4); - auto new_order = get_physical_device_ids(*mesh); - EXPECT_EQ(original_order, new_order); + std::vector expected_physical_device_id_order = { + original_order[0], + original_order[1], + original_order[2], + original_order[3], + original_order[7], + original_order[6], + original_order[5], + original_order[4], + }; + + auto new_order = mesh->get_device_ids(); + EXPECT_EQ(new_order, expected_physical_device_id_order); + + mesh->reshape({1, 8}); + EXPECT_EQ(mesh->get_device_ids(), original_order); } TEST_F(T3000ReshapeTest, InvalidTotalDeviceCount) { @@ -165,26 +164,6 @@ TEST_F(T3000ReshapeTest, InvalidTotalDeviceCount) { EXPECT_EQ(mesh->num_cols(), 8); } -TEST_F(T3000ReshapeTest, MultipleReshapes) { - auto mesh = ttnn::distributed::open_mesh_device( - {1, 8}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER); - - auto original_order = get_physical_device_ids(*mesh); - - // Test multiple reshapes - mesh->reshape({2, 4}); // 1x8 -> 2x4 - auto order1 = get_physical_device_ids(*mesh); - EXPECT_EQ(order1, original_order); - - mesh->reshape({4, 2}); // 2x4 -> 4x2 - auto order2 = get_physical_device_ids(*mesh); - EXPECT_EQ(order2, original_order); - - mesh->reshape({1, 8}); // 4x2 -> 1x8 (back to original) - auto final_order = get_physical_device_ids(*mesh); - EXPECT_EQ(final_order, original_order); -} - TEST_F(T3000ReshapeTest, RingPreservation) { auto mesh = ttnn::distributed::open_mesh_device( {1, 8}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER); @@ -239,7 +218,7 @@ TEST_F(T3000ReshapeTest, From1x4To2x2Valid) { mesh->reshape({2, 2}); EXPECT_EQ(mesh->num_rows(), 2); EXPECT_EQ(mesh->num_cols(), 2); - auto new_layout = get_physical_device_ids(*mesh); + auto new_layout = mesh->get_device_ids(); for (auto physical_device_id : physical_device_ids) { EXPECT_TRUE(std::find(new_layout.begin(), new_layout.end(), physical_device_id) != new_layout.end()); } @@ -249,27 +228,21 @@ TEST_F(T3000ReshapeTest, From2x2To1x4) { auto mesh = ttnn::distributed::open_mesh_device( {2, 2}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, 1, tt::tt_metal::DispatchCoreType::WORKER); - std::vector original_layout; - for (size_t i = 0; i < mesh->num_rows(); ++i) { - for (size_t j = 0; j < mesh->num_cols(); ++j) { - auto id = mesh->get_device(i, j)->id(); - original_layout.push_back(id); - } - } + auto mesh_2x2_device_ids = mesh->get_device_ids(); mesh->reshape({1, 4}); EXPECT_EQ(mesh->num_rows(), 1); EXPECT_EQ(mesh->num_cols(), 4); - std::vector new_layout; - for (size_t i = 0; i < mesh->num_rows(); ++i) { - for (size_t j = 0; j < mesh->num_cols(); ++j) { - auto id = mesh->get_device(i, j)->id(); - new_layout.push_back(id); - } - } + auto mesh_1x4_device_ids = mesh->get_device_ids(); + std::vector expected_1x4_device_ids = { + mesh_2x2_device_ids[0], + mesh_2x2_device_ids[1], + mesh_2x2_device_ids[3], + mesh_2x2_device_ids[2], + }; - EXPECT_EQ(new_layout, original_layout); + EXPECT_EQ(mesh_1x4_device_ids, expected_1x4_device_ids); } } // namespace 
ttnn::distributed::test diff --git a/tests/ttnn/integration_tests/resnet/test_ttnn_functional_resnet50_new.py b/tests/ttnn/integration_tests/resnet/test_ttnn_functional_resnet50.py similarity index 69% rename from tests/ttnn/integration_tests/resnet/test_ttnn_functional_resnet50_new.py rename to tests/ttnn/integration_tests/resnet/test_ttnn_functional_resnet50.py index 78ff318c5d8..bca11c9e2e8 100644 --- a/tests/ttnn/integration_tests/resnet/test_ttnn_functional_resnet50_new.py +++ b/tests/ttnn/integration_tests/resnet/test_ttnn_functional_resnet50.py @@ -1,14 +1,16 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. # SPDX-License-Identifier: Apache-2.0 import pytest import ttnn +from loguru import logger + from models.demos.ttnn_resnet.tests.resnet50_test_infra import create_test_infra from models.utility_functions import ( - is_wormhole_b0, - enable_memory_reports, + run_for_blackhole, + is_blackhole, ) @@ -16,16 +18,8 @@ @pytest.mark.parametrize( "batch_size, act_dtype, weight_dtype, math_fidelity", ( - ( - 8, - ttnn.bfloat8_b, - ttnn.bfloat8_b, - ttnn.MathFidelity.LoFi, - ), ## memory config issue due to l4m1 downsample reshard (16, ttnn.bfloat8_b, ttnn.bfloat8_b, ttnn.MathFidelity.HiFi2), (16, ttnn.bfloat8_b, ttnn.bfloat8_b, ttnn.MathFidelity.LoFi), - (20, ttnn.bfloat8_b, ttnn.bfloat8_b, ttnn.MathFidelity.HiFi2), - (20, ttnn.bfloat8_b, ttnn.bfloat8_b, ttnn.MathFidelity.LoFi), ), ) @pytest.mark.parametrize( @@ -46,13 +40,14 @@ def test_resnet_50( use_pretrained_weight, model_location_generator, ): - if batch_size == 8: - pytest.skip("Skipping batch size 8 due to memory config issue") - if is_wormhole_b0() and batch_size == 20: - pytest.skip("Skipping batch size 20 for Wormhole B0 due to fitting issue") if (device.compute_with_storage_grid_size().x, device.compute_with_storage_grid_size().y) == (8, 7): pytest.skip("Test is not supported on n300 (8,7) grid") + if is_blackhole() and use_pretrained_weight: + pytest.skip( + "Skipping pretrained weight test on blackhole due to PCC error: https://github.com/tenstorrent/tt-metal/issues/17558" + ) + test_infra = create_test_infra( device, batch_size, @@ -62,7 +57,6 @@ def test_resnet_50( use_pretrained_weight, model_location_generator=model_location_generator, ) - enable_memory_reports() tt_inputs_host, input_mem_config = test_infra.setup_l1_sharded_input(device) test_infra.input_tensor = tt_inputs_host.to(device, input_mem_config) # First run configures convs JIT @@ -70,7 +64,8 @@ def test_resnet_50( # Optimized run test_infra.input_tensor = tt_inputs_host.to(device, input_mem_config) test_infra.run() - # # More optimized run with caching + # More optimized run with caching test_infra.input_tensor = tt_inputs_host.to(device, input_mem_config) test_infra.run() - test_infra.validate() + passed, message = test_infra.validate() + assert passed, message diff --git a/tests/ttnn/unit_tests/gtests/CMakeLists.txt b/tests/ttnn/unit_tests/gtests/CMakeLists.txt index a76b5284298..931739e9e6b 100644 --- a/tests/ttnn/unit_tests/gtests/CMakeLists.txt +++ b/tests/ttnn/unit_tests/gtests/CMakeLists.txt @@ -17,6 +17,7 @@ set(TTNN_CCL_UNIT_TESTS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/ccl/test_ccl_commands.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ccl/test_ccl_helpers.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ccl/test_ccl_tensor_slicers.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ccl/test_sharded_address_generators_new.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ccl/test_sharded_address_generators.cpp 
${CMAKE_CURRENT_SOURCE_DIR}/ccl/test_ccl_reduce_scatter_host_helpers.cpp ) @@ -30,7 +31,6 @@ set(TTNN_TENSOR_UNIT_TESTS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_distributed_tensor.cpp ${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_mesh_tensor.cpp ${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_partition.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_shape_base.cpp ${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_tensor_sharding.cpp ${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_vector_conversion.cpp ${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_xtensor_conversion.cpp diff --git a/tests/ttnn/unit_tests/gtests/ccl/kernels/edm_fabric_writer.cpp b/tests/ttnn/unit_tests/gtests/ccl/kernels/edm_fabric_writer.cpp index 8fdcfc5e302..717791c746c 100644 --- a/tests/ttnn/unit_tests/gtests/ccl/kernels/edm_fabric_writer.cpp +++ b/tests/ttnn/unit_tests/gtests/ccl/kernels/edm_fabric_writer.cpp @@ -115,6 +115,7 @@ void kernel_main() { sync_noc_x, sync_noc_y, start_sync_val); + noc_async_writes_flushed(); line_sync( fabric_connection, mcast_fwd_packet_header, @@ -190,7 +191,7 @@ void kernel_main() { fabric_conn.wait_for_empty_write_slot(); fabric_conn.send_payload_without_header_non_blocking_from_address( source_l1_buffer_address, packet_payload_size_bytes); - fabric_conn.send_payload_flush_blocking_from_address( + fabric_conn.send_payload_blocking_from_address( (uint32_t)unicast_packet_header, sizeof(tt::fabric::PacketHeader)); } @@ -203,6 +204,13 @@ void kernel_main() { sync_noc_x, sync_noc_y, finish_sync_val); + + if (sync_noc_x == my_x[0] && sync_noc_y == my_y[0]) { + // reset the global semaphore in case it is used in a op/kernel + // invocation + *reinterpret_cast(sync_bank_addr) = 0; + ; + } } { diff --git a/tests/ttnn/unit_tests/gtests/ccl/test_fabric_erisc_data_mover_loopback_with_workers.cpp b/tests/ttnn/unit_tests/gtests/ccl/test_fabric_erisc_data_mover_loopback_with_workers.cpp index 410c3206ee5..78cf7ebcab3 100644 --- a/tests/ttnn/unit_tests/gtests/ccl/test_fabric_erisc_data_mover_loopback_with_workers.cpp +++ b/tests/ttnn/unit_tests/gtests/ccl/test_fabric_erisc_data_mover_loopback_with_workers.cpp @@ -224,12 +224,15 @@ std::tuple, std::vector> build_input_buffer( return {local_input_buffer, inputs}; } -static void build_and_enqueue(const std::vector& devices, std::vector& programs) { +static void build_and_enqueue( + const std::vector& devices, std::vector& programs, bool enqueue_only = false) { TT_FATAL( devices.size() == programs.size(), "Number of devices must match number of programs when calling build_and_enqueue in test"); - for (size_t i = 0; i < devices.size(); i++) { - tt::tt_metal::detail::CompileProgram(devices[i], programs[i]); + if (!enqueue_only) { + for (size_t i = 0; i < devices.size(); i++) { + tt::tt_metal::detail::CompileProgram(devices[i], programs[i]); + } } for (size_t i = 0; i < devices.size(); i++) { tt_metal::EnqueueProgram(devices[i]->command_queue(), programs[i], false); @@ -1410,13 +1413,13 @@ bool TestMultiInputReaderKernel( // multicast to a consistent destination address for (size_t i = 0; i < devices.size(); i++) { input0_tensors_device.push_back( - input_tensor0.to(devices.at(i), input_tensor0_mem_config, ttnn::DefaultQueueId)); + input_tensor0.to_device(devices.at(i), input_tensor0_mem_config, ttnn::DefaultQueueId)); input1_tensors_device.push_back( - input_tensor1.to(devices.at(i), input_tensor1_mem_config, ttnn::DefaultQueueId)); + input_tensor1.to_device(devices.at(i), input_tensor1_mem_config, ttnn::DefaultQueueId)); output0_tensors_device.push_back( - output_tensor0.to(devices.at(i), 
output_tensor0_mem_config, ttnn::DefaultQueueId)); + output_tensor0.to_device(devices.at(i), output_tensor0_mem_config, ttnn::DefaultQueueId)); output1_tensors_device.push_back( - output_tensor1.to(devices.at(i), output_tensor1_mem_config, ttnn::DefaultQueueId)); + output_tensor1.to_device(devices.at(i), output_tensor1_mem_config, ttnn::DefaultQueueId)); } TT_FATAL( !enable_persistent_fabric || subdevice_managers.has_value(), @@ -1681,9 +1684,10 @@ bool RunMultiInputReaderTestPropagateFullTensorIn( TwoInputReaderKernelWriteMode test_writeback_mode) { auto num_elems = std::reduce(tensor_shape.cbegin(), tensor_shape.cend(), 1, std::multiplies()); Tensor input_tensor0 = - ttnn::experimental::view(ttnn::arange(0, num_elems, 1, DataType::UINT32), tensor_shape).to(layout); + ttnn::experimental::view(ttnn::arange(0, num_elems, 1, DataType::UINT32), tensor_shape).to_layout(layout); Tensor input_tensor1 = - ttnn::experimental::view(ttnn::arange(num_elems, 2 * num_elems, 1, DataType::UINT32), tensor_shape).to(layout); + ttnn::experimental::view(ttnn::arange(num_elems, 2 * num_elems, 1, DataType::UINT32), tensor_shape) + .to_layout(layout); Tensor output_tensor0 = ttnn::experimental::view(ttnn::ones(tensor_shape, DataType::UINT32, layout), tensor_shape); Tensor output_tensor1 = ttnn::experimental::view(ttnn::ones(tensor_shape, DataType::UINT32, layout), tensor_shape); input_tensor0.set_tensor_spec(TensorSpec( @@ -1972,9 +1976,10 @@ TEST(WorkerCclCommandProcessingKernelFabricUnicastMode, MultiInputReader_SingleP auto num_elems = std::reduce(tensor_shape.cbegin(), tensor_shape.cend(), 1, std::multiplies()); Tensor input_tensor0 = - ttnn::experimental::view(ttnn::arange(0, num_elems, 1, DataType::UINT32), tensor_shape).to(layout); + ttnn::experimental::view(ttnn::arange(0, num_elems, 1, DataType::UINT32), tensor_shape).to_layout(layout); Tensor input_tensor1 = - ttnn::experimental::view(ttnn::arange(num_elems, 2 * num_elems, 1, DataType::UINT32), tensor_shape).to(layout); + ttnn::experimental::view(ttnn::arange(num_elems, 2 * num_elems, 1, DataType::UINT32), tensor_shape) + .to_layout(layout); Tensor output_tensor0 = ttnn::experimental::view(ttnn::ones(tensor_shape, DataType::UINT32, layout), tensor_shape); Tensor output_tensor1 = ttnn::experimental::view(ttnn::ones(tensor_shape, DataType::UINT32, layout), tensor_shape); @@ -2048,9 +2053,10 @@ void RunFabricMcastFullTensorPropagateTest( auto num_elems = std::reduce(tensor_shape.cbegin(), tensor_shape.cend(), 1, std::multiplies()); Tensor input_tensor1 = - ttnn::experimental::view(ttnn::arange(num_elems, 2 * num_elems, 1, DataType::UINT32), tensor_shape).to(layout); + ttnn::experimental::view(ttnn::arange(num_elems, 2 * num_elems, 1, DataType::UINT32), tensor_shape) + .to_layout(layout); Tensor input_tensor0 = - ttnn::experimental::view(ttnn::arange(0, num_elems, 1, DataType::UINT32), tensor_shape).to(layout); + ttnn::experimental::view(ttnn::arange(0, num_elems, 1, DataType::UINT32), tensor_shape).to_layout(layout); Tensor output_tensor1 = ttnn::experimental::view(ttnn::ones(tensor_shape, DataType::UINT32, layout), tensor_shape); Tensor output_tensor0 = ttnn::experimental::view(ttnn::ones(tensor_shape, DataType::UINT32, layout), tensor_shape); input_tensor0.set_tensor_spec(TensorSpec( @@ -2231,7 +2237,7 @@ bool RunPipelinedWorkersTest( device_tensors.reserve(num_tensors); auto num_elems = std::reduce(tensor_shape.cbegin(), tensor_shape.cend(), 1, std::multiplies()); host_tensors.push_back( - ttnn::experimental::view(ttnn::arange(0, num_elems, 1, 
DataType::UINT32), tensor_shape).to(layout)); + ttnn::experimental::view(ttnn::arange(0, num_elems, 1, DataType::UINT32), tensor_shape).to_layout(layout)); for (size_t i = 1; i < num_tensors; ++i) { host_tensors.push_back( ttnn::experimental::view(ttnn::ones(tensor_shape, DataType::UINT32, layout), tensor_shape)); @@ -2239,7 +2245,7 @@ bool RunPipelinedWorkersTest( TT_FATAL(mem_configs.size() == num_tensors, "Must have a memory config for each tensor"); for (size_t i = 0; i < num_tensors; i++) { host_tensors[i].set_tensor_spec(tensor_specs[i]); - device_tensors.push_back(host_tensors[i].to(device, mem_configs[i])); + device_tensors.push_back(host_tensors[i].to_device(device, mem_configs[i])); log_info("Tensor[{}] allocated starting at address {}", i, device_tensors[i].buffer()->address()); } TT_ASSERT(device_tensors.size() == num_tensors); @@ -2748,12 +2754,13 @@ TEST(CclAsyncOp, ReduceScatterSmall_PersistentFabric) { std::vector device_input_tensors; for (size_t i = 0; i < num_devices; i++) { // host_input_tensors.push_back(ttnn::numpy::random::uniform(bfloat16(-1.0f), bfloat16(1.0f) , - // {input_shape[0],input_shape[1],input_shape[2],input_shape[3]}, layout).to(devices[i])); - auto t = ttnn::experimental::view(ttnn::arange(0, num_elems, 1, DataType::BFLOAT16), input_shape).to(layout); + // {input_shape[0],input_shape[1],input_shape[2],input_shape[3]}, layout).to_device(devices[i])); + auto t = + ttnn::experimental::view(ttnn::arange(0, num_elems, 1, DataType::BFLOAT16), input_shape).to_layout(layout); t.set_tensor_spec(TensorSpec( input_shape, TensorLayout(DataType::BFLOAT16, PageConfig(layout, tt_metal::Tile()), in_memory_config))); - device_input_tensors.push_back(t.to(devices[i])); + device_input_tensors.push_back(t.to_device(devices[i])); } // Need to make it a mesh tensor for use with the op const Tensor input_mesh_tensor = ttnn::distributed::aggregate_as_tensor(device_input_tensors, AllGatherTensor{}); @@ -2866,11 +2873,11 @@ void run_all_gather_with_persistent_fabric(const size_t dim, const size_t num_li size_t page_size = tile_size(DataFormat::Float16); std::vector device_input_tensors; for (size_t i = 0; i < num_devices; i++) { - auto t = ttnn::experimental::view(ttnn::arange(0, num_elems, 1), input_shape).to(layout); + auto t = ttnn::experimental::view(ttnn::arange(0, num_elems, 1), input_shape).to_layout(layout); t.set_tensor_spec(TensorSpec( input_shape, TensorLayout(DataType::BFLOAT16, PageConfig(layout, tt_metal::Tile()), in_memory_config))); - device_input_tensors.push_back(t.to(devices[i])); + device_input_tensors.push_back(t.to_device(devices[i])); } // Need to make it a mesh tensor for use with the op const Tensor input_mesh_tensor = ttnn::distributed::aggregate_as_tensor(device_input_tensors, AllGatherTensor{}); @@ -2946,7 +2953,7 @@ TEST(CclAsyncOp, DISABLED_AllGather_PersistentFabric_Dim3_Links2_Shape1_1_32_819 struct WriteThroughputStabilityTestWithPersistentFabricParams { size_t line_size = 4; size_t num_devices_with_workers = 0; - bool line_sync = false; + bool line_sync = true; }; void RunWriteThroughputStabilityTestWithPersistentFabric( @@ -2974,10 +2981,6 @@ void RunWriteThroughputStabilityTestWithPersistentFabric( using namespace ttnn::ccl; TT_FATAL(num_devices_with_workers <= line_size, "num_devices_with_workers must be less than or equal to num_links"); - if (params.line_sync) { - TT_FATAL(num_op_invocations == 1, "Performance reporting only supported for 1 invocation per test"); - } - auto worker_core_logical = [](size_t link) { return CoreCoord(link, 0); 
}; // static constexpr size_t source_l1_buffer_address = 1000000; @@ -3040,15 +3043,22 @@ void RunWriteThroughputStabilityTestWithPersistentFabric( [dest_bank_addr](const auto& buffer) { return buffer->address() == dest_bank_addr; }), "Test setup error: all destination buffers must have the same bank address across devices"); - auto global_semaphores = ttnn::global_semaphore::create_global_semaphore_with_same_address( - test_fixture.mesh_device_.get(), - devices[0]->worker_cores(HalProgrammableCoreType::TENSIX, SubDeviceId{0}), - 0, // initial value - tt::tt_metal::BufferType::L1, // buffer type - 1000 // attempts - ); - auto global_semaphore_addr = - ttnn::global_semaphore::get_global_semaphore_address(global_semaphores.global_semaphores.at(0)); + std::vector global_semaphore_addrs; + global_semaphore_addrs.reserve(line_size + 1); + std::vector global_semaphore_handles; + for (size_t i = 0; i < line_size * 4; i++) { + auto global_semaphores = ttnn::global_semaphore::create_global_semaphore_with_same_address( + test_fixture.mesh_device_.get(), + devices[0]->worker_cores(HalProgrammableCoreType::TENSIX, SubDeviceId{0}), + 0, // initial value + tt::tt_metal::BufferType::L1, // buffer type + 1000 // attempts + ); + global_semaphore_handles.push_back(global_semaphores); + auto global_semaphore_addr = + ttnn::global_semaphore::get_global_semaphore_address(global_semaphores.global_semaphores.at(0)); + global_semaphore_addrs.push_back(global_semaphore_addr); + } std::vector worker_devices; for (size_t i = 0; i < num_devices_with_workers; i++) { @@ -3062,6 +3072,8 @@ void RunWriteThroughputStabilityTestWithPersistentFabric( "instead.", line_size, worker_devices.size()); + std::vector worker_kernel_ids; + std::vector per_device_global_sem_addr_rt_arg; for (size_t i = 0; i < num_devices_with_workers; i++) { const size_t line_index = i; auto& program = programs[i]; @@ -3113,6 +3125,7 @@ void RunWriteThroughputStabilityTestWithPersistentFabric( "tests/ttnn/unit_tests/gtests/ccl/kernels/edm_fabric_writer.cpp", worker_cores, tt_metal::WriterDataMovementConfig(worker_ct_args)); + worker_kernel_ids.push_back(worker_kernel_id); for (size_t l = 0; l < num_links; l++) { auto worker_core = worker_cores_vec[l]; auto build_connection_args = [&local_device_fabric_handle, device, &program, &worker_core]( @@ -3160,8 +3173,12 @@ void RunWriteThroughputStabilityTestWithPersistentFabric( if (params.line_sync) { rt_args.push_back(sync_core_noc_x); rt_args.push_back(sync_core_noc_y); - rt_args.push_back(global_semaphore_addr); - rt_args.push_back(num_links * num_devices_with_workers /*line_size*/); + if (l == 0) { + per_device_global_sem_addr_rt_arg.push_back(rt_args.size()); + } + TT_FATAL(global_semaphore_addrs.at(0) != -1, "Invalid test setup. 
Global semaphore address is -1"); + rt_args.push_back(global_semaphore_addrs.at(0)); + rt_args.push_back(num_links * num_devices_with_workers); } tt_metal::SetRuntimeArgs(program, worker_kernel_id, worker_core, rt_args); @@ -3170,7 +3187,19 @@ void RunWriteThroughputStabilityTestWithPersistentFabric( for (size_t i = 0; i < num_op_invocations; i++) { log_info(tt::LogTest, "Iteration: {}", i); - build_and_enqueue(worker_devices, programs); + if (i != 0 && params.line_sync) { + for (size_t k = 0; k < worker_kernel_ids.size(); k++) { + auto& worker_rt_args_by_core = GetRuntimeArgs(programs[k], worker_kernel_ids[k]); + auto global_sem_addr_rt_arg_idx = per_device_global_sem_addr_rt_arg[k]; + for (size_t l = 0; l < num_links; l++) { + auto& worker_rt_args = worker_rt_args_by_core[worker_cores_vec[l].x][worker_cores_vec[l].y]; + worker_rt_args.at(global_sem_addr_rt_arg_idx) = + global_semaphore_addrs[i % global_semaphore_addrs.size()]; + } + } + } + + build_and_enqueue(worker_devices, programs, i != 0); log_info(tt::LogTest, "Waiting for Op finish on all devices"); wait_for_worker_subdevice_program_completion(worker_devices, subdevice_managers); @@ -3180,7 +3209,7 @@ void RunWriteThroughputStabilityTestWithPersistentFabric( TT_FATAL(fabric_programs->size() == devices.size(), "Expected fabric programs size to be same as devices size"); log_info(tt::LogTest, "Fabric teardown"); persistent_fabric_teardown_sequence( - devices, subdevice_managers, fabric_handle.value(), tt::fabric::TerminationSignal::IMMEDIATELY_TERMINATE); + devices, subdevice_managers, fabric_handle.value(), tt::fabric::TerminationSignal::GRACEFULLY_TERMINATE); log_info(tt::LogTest, "Waiting for teardown completion"); for (IDevice* d : devices) { @@ -3204,7 +3233,7 @@ TEST(EdmFabric, BasicMcastThroughputTest_SingleLink_LineSize2_SingleMcast) { const size_t num_unicasts = 2; const size_t num_links = 1; const size_t num_op_invocations = 1; - const bool line_sync = false; + const bool line_sync = true; WriteThroughputStabilityTestWithPersistentFabricParams params; params.line_sync = line_sync; params.line_size = 2; @@ -3217,9 +3246,66 @@ TEST(EdmFabric, BasicMcastThroughputTest_SingleMcast) { const size_t num_unicasts = 2; const size_t num_links = 2; const size_t num_op_invocations = 1; + const bool line_sync = true; + WriteThroughputStabilityTestWithPersistentFabricParams params; + params.line_sync = line_sync; + RunWriteThroughputStabilityTestWithPersistentFabric( + num_mcasts, num_unicasts, num_links, num_op_invocations, params); +} +TEST(EdmFabric, DISABLED_BasicMcastThroughputTest_SenderFullNoWrap_ReceiverNoWrap_SingleWorker_2Device) { + const size_t num_mcasts = 9; + const size_t num_unicasts = 0; + const size_t num_links = 1; + const size_t num_op_invocations = 1; + const size_t line_size = 2; const bool line_sync = false; WriteThroughputStabilityTestWithPersistentFabricParams params; + params.line_size = line_size; params.line_sync = line_sync; + params.num_devices_with_workers = 1; + RunWriteThroughputStabilityTestWithPersistentFabric( + num_mcasts, num_unicasts, num_links, num_op_invocations, params); +} +// hangs with DPRINT +TEST(EdmFabric, BasicMcastThroughputTest_SenderFullNoWrap_ReceiverNoWrap_2Device) { + const size_t num_mcasts = 9; + const size_t num_unicasts = 0; + const size_t num_links = 1; + const size_t num_op_invocations = 1; + const size_t line_size = 2; + const bool line_sync = true; + WriteThroughputStabilityTestWithPersistentFabricParams params; + params.line_size = line_size; + params.line_sync = line_sync; 
+ RunWriteThroughputStabilityTestWithPersistentFabric( + num_mcasts, num_unicasts, num_links, num_op_invocations, params); +} +TEST(EdmFabric, DISABLED_BasicMcastThroughputTest_SenderFullNoWrap_ReceiverNoWrap_SingleWorker_4Device) { + const size_t num_mcasts = 9; + const size_t num_unicasts = 0; + const size_t num_links = 1; + const size_t num_op_invocations = 1; + const size_t line_size = 4; + const bool line_sync = false; + WriteThroughputStabilityTestWithPersistentFabricParams params; + params.line_size = line_size; + params.line_sync = line_sync; + params.num_devices_with_workers = 1; + RunWriteThroughputStabilityTestWithPersistentFabric( + num_mcasts, num_unicasts, num_links, num_op_invocations, params); +} +// First to hang - maybe somethign to do with merging traffic +TEST(EdmFabric, DISABLED_BasicMcastThroughputTest_SenderFullNoWrap_ReceiverNoWrap_TwoWorkers_4Device) { + const size_t num_mcasts = 9; + const size_t num_unicasts = 0; + const size_t num_links = 1; + const size_t num_op_invocations = 1; + const size_t line_size = 4; + const bool line_sync = false; + WriteThroughputStabilityTestWithPersistentFabricParams params; + params.line_size = line_size; + params.line_sync = line_sync; + params.num_devices_with_workers = 2; RunWriteThroughputStabilityTestWithPersistentFabric( num_mcasts, num_unicasts, num_links, num_op_invocations, params); } @@ -3228,9 +3314,23 @@ TEST(EdmFabric, BasicMcastThroughputTest_SenderFullNoWrap_ReceiverNoWrap) { const size_t num_unicasts = 0; const size_t num_links = 1; const size_t num_op_invocations = 1; + const bool line_sync = true; + WriteThroughputStabilityTestWithPersistentFabricParams params; + params.line_sync = line_sync; + RunWriteThroughputStabilityTestWithPersistentFabric( + num_mcasts, num_unicasts, num_links, num_op_invocations, params); +} +TEST(EdmFabric, DISABLED_BasicMcastThroughputTest_SenderOneElemWrap_ReceiverNoWrap_SingleWorker_2Device) { + const size_t num_mcasts = 10; + const size_t num_unicasts = 0; + const size_t num_links = 1; + const size_t num_op_invocations = 1; + const size_t line_size = 2; const bool line_sync = false; WriteThroughputStabilityTestWithPersistentFabricParams params; + params.line_size = line_size; params.line_sync = line_sync; + params.num_devices_with_workers = 1; RunWriteThroughputStabilityTestWithPersistentFabric( num_mcasts, num_unicasts, num_links, num_op_invocations, params); } @@ -3240,7 +3340,7 @@ TEST(EdmFabric, BasicMcastThroughputTest_SenderOneElemWrap_ReceiverNoWrap_2Devic const size_t num_links = 1; const size_t num_op_invocations = 1; const size_t line_size = 2; - const bool line_sync = false; + const bool line_sync = true; WriteThroughputStabilityTestWithPersistentFabricParams params; params.line_size = line_size; params.line_sync = line_sync; @@ -3252,7 +3352,7 @@ TEST(EdmFabric, BasicMcastThroughputTest_SenderOneElemWrap_ReceiverNoWrap) { const size_t num_unicasts = 0; const size_t num_links = 1; const size_t num_op_invocations = 1; - const bool line_sync = false; + const bool line_sync = true; WriteThroughputStabilityTestWithPersistentFabricParams params; params.line_sync = line_sync; RunWriteThroughputStabilityTestWithPersistentFabric( @@ -3264,7 +3364,7 @@ TEST(EdmFabric, BasicMcastThroughputTest_SenderTwiceFilled_ReceiverOnceFilled_2D const size_t num_links = 1; const size_t num_op_invocations = 1; const size_t line_size = 2; - const bool line_sync = false; + const bool line_sync = true; WriteThroughputStabilityTestWithPersistentFabricParams params; params.line_size = line_size; 
params.line_sync = line_sync; @@ -3276,7 +3376,7 @@ TEST(EdmFabric, BasicMcastThroughputTest_SenderTwiceFilled_ReceiverOnceFilled) { const size_t num_unicasts = 0; const size_t num_links = 1; const size_t num_op_invocations = 1; - const bool line_sync = false; + const bool line_sync = true; WriteThroughputStabilityTestWithPersistentFabricParams params; params.line_sync = line_sync; RunWriteThroughputStabilityTestWithPersistentFabric( @@ -3287,7 +3387,7 @@ TEST(EdmFabric, BasicMcastThroughputTest_SenderTwoWrap_ReceiverOneWrap) { const size_t num_unicasts = 0; const size_t num_links = 1; const size_t num_op_invocations = 1; - const bool line_sync = false; + const bool line_sync = true; WriteThroughputStabilityTestWithPersistentFabricParams params; params.line_sync = line_sync; RunWriteThroughputStabilityTestWithPersistentFabric( @@ -3376,7 +3476,7 @@ TEST(EdmFabric, BasicMcastThroughputTest_SenderTwiceFilled_ReceiverOnceFilled_Li RunWriteThroughputStabilityTestWithPersistentFabric( num_mcasts, num_unicasts, num_links, num_op_invocations, params); } -TEST(EdmFabric, BasicMcastThroughputTest_SenderFourTImesFilled_ReceiverTwiceFilled_2Device_1Worker) { +TEST(EdmFabric, DISABLED_BasicMcastThroughputTest_SenderFourTImesFilled_ReceiverTwiceFilled_2Device_1Worker) { const size_t num_mcasts = 36; const size_t num_unicasts = 0; const size_t num_links = 1; @@ -3461,7 +3561,7 @@ TEST(EdmFabric, BasicMcastThroughputTest_SmallPerf1) { num_mcasts, num_unicasts, num_links, num_op_invocations, params); } -TEST(EdmFabric, BasicMcastThroughputTest_0) { +TEST(EdmFabric, DISABLED_BasicMcastThroughputTest_0) { const size_t num_mcasts = 100; const size_t num_unicasts = 2; const size_t num_links = 2; @@ -3469,16 +3569,20 @@ TEST(EdmFabric, BasicMcastThroughputTest_0) { const bool line_sync = false; WriteThroughputStabilityTestWithPersistentFabricParams params; params.line_size = 2; + params.line_sync = line_sync; RunWriteThroughputStabilityTestWithPersistentFabric( num_mcasts, num_unicasts, num_links, num_op_invocations, params); } -TEST(EdmFabric, BasicMcastThroughputTest_1) { +TEST(EdmFabric, DISABLED_BasicMcastThroughputTest_1) { const size_t num_mcasts = 1000; const size_t num_unicasts = 2; const size_t num_links = 2; const size_t num_op_invocations = 1; const bool line_sync = false; - RunWriteThroughputStabilityTestWithPersistentFabric(num_mcasts, num_unicasts, num_links, num_op_invocations); + WriteThroughputStabilityTestWithPersistentFabricParams params; + params.line_sync = line_sync; + RunWriteThroughputStabilityTestWithPersistentFabric( + num_mcasts, num_unicasts, num_links, num_op_invocations, params); } TEST(EdmFabric, BasicMcastThroughputTest_2) { const size_t num_mcasts = 50000; @@ -3493,7 +3597,11 @@ TEST(EdmFabric, BasicMcastThroughputTest_3) { const size_t num_unicasts = 2; const size_t num_links = 2; const size_t num_op_invocations = 1; - RunWriteThroughputStabilityTestWithPersistentFabric(num_mcasts, num_unicasts, num_links, num_op_invocations); + const bool line_sync = true; + WriteThroughputStabilityTestWithPersistentFabricParams params; + params.line_sync = line_sync; + RunWriteThroughputStabilityTestWithPersistentFabric( + num_mcasts, num_unicasts, num_links, num_op_invocations, params); } TEST(EdmFabric, BasicMcastThroughputTest_4) { const size_t num_mcasts = 800000; @@ -3550,7 +3658,6 @@ TEST(EdmFabric, DISABLED_BasicMcastThroughputTest_10) { const size_t num_op_invocations = 50; RunWriteThroughputStabilityTestWithPersistentFabric(num_mcasts, num_unicasts, num_links, num_op_invocations); } 
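The hunks above flip line_sync on by default while the same programs are re-run across multiple op invocations. To make that safe, the earlier change in this file pre-allocates a pool of global semaphores (the `line_size * 4` loop), records the runtime-arg slot holding the semaphore address for each worker, and on every invocation after the first patches that slot with the next address in round-robin order; the matching kernel change in edm_fabric_writer.cpp zeroes the semaphore on the designated sync core once the final sync completes. The sketch below is a minimal plain-C++ restatement of that rotation; WorkerArgs, rotate_sync_semaphore, and sem_addr_pool are hypothetical illustration names, not identifiers from this patch.

    // Sketch only: round-robin selection of a sync-semaphore address per invocation,
    // so an address that may still hold a stale value is never reused immediately.
    #include <cstddef>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    struct WorkerArgs {
        std::vector<uint32_t> rt_args;  // flat runtime-arg list, in the spirit of SetRuntimeArgs
        size_t sem_addr_idx;            // index of the semaphore-address slot, recorded when args were built
    };

    void rotate_sync_semaphore(
        std::vector<WorkerArgs>& workers, const std::vector<uint32_t>& sem_addr_pool, size_t invocation) {
        const uint32_t addr = sem_addr_pool[invocation % sem_addr_pool.size()];
        for (auto& w : workers) {
            w.rt_args.at(w.sem_addr_idx) = addr;  // patch only the semaphore-address argument
        }
    }

    int main() {
        std::vector<WorkerArgs> workers{{{11, 22, 0x1000}, 2}, {{33, 44, 0x1000}, 2}};
        std::vector<uint32_t> pool{0x1000, 0x2000, 0x3000};
        for (size_t i = 1; i < 4; ++i) {
            rotate_sync_semaphore(workers, pool, i);
            std::cout << "invocation " << i << " uses 0x" << std::hex << workers[0].rt_args[2] << std::dec << "\n";
        }
    }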
-// DISABLED due to long runtime TEST(EdmFabric, BasicMcastThroughputTest_6_Short) { const size_t num_mcasts = 100; const size_t num_unicasts = 2; @@ -3558,7 +3665,6 @@ TEST(EdmFabric, BasicMcastThroughputTest_6_Short) { const size_t num_op_invocations = 100; RunWriteThroughputStabilityTestWithPersistentFabric(num_mcasts, num_unicasts, num_links, num_op_invocations); } -// DISABLED due to long runtime TEST(EdmFabric, BasicMcastThroughputTest_7_Short) { const size_t num_mcasts = 1000; const size_t num_unicasts = 2; @@ -3566,7 +3672,6 @@ TEST(EdmFabric, BasicMcastThroughputTest_7_Short) { const size_t num_op_invocations = 50; RunWriteThroughputStabilityTestWithPersistentFabric(num_mcasts, num_unicasts, num_links, num_op_invocations); } -// DISABLED due to long runtime TEST(EdmFabric, BasicMcastThroughputTest_8_Short) { const size_t num_mcasts = 50000; const size_t num_unicasts = 2; @@ -3574,7 +3679,6 @@ TEST(EdmFabric, BasicMcastThroughputTest_8_Short) { const size_t num_op_invocations = 20; RunWriteThroughputStabilityTestWithPersistentFabric(num_mcasts, num_unicasts, num_links, num_op_invocations); } -// DISABLED due to long runtime TEST(EdmFabric, BasicMcastThroughputTest_9_Short) { const size_t num_mcasts = 200000; const size_t num_unicasts = 2; @@ -3582,7 +3686,6 @@ TEST(EdmFabric, BasicMcastThroughputTest_9_Short) { const size_t num_op_invocations = 10; RunWriteThroughputStabilityTestWithPersistentFabric(num_mcasts, num_unicasts, num_links, num_op_invocations); } -// DISABLED due to long runtime TEST(EdmFabric, BasicMcastThroughputTest_10_Short) { const size_t num_mcasts = 800000; const size_t num_unicasts = 2; diff --git a/tests/ttnn/unit_tests/gtests/ccl/test_sharded_address_generators_new.cpp b/tests/ttnn/unit_tests/gtests/ccl/test_sharded_address_generators_new.cpp new file mode 100644 index 00000000000..a64d2e8a141 --- /dev/null +++ b/tests/ttnn/unit_tests/gtests/ccl/test_sharded_address_generators_new.cpp @@ -0,0 +1,187 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include +#include "gtest/gtest.h" +#if !(defined(KERNEL_BUILD) || defined(FW_BUILD)) + +#define NOC_ADDR_LOCAL_BITS 36 +#define NOC_ADDR_NODE_ID_BITS 6 + +#define FORCE_INLINE inline __attribute__((always_inline)) +#define noc_index 0 +#define NOC_XY_ADDR(x, y, addr) \ + ((((uint64_t)(y)) << (NOC_ADDR_LOCAL_BITS + NOC_ADDR_NODE_ID_BITS)) | (((uint64_t)(x)) << NOC_ADDR_LOCAL_BITS) | \ + ((uint64_t)(addr))) +#define NOC_0_X(noc_index, noc_size_x, x) x +#define NOC_0_Y(noc_index, noc_size_y, y) y +#define DYNAMIC_NOC_X(noc, x) NOC_0_X(noc, noc_size_x, (x)) +#define DYNAMIC_NOC_Y(noc, y) NOC_0_Y(noc, noc_size_y, (y)) + +#endif +#include "ttnn/cpp/ttnn/operations/ccl/kernel_common/sharding_addrgen.hpp" +namespace sharding_testing_parameters { +mapping_table_t map[9] = {0x00000001, 0x00020003, 0x00040200, 0x02010202, 0x02030204, 0x03000301, 0x03020303, 0x04000401, 0x04020403}; +uint64_t real_core_x_vals [18] = {0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x2, 0x2, 0x2, 0x2, 0x3, 0x3, 0x3, 0x3, 0x4, 0x4, 0x4, 0x4}; +uint64_t real_core_y_vals [18] = {0x0, 0x1, 0x2, 0x3, 0x4, 0x0, 0x1, 0x2, 0x3, 0x4, 0x0, 0x1, 0x2, 0x3, 0x0, 0x1, 0x2, 0x3}; +} // namespace sharding_testing_parameters +namespace tt { +namespace tt_metal { + +template +void run_full_width_test(ADDRgen addrgen, ADDRgenInfo constants, uint32_t bank_base_address) { + uint32_t rows[7] = {0, 1, 31, 32, 33, 66, 10000}; + for (int i = 0; i < std::size(rows); i++) { + uint32_t page = constants.pages_per_tensor_row * rows[i]; + uint32_t base_address = constants.pages_per_shard_width * rows[i] * constants.page_size_jump; + for (int j = 0; j < constants.number_of_cores; j++) { + uint64_t l1_address = base_address; + for (int k = 0; k < constants.pages_per_shard_width; k++) { + if (j * constants.pages_per_shard_width + k < constants.pages_per_tensor_row) { + uint64_t calculated_address = + bank_base_address + NOC_XY_ADDR( + DYNAMIC_NOC_X(noc, sharding_testing_parameters::real_core_x_vals[j]), + DYNAMIC_NOC_Y(noc, sharding_testing_parameters::real_core_y_vals[j]), + l1_address); + uint64_t retrieved_address = addrgen.get_noc_addr(page); + ASSERT_EQ(calculated_address, retrieved_address); + l1_address += constants.page_size_jump; + } + page++; + } + } + } +} + +template +void run_full_height_test(ADDRgen addrgen, ADDRgenInfo constants, uint32_t bank_base_address) { + uint32_t width_pages[5] = {0, 1, 31, 30, 14}; + for (int i = 0; i < std::size(width_pages); i++) { + uint32_t page = width_pages[i]; + uint32_t base_address = page * constants.page_size_jump; + for (int j = 0; j < constants.number_of_cores; j++) { + uint32_t l1_address = base_address; + for (int k = 0; k < constants.rows_per_shard_height; k++) { + uint64_t calculated_address = + bank_base_address + NOC_XY_ADDR( + DYNAMIC_NOC_X(noc, sharding_testing_parameters::real_core_x_vals[j]), + DYNAMIC_NOC_Y(noc, sharding_testing_parameters::real_core_y_vals[j]), + l1_address); + uint64_t retrieved_address = addrgen.get_noc_addr(page); + ASSERT_EQ(calculated_address, retrieved_address); + l1_address += constants.page_size_jump * constants.pages_per_tensor_row; + page = page + constants.pages_per_tensor_row; + } + } + } +} + +template +void run_full_block_test(ADDRgen addrgen, ADDRgenInfo constants, uint32_t bank_base_address) { + uint32_t random_width_offsets[4] = {0, 1, 5, 7}; + uint32_t random_height_offsets[4] = {0, 1, 5, 7}; + uint32_t cores_per_block_row = (constants.pages_per_tensor_row - 1) / constants.pages_per_shard_width + 1; + uint32_t 
cores_height = constants.number_of_cores / cores_per_block_row; + for (int i = 0; i < std::size(random_width_offsets); i++) { + for (int j = 0; j < std::size(random_height_offsets); j++) { + uint64_t outer_page = random_width_offsets[i] + random_height_offsets[j] * constants.pages_per_tensor_row; + uint64_t l1_address = + (random_width_offsets[i] + random_height_offsets[j] * constants.pages_per_shard_width) * + constants.page_size_jump; + for (int h = 0; h < cores_height; h++) { + uint64_t page = outer_page; + for (int w = 0; w < cores_per_block_row; w++) { + uint32_t core_number = w + h * cores_per_block_row; + uint64_t calculated_address = + bank_base_address + + NOC_XY_ADDR( + DYNAMIC_NOC_X(noc, sharding_testing_parameters::real_core_x_vals[core_number]), + DYNAMIC_NOC_Y(noc, sharding_testing_parameters::real_core_y_vals[core_number]), + l1_address); + uint64_t retrieved_address = addrgen.get_noc_addr(page); + ASSERT_EQ(calculated_address, retrieved_address); + page += constants.pages_per_shard_width; + } + outer_page += constants.pages_per_tensor_row * constants.rows_per_shard_height; + } + } + } +} + +TEST(CclnewWidthShardedTensorSliceIndexer_Wormhole, width_sharded_test) { + constexpr std::size_t shard_type = static_cast(tt::tt_metal::TensorMemoryLayout::WIDTH_SHARDED); + constexpr std::size_t number_of_cores = 8; + constexpr std::size_t page_size_jump = 1024; + constexpr std::size_t pages_per_tensor_row = 32; + constexpr std::size_t contiguity = static_cast(shard_addr_gen_consts::ContiguityType::NO_SHARD_PADDING); + constexpr std::size_t pages_per_shard_width = 6; + constexpr std::size_t rows_per_shard_height = 1; + constexpr std::size_t tensor_address = 0x100000; + using ct_shard_info = ShardedInfo< + shard_type, + number_of_cores, + page_size_jump, + pages_per_tensor_row, + contiguity, + pages_per_shard_width, + rows_per_shard_height>; + auto info_var = ct_shard_info{}; + experimental::ShardedAddrGen addrgen = { + .bank_base_address = tensor_address, .shard_array = sharding_testing_parameters::map}; + run_full_width_test(addrgen, info_var, tensor_address); +} + +TEST(CclnewHeightShardedTensorSliceIndexer_Wormhole, height_sharded_test) { + static constexpr std::size_t shard_type = static_cast(tt::tt_metal::TensorMemoryLayout::HEIGHT_SHARDED); + static constexpr std::size_t number_of_cores = 4; + static constexpr std::size_t page_size_jump = 1024; + static constexpr std::size_t pages_per_tensor_row = 32; + static constexpr std::size_t contiguity = + static_cast(shard_addr_gen_consts::ContiguityType::NO_SHARD_PADDING); + static constexpr std::size_t pages_per_shard_width = 1; + static constexpr std::size_t rows_per_shard_height = 8; + static constexpr std::size_t tensor_address = 0x100000; + typedef ShardedInfo< + shard_type, + number_of_cores, + page_size_jump, + pages_per_tensor_row, + contiguity, + pages_per_shard_width, + rows_per_shard_height> + ct_shard_info; + auto info_var = ct_shard_info{}; + experimental::ShardedAddrGen addrgen = { + .bank_base_address = tensor_address, .shard_array = sharding_testing_parameters::map}; + run_full_height_test(addrgen, info_var, tensor_address); +} + +TEST(CclnewBlockShardedTensorSliceIndexer_Wormhole, block_sharded_test) { + static constexpr std::size_t shard_type = static_cast(tt::tt_metal::TensorMemoryLayout::BLOCK_SHARDED); + static constexpr std::size_t number_of_cores = 16; + static constexpr std::size_t page_size_jump = 1024; + static constexpr std::size_t pages_per_tensor_row = 32; + static constexpr std::size_t contiguity = + 
static_cast(shard_addr_gen_consts::ContiguityType::NO_SHARD_PADDING); + static constexpr std::size_t pages_per_shard_width = 8; + static constexpr std::size_t rows_per_shard_height = 8; + static constexpr std::size_t tensor_address = 0x1000000; + typedef ShardedInfo< + shard_type, + number_of_cores, + page_size_jump, + pages_per_tensor_row, + contiguity, + pages_per_shard_width, + rows_per_shard_height> + ct_shard_info; + auto info_var = ct_shard_info{}; + experimental::ShardedAddrGen addrgen = { + .bank_base_address = tensor_address, .shard_array = sharding_testing_parameters::map}; + run_full_block_test(addrgen, info_var, tensor_address); +} + +} // namespace tt_metal +} // namespace tt diff --git a/tests/ttnn/unit_tests/gtests/tensor/test_mesh_tensor.cpp b/tests/ttnn/unit_tests/gtests/tensor/test_mesh_tensor.cpp index e9243c91a17..4e667b33727 100644 --- a/tests/ttnn/unit_tests/gtests/tensor/test_mesh_tensor.cpp +++ b/tests/ttnn/unit_tests/gtests/tensor/test_mesh_tensor.cpp @@ -5,7 +5,11 @@ #include #include +#include "ttnn/distributed/api.hpp" +#include "ttnn/distributed/distributed_tensor_config.hpp" #include "ttnn/tensor/tensor.hpp" +#include "ttnn/tensor/tensor_impl.hpp" +#include "ttnn/tensor/tensor_impl_wrapper.hpp" #include "ttnn_test_fixtures.hpp" #include #include @@ -13,6 +17,9 @@ namespace ttnn::distributed::test { namespace { +using ::testing::FloatEq; +using ::testing::Pointwise; + using MeshTensorTest = T3kMultiDeviceFixture; TEST_F(MeshTensorTest, Lifecycle) { @@ -43,5 +50,97 @@ TEST_F(MeshTensorTest, Lifecycle) { EXPECT_FALSE(input_tensor.is_allocated()); } +using MeshTensorDeviceTest = T3kMultiDeviceFixture; + +TEST_F(MeshTensorDeviceTest, ToHostNonMeshTensor) { + const ttnn::Shape shape{1, 1, 32, 32}; + const TensorSpec tensor_spec = + TensorSpec(shape, TensorLayout(DataType::FLOAT32, Layout::ROW_MAJOR, MemoryConfig{})); + Tensor input_host_tensor = Tensor::from_vector(std::vector(shape.volume()), tensor_spec); + EXPECT_TRUE(input_host_tensor.storage_type() == StorageType::OWNED); + + EXPECT_ANY_THROW(tensor_impl::to_host_mesh_tensor_wrapper(input_host_tensor)); +} + +TEST_F(MeshTensorDeviceTest, ReplicateHostTensor) { + const ttnn::Shape shape{1, 1, 32, 32}; + const TensorSpec tensor_spec = + TensorSpec(shape, TensorLayout(DataType::FLOAT32, Layout::ROW_MAJOR, MemoryConfig{})); + + std::vector host_data(shape.volume()); + std::iota(host_data.begin(), host_data.end(), 0); + + // Prepare host tensor to offload on device. + Tensor input_host_tensor = Tensor::from_vector(host_data, tensor_spec); + EXPECT_TRUE(input_host_tensor.storage_type() == StorageType::OWNED); + EXPECT_EQ(input_host_tensor.get_tensor_spec().logical_shape(), shape); + + // Write host tensor to device. + Tensor device_tensor = + tensor_impl::to_device_mesh_tensor_wrapper(input_host_tensor, mesh_device_.get(), MemoryConfig{}); + EXPECT_TRUE(distributed::is_mesh_buffer_tensor(device_tensor)); + EXPECT_EQ(device_tensor.get_tensor_spec().logical_shape(), shape); + + auto* multi_device_storage = std::get_if(&device_tensor.get_storage()); + ASSERT_NE(multi_device_storage, nullptr); + for (const auto& [_, shard_spec] : multi_device_storage->specs) { + EXPECT_EQ(shard_spec.logical_shape(), shape); + } + EXPECT_TRUE(std::holds_alternative(multi_device_storage->strategy)); + + // Read the tensor back, and compare it with input data. 
+ Tensor output_host_tensor = tensor_impl::to_host_mesh_tensor_wrapper(device_tensor); + EXPECT_TRUE(output_host_tensor.storage_type() == StorageType::MULTI_DEVICE_HOST); + EXPECT_EQ(output_host_tensor.get_tensor_spec().logical_shape(), shape); + + for (const auto& tensor : get_tensors_from_multi_device_storage(output_host_tensor)) { + EXPECT_EQ(tensor.get_tensor_spec().logical_shape(), shape); + EXPECT_THAT(tensor.to_vector(), Pointwise(FloatEq(), host_data)); + } +} + +TEST_F(MeshTensorDeviceTest, WriteMultiDeviceHostTensor) { + const int num_devices = mesh_device_->num_devices(); + ASSERT_EQ(num_devices, 8); + // Test uneven shard shapes. + const ttnn::Shape shape{1, 9, 32, 32}; + const TensorSpec tensor_spec = + TensorSpec(shape, TensorLayout(DataType::FLOAT32, Layout::ROW_MAJOR, MemoryConfig{})); + + std::vector host_data(shape.volume()); + std::iota(host_data.begin(), host_data.end(), 0); + + // Prepare multi-device host tensor to offload on device. + Tensor input_host_tensor_sharded = distribute_tensor( + Tensor::from_vector(host_data, tensor_spec), *shard_tensor_to_mesh_mapper(*mesh_device_, /*dim=*/1)); + EXPECT_TRUE(input_host_tensor_sharded.storage_type() == StorageType::MULTI_DEVICE_HOST); + + auto* multi_device_host_storage = + std::get_if(&input_host_tensor_sharded.get_storage()); + ASSERT_NE(multi_device_host_storage, nullptr); + const auto* strategy = std::get_if(&multi_device_host_storage->strategy); + ASSERT_NE(strategy, nullptr); + EXPECT_EQ(strategy->shard_dimension, 1); + + // Write host tensor to device. + Tensor device_tensor = + tensor_impl::to_device_mesh_tensor_wrapper(input_host_tensor_sharded, mesh_device_.get(), MemoryConfig{}); + EXPECT_TRUE(distributed::is_mesh_buffer_tensor(device_tensor)); + + auto* multi_device_storage = std::get_if(&device_tensor.get_storage()); + ASSERT_NE(multi_device_storage, nullptr); + const auto* device_tensor_strategy = std::get_if(&multi_device_storage->strategy); + ASSERT_NE(device_tensor_strategy, nullptr); + EXPECT_EQ(device_tensor_strategy->shard_dimension, 1); + + // Read the tensor back, and compare it with input data. 
+ Tensor output_host_tensor = aggregate_tensor( + tensor_impl::to_host_mesh_tensor_wrapper(device_tensor), *concat_mesh_to_tensor_composer(/*dim=*/1)); + EXPECT_TRUE(output_host_tensor.storage_type() == StorageType::OWNED); + EXPECT_EQ(output_host_tensor.get_tensor_spec().logical_shape(), shape); + + EXPECT_THAT(output_host_tensor.to_vector(), Pointwise(FloatEq(), host_data)); +} + } // namespace } // namespace ttnn::distributed::test diff --git a/tests/ttnn/unit_tests/gtests/test_graph_add.cpp b/tests/ttnn/unit_tests/gtests/test_graph_add.cpp index b3b90cffc0d..60a638757f4 100644 --- a/tests/ttnn/unit_tests/gtests/test_graph_add.cpp +++ b/tests/ttnn/unit_tests/gtests/test_graph_add.cpp @@ -155,9 +155,25 @@ INSTANTIATE_TEST_SUITE_P( .expected_calltrace = {"ttnn::add", "ttnn::repeat", + "ttnn::to_layout", + "ttnn::untilize", + "ttnn::prim::old_infra_device_operation", + "Untilize", + "tt::tt_metal::create_device_tensor", + "ttnn::view", + "ttnn::experimental::view", + "Tensor::reshape", "ttnn::prim::old_infra_device_operation", "RepeatDeviceOperation", "tt::tt_metal::create_device_tensor", + "ttnn::view", + "ttnn::experimental::view", + "Tensor::reshape", + "ttnn::to_layout", + "ttnn::tilize", + "ttnn::prim::old_infra_device_operation", + "Tilize", + "tt::tt_metal::create_device_tensor", "ttnn::prim::binary", "BinaryDeviceOperation", "tt::tt_metal::create_device_tensor"}, diff --git a/tests/ttnn/unit_tests/gtests/test_graph_query_op_constraints.cpp b/tests/ttnn/unit_tests/gtests/test_graph_query_op_constraints.cpp index a00f0118aba..a2011ad7bab 100644 --- a/tests/ttnn/unit_tests/gtests/test_graph_query_op_constraints.cpp +++ b/tests/ttnn/unit_tests/gtests/test_graph_query_op_constraints.cpp @@ -414,13 +414,13 @@ INSTANTIATE_TEST_SUITE_P( ResourceUsageMap{ {BoardType::N300, ttnn::graph::ResourceUsage{ - .cb_peak_size_per_core = 3 * (2 * 2 * 32 * 32), - .l1_buffers_peak_per_core = 20480, + .cb_peak_size_per_core = 57344, + .l1_buffers_peak_per_core = 26688, .l1_output_buffer_per_core = 10240}}, {BoardType::E150, ttnn::graph::ResourceUsage{ - .cb_peak_size_per_core = 3 * (2 * 2 * 32 * 32), - .l1_buffers_peak_per_core = 12288, + .cb_peak_size_per_core = 57344, + .l1_buffers_peak_per_core = 14720, .l1_output_buffer_per_core = 6144}}}), std::make_tuple( // broadcast g_interleave_4_2_160_244_tiled, @@ -428,13 +428,13 @@ INSTANTIATE_TEST_SUITE_P( ResourceUsageMap{ {BoardType::N300, ttnn::graph::ResourceUsage{ - .cb_peak_size_per_core = 3 * (2 * 2 * 32 * 32), - .l1_buffers_peak_per_core = 20480, + .cb_peak_size_per_core = 57344, + .l1_buffers_peak_per_core = 26688, .l1_output_buffer_per_core = 10240}}, {BoardType::E150, ttnn::graph::ResourceUsage{ - .cb_peak_size_per_core = 3 * (2 * 2 * 32 * 32), - .l1_buffers_peak_per_core = 12288, + .cb_peak_size_per_core = 57344, + .l1_buffers_peak_per_core = 14720, .l1_output_buffer_per_core = 6144}}})), [](const testing::TestParamInfo>& info) { std::stringstream ss; diff --git a/tests/ttnn/unit_tests/gtests/test_multiprod_queue.cpp b/tests/ttnn/unit_tests/gtests/test_multiprod_queue.cpp index 0dffd658361..1c7c33ee8aa 100644 --- a/tests/ttnn/unit_tests/gtests/test_multiprod_queue.cpp +++ b/tests/ttnn/unit_tests/gtests/test_multiprod_queue.cpp @@ -54,7 +54,7 @@ TEST_F(MultiProducerCommandQueueTest, Stress) { std::thread t0([&]() { for (int j = 0; j < 100; j++) { - Tensor t0_tensor = t0_host_tensor.to(device, mem_cfg, t0_io_cq); + Tensor t0_tensor = t0_host_tensor.to_device(device, mem_cfg, t0_io_cq); EXPECT_TRUE(is_tensor_on_device(t0_tensor)); 
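The C++ test changes throughout this patch replace the overloaded Tensor::to(...) calls with the explicitly named to_device(...) and to_layout(...) member functions, as in the hunk above and in the fabric and mesh-tensor tests earlier. The toy sketch below illustrates the general shape of that rename with stand-in types; it is not ttnn code, and the rationale (keeping device moves and layout conversions distinct at the call site) is inferred rather than stated in the patch.

    // Toy illustration only -- Device, Layout, and Tensor here are stand-ins, not ttnn types.
    #include <iostream>
    #include <string>

    struct Device { std::string name; };
    enum class Layout { ROW_MAJOR, TILE };

    struct Tensor {
        // Before: a single overloaded `to(...)` accepted either a device or a layout,
        // so the intent of a call site depended on the argument type alone.
        Tensor to_device(const Device& d) const {
            std::cout << "moved to " << d.name << "\n";
            return *this;
        }
        Tensor to_layout(Layout) const {
            std::cout << "converted layout\n";
            return *this;
        }
    };

    int main() {
        Device dev{"dev0"};
        Tensor t;
        t.to_device(dev).to_layout(Layout::TILE);  // the operation is named at the call site
    }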
EXPECT_THAT(t0_tensor.to_vector(), Pointwise(FloatEq(), t0_host_data)); } @@ -62,7 +62,7 @@ TEST_F(MultiProducerCommandQueueTest, Stress) { std::thread t1([&]() { for (int j = 0; j < 100; j++) { - Tensor t1_tensor = t1_host_tensor.to(device, mem_cfg, t1_io_cq); + Tensor t1_tensor = t1_host_tensor.to_device(device, mem_cfg, t1_io_cq); EXPECT_TRUE(is_tensor_on_device(t1_tensor)); EXPECT_THAT(t1_tensor.to_vector(), Pointwise(FloatEq(), t1_host_data)); } diff --git a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py index 405e08b046f..a476163c8d5 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py @@ -74,7 +74,9 @@ def run_with_trace( cluster_axis=cluster_axis, mesh_device=mesh_device, topology=ttnn.Topology.Linear, - multi_device_global_semaphore=ccl_semaphore_handles, + multi_device_global_semaphore=ccl_semaphore_handles[0] + if type(ccl_semaphore_handles) == list + else ccl_semaphore_handles, num_links=num_links, memory_config=output_mem_config, subdevice_id=worker_sub_device_id, @@ -105,7 +107,9 @@ def run_with_trace( cluster_axis=cluster_axis, mesh_device=mesh_device, topology=ttnn.Topology.Linear, - multi_device_global_semaphore=ccl_semaphore_handles, + multi_device_global_semaphore=ccl_semaphore_handles[i] + if type(ccl_semaphore_handles) == list + else ccl_semaphore_handles, num_links=num_links, memory_config=output_mem_config, subdevice_id=worker_sub_device_id, @@ -149,6 +153,7 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows( function_level_defaults, enable_async, input_shard_spec: ttnn.ShardSpec = None, + output_shard_spec: ttnn.ShardSpec = None, num_all_gather_instances: int = 1, num_iters: int = 1, cluster_axis: int = 0, @@ -200,8 +205,7 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows( else (num_all_gather_instances, num_devices_per_line) ) - output_shard_spec = None - if input_shard_spec is not None: + if input_shard_spec is not None and output_shard_spec is None: output_shard_shape = list(input_shard_spec.shape) if dim == len(per_chip_output_shape) - 1: output_shard_shape[1] *= num_devices_per_line @@ -223,6 +227,7 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows( mesh_mapper=ShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dims), ) ttnn_tensor = ttnn.to_device(ttnn_tensor, mesh_device) + ttnn_tensor = ttnn.to_memory_config(ttnn_tensor, input_mem_config) sub_device_stall_group = [] if use_all_gather_async: @@ -246,60 +251,64 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows( mesh_device.set_sub_device_stall_group(sub_device_stall_group) # create global semaphore handles - ccl_semaphore_handles = create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0) + ccl_semaphore_handles = [ + create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0) for _ in range(num_iters) + ] + try: + # ttnn.visualize_mesh_device(mesh_device, tensor=ttnn_tensor) + if trace_mode: + ttnn_tensor_out = run_with_trace( + input_tensor=ttnn_tensor, + dim=dim, + cluster_axis=cluster_axis, + mesh_device=mesh_device, + num_links=num_links, + output_mem_config=output_mem_config, + ccl_semaphore_handles=ccl_semaphore_handles, + worker_sub_device_id=worker_sub_device_id, + enable_persistent_fabric=enable_persistent_fabric, + all_gather_topology=ttnn.Topology.Linear, + num_iter=num_iters, + 
use_all_gather_async=use_all_gather_async, + ) - # ttnn.visualize_mesh_device(mesh_device, tensor=ttnn_tensor) - if trace_mode: - ttnn_tensor_out = run_with_trace( - input_tensor=ttnn_tensor, - dim=dim, - cluster_axis=cluster_axis, - mesh_device=mesh_device, - num_links=num_links, - output_mem_config=output_mem_config, - ccl_semaphore_handles=ccl_semaphore_handles, - worker_sub_device_id=worker_sub_device_id, - enable_persistent_fabric=enable_persistent_fabric, - all_gather_topology=ttnn.Topology.Linear, - num_iter=num_iters, - use_all_gather_async=use_all_gather_async, - ) - else: - for _ in range(num_iters): - if use_all_gather_async: - logger.info("Running all-gather async") - ttnn_tensor_out = ttnn.experimental.all_gather_async( - ttnn_tensor, - dim, - cluster_axis=cluster_axis, - mesh_device=mesh_device, - topology=ttnn.Topology.Linear, - multi_device_global_semaphore=ccl_semaphore_handles, - num_links=num_links, - memory_config=output_mem_config, - subdevice_id=worker_sub_device_id, - enable_persistent_fabric_mode=enable_persistent_fabric, - ) - else: - ttnn_tensor_out = ttnn.all_gather( - ttnn_tensor, - dim=dim, - cluster_axis=cluster_axis, - mesh_device=mesh_device, - num_links=num_links, - memory_config=output_mem_config, - topology=ttnn.Topology.Linear, - ) - - if enable_persistent_fabric: - ttnn.synchronize_devices(mesh_device, sub_device_ids=sub_device_stall_group) - ttnn.synchronize_devices(mesh_device, sub_device_ids=sub_device_stall_group) - - if enable_persistent_fabric and teardown_persistent_fabric: - logger.info("Tearing down persistent fabric interface") - mesh_device.reset_sub_device_stall_group() - teardown_fabric_interface(mesh_device) - logger.info("Done tearing down persistent fabric interface") + else: + for i in range(num_iters): + if use_all_gather_async: + logger.info("Running all-gather async") + ttnn_tensor_out = ttnn.experimental.all_gather_async( + ttnn_tensor, + dim, + cluster_axis=cluster_axis, + mesh_device=mesh_device, + topology=ttnn.Topology.Linear, + multi_device_global_semaphore=ccl_semaphore_handles[i], + num_links=num_links, + memory_config=output_mem_config, + subdevice_id=worker_sub_device_id, + enable_persistent_fabric_mode=enable_persistent_fabric, + ) + else: + ttnn_tensor_out = ttnn.all_gather( + ttnn_tensor, + dim=dim, + cluster_axis=cluster_axis, + mesh_device=mesh_device, + num_links=num_links, + memory_config=output_mem_config, + topology=ttnn.Topology.Linear, + ) + ttnn.synchronize_devices(mesh_device, sub_device_ids=sub_device_stall_group) + + except Exception as e: + logger.error(f"Exception: {e}") + raise e + finally: + if enable_persistent_fabric and teardown_persistent_fabric: + logger.info("Tearing down persistent fabric interface") + mesh_device.reset_sub_device_stall_group() + teardown_fabric_interface(mesh_device) + logger.info("Done tearing down persistent fabric interface") # ttnn.visualize_mesh_device(mesh_device, tensor=ttnn_tensor_out) tt_output_tensor = ttnn.to_torch( diff --git a/tests/ttnn/unit_tests/operations/ccl/test_ccl_async_TG_llama.py b/tests/ttnn/unit_tests/operations/ccl/test_ccl_async_TG_llama.py index 7e054a2e409..c1673280601 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_ccl_async_TG_llama.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_ccl_async_TG_llama.py @@ -25,25 +25,74 @@ ) +PREFETCHER_NOC1_RING = [ + (6, 6), + (6, 7), + (6, 9), + (6, 0), + (6, 1), + (6, 2), + (6, 4), + (6, 5), + (5, 5), + (5, 6), + (5, 7), + (5, 9), + (5, 0), + (5, 1), + (5, 2), + (5, 4), + (1, 4), + (1, 5), + (1, 9), + (1, 
0), + (2, 0), + (2, 4), + (2, 5), + (2, 9), +] + + +def get_core_range_set(output_core_grid): + if isinstance(output_core_grid, ttnn.CoreGrid): + output_core_range_set = ttnn.CoreRangeSet( + [ + ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(output_core_grid.x - 1, output_core_grid.y - 1)), + ] + ) + else: + output_core_range_set = ttnn.CoreRangeSet( + [ + ttnn.CoreRange( + ttnn.CoreCoord(x, y), + ttnn.CoreCoord(x, y), + ) + for x, y in output_core_grid + ] + ) + return output_core_range_set + + # Enumerate the post-commit cases explicitly @skip_for_grayskull("Requires eth connected devices to run") @pytest.mark.parametrize( "num_devices, num_links", [ + (4, 3), + (4, 2), (4, 1), ], - # [(4, 3), (4, 2)], Multi-links fails https://github.com/tenstorrent/tt-metal/issues/16699 ) @pytest.mark.parametrize( "input_dtype", [ - ttnn.bfloat16, # hang?? - # ttnn.bfloat8_b, + ttnn.bfloat16, + ttnn.bfloat8_b, ], ) @pytest.mark.parametrize("shard_grid_orientation", [ttnn.ShardOrientation.ROW_MAJOR]) @pytest.mark.parametrize( - "tensor_mem_layout, output_shape, dim, input_shard_shape,shard_grid,layout", + "tensor_mem_layout, output_shape, dim, input_shard_shape,input_shard_grid,output_shard_shape, output_shard_grid, layout", ( ( # AllGather after SDPA (~160 us) ttnn.TensorMemoryLayout.HEIGHT_SHARDED, @@ -51,6 +100,13 @@ 1, (32, 128), ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 0))}), + (32, 128), + ttnn.CoreRangeSet( + { + ttnn.CoreRange(ttnn.CoreCoord(1, 0), ttnn.CoreCoord(3, 9)), + ttnn.CoreRange(ttnn.CoreCoord(5, 0), ttnn.CoreCoord(5, 1)), + } + ), ttnn.TILE_LAYOUT, ), ( # AllGather after Binary Mult+Silu (~160 us) @@ -59,6 +115,8 @@ 3, (32, 32), ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(5, 4))}), + (32, 160), + get_core_range_set(PREFETCHER_NOC1_RING), ttnn.TILE_LAYOUT, ), ), @@ -66,12 +124,15 @@ @pytest.mark.parametrize("replication_factor", [8]) @pytest.mark.parametrize("enable_async", [True]) @pytest.mark.parametrize("mesh_device", [pytest.param((8, 4), id="8x4_grid")], indirect=True) +@pytest.mark.parametrize("device_params", [{"trace_region_size": 17068032}], indirect=True) def test_line_all_gather_sharded_on_TG_rows_llama( mesh_device, num_devices, output_shape, input_shard_shape, - shard_grid, + input_shard_grid, + output_shard_shape, + output_shard_grid, shard_grid_orientation, tensor_mem_layout, dim, @@ -82,23 +143,30 @@ def test_line_all_gather_sharded_on_TG_rows_llama( function_level_defaults, enable_async, replication_factor, - num_iters=10, + num_iters=100, ): if len(mesh_device.get_devices()) != 32: pytest.skip("Not TG!") input_shard_spec = ttnn.ShardSpec( - shard_grid, + input_shard_grid, input_shard_shape, shard_grid_orientation, ) - logger.warning("sharding not used due to issue #16699") + if output_shard_grid is not None and output_shard_shape is not None: + output_shard_spec = ttnn.ShardSpec( + output_shard_grid, + output_shard_shape, + shard_grid_orientation, + ) + else: + output_shard_spec = None run_line_all_gather_on_TG_with_mesh_tensor_along_rows( mesh_device, num_devices, output_shape, - ttnn.TensorMemoryLayout.INTERLEAVED, # tensor_mem_layout, + tensor_mem_layout, dim, num_links, input_dtype, @@ -108,9 +176,11 @@ def test_line_all_gather_sharded_on_TG_rows_llama( function_level_defaults, enable_async=enable_async, num_iters=num_iters, - # input_shard_spec=input_shard_spec, + input_shard_spec=input_shard_spec, + output_shard_spec=output_shard_spec, num_all_gather_instances=replication_factor, cluster_axis=1, + 
trace_mode=True, use_all_gather_async=True, enable_persistent_fabric=True, create_persistent_fabric=True, @@ -216,7 +286,7 @@ def test_line_reduce_scatter_sharded_on_TG_rows_llama( @pytest.mark.parametrize( "num_devices, num_links, per_chip_output_shape, layout", [ - (4, 2, [1, 1, 32, 1280], ttnn.TILE_LAYOUT), # AllReduce after QKV (~110 us) + (4, 1, [1, 1, 32, 1280], ttnn.TILE_LAYOUT), # AllReduce after QKV (~110 us) ], ) @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/ccl/test_new_all_gather.py b/tests/ttnn/unit_tests/operations/ccl/test_new_all_gather.py index 3b11f56b80a..08d359325c2 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_new_all_gather.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_new_all_gather.py @@ -18,6 +18,11 @@ run_line_all_gather_on_TG_with_mesh_tensor_along_rows, ) +from tests.ttnn.unit_tests.operations.ccl.test_ccl_async_TG_llama import ( + PREFETCHER_NOC1_RING, + get_core_range_set, +) + def is_unsupported_case(input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout): if layout == ttnn.ROW_MAJOR_LAYOUT and input_dtype == ttnn.bfloat8_b: @@ -133,7 +138,9 @@ def run_all_gather_impl( rand_tensor=True, mem_config=None, input_shard_shape=None, - shard_grid=None, + input_shard_grid=None, + output_shard_shape=None, + output_shard_grid=None, tensor_mem_layout=None, use_cluster_axis_api=False, cluster_axis=None, @@ -173,40 +180,56 @@ def run_all_gather_impl( mesh_device.set_sub_device_stall_group(sub_device_stall_group) # create global semaphore handles - ccl_semaphore_handles = create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0) + ccl_semaphore_handles = [ + create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0) for _ in range(num_iters) + ] logger.info(f"Output shape: {output_shape}") logger.info(f"dim: {dim}") logger.info(f"input_shard_shape: {input_shard_shape}") - logger.info(f"shard_grid: {shard_grid}") + logger.info(f"input_shard_grid: {input_shard_grid}") ### For sharded all gather only - if bool(input_shard_shape) != bool(shard_grid) and bool(tensor_mem_layout) != bool(shard_grid): + if bool(input_shard_shape) != bool(input_shard_grid) and bool(tensor_mem_layout) != bool(input_shard_grid): pytest.fail( "Both input_shard_shape, shard_grid, and tensor_mem_layout must be provided together or all must be None" ) - if input_shard_shape and shard_grid: + if input_shard_shape and input_shard_grid: input_shard_spec = ttnn.ShardSpec( - shard_grid, + input_shard_grid, input_shard_shape, ttnn.ShardOrientation.ROW_MAJOR, ) input_mem_config = ttnn.MemoryConfig( tensor_mem_layout, buffer_type=ttnn.BufferType.L1, shard_spec=input_shard_spec ) - output_shard_shape = list(input_shard_shape) - if dim == len(output_shape) - 1: - output_shard_shape[1] *= num_devices + if output_shard_shape is None: + assert ( + output_shard_grid is None + ), "output_shard_grid must not be provided if output_shard_shape is not provided" + output_shard_shape = list(input_shard_shape) + if dim == len(output_shape) - 1: + output_shard_shape[1] *= num_devices + else: + output_shard_shape[0] *= num_devices + output_shard_spec = ttnn.ShardSpec( + input_shard_grid, + output_shard_shape, + ttnn.ShardOrientation.ROW_MAJOR, + ) + output_mem_config = ttnn.MemoryConfig( + tensor_mem_layout, buffer_type=ttnn.BufferType.L1, shard_spec=output_shard_spec + ) else: - output_shard_shape[0] *= num_devices - output_shard_spec = ttnn.ShardSpec( - shard_grid, - output_shard_shape, - ttnn.ShardOrientation.ROW_MAJOR, - ) - 
output_mem_config = ttnn.MemoryConfig( - tensor_mem_layout, buffer_type=ttnn.BufferType.L1, shard_spec=output_shard_spec - ) + assert output_shard_grid is not None, "output_shard_grid must be provided if output_shard_shape is provided" + output_shard_spec = ttnn.ShardSpec( + output_shard_grid, + output_shard_shape, + ttnn.ShardOrientation.ROW_MAJOR, + ) + output_mem_config = ttnn.MemoryConfig( + tensor_mem_layout, buffer_type=ttnn.BufferType.L1, shard_spec=output_shard_spec + ) else: assert mem_config is not None input_mem_config = mem_config @@ -252,7 +275,7 @@ def run_all_gather_impl( num_links, output_mem_config, enable_persistent_fabric, - multi_device_global_semaphore=ccl_semaphore_handles, + multi_device_global_semaphore=ccl_semaphore_handles[0], num_iter=num_iters, subdevice_id=worker_sub_device_id, ) @@ -267,7 +290,7 @@ def run_all_gather_impl( mesh_device=mesh_device, memory_config=output_mem_config, topology=all_gather_topology, - multi_device_global_semaphore=ccl_semaphore_handles, + multi_device_global_semaphore=ccl_semaphore_handles[i], subdevice_id=worker_sub_device_id, enable_persistent_fabric_mode=enable_persistent_fabric, num_preferred_links=num_links, @@ -277,7 +300,7 @@ def run_all_gather_impl( tt_out_tensor = ttnn.experimental.all_gather_async( input_tensor_mesh_list[i], dim, - multi_device_global_semaphore=ccl_semaphore_handles, + multi_device_global_semaphore=ccl_semaphore_handles[i], num_links=num_links, memory_config=output_mem_config, topology=all_gather_topology, @@ -286,9 +309,9 @@ def run_all_gather_impl( ) tt_out_tensor_list.append(tt_out_tensor) - logger.info(f"Waiting for op {i}") - ttnn.synchronize_devices(mesh_device, sub_device_ids=sub_device_stall_group) - logger.info(f"Done iteration {i}") + logger.info(f"Waiting for op") + ttnn.synchronize_devices(mesh_device, sub_device_ids=sub_device_stall_group) + logger.info(f"Done op") passed = True for tensor_index in range(len(tt_out_tensor_list)): @@ -306,6 +329,12 @@ def run_all_gather_impl( logger.error(f"output mismatch for tensor {i}") passed = False + for i in range(num_devices): + assert ( + mesh_device.get_devices()[i].num_program_cache_entries() == 1 + or mesh_device.get_devices()[i].num_program_cache_entries() == num_iters + ), f"Device {i} has {mesh_device.get_devices()[i].num_program_cache_entries()} program cache entries" + if enable_persistent_fabric and teardown_persistent_fabric: mesh_device.reset_sub_device_stall_group() teardown_fabric_interface(mesh_device) @@ -319,7 +348,7 @@ def run_all_gather_impl( @pytest.mark.parametrize( "num_devices, num_links, output_shape, dim, layout", [ - (4, 1, [1, 1, 64, 512], 3, ttnn.TILE_LAYOUT), + # (4, 1, [1, 1, 64, 512], 3, ttnn.TILE_LAYOUT), # (4, 1, [1, 1, 32, 32768], 3, ttnn.TILE_LAYOUT), # (4, 1, [1, 1, 2048, 16384], 3, ttnn.TILE_LAYOUT), (4, 1, [1, 1, 32, 1280], 3, ttnn.TILE_LAYOUT), @@ -328,7 +357,7 @@ def run_all_gather_impl( @pytest.mark.parametrize( "input_dtype", [ - ttnn.bfloat16, + # ttnn.bfloat16, ttnn.bfloat8_b, ], ) @@ -338,7 +367,7 @@ def run_all_gather_impl( ttnn.MemoryConfig(buffer_type=ttnn.BufferType.DRAM), ], ) -@pytest.mark.parametrize("num_iters", [8]) +@pytest.mark.parametrize("num_iters", [10]) @pytest.mark.parametrize("enable_async", [True]) def test_all_gather( t3k_mesh_device, @@ -378,7 +407,7 @@ def test_all_gather( # Enumerate the post-commit cases explicitly @skip_for_grayskull("Requires eth connected devices to run") @pytest.mark.parametrize( - "num_devices, output_shape, dim, layout, input_shard_shape, shard_grid, 
tensor_mem_layout", + "num_devices, output_shape, dim, layout, input_shard_shape, input_shard_grid, output_shard_shape, output_shard_grid, tensor_mem_layout", [ ( 2, @@ -387,6 +416,8 @@ def test_all_gather( ttnn.TILE_LAYOUT, (32, 32), ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 3))}), + None, + None, ttnn.TensorMemoryLayout.WIDTH_SHARDED, ), ( @@ -396,6 +427,8 @@ def test_all_gather( ttnn.TILE_LAYOUT, (32, 64), ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 1))}), + None, + None, ttnn.TensorMemoryLayout.WIDTH_SHARDED, ), ( @@ -405,6 +438,8 @@ def test_all_gather( ttnn.TILE_LAYOUT, (32, 128), ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 0))}), + None, + None, ttnn.TensorMemoryLayout.WIDTH_SHARDED, ), ( @@ -414,6 +449,8 @@ def test_all_gather( ttnn.TILE_LAYOUT, (32, 128), ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 1))}), + None, + None, ttnn.TensorMemoryLayout.WIDTH_SHARDED, ), ( @@ -423,15 +460,8 @@ def test_all_gather( ttnn.TILE_LAYOUT, (32, 128), ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 3))}), - ttnn.TensorMemoryLayout.HEIGHT_SHARDED, - ), - ( - 4, - [1, 4, 32, 1280], - 3, - ttnn.TILE_LAYOUT, - (32, 128), - ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(1, 4))}), + None, + None, ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ), ], @@ -441,12 +471,13 @@ def test_all_gather( "input_dtype", [ ttnn.bfloat16, + ttnn.bfloat8_b, ], ) @pytest.mark.parametrize("num_iters", [8]) @pytest.mark.parametrize("enable_async", [True]) def test_all_gather_sharded( - pcie_mesh_device, + t3k_mesh_device, num_devices, output_shape, dim, @@ -458,11 +489,16 @@ def test_all_gather_sharded( function_level_defaults, enable_async, input_shard_shape, - shard_grid, + input_shard_grid, + output_shard_shape, + output_shard_grid, tensor_mem_layout, ): + if num_links > 1: + assert f"num_links > 1 not supported for sharded all gather test function which is currently using the t3k_mesh_device (and hence only has 1 link available for use)" + run_all_gather_impl( - pcie_mesh_device, + t3k_mesh_device, num_devices, output_shape, dim, @@ -476,7 +512,9 @@ def test_all_gather_sharded( enable_async=enable_async, rand_tensor=True, input_shard_shape=input_shard_shape, - shard_grid=shard_grid, + input_shard_grid=input_shard_grid, + output_shard_shape=output_shard_shape, + output_shard_grid=output_shard_grid, tensor_mem_layout=tensor_mem_layout, create_persistent_fabric=True, teardown_persistent_fabric=True, diff --git a/tests/ttnn/unit_tests/operations/eltwise/test_composite.py b/tests/ttnn/unit_tests/operations/eltwise/test_composite.py index b9fa239e6d1..ddf202775b3 100644 --- a/tests/ttnn/unit_tests/operations/eltwise/test_composite.py +++ b/tests/ttnn/unit_tests/operations/eltwise/test_composite.py @@ -10,6 +10,8 @@ data_gen_with_range_dtype, compare_pcc, ) +from tests.ttnn.utils_for_testing import assert_with_pcc + from models.utility_functions import skip_for_grayskull, is_wormhole_b0, is_blackhole @@ -838,27 +840,30 @@ def test_unary_softshrink(input_shapes, param, device): assert comp_pass -@pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH") @pytest.mark.parametrize( "input_shapes", ( (torch.Size([1, 1, 32, 32])), (torch.Size([1, 1, 320, 384])), (torch.Size([1, 3, 320, 384])), + (torch.Size([7, 185, 20])), + (torch.Size([6, 45, 233])), ), ) @pytest.mark.parametrize( "param", - {-1e4, -98.5, -43.7, -8.5, 0.45, 7.7, 58.4, 
89.9, 1e5}, + {-1e4, -98.5, -43.7, -8.5, 0.0, 0.45, 1.0, 7.7, 58.4, 89.9, 1e5}, ) def test_unary_logit(input_shapes, param, device): - in_data, input_tensor = data_gen_with_range(input_shapes, 0, 1, device) + in_data = torch.Tensor(size=input_shapes).uniform_(-100, 100).to(torch.bfloat16) + input_tensor = ttnn.from_torch(in_data, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) + output_tensor = ttnn.logit(input_tensor, eps=param) golden_function = ttnn.get_golden_function(ttnn.logit) golden_tensor = golden_function(in_data, eps=param, device=device) - comp_pass = compare_pcc([output_tensor], [golden_tensor]) - assert comp_pass + out = ttnn.to_torch(output_tensor) + assert_with_pcc(golden_tensor, out, 0.99) @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/test_core.py b/tests/ttnn/unit_tests/operations/test_core.py index c39154379df..57709827f07 100644 --- a/tests/ttnn/unit_tests/operations/test_core.py +++ b/tests/ttnn/unit_tests/operations/test_core.py @@ -527,8 +527,9 @@ def test_bh_alignment_i2s( memory_config=input_buffer_type, dtype=ttnn.bfloat16, ) - x_t_sharded = ttnn.to_memory_config(x_t, shard_config) - x_t = ttnn.to_memory_config(x_t_sharded, output_buffer_type) + # So far the sharded tensor alignment is controled by keep_l1_aligned flag, will remove it later after launch + x_t_sharded = ttnn.interleaved_to_sharded(x_t, shard_config, keep_l1_aligned=True) + x_t = ttnn.sharded_to_interleaved(x_t_sharded, output_buffer_type, is_l1_aligned=True) output_data = ttnn.from_device(x_t) output_data = ttnn.to_torch(output_data) passing = torch.equal(input_data, output_data) diff --git a/tests/ttnn/unit_tests/operations/test_group_norm.py b/tests/ttnn/unit_tests/operations/test_group_norm.py index 57441a6f047..2ae8848ee2f 100644 --- a/tests/ttnn/unit_tests/operations/test_group_norm.py +++ b/tests/ttnn/unit_tests/operations/test_group_norm.py @@ -292,7 +292,7 @@ def test_group_norm_with_block_sharded_v2_8x8_grid(device, N, C, H, W, num_group sharded_mem_config = ttnn.MemoryConfig( ttnn.types.TensorMemoryLayout.BLOCK_SHARDED, ttnn.types.BufferType.L1, shard_spec ) - input_tensor = ttnn.to_memory_config(input_tensor, sharded_mem_config) + input_tensor = ttnn.interleaved_to_sharded(input_tensor, sharded_mem_config, keep_l1_aligned=True) # groupnorm output_tensor = ttnn.group_norm( @@ -306,7 +306,7 @@ def test_group_norm_with_block_sharded_v2_8x8_grid(device, N, C, H, W, num_group ) # output tensor - output_tensor = ttnn.to_memory_config(output_tensor, ttnn.L1_MEMORY_CONFIG) + output_tensor = ttnn.sharded_to_interleaved(output_tensor, ttnn.L1_MEMORY_CONFIG, is_l1_aligned=True) output_tensor = ttnn.from_device(output_tensor) output_tensor = ttnn.to_torch(output_tensor) diff --git a/tests/ttnn/unit_tests/operations/test_linear.py b/tests/ttnn/unit_tests/operations/test_linear.py index 904c3a1af65..9f77989e3ec 100644 --- a/tests/ttnn/unit_tests/operations/test_linear.py +++ b/tests/ttnn/unit_tests/operations/test_linear.py @@ -341,3 +341,73 @@ def test_linear_with_fp32_dest_acc_and_bias(device): ) output_tensor = ttnn.to_torch(output1) assert_with_pcc(torch_output_tensor, output_tensor, pcc=0.99) + + +def test_resnet50_linear(device, use_program_cache): + torch.manual_seed(0) + batch_size = 16 + input_channels = 2048 + output_channels = 1000 + input_shape = [1, 1, batch_size, input_channels] + torch_input_tensor = torch.randn(input_shape, dtype=torch.bfloat16) + torch_weight_tensor = torch.randn([1, 1, output_channels, input_channels], dtype=torch.bfloat16) + 
torch_bias_tensor = torch.randn([1, 1, 1, output_channels], dtype=torch.bfloat16) + torch_out_golden_tensor = torch.nn.functional.linear( + torch_input_tensor[0, 0, :, :], torch_weight_tensor[0, 0, :, :], bias=torch_bias_tensor[0, 0, :, :] + ) + + tt_input_tensor = ttnn.from_torch(torch_input_tensor, ttnn.bfloat8_b, device=device, layout=ttnn.TILE_LAYOUT) + tt_weight_tensor = ttnn.from_torch( + torch.permute(torch_weight_tensor, (0, 1, 3, 2)), ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT + ) + tt_bias_tensor = ttnn.from_torch(torch_bias_tensor, ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT) + + compute_config = ttnn.init_device_compute_kernel_config( + device.arch(), + math_fidelity=ttnn.MathFidelity.LoFi, + math_approx_mode=True, + fp32_dest_acc_en=False, + packer_l1_acc=True, + ) + matmul_config = ttnn.MatmulMultiCoreReuseMultiCast1DProgramConfig( + compute_with_storage_grid_size=(8, 4), + in0_block_w=2, + out_subblock_h=1, + out_subblock_w=1, + per_core_M=1, + per_core_N=1, + fuse_batch=True, + fused_activation=None, + mcast_in0=True, + ) + + grid_size = (8, 4) + shard_grid = ttnn.CoreRangeSet( + { + ttnn.CoreRange( + ttnn.CoreCoord(0, 0), + ttnn.CoreCoord(grid_size[0] - 1, grid_size[1] - 1), + ) + } + ) + x = tt_input_tensor + shard_shape = [ + x.volume() // x.padded_shape[-1], + x.padded_shape[-1] // (grid_size[0] * grid_size[1]), + ] + shard_spec = ttnn.ShardSpec(shard_grid, shard_shape, ttnn.ShardOrientation.ROW_MAJOR) + width_sharded_mem_config = ttnn.MemoryConfig(ttnn.TensorMemoryLayout.WIDTH_SHARDED, ttnn.BufferType.L1, shard_spec) + x = ttnn.to_memory_config(x, width_sharded_mem_config) + + tt_output_tensor_on_device = ttnn.linear( + x, + tt_weight_tensor, + bias=tt_bias_tensor, + program_config=matmul_config, + memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG, + dtype=ttnn.bfloat8_b, + compute_kernel_config=compute_config, + ) + tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device) + torch_output_tensor = ttnn.to_torch(tt_output_tensor) + assert_with_pcc(torch_out_golden_tensor, torch_output_tensor[0, 0, :, :], pcc=0.99) diff --git a/tests/ttnn/unit_tests/operations/test_new_conv2d.py b/tests/ttnn/unit_tests/operations/test_new_conv2d.py index 1d67d1294de..c3f02edef65 100644 --- a/tests/ttnn/unit_tests/operations/test_new_conv2d.py +++ b/tests/ttnn/unit_tests/operations/test_new_conv2d.py @@ -14,7 +14,10 @@ ) from tests.ttnn.utils_for_testing import assert_with_pcc, check_with_pcc, check_with_pcc_without_tensor_printout import ttnn -from tests.ttnn.ttnn_utility_fuction import get_shard_grid_from_num_cores + +HS = ttnn.TensorMemoryLayout.HEIGHT_SHARDED +BS = ttnn.TensorMemoryLayout.BLOCK_SHARDED +WS = ttnn.TensorMemoryLayout.WIDTH_SHARDED def run_conv( @@ -33,13 +36,9 @@ def run_conv( stride_w, pad_h, pad_w, - use_1d_systolic_array, config_override, dilation=1, use_shallow_conv_variant=False, - transpose_mcast=True, - enable_auto_formatting=False, - padded_input_channels=None, fp32_accum=False, packer_l1_acc=False, output_layout=ttnn.TILE_LAYOUT, @@ -83,12 +82,6 @@ def run_conv( dilation=(dilation, dilation), groups=groups, ) - output_shape_nhwc = [ - torch_out_golden_tensor.shape[0], - torch_out_golden_tensor.shape[2], - torch_out_golden_tensor.shape[3], - torch_out_golden_tensor.shape[1], - ] reader_patterns_cache = {} @@ -105,16 +98,16 @@ def run_conv( mesh_mapper=weight_mesh_mapper, ) - tt_input_tensor = ttnn.from_torch(torch_input_tensor, ttnn.bfloat16, mesh_mapper=input_mesh_mapper) + tt_input_tensor = ttnn.from_torch( + torch_input_tensor, + 
activations_dtype if activations_dtype == ttnn.float32 else ttnn.bfloat16, + mesh_mapper=input_mesh_mapper, + ) - if shard_layout is None and not auto_shard: - shard_layout = ( - ttnn.TensorMemoryLayout.HEIGHT_SHARDED if use_1d_systolic_array else ttnn.TensorMemoryLayout.BLOCK_SHARDED - ) conv_config = ttnn.Conv2dConfig( dtype=activations_dtype, weights_dtype=weights_dtype, - shard_layout=shard_layout, + shard_layout=shard_layout if not auto_shard else None, input_channels_alignment=( 16 if use_shallow_conv_variant or (input_channels == 16 and input_height == 115) else 32 ), @@ -142,7 +135,7 @@ def run_conv( conv_config.override_sharding_config = True print("Setting num_cores_nhw to 98") - [tt_output_tensor_on_device, [out_height, out_width], [weights_device, bias_device]] = ttnn.conv2d( + [tt_output_tensor_on_device, [out_height, out_width]] = ttnn.conv2d( input_tensor=tt_input_tensor, weight_tensor=tt_weight_tensor, in_channels=input_channels, @@ -162,7 +155,6 @@ def run_conv( debug=debug, groups=groups, memory_config=memory_config, - return_weights_and_bias=True, return_output_dim=True, ) @@ -212,11 +204,12 @@ def run_conv_with_split( stride_w, pad_h, pad_w, - use_1d_systolic_array, config_override, + shard_layout=None, split_factor=2, fp32_accum=False, packer_l1_acc=False, + auto_shard=False, ): torch.manual_seed(0) assert input_channels % split_factor == 0 @@ -238,32 +231,13 @@ def run_conv_with_split( split_input_tensors = torch.split(torch_input_tensor_nchw, split_input_channels, 1) split_weight_tensors = torch.split(torch_weight_tensor, split_input_channels, 1) - # conv1_output_tensor = torch.nn.functional.conv2d( - # split_input_tensors[0], - # split_weight_tensors[0], - # bias=torch_bias_tensor.reshape(-1), - # stride=(stride_h, stride_w), - # padding=(pad_h, pad_w), - # ) - # conv2_output_tensor = torch.nn.functional.conv2d( - # split_input_tensors[1], - # split_weight_tensors[1], - # stride=(stride_h, stride_w), - # padding=(pad_h, pad_w), - # ) - # torch_output_tensor = torch.add(conv1_output_tensor, conv2_output_tensor) - - torch_input1_tensor = torch.permute(split_input_tensors[0], (0, 2, 3, 1)) - torch_input2_tensor = torch.permute(split_input_tensors[1], (0, 2, 3, 1)) + reader_patterns_cache = {} - shard_layout = ( - ttnn.TensorMemoryLayout.HEIGHT_SHARDED if use_1d_systolic_array else ttnn.TensorMemoryLayout.BLOCK_SHARDED - ) conv_config = ttnn.Conv2dConfig( dtype=activations_dtype, weights_dtype=weights_dtype, - shard_layout=shard_layout if use_1d_systolic_array else ttnn.TensorMemoryLayout.BLOCK_SHARDED, + shard_layout=shard_layout if not auto_shard else None, # input_channels_alignment=(16 if use_shallow_conv_variant else 32), ) compute_config = ttnn.init_device_compute_kernel_config( @@ -292,7 +266,7 @@ def run_conv_with_split( tt_input_tensor = ttnn.from_torch(torch_input_tensor, ttnn.bfloat16) # tt_input_tensor_on_device = convs[i].copy_input_to_device(tt_input_tensor) # tt_output_tensor_on_device = convs[i](tt_input_tensor_on_device) - [tt_output_tensor_on_device, [out_height, out_width], [weights_device, bias_device]] = ttnn.conv2d( + [tt_output_tensor_on_device, [out_height, out_width]] = ttnn.conv2d( input_tensor=tt_input_tensor, weight_tensor=tt_weight_tensor, in_channels=split_input_channels, @@ -309,7 +283,6 @@ def run_conv_with_split( compute_config=compute_config, conv_op_cache=reader_patterns_cache, return_output_dim=True, - return_weights_and_bias=True, ) tt_conv_output_tensor = ttnn.from_device(tt_output_tensor_on_device) torch_conv_output_tensor = 
ttnn.to_torch(tt_conv_output_tensor) @@ -338,14 +311,14 @@ def run_conv_with_split( @pytest.mark.parametrize( "output_channels, input_channels, input_height, input_width, shard_layout, config", ( - (256, 256, 8, 8, ttnn.TensorMemoryLayout.WIDTH_SHARDED, None), - (128, 128, 32, 32, ttnn.TensorMemoryLayout.BLOCK_SHARDED, None), - (16, 16, 256, 256, ttnn.TensorMemoryLayout.HEIGHT_SHARDED, {"act_block_h": 32}), + (256, 256, 8, 8, WS, None), + (128, 128, 32, 32, BS, None), + (16, 16, 256, 256, HS, {"act_block_h": 32}), ), ) @pytest.mark.parametrize( "weights_dtype", - [ttnn.bfloat8_b, ttnn.bfloat16], + [ttnn.bfloat16], ) @pytest.mark.parametrize( "activations_dtype", @@ -357,7 +330,7 @@ def run_conv_with_split( ) @pytest.mark.parametrize( "packer_l1_acc", - [True, False], + [False], ) @pytest.mark.parametrize( "filter, pad", @@ -367,7 +340,7 @@ def run_conv_with_split( [5, 2], ], ) -@pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.LoFi, ttnn.MathFidelity.HiFi4]) +@pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.HiFi4]) @pytest.mark.parametrize("output_layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT]) def test_conv_features( device, @@ -411,7 +384,6 @@ def test_conv_features( stride, pad, pad, - True, config, shard_layout=shard_layout, output_layout=output_layout, @@ -429,9 +401,9 @@ def test_conv_features( @pytest.mark.parametrize( "output_channels, input_channels, input_height, input_width, shard_layout, config", ( - (256, 256, 8, 8, ttnn.TensorMemoryLayout.WIDTH_SHARDED, None), - (128, 128, 32, 32, ttnn.TensorMemoryLayout.BLOCK_SHARDED, None), - (16, 16, 256, 256, ttnn.TensorMemoryLayout.HEIGHT_SHARDED, {"act_block_h": 32}), + (256, 256, 8, 8, WS, None), + (128, 128, 32, 32, BS, None), + (16, 16, 256, 256, HS, {"act_block_h": 32}), ), ) @pytest.mark.parametrize( @@ -488,7 +460,6 @@ def test_conv_features_multi_device( stride, pad, pad, - True, config, shard_layout=shard_layout, output_layout=output_layout, @@ -502,7 +473,7 @@ def test_conv_features_multi_device( @skip_for_grayskull() @pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) -@pytest.mark.parametrize("stride", [1, 2]) +@pytest.mark.parametrize("stride", [1]) @pytest.mark.parametrize( "batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, pad_h, pad_w, act_block_w_div", ( @@ -526,7 +497,7 @@ def test_conv_features_multi_device( ) @pytest.mark.parametrize( "weights_dtype", - [ttnn.bfloat16, ttnn.bfloat8_b], + [ttnn.bfloat16], ) @pytest.mark.parametrize( "activations_dtype", @@ -626,7 +597,7 @@ def test_conv_ws( fp32_dest_acc_en=fp32_accum, packer_l1_acc=packer_l1_acc, ) - [tt_output_tensor_on_device, [out_height, out_width], [weights_device, bias_device]] = ttnn.conv2d( + [tt_output_tensor_on_device, [out_height, out_width]] = ttnn.conv2d( input_tensor=tt_input_tensor, weight_tensor=tt_weight_tensor, in_channels=input_channels, @@ -645,7 +616,6 @@ def test_conv_ws( debug=debug, groups=groups, return_output_dim=True, - return_weights_and_bias=True, ) tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device) @@ -669,22 +639,29 @@ def test_conv_ws( @pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) @pytest.mark.parametrize( - "batch_size, input_channels, output_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, groups, use_1d_systolic_array, config_override, use_shallow_conv_variant", + "batch_size, input_channels, output_channels, input_height, 
input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, groups, shard_layout, config_override, use_shallow_conv_variant", ( # mlp sub_module - (1, 3, 32, 512, 512, 7, 7, 4, 4, 3, 3, 1, True, {"act_block_h": 64}, False), # ncrisc build failed + (1, 3, 32, 512, 512, 7, 7, 4, 4, 3, 3, 1, HS, {"act_block_h": 64}, False), # ncrisc build failed # efficient selfattention sub_module - (1, 32, 32, 128, 128, 8, 8, 8, 8, 0, 0, 1, True, None, False), # ncrisc build failed, Two times called in model - (1, 64, 64, 64, 64, 4, 4, 4, 4, 0, 0, 1, True, None, False), # ncrisc build failed, Two times called in model - (1, 160, 160, 32, 32, 2, 2, 2, 2, 0, 0, 1, True, None, False), # pass , Two times called in model + (1, 32, 32, 128, 128, 8, 8, 8, 8, 0, 0, 1, HS, None, False), # ncrisc build failed, Two times called in model + (1, 64, 64, 64, 64, 4, 4, 4, 4, 0, 0, 1, HS, None, False), # ncrisc build failed, Two times called in model + (1, 160, 160, 32, 32, 2, 2, 2, 2, 0, 0, 1, HS, None, False), # pass , Two times called in model # dwconv sub_module - (1, 128, 128, 128, 128, 3, 3, 1, 1, 1, 1, 128, True, {"act_block_h": 64}, False), - (1, 256, 256, 64, 64, 3, 3, 1, 1, 1, 1, 256, True, None, False), # pass , Two times called in model - (1, 640, 640, 32, 32, 3, 3, 1, 1, 1, 1, 640, False, {"act_block_h": 32}, False), - # (1,1024, 1024, 16, 16, 3, 3, 1, 1, 1, 1, 1024, False, None, False), #Switch to Width Sharding + (1, 128, 128, 128, 128, 3, 3, 1, 1, 1, 1, 128, HS, {"act_block_h": 64}, False), + (1, 256, 256, 64, 64, 3, 3, 1, 1, 1, 1, 256, HS, None, False), # pass , Two times called in model + (1, 640, 640, 32, 32, 3, 3, 1, 1, 1, 1, 640, ttnn.TensorMemoryLayout.BLOCK_SHARDED, {"act_block_h": 32}, False), + # (1,1024, 1024, 16, 16, 3, 3, 1, 1, 1, 1, 1024, BS, None, False), #Switch to Width Sharding # decode_head sub_module - # (1,1024, 256, 128, 128, 1, 1, 1, 1, 0, 0, 1, False, {"act_block_h": 32}, False), #pass for activation_dtype=bf8 but fails for bf16 - (1, 256, 150, 128, 128, 1, 1, 1, 1, 0, 0, 1, True, None, False), + # (1,1024, 256, 128, 128, 1, 1, 1, 1, 0, 0, 1, BS, {"act_block_h": 32}, False), #pass for activation_dtype=bf8 but fails for bf16 + (1, 256, 150, 128, 128, 1, 1, 1, 1, 0, 0, 1, HS, None, False), + (1, 32, 16, 64, 64, 1, 1, 1, 1, 0, 0, 1, HS, None, False), + (1, 96, 24, 32, 32, 1, 1, 1, 1, 0, 0, 1, HS, None, False), + (1, 576, 576, 8, 8, 3, 3, 1, 1, 0, 0, 576, WS, None, False), + (1, 576, 576, 8, 8, 3, 3, 2, 2, 0, 0, 576, WS, None, False), + (1, 960, 960, 4, 4, 3, 3, 1, 1, 0, 0, 960, WS, None, False), + (1, 144, 24, 32, 32, 1, 1, 1, 1, 0, 0, 1, HS, None, False), + (1, 144, 32, 16, 16, 1, 1, 1, 1, 0, 0, 1, HS, None, False), ), ) @pytest.mark.parametrize( @@ -716,7 +693,7 @@ def test_conv_for_segformer_512x512( stride_w, pad_h, pad_w, - use_1d_systolic_array, + shard_layout, config_override, use_shallow_conv_variant, groups, @@ -739,39 +716,39 @@ def test_conv_for_segformer_512x512( stride_w, pad_h, pad_w, - use_1d_systolic_array, config_override, use_shallow_conv_variant=use_shallow_conv_variant, groups=groups, output_layout=output_layout, has_bias=False, auto_shard=auto_shard, + shard_layout=shard_layout, ) @pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="This is test is for Grayskull only. 
Skipping") @pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) @pytest.mark.parametrize( - "output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, use_1d_systolic_array, config_override", + "output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, shard_layout, config_override", ( # unique convs in rn50 (complete list) # first conv post folding and input_channels padding to tile width # (64, 16, 115, 115, 4, 4, 1, 1, 0, 0, True), act_block_h_ntiles % 2 == 0 # rn50 layer1 - (64, 64, 56, 56, 1, 1, 1, 1, 0, 0, True, None), - (64, 64, 56, 56, 1, 1, 2, 2, 0, 0, True, None), - (64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None), + (64, 64, 56, 56, 1, 1, 1, 1, 0, 0, HS, None), + (64, 64, 56, 56, 1, 1, 2, 2, 0, 0, HS, None), + (64, 64, 56, 56, 3, 3, 1, 1, 1, 1, HS, None), # rn50 layer2 - (128, 128, 56, 56, 3, 3, 2, 2, 1, 1, True, None), - (128, 128, 28, 28, 3, 3, 1, 1, 1, 1, True, None), + (128, 128, 56, 56, 3, 3, 2, 2, 1, 1, HS, None), + (128, 128, 28, 28, 3, 3, 1, 1, 1, 1, HS, None), # rn50 layer3 - (256, 256, 28, 28, 3, 3, 2, 2, 1, 1, False, None), - (256, 256, 14, 14, 3, 3, 1, 1, 1, 1, False, None), + (256, 256, 28, 28, 3, 3, 2, 2, 1, 1, BS, None), + (256, 256, 14, 14, 3, 3, 1, 1, 1, 1, BS, None), # rn50 layer4 - (512, 512, 14, 14, 3, 3, 2, 2, 1, 1, False, None), - (512, 512, 7, 7, 3, 3, 1, 1, 1, 1, False, None), + (512, 512, 14, 14, 3, 3, 2, 2, 1, 1, BS, None), + (512, 512, 7, 7, 3, 3, 1, 1, 1, 1, BS, None), ## 1x1s2 - # (256, 256, 28, 28, 1, 1, 2, 2, 0, 0, True, {"num_cores_nhw": 98}), + # (256, 256, 28, 28, 1, 1, 2, 2, 0, 0, HS, {"num_cores_nhw": 98}), ), ) @pytest.mark.parametrize( @@ -805,7 +782,7 @@ def test_resnet50_conv_gs( stride_w, pad_h, pad_w, - use_1d_systolic_array, + shard_layout, config_override, auto_shard, ): @@ -842,78 +819,77 @@ def test_resnet50_conv_gs( stride_w, pad_h, pad_w, - use_1d_systolic_array, config_override=config_override, use_shallow_conv_variant=input_channels == 16, - padded_input_channels=16 if input_channels == 16 else None, debug=not (batch_size == 20 and input_height == 115), auto_shard=auto_shard, + shard_layout=shard_layout, ) @skip_for_grayskull() @pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) @pytest.mark.parametrize( - "batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, use_1d_systolic_array, config_override", + "batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, shard_layout, config_override", ( # unique convs in rn50 (complete list) # first conv post folding and input_channels padding to tile width # (8, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, True, None), HANGS!! - (16, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, True, {"act_block_h": 256}), - # (20, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, True, {"act_block_h": 32}), Out of Memory!! + (16, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, HS, {"act_block_h": 256}), + # (20, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, HS, {"act_block_h": 32}), Out of Memory!! 
# rn50 layer1 - (8, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None), - (16, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None), - (20, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None), + (8, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, HS, None), + (16, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, HS, None), + (20, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, HS, None), # rn50 layer2 - (8, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, True, None), - (16, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, True, None), - (20, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, True, {"act_block_h": 32}), - (8, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, True, None), - (16, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, True, None), - (20, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, True, None), + (8, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, HS, None), + (16, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, HS, None), + (20, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, HS, {"act_block_h": 32}), + (8, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, HS, None), + (16, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, HS, None), + (20, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, HS, None), # rn50 layer3 - (8, 256, 256, 28, 28, 3, 3, 2, 2, 1, 1, False, None), - (16, 256, 256, 28, 28, 3, 3, 2, 2, 1, 1, False, None), - (20, 256, 256, 28, 28, 3, 3, 2, 2, 1, 1, False, None), - (8, 256, 256, 14, 14, 3, 3, 1, 1, 1, 1, False, None), - (16, 256, 256, 14, 14, 3, 3, 1, 1, 1, 1, False, None), - (20, 256, 256, 14, 14, 3, 3, 1, 1, 1, 1, False, None), + (8, 256, 256, 28, 28, 3, 3, 2, 2, 1, 1, BS, None), + (16, 256, 256, 28, 28, 3, 3, 2, 2, 1, 1, BS, None), + (20, 256, 256, 28, 28, 3, 3, 2, 2, 1, 1, BS, None), + (8, 256, 256, 14, 14, 3, 3, 1, 1, 1, 1, BS, None), + (16, 256, 256, 14, 14, 3, 3, 1, 1, 1, 1, BS, None), + (20, 256, 256, 14, 14, 3, 3, 1, 1, 1, 1, BS, None), # rn50 layer4 - (8, 512, 512, 14, 14, 3, 3, 2, 2, 1, 1, False, None), - (16, 512, 512, 14, 14, 3, 3, 2, 2, 1, 1, False, None), - (20, 512, 512, 14, 14, 3, 3, 2, 2, 1, 1, False, None), - (8, 512, 512, 7, 7, 3, 3, 1, 1, 1, 1, False, None), - (16, 512, 512, 7, 7, 3, 3, 1, 1, 1, 1, False, None), - (20, 512, 512, 7, 7, 3, 3, 1, 1, 1, 1, False, None), + (8, 512, 512, 14, 14, 3, 3, 2, 2, 1, 1, BS, None), + (16, 512, 512, 14, 14, 3, 3, 2, 2, 1, 1, BS, None), + (20, 512, 512, 14, 14, 3, 3, 2, 2, 1, 1, BS, None), + (8, 512, 512, 7, 7, 3, 3, 1, 1, 1, 1, BS, None), + (16, 512, 512, 7, 7, 3, 3, 1, 1, 1, 1, BS, None), + (20, 512, 512, 7, 7, 3, 3, 1, 1, 1, 1, BS, None), ## small test - (1, 64, 64, 8, 8, 3, 3, 1, 1, 1, 1, False, {"num_cores_nhw": 2, "grid_size": (2, 2)}), - (1, 64, 64, 16, 16, 3, 3, 1, 1, 1, 1, False, {"num_cores_nhw": 4, "grid_size": (2, 4)}), - # (1, 160, 160, 7, 7, 3, 3, 1, 1, 1, 1, False, None), sliding_window_op_infra/sliding_window.cpp:341: indices_length_last_core <= indices_length_per_core - (8, 256, 256, 7, 7, 3, 3, 1, 1, 1, 1, False, None), + (1, 64, 64, 8, 8, 3, 3, 1, 1, 1, 1, BS, {"num_cores_nhw": 2, "grid_size": (2, 2)}), + (1, 64, 64, 16, 16, 3, 3, 1, 1, 1, 1, BS, {"num_cores_nhw": 4, "grid_size": (2, 4)}), + # (1, 160, 160, 7, 7, 3, 3, 1, 1, 1, 1, BS, None), sliding_window_op_infra/sliding_window.cpp:341: indices_length_last_core <= indices_length_per_core + (8, 256, 256, 7, 7, 3, 3, 1, 1, 1, 1, BS, None), # r50 1x1s2 shapes - # Fails with packer_l1_acc = True (20, 256, 64, 56, 56, 1, 1, 2, 2, 0, 0, False, None), # r50 first bottleneck downsample shape - (20, 256, 64, 56, 56, 1, 1, 2, 2, 0, 0, True, None), # r50 first bottleneck downsample shape - # Fails with packer_l1_acc = True (20, 512, 256, 56, 56, 1, 1, 2, 2, 0, 0, False, None), # r50 second bottleneck downsample shape - # (20, 512, 256, 56, 56, 
1, 1, 2, 2, 0, 0, True, None), - doesnt fit - (20, 1024, 512, 28, 28, 1, 1, 2, 2, 0, 0, False, None), # r50 third bottleneck downsample shape - # (20, 1024, 512, 28, 28, 1, 1, 2, 2, 0, 0, True, None), - doesnt fit - (20, 2048, 1024, 14, 14, 1, 1, 2, 2, 0, 0, False, None), # r50 fourth bottleneck downsample shape - # (20, 2048, 1024, 14, 14, 1, 1, 2, 2, 0, 0, True, None), - doesnt fit - # (20, 128, 256, 56, 56, 1, 1, 2, 2, 0, 0, True, None), ## L2M1 DS: doesn't fit + # Fails with packer_l1_acc = True (20, 256, 64, 56, 56, 1, 1, 2, 2, 0, 0, BS, None), # r50 first bottleneck downsample shape + (20, 256, 64, 56, 56, 1, 1, 2, 2, 0, 0, HS, None), # r50 first bottleneck downsample shape + # Fails with packer_l1_acc = True (20, 512, 256, 56, 56, 1, 1, 2, 2, 0, 0, BS, None), # r50 second bottleneck downsample shape + # (20, 512, 256, 56, 56, 1, 1, 2, 2, 0, 0, HS, None), - doesnt fit + (20, 1024, 512, 28, 28, 1, 1, 2, 2, 0, 0, BS, None), # r50 third bottleneck downsample shape + # (20, 1024, 512, 28, 28, 1, 1, 2, 2, 0, 0, HS, None), - doesnt fit + (20, 2048, 1024, 14, 14, 1, 1, 2, 2, 0, 0, BS, None), # r50 fourth bottleneck downsample shape + # (20, 2048, 1024, 14, 14, 1, 1, 2, 2, 0, 0, HS, None), - doesnt fit + # (20, 128, 256, 56, 56, 1, 1, 2, 2, 0, 0, HS, None), ## L2M1 DS: doesn't fit ), ) @pytest.mark.parametrize( "weights_dtype", - [ttnn.bfloat16, ttnn.bfloat8_b], + [ttnn.bfloat8_b], ) @pytest.mark.parametrize( "activations_dtype", [ttnn.bfloat16, ttnn.bfloat8_b], ) @pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.LoFi]) -@pytest.mark.parametrize("packer_l1_acc", [True, False], ids=["pack_l1", "no_pack_l1"]) -@pytest.mark.parametrize("has_bias", [True, False], ids=["with_bias", "no_bias"]) +@pytest.mark.parametrize("packer_l1_acc", [True]) +@pytest.mark.parametrize("has_bias", [True]) @pytest.mark.parametrize("auto_shard", [True, False], ids=["auto_shard", "no_auto_shard"]) def test_resnet50_conv_wh( device, @@ -932,7 +908,7 @@ def test_resnet50_conv_wh( stride_w, pad_h, pad_w, - use_1d_systolic_array, + shard_layout, config_override, packer_l1_acc, has_bias, @@ -940,27 +916,8 @@ def test_resnet50_conv_wh( ): if device.core_grid.y == 7: pytest.skip("Issue #6992: Statically allocated circular buffers in program clash with L1 buffers on core range") - if batch_size > 8 and (activations_dtype != ttnn.bfloat8_b or weights_dtype != ttnn.bfloat8_b): - pytest.skip("Batch > 8 must be run fully bfp8") - if ( - ( - activations_dtype == ttnn.bfloat16 - and batch_size == 20 - and ( - output_channels == 64 - or ( - stride_h == 2 - and (output_channels == 256 or (output_channels == 128 and weights_dtype == ttnn.bfloat16)) - ) - ) - ) - # packer l1 acc has separate buffers when interm != output df, cannot fit into L1 - or (batch_size == 20 and activations_dtype == ttnn.bfloat8_b and packer_l1_acc and input_height >= 64) - ): - pytest.skip("Skipping test because it won't fit in L1!") - - use_shallow_conv_variant = (input_channels == 16) and device.arch() != ttnn.device.Arch.WORMHOLE_B0 + use_shallow_conv_variant = (input_channels == 16) and device.arch() == ttnn.device.Arch.GRAYSKULL run_conv( device, math_fidelity, @@ -977,24 +934,23 @@ def test_resnet50_conv_wh( stride_w, pad_h, pad_w, - use_1d_systolic_array, config_override=config_override, use_shallow_conv_variant=use_shallow_conv_variant, - transpose_mcast=use_1d_systolic_array, ## use RM (transpose_mcast=False) with 2D on WH packer_l1_acc=packer_l1_acc, fp32_accum=False, has_bias=has_bias, auto_shard=auto_shard, + shard_layout=shard_layout, 
) @skip_for_grayskull() @pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) @pytest.mark.parametrize( - "batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, use_1d_systolic_array, config_override", + "batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, shard_layout, config_override", ( - (16, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, True, {"act_block_h": 256}), - (8, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None), + (16, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, HS, {"act_block_h": 256}), + (8, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, HS, None), ), ) @pytest.mark.parametrize("memory_config", [ttnn.L1_MEMORY_CONFIG, ttnn.DRAM_MEMORY_CONFIG]) @@ -1012,7 +968,7 @@ def test_conv_mem_config_wh( stride_w, pad_h, pad_w, - use_1d_systolic_array, + shard_layout, config_override, memory_config, ): @@ -1036,10 +992,9 @@ def test_conv_mem_config_wh( stride_w, pad_h, pad_w, - use_1d_systolic_array, + shard_layout=shard_layout, config_override=config_override, use_shallow_conv_variant=use_shallow_conv_variant, - transpose_mcast=use_1d_systolic_array, ## use RM (transpose_mcast=False) with 2D on WH packer_l1_acc=True, fp32_accum=False, has_bias=True, @@ -1051,38 +1006,38 @@ def test_conv_mem_config_wh( @skip_for_grayskull() @pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) @pytest.mark.parametrize( - "batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, use_1d_systolic_array, config_override", + "batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, shard_layout, config_override", ( # unique convs in rn50 (complete list) # first conv post folding and input_channels padding to tile width - # (8, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, True, None), - # (16, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, True, {"act_block_h": 32}), - # (20, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, True, {"act_block_h": 32}), + # (8, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, HS, None), + # (16, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, HS, {"act_block_h": 32}), + # (20, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, HS, {"act_block_h": 32}), # rn50 layer1 - (8, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None), - # (16, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None), - # (20, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None), + (8, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, HS, None), + # (16, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, HS, None), + # (20, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, HS, None), # # rn50 layer2 - (8, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, True, None), - # (16, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, True, None), - # (20, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, True, {"act_block_h": 32}), - (8, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, True, None), - # (16, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, True, None), - # (20, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, True, None), + (8, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, HS, None), + # (16, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, HS, None), + # (20, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, HS, {"act_block_h": 32}), + (8, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, HS, None), + # (16, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, HS, None), + # (20, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, HS, None), # # rn50 layer3 - # (8, 256, 256, 28, 28, 3, 3, 2, 2, 1, 1, False, None), - # (16, 256, 256, 28, 28, 3, 3, 2, 2, 1, 1, 
False, None), - # (20, 256, 256, 28, 28, 3, 3, 2, 2, 1, 1, False, None), - # (8, 256, 256, 14, 14, 3, 3, 1, 1, 1, 1, False, None), - # (16, 256, 256, 14, 14, 3, 3, 1, 1, 1, 1, False, None), - # (20, 256, 256, 14, 14, 3, 3, 1, 1, 1, 1, False, None), + # (8, 256, 256, 28, 28, 3, 3, 2, 2, 1, 1, BS, None), + # (16, 256, 256, 28, 28, 3, 3, 2, 2, 1, 1, BS, None), + # (20, 256, 256, 28, 28, 3, 3, 2, 2, 1, 1, BS, None), + # (8, 256, 256, 14, 14, 3, 3, 1, 1, 1, 1, BS, None), + # (16, 256, 256, 14, 14, 3, 3, 1, 1, 1, 1, BS, None), + # (20, 256, 256, 14, 14, 3, 3, 1, 1, 1, 1, BS, None), # # rn50 layer4 - # (8, 512, 512, 14, 14, 3, 3, 2, 2, 1, 1, False, None), - # (16, 512, 512, 14, 14, 3, 3, 2, 2, 1, 1, False, None), - # (20, 512, 512, 14, 14, 3, 3, 2, 2, 1, 1, False, None), - # (8, 512, 512, 7, 7, 3, 3, 1, 1, 1, 1, False, None), - # (16, 512, 512, 7, 7, 3, 3, 1, 1, 1, 1, False, None), - # (20, 512, 512, 7, 7, 3, 3, 1, 1, 1, 1, False, None), + # (8, 512, 512, 14, 14, 3, 3, 2, 2, 1, 1, BS, None), + # (16, 512, 512, 14, 14, 3, 3, 2, 2, 1, 1, BS, None), + # (20, 512, 512, 14, 14, 3, 3, 2, 2, 1, 1, BS, None), + # (8, 512, 512, 7, 7, 3, 3, 1, 1, 1, 1, BS, None), + # (16, 512, 512, 7, 7, 3, 3, 1, 1, 1, 1, BS, None), + # (20, 512, 512, 7, 7, 3, 3, 1, 1, 1, 1, BS, None), ), ) @pytest.mark.parametrize( @@ -1101,7 +1056,7 @@ def test_conv_mem_config_wh( ], ) @pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.HiFi4]) -@pytest.mark.parametrize("packer_l1_acc", [True, False], ids=["pack_l1", "no_pack_l1"]) +@pytest.mark.parametrize("packer_l1_acc", [True]) @pytest.mark.parametrize("auto_shard", [True, False], ids=["auto_shard", "no_auto_shard"]) def test_resnet50_conv_wh_fp32( device, @@ -1121,7 +1076,7 @@ def test_resnet50_conv_wh_fp32( stride_w, pad_h, pad_w, - use_1d_systolic_array, + shard_layout, config_override, packer_l1_acc, auto_shard, @@ -1159,67 +1114,66 @@ def test_resnet50_conv_wh_fp32( stride_w, pad_h, pad_w, - use_1d_systolic_array, + shard_layout=shard_layout, config_override=config_override, use_shallow_conv_variant=use_shallow_conv_variant, fp32_accum=fp32_accum, packer_l1_acc=packer_l1_acc, - transpose_mcast=use_1d_systolic_array, ## use RM (transpose_mcast=False) with 2D on WH auto_shard=auto_shard, ) @pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) @pytest.mark.parametrize( - "batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, use_1d_systolic_array, config_override", + "batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, shard_layout, config_override", ( # sd convs with HxW=32x32 - # (1, 320, 320, 32, 32, 3, 3, 1, 1, 1, 1, False, None), - # (1, 320, 320, 32, 32, 3, 3, 2, 2, 1, 1, False, None), - # (1, 640, 640, 16, 16, 3, 3, 1, 1, 1, 1, False, None), - # (1, 640, 640, 16, 16, 3, 3, 2, 2, 1, 1, False, None), - # (1, 640, 640, 16, 16, 3, 3, 2, 2, 1, 1, False, None), # bfloat16 activations doesnt fit - # (1, 1280, 1280, 8, 8, 3, 3, 1, 1, 1, 1, False, None), # slighlty low pcc with 0.99689. bfloat16 weights doesnt fit - # (1, 1280, 1280, 8, 8, 3, 3, 2, 2, 1, 1, False, None), #fails to parallelize with sharding - # (1, 1280, 1280, 4, 4, 3, 3, 1, 1, 1, 1, False, None), #fails to parallelize with sharding - # (1, 1280, 1280, 16, 16, 3, 3, 1, 1, 1, 1, False, None), # slightly low pcc with 0.99698. 
bfloat16 weights doesnt fit - # (1, 640, 640, 32, 32, 3, 3, 1, 1, 1, 1, False, None), # doesnt fit at all.. for all data types + # (1, 320, 320, 32, 32, 3, 3, 1, 1, 1, 1, BS, None), + # (1, 320, 320, 32, 32, 3, 3, 2, 2, 1, 1, BS, None), + # (1, 640, 640, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), + # (1, 640, 640, 16, 16, 3, 3, 2, 2, 1, 1, BS, None), + # (1, 640, 640, 16, 16, 3, 3, 2, 2, 1, 1, BS, None), # bfloat16 activations doesnt fit + # (1, 1280, 1280, 8, 8, 3, 3, 1, 1, 1, 1, BS, None), # slighlty low pcc with 0.99689. bfloat16 weights doesnt fit + # (1, 1280, 1280, 8, 8, 3, 3, 2, 2, 1, 1, BS, None), #fails to parallelize with sharding + # (1, 1280, 1280, 4, 4, 3, 3, 1, 1, 1, 1, BS, None), #fails to parallelize with sharding + # (1, 1280, 1280, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), # slightly low pcc with 0.99698. bfloat16 weights doesnt fit + # (1, 640, 640, 32, 32, 3, 3, 1, 1, 1, 1, BS, None), # doesnt fit at all.. for all data types # sd convs with HxW=64x64 with batch size = 1 - (1, 320, 16, 64, 64, 3, 3, 1, 1, 1, 1, True, None), - (1, 320, 320, 64, 64, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 32}), # bfloat16 doesnt fit - (1, 320, 320, 64, 64, 3, 3, 2, 2, 1, 1, False, None), - (1, 640, 640, 32, 32, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 32}), # - (1, 640, 640, 32, 32, 3, 3, 2, 2, 1, 1, False, None), # bfloat16 doesnt fit - (1, 1280, 1280, 16, 16, 3, 3, 1, 1, 1, 1, False, None), # bfloat16 weights doesnt fit - (1, 1280, 1280, 16, 16, 3, 3, 2, 2, 1, 1, False, None), # bfloat16 doesnt fit. - (1, 1280, 1280, 8, 8, 3, 3, 1, 1, 1, 1, False, None), # bfloat16 weights doesnt fit - # (1, 1280, 1280, 32, 32, 3, 3, 1, 1, 1, 1, False, None), IndexError: vector::_M_range_check: __n (which is 1) >= this->size() (which is 1) - (1, 640, 640, 64, 64, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 32}), - # (1, 1280, 2560, 8, 8, 3, 3, 1, 1, 1, 1, False, None), IndexError: vector::_M_range_check: __n (which is 1) >= this->size() (which is 1) - # (1, 1280, 2560, 16, 16, 3, 3, 1, 1, 1, 1, False, None), IndexError: vector::_M_range_check: __n (which is 1) >= this->size() (which is 1) + (1, 320, 16, 64, 64, 3, 3, 1, 1, 1, 1, HS, None), + (1, 320, 320, 64, 64, 3, 3, 1, 1, 1, 1, BS, {"act_block_h": 32}), # bfloat16 doesnt fit + (1, 320, 320, 64, 64, 3, 3, 2, 2, 1, 1, BS, None), + (1, 640, 640, 32, 32, 3, 3, 1, 1, 1, 1, BS, {"act_block_h": 32}), # + (1, 640, 640, 32, 32, 3, 3, 2, 2, 1, 1, BS, None), # bfloat16 doesnt fit + (1, 1280, 1280, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), # bfloat16 weights doesnt fit + (1, 1280, 1280, 16, 16, 3, 3, 2, 2, 1, 1, BS, None), # bfloat16 doesnt fit. 
+ (1, 1280, 1280, 8, 8, 3, 3, 1, 1, 1, 1, BS, None), # bfloat16 weights doesnt fit + # (1, 1280, 1280, 32, 32, 3, 3, 1, 1, 1, 1, BS, None), IndexError: vector::_M_range_check: __n (which is 1) >= this->size() (which is 1) + (1, 640, 640, 64, 64, 3, 3, 1, 1, 1, 1, BS, {"act_block_h": 32}), + # (1, 1280, 2560, 8, 8, 3, 3, 1, 1, 1, 1, BS, None), IndexError: vector::_M_range_check: __n (which is 1) >= this->size() (which is 1) + # (1, 1280, 2560, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), IndexError: vector::_M_range_check: __n (which is 1) >= this->size() (which is 1) # sd convs with HxW=64x64 with batch size=2 - # (2, 320, 16, 64, 64, 3, 3, 1, 1, 1, 1, True, None), Hangs on WH - (2, 320, 320, 64, 64, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 64}), - (2, 320, 320, 64, 64, 3, 3, 2, 2, 1, 1, False, None), # fits with bfloat8_b - (2, 640, 640, 32, 32, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 64}), - (2, 640, 640, 32, 32, 3, 3, 2, 2, 1, 1, False, None), # bfloat16 doesnt fit - (2, 1280, 1280, 16, 16, 3, 3, 1, 1, 1, 1, False, None), # bfloat16 doesnt fit - (2, 1280, 1280, 16, 16, 3, 3, 2, 2, 1, 1, False, {"act_block_h": 32}), # bfloat16 doesnt fit - (2, 1280, 1280, 8, 8, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 32}), - # (2, 1280, 1280, 32, 32, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 32}), IndexError: vector::_M_range_check: __n (which is 1) >= this->size() (which is 1) - (2, 640, 640, 64, 64, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 64}), - # (2, 1280, 2560, 8, 8, 3, 3, 1, 1, 1, 1, False, None), IndexError: vector::_M_range_check: __n (which is 1) >= this->size() (which is 1) - # (2, 1280, 2560, 16, 16, 3, 3, 1, 1, 1, 1, False, None), IndexError: vector::_M_range_check: __n (which is 1) >= this->size() (which is 1) - # (2, 1280, 1920, 16, 16, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 32}), IndexError: vector::_M_range_check: __n (which is 1) >= this->size() (which is 1) - # (2, 640, 1920, 32, 32, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 32}), IndexError: vector::_M_range_check: __n (which is 1) >= this->size() (which is 1) - # (2, 640, 1280, 32, 32, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 32}), IndexError: vector::_M_range_check: __n (which is 1) >= this->size() (which is 1) - # (2, 640, 960, 32, 32, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 32}), IndexError: vector::_M_range_check: __n (which is 1) >= this->size() (which is 1) - # (2, 320, 960, 64, 64, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 32}), IndexError: vector::_M_range_check: __n (which is 1) >= this->size() (which is 1) - # (2, 320, 640, 64, 64, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 32}), IndexError: vector::_M_range_check: __n (which is 1) >= this->size() (which is 1) + # (2, 320, 16, 64, 64, 3, 3, 1, 1, 1, 1, HS, None), Hangs on WH + (2, 320, 320, 64, 64, 3, 3, 1, 1, 1, 1, BS, {"act_block_h": 64}), + (2, 320, 320, 64, 64, 3, 3, 2, 2, 1, 1, BS, None), # fits with bfloat8_b + (2, 640, 640, 32, 32, 3, 3, 1, 1, 1, 1, BS, {"act_block_h": 64}), + (2, 640, 640, 32, 32, 3, 3, 2, 2, 1, 1, BS, None), # bfloat16 doesnt fit + (2, 1280, 1280, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), # bfloat16 doesnt fit + (2, 1280, 1280, 16, 16, 3, 3, 2, 2, 1, 1, BS, {"act_block_h": 32}), # bfloat16 doesnt fit + (2, 1280, 1280, 8, 8, 3, 3, 1, 1, 1, 1, BS, {"act_block_h": 32}), + # (2, 1280, 1280, 32, 32, 3, 3, 1, 1, 1, 1, BS, {"act_block_h": 32}), IndexError: vector::_M_range_check: __n (which is 1) >= this->size() (which is 1) + (2, 640, 640, 64, 64, 3, 3, 1, 1, 1, 1, BS, {"act_block_h": 64}), + # (2, 1280, 2560, 8, 8, 3, 3, 1, 1, 1, 1, BS, None), IndexError: 
vector::_M_range_check: __n (which is 1) >= this->size() (which is 1) + # (2, 1280, 2560, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), IndexError: vector::_M_range_check: __n (which is 1) >= this->size() (which is 1) + # (2, 1280, 1920, 16, 16, 3, 3, 1, 1, 1, 1, BS, {"act_block_h": 32}), IndexError: vector::_M_range_check: __n (which is 1) >= this->size() (which is 1) + # (2, 640, 1920, 32, 32, 3, 3, 1, 1, 1, 1, BS, {"act_block_h": 32}), IndexError: vector::_M_range_check: __n (which is 1) >= this->size() (which is 1) + # (2, 640, 1280, 32, 32, 3, 3, 1, 1, 1, 1, BS, {"act_block_h": 32}), IndexError: vector::_M_range_check: __n (which is 1) >= this->size() (which is 1) + # (2, 640, 960, 32, 32, 3, 3, 1, 1, 1, 1, BS, {"act_block_h": 32}), IndexError: vector::_M_range_check: __n (which is 1) >= this->size() (which is 1) + # (2, 320, 960, 64, 64, 3, 3, 1, 1, 1, 1, BS, {"act_block_h": 32}), IndexError: vector::_M_range_check: __n (which is 1) >= this->size() (which is 1) + # (2, 320, 640, 64, 64, 3, 3, 1, 1, 1, 1, BS, {"act_block_h": 32}), IndexError: vector::_M_range_check: __n (which is 1) >= this->size() (which is 1) # 1x1 conv - (2, 320, 960, 64, 64, 1, 1, 1, 1, 0, 0, True, None), + (2, 320, 960, 64, 64, 1, 1, 1, 1, 0, 0, HS, None), # Small conv - # (1, 32, 32, 16, 16, 3, 3, 2, 2, 1, 1, True, None), ## batch = 1 is currently not supported + # (1, 32, 32, 16, 16, 3, 3, 2, 2, 1, 1, HS, None), ## batch = 1 is currently not supported ), ) @pytest.mark.parametrize( @@ -1251,7 +1205,7 @@ def test_sd_conv( stride_w, pad_h, pad_w, - use_1d_systolic_array, + shard_layout, config_override, enable_auto_formatting, auto_shard, @@ -1275,10 +1229,9 @@ def test_sd_conv( stride_w, pad_h, pad_w, - use_1d_systolic_array, config_override, + shard_layout=shard_layout, split_factor=3 if input_channels == 1920 else 2, - auto_shard=auto_shard, ) else: run_conv( @@ -1297,11 +1250,9 @@ def test_sd_conv( stride_w, pad_h, pad_w, - use_1d_systolic_array, config_override, + shard_layout=shard_layout, use_shallow_conv_variant=(input_channels == 16), - enable_auto_formatting=enable_auto_formatting, - padded_input_channels=16 if input_channels == 16 else None, auto_shard=auto_shard, ) @@ -1309,60 +1260,60 @@ def test_sd_conv( @skip_for_grayskull() @pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) @pytest.mark.parametrize( - "batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, use_1d_systolic_array, config_override", + "batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, shard_layout, config_override", ( # sd convs with HxW=32x32 - # (1, 320, 320, 32, 32, 3, 3, 1, 1, 1, 1, False, None), - # (1, 320, 320, 32, 32, 3, 3, 2, 2, 1, 1, False, None), - # (1, 640, 640, 16, 16, 3, 3, 1, 1, 1, 1, False, None), - # (1, 640, 640, 16, 16, 3, 3, 2, 2, 1, 1, False, None), - # (1, 640, 640, 16, 16, 3, 3, 2, 2, 1, 1, False, None), # bfloat16 activations doesnt fit - # (1, 1280, 1280, 8, 8, 3, 3, 1, 1, 1, 1, False, None), # slighlty low pcc with 0.99689. bfloat16 weights doesnt fit - # (1, 1280, 1280, 8, 8, 3, 3, 2, 2, 1, 1, False, None), #fails to parallelize with sharding - # (1, 1280, 1280, 4, 4, 3, 3, 1, 1, 1, 1, False, None), #fails to parallelize with sharding - # (1, 1280, 1280, 16, 16, 3, 3, 1, 1, 1, 1, False, None), # slightly low pcc with 0.99698. 
bfloat16 weights doesnt fit - # (1, 640, 640, 32, 32, 3, 3, 1, 1, 1, 1, False, None), # doesnt fit at all.. for all data types + # (1, 320, 320, 32, 32, 3, 3, 1, 1, 1, 1, BS, None), + # (1, 320, 320, 32, 32, 3, 3, 2, 2, 1, 1, BS, None), + # (1, 640, 640, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), + # (1, 640, 640, 16, 16, 3, 3, 2, 2, 1, 1, BS, None), + # (1, 640, 640, 16, 16, 3, 3, 2, 2, 1, 1, BS, None), # bfloat16 activations doesnt fit + # (1, 1280, 1280, 8, 8, 3, 3, 1, 1, 1, 1, BS, None), # slighlty low pcc with 0.99689. bfloat16 weights doesnt fit + # (1, 1280, 1280, 8, 8, 3, 3, 2, 2, 1, 1, BS, None), #fails to parallelize with sharding + # (1, 1280, 1280, 4, 4, 3, 3, 1, 1, 1, 1, BS, None), #fails to parallelize with sharding + # (1, 1280, 1280, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), # slightly low pcc with 0.99698. bfloat16 weights doesnt fit + # (1, 640, 640, 32, 32, 3, 3, 1, 1, 1, 1, BS, None), # doesnt fit at all.. for all data types # sd convs with HxW=64x64 with batch size = 1 - # (1, 320, 16, 64, 64, 3, 3, 1, 1, 1, 1, True, None), - # (1, 320, 320, 64, 64, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 32}), # bfloat16 doesnt fit - # (1, 320, 320, 64, 64, 3, 3, 2, 2, 1, 1, False, None), - # (1, 640, 640, 32, 32, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 32}), # - # (1, 640, 640, 32, 32, 3, 3, 2, 2, 1, 1, False, None), # bfloat16 doesnt fit - # (1, 1280, 1280, 16, 16, 3, 3, 1, 1, 1, 1, False, None), # bfloat16 weights doesnt fit - # (1, 1280, 1280, 16, 16, 3, 3, 2, 2, 1, 1, False, None), # bfloat16 doesnt fit. - # (1, 1280, 1280, 8, 8, 3, 3, 1, 1, 1, 1, False, None), # bfloat16 weights doesnt fit - # (1, 1280, 1280, 32, 32, 3, 3, 1, 1, 1, 1, False, None), - # (1, 640, 640, 64, 64, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 32}), - # (1, 1280, 2560, 8, 8, 3, 3, 1, 1, 1, 1, False, None), - # (1, 1280, 2560, 16, 16, 3, 3, 1, 1, 1, 1, False, None), + # (1, 320, 16, 64, 64, 3, 3, 1, 1, 1, 1, HS, None), + # (1, 320, 320, 64, 64, 3, 3, 1, 1, 1, 1, BS, {"act_block_h": 32}), # bfloat16 doesnt fit + # (1, 320, 320, 64, 64, 3, 3, 2, 2, 1, 1, BS, None), + # (1, 640, 640, 32, 32, 3, 3, 1, 1, 1, 1, BS, {"act_block_h": 32}), # + # (1, 640, 640, 32, 32, 3, 3, 2, 2, 1, 1, BS, None), # bfloat16 doesnt fit + # (1, 1280, 1280, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), # bfloat16 weights doesnt fit + # (1, 1280, 1280, 16, 16, 3, 3, 2, 2, 1, 1, BS, None), # bfloat16 doesnt fit. 
+ # (1, 1280, 1280, 8, 8, 3, 3, 1, 1, 1, 1, BS, None), # bfloat16 weights doesnt fit + # (1, 1280, 1280, 32, 32, 3, 3, 1, 1, 1, 1, BS, None), + # (1, 640, 640, 64, 64, 3, 3, 1, 1, 1, 1, BS, {"act_block_h": 32}), + # (1, 1280, 2560, 8, 8, 3, 3, 1, 1, 1, 1, BS, None), + # (1, 1280, 2560, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), # # sd convs with HxW=64x64 with batch size=2 - (2, 320, 16, 64, 64, 3, 3, 1, 1, 1, 1, True, None), - (2, 320, 320, 64, 64, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 64}), - (2, 320, 320, 64, 64, 3, 3, 2, 2, 1, 1, False, None), # fits with bfloat8_b - (2, 640, 640, 32, 32, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 32}), - (2, 640, 640, 32, 32, 3, 3, 2, 2, 1, 1, False, None), # bfloat16 doesnt fit - (2, 1280, 1280, 16, 16, 3, 3, 1, 1, 1, 1, False, None), # bfloat16 doesnt fit - (2, 1280, 1280, 16, 16, 3, 3, 2, 2, 1, 1, False, {"act_block_h": 32}), # bfloat16 doesnt fit - (2, 1280, 1280, 8, 8, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 32}), - (2, 1280, 1280, 32, 32, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 32}), # bfloat16 doesnt fit - # (2, 640, 640, 64, 64, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 32}), L1 Allocation Error - (2, 1280, 2560, 8, 8, 3, 3, 1, 1, 1, 1, False, None), - (2, 1280, 2560, 16, 16, 3, 3, 1, 1, 1, 1, False, None), - (2, 1280, 1920, 16, 16, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 32}), - (2, 640, 1920, 32, 32, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 32}), - (2, 640, 1280, 32, 32, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 32}), - (2, 640, 960, 32, 32, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 32}), - (2, 320, 960, 64, 64, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 32}), - (2, 320, 640, 64, 64, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 32}), + (2, 320, 16, 64, 64, 3, 3, 1, 1, 1, 1, HS, None), + (2, 320, 320, 64, 64, 3, 3, 1, 1, 1, 1, BS, {"act_block_h": 64}), + (2, 320, 320, 64, 64, 3, 3, 2, 2, 1, 1, BS, None), # fits with bfloat8_b + (2, 640, 640, 32, 32, 3, 3, 1, 1, 1, 1, BS, {"act_block_h": 32}), + (2, 640, 640, 32, 32, 3, 3, 2, 2, 1, 1, BS, None), # bfloat16 doesnt fit + (2, 1280, 1280, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), # bfloat16 doesnt fit + (2, 1280, 1280, 16, 16, 3, 3, 2, 2, 1, 1, BS, {"act_block_h": 32}), # bfloat16 doesnt fit + (2, 1280, 1280, 8, 8, 3, 3, 1, 1, 1, 1, BS, {"act_block_h": 32}), + (2, 1280, 1280, 32, 32, 3, 3, 1, 1, 1, 1, BS, {"act_block_h": 32}), # bfloat16 doesnt fit + # (2, 640, 640, 64, 64, 3, 3, 1, 1, 1, 1, BS, {"act_block_h": 32}), L1 Allocation Error + (2, 1280, 2560, 8, 8, 3, 3, 1, 1, 1, 1, BS, None), + (2, 1280, 2560, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), + (2, 1280, 1920, 16, 16, 3, 3, 1, 1, 1, 1, BS, {"act_block_h": 32}), + (2, 640, 1920, 32, 32, 3, 3, 1, 1, 1, 1, BS, {"act_block_h": 32}), + (2, 640, 1280, 32, 32, 3, 3, 1, 1, 1, 1, BS, {"act_block_h": 32}), + (2, 640, 960, 32, 32, 3, 3, 1, 1, 1, 1, BS, {"act_block_h": 32}), + (2, 320, 960, 64, 64, 3, 3, 1, 1, 1, 1, BS, {"act_block_h": 32}), + (2, 320, 640, 64, 64, 3, 3, 1, 1, 1, 1, BS, {"act_block_h": 32}), # 1x1 conv - (2, 320, 960, 64, 64, 1, 1, 1, 1, 0, 0, True, None), + (2, 320, 960, 64, 64, 1, 1, 1, 1, 0, 0, HS, None), # Small conv - # (1, 32, 32, 16, 16, 3, 3, 2, 2, 1, 1, True, None), fails + # (1, 32, 32, 16, 16, 3, 3, 2, 2, 1, 1, HS, None), fails ), ) @pytest.mark.parametrize( "weights_dtype", - [ttnn.bfloat16, ttnn.bfloat8_b], + [ttnn.bfloat8_b], ) @pytest.mark.parametrize( "activations_dtype", @@ -1372,11 +1323,9 @@ def test_sd_conv( "fp32_accum", [ False, - True, ], ) @pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.LoFi]) 
-@pytest.mark.parametrize("enable_auto_formatting", [True, False]) def test_sd_conv_wh( device, use_program_cache, @@ -1395,9 +1344,8 @@ def test_sd_conv_wh( stride_w, pad_h, pad_w, - use_1d_systolic_array, + shard_layout, config_override, - enable_auto_formatting, ): if device.core_grid.y == 7: pytest.skip("This test is not supported for N300") @@ -1411,14 +1359,11 @@ def test_sd_conv_wh( and input_height == 32 and activations_dtype == ttnn.bfloat16 and weights_dtype == ttnn.bfloat16 - and enable_auto_formatting == False ) ): pytest.skip("Skip the test cases raising OOM but not affecting e2e test") if filter_height > 1 and (input_channels > 1280 or (input_channels > 640 and input_height > 16)): - if enable_auto_formatting: - pytest.skip("Not running split SD conv with auto formatting") run_conv_with_split( device, math_fidelity, @@ -1435,8 +1380,8 @@ def test_sd_conv_wh( stride_w, pad_h, pad_w, - use_1d_systolic_array, config_override, + shard_layout=shard_layout, split_factor=3 if input_channels == 1920 else 2, fp32_accum=fp32_accum, packer_l1_acc=True, @@ -1458,12 +1403,9 @@ def test_sd_conv_wh( stride_w, pad_h, pad_w, - use_1d_systolic_array, config_override, + shard_layout=shard_layout, use_shallow_conv_variant=(input_channels == 16), - transpose_mcast=use_1d_systolic_array, ## use RM (transpose_mcast=False) with 2D on WH - enable_auto_formatting=enable_auto_formatting, - padded_input_channels=16 if input_channels == 16 else None, fp32_accum=fp32_accum, packer_l1_acc=True, ) @@ -1472,22 +1414,23 @@ def test_sd_conv_wh( @pytest.mark.skipif(is_wormhole_b0() or is_blackhole(), reason="Unsupported on WH and BH") @pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) @pytest.mark.parametrize( - "batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, use_1d_systolic_array, config_override, use_shallow_conv_variant", + "batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, shard_layout, config_override, use_shallow_conv_variant", ( # unet convs with batch size 2 # unique convs in unet (complete list) - (2, 16, 3, 1056, 160, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 64}, True), - (2, 16, 16, 1056, 160, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 64}, True), - (2, 16, 16, 528, 80, 3, 3, 1, 1, 1, 1, True, None, False), - (2, 32, 16, 264, 40, 3, 3, 1, 1, 1, 1, True, None, True), - (2, 32, 32, 264, 40, 3, 3, 1, 1, 1, 1, True, None, True), - (2, 32, 32, 132, 20, 3, 3, 1, 1, 1, 1, True, None, False), - (2, 64, 32, 66, 10, 3, 3, 1, 1, 1, 1, True, None, False), - (2, 64, 64, 66, 10, 3, 3, 1, 1, 1, 1, True, None, False), - (2, 32, 96, 132, 20, 3, 3, 1, 1, 1, 1, True, None, False), - (2, 32, 32, 132, 20, 3, 3, 1, 1, 1, 1, True, None, False), - (2, 32, 64, 264, 40, 3, 3, 1, 1, 1, 1, True, None, True), - (2, 32, 32, 264, 40, 3, 3, 1, 1, 1, 1, True, None, True), + (2, 16, 3, 1056, 160, 3, 3, 1, 1, 1, 1, HS, {"act_block_h": 64}, True), + (2, 16, 16, 1056, 160, 3, 3, 1, 1, 1, 1, HS, {"act_block_h": 64}, True), + (2, 16, 16, 528, 80, 3, 3, 1, 1, 1, 1, HS, None, False), + (2, 32, 16, 264, 40, 3, 3, 1, 1, 1, 1, HS, None, True), + (2, 32, 32, 264, 40, 3, 3, 1, 1, 1, 1, HS, None, True), + (2, 32, 32, 132, 20, 3, 3, 1, 1, 1, 1, HS, None, False), + (2, 64, 32, 66, 10, 3, 3, 1, 1, 1, 1, HS, None, False), + (2, 64, 64, 66, 10, 3, 3, 1, 1, 1, 1, HS, None, False), + (2, 32, 96, 132, 20, 3, 3, 1, 1, 1, 1, HS, None, False), + (2, 32, 32, 
132, 20, 3, 3, 1, 1, 1, 1, HS, None, False), + (2, 32, 64, 264, 40, 3, 3, 1, 1, 1, 1, HS, None, True), + (2, 32, 32, 264, 40, 3, 3, 1, 1, 1, 1, HS, None, True), + # fails. mismatch. It passes when input_channels=64. Probably an issue with padding when input_channels % 32 != 0. ( 2, 16, @@ -1500,14 +1443,14 @@ def test_sd_conv_wh( 1, 1, 1, - True, + HS, {"act_block_h": 32}, False, - ), # fails. mismatch. It passes when input_channels=64. Probably an issue with padding when input_channels % 32 != 0. - (2, 16, 16, 528, 80, 3, 3, 1, 1, 1, 1, True, None, False), - (2, 16, 32, 1056, 160, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 8 * 32}, False), - (2, 16, 16, 1056, 160, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 8 * 32}, False), - (2, 1, 16, 1056, 160, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 8 * 32}, False), + ), + (2, 16, 16, 528, 80, 3, 3, 1, 1, 1, 1, HS, None, False), + (2, 16, 32, 1056, 160, 3, 3, 1, 1, 1, 1, HS, {"act_block_h": 8 * 32}, False), + (2, 16, 16, 1056, 160, 3, 3, 1, 1, 1, 1, HS, {"act_block_h": 8 * 32}, False), + (2, 1, 16, 1056, 160, 3, 3, 1, 1, 1, 1, HS, {"act_block_h": 8 * 32}, False), ), ) @pytest.mark.parametrize( @@ -1520,7 +1463,7 @@ def test_sd_conv_wh( ) @pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.LoFi]) @pytest.mark.parametrize("output_layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT]) -@pytest.mark.parametrize("auto_shard", [True, False], ids=["auto_shard", "no_auto_shard"]) +@pytest.mark.parametrize("auto_shard", [True, BS], ids=["auto_shard", "no_auto_shard"]) def test_unet_conv( device, use_program_cache, @@ -1538,7 +1481,7 @@ def test_unet_conv( stride_w, pad_h, pad_w, - use_1d_systolic_array, + shard_layout, config_override, use_shallow_conv_variant, output_layout, @@ -1567,10 +1510,9 @@ def test_unet_conv( stride_w, pad_h, pad_w, - use_1d_systolic_array, config_override, + shard_layout=shard_layout, use_shallow_conv_variant=use_shallow_conv_variant, - padded_input_channels=16 if input_channels == 3 else None, output_layout=output_layout, auto_shard=auto_shard, ) @@ -1579,27 +1521,27 @@ def test_unet_conv( @skip_for_grayskull() @pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) @pytest.mark.parametrize( - "batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, use_1d_systolic_array, config_override, use_shallow_conv_variant", + "batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, shard_layout, config_override, use_shallow_conv_variant", ( # unet convs with batch size 2 # unique convs in unet (complete list) - (2, 16, 4, 1056, 160, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 16 * 32}, True), - (2, 16, 16, 1056, 160, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 16 * 32}, True), - (2, 16, 16, 528, 80, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 16 * 32}, True), - (2, 32, 16, 264, 40, 3, 3, 1, 1, 1, 1, True, None, False), - (2, 32, 32, 264, 40, 3, 3, 1, 1, 1, 1, True, None, False), - (2, 32, 32, 132, 20, 3, 3, 1, 1, 1, 1, True, None, False), - (2, 64, 32, 66, 10, 3, 3, 1, 1, 1, 1, True, None, False), - (2, 64, 64, 66, 10, 3, 3, 1, 1, 1, 1, True, None, False), - (2, 32, 96, 132, 20, 3, 3, 1, 1, 1, 1, True, None, False), - (2, 32, 32, 132, 20, 3, 3, 1, 1, 1, 1, True, None, False), - (2, 32, 64, 264, 40, 3, 3, 1, 1, 1, 1, True, None, False), - (2, 32, 32, 264, 40, 3, 3, 1, 1, 1, 1, True, None, False), - (2, 16, 48, 528, 80, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 16 * 32}, 
True), - (2, 16, 16, 528, 80, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 16 * 32}, True), - (2, 16, 32, 1056, 160, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 16 * 32}, True), - (2, 16, 16, 1056, 160, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 16 * 32}, True), - # (2, 1, 16, 1056, 160, 1, 1, 1, 1, 0, 0, True, {"act_block_h": 5 * 32}, False) # Enable when issue #11490 resolved + (2, 16, 4, 1056, 160, 3, 3, 1, 1, 1, 1, HS, {"act_block_h": 16 * 32}, True), + (2, 16, 16, 1056, 160, 3, 3, 1, 1, 1, 1, HS, {"act_block_h": 16 * 32}, True), + (2, 16, 16, 528, 80, 3, 3, 1, 1, 1, 1, HS, {"act_block_h": 16 * 32}, True), + (2, 32, 16, 264, 40, 3, 3, 1, 1, 1, 1, HS, None, False), + (2, 32, 32, 264, 40, 3, 3, 1, 1, 1, 1, HS, None, False), + (2, 32, 32, 132, 20, 3, 3, 1, 1, 1, 1, HS, None, False), + (2, 64, 32, 66, 10, 3, 3, 1, 1, 1, 1, HS, None, False), + (2, 64, 64, 66, 10, 3, 3, 1, 1, 1, 1, HS, None, False), + (2, 32, 96, 132, 20, 3, 3, 1, 1, 1, 1, HS, None, False), + (2, 32, 32, 132, 20, 3, 3, 1, 1, 1, 1, HS, None, False), + (2, 32, 64, 264, 40, 3, 3, 1, 1, 1, 1, HS, None, False), + (2, 32, 32, 264, 40, 3, 3, 1, 1, 1, 1, HS, None, False), + (2, 16, 48, 528, 80, 3, 3, 1, 1, 1, 1, HS, {"act_block_h": 16 * 32}, True), + (2, 16, 16, 528, 80, 3, 3, 1, 1, 1, 1, HS, {"act_block_h": 16 * 32}, True), + (2, 16, 32, 1056, 160, 3, 3, 1, 1, 1, 1, HS, {"act_block_h": 16 * 32}, True), + (2, 16, 16, 1056, 160, 3, 3, 1, 1, 1, 1, HS, {"act_block_h": 16 * 32}, True), + # (2, 1, 16, 1056, 160, 1, 1, 1, 1, 0, 0, HS, {"act_block_h": 5 * 32}, False) # Enable when issue #11490 resolved ), ) @pytest.mark.parametrize( @@ -1630,7 +1572,7 @@ def test_unet_conv_wh( stride_w, pad_h, pad_w, - use_1d_systolic_array, + shard_layout, config_override, use_shallow_conv_variant, output_layout, @@ -1658,11 +1600,9 @@ def test_unet_conv_wh( stride_w, pad_h, pad_w, - use_1d_systolic_array, config_override, + shard_layout=shard_layout, use_shallow_conv_variant=use_shallow_conv_variant, - transpose_mcast=use_1d_systolic_array, ## use RM (transpose_mcast=False) with 2D on WH - padded_input_channels=None, output_layout=output_layout, auto_shard=auto_shard, ) @@ -1679,24 +1619,24 @@ def test_unet_conv_wh( ) @pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) @pytest.mark.parametrize( - "output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, use_1d_systolic_array, config_override, use_shallow_conv_variant", + "output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, shard_layout, config_override, use_shallow_conv_variant", ( - (16, 4, 1056, 160, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 16 * 32}, True), - (16, 16, 1056, 160, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 16 * 32}, True), - (16, 16, 528, 80, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 8 * 32}, True), - (32, 16, 264, 40, 3, 3, 1, 1, 1, 1, True, None, False), - (32, 32, 264, 40, 3, 3, 1, 1, 1, 1, True, None, False), - (32, 32, 132, 20, 3, 3, 1, 1, 1, 1, True, None, False), - (64, 32, 66, 10, 3, 3, 1, 1, 1, 1, True, None, False), - (64, 64, 66, 10, 3, 3, 1, 1, 1, 1, True, None, False), - (32, 96, 132, 20, 3, 3, 1, 1, 1, 1, True, None, False), - (32, 32, 132, 20, 3, 3, 1, 1, 1, 1, True, None, False), - (32, 64, 264, 40, 3, 3, 1, 1, 1, 1, True, None, False), - (32, 32, 264, 40, 3, 3, 1, 1, 1, 1, True, None, False), - (16, 48, 528, 80, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 8 * 32}, True), - (16, 16, 528, 80, 3, 3, 1, 1, 1, 1, True, 
{"act_block_h": 8 * 32}, True), - (16, 32, 1056, 160, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 16 * 32}, True), - (1, 16, 1056, 160, 1, 1, 1, 1, 0, 0, True, {"act_block_h": 5 * 32}, False), + (16, 4, 1056, 160, 3, 3, 1, 1, 1, 1, HS, {"act_block_h": 16 * 32}, True), + (16, 16, 1056, 160, 3, 3, 1, 1, 1, 1, HS, {"act_block_h": 16 * 32}, True), + (16, 16, 528, 80, 3, 3, 1, 1, 1, 1, HS, {"act_block_h": 8 * 32}, True), + (32, 16, 264, 40, 3, 3, 1, 1, 1, 1, HS, None, False), + (32, 32, 264, 40, 3, 3, 1, 1, 1, 1, HS, None, False), + (32, 32, 132, 20, 3, 3, 1, 1, 1, 1, HS, None, False), + (64, 32, 66, 10, 3, 3, 1, 1, 1, 1, HS, None, False), + (64, 64, 66, 10, 3, 3, 1, 1, 1, 1, HS, None, False), + (32, 96, 132, 20, 3, 3, 1, 1, 1, 1, HS, None, False), + (32, 32, 132, 20, 3, 3, 1, 1, 1, 1, HS, None, False), + (32, 64, 264, 40, 3, 3, 1, 1, 1, 1, HS, None, False), + (32, 32, 264, 40, 3, 3, 1, 1, 1, 1, HS, None, False), + (16, 48, 528, 80, 3, 3, 1, 1, 1, 1, HS, {"act_block_h": 8 * 32}, True), + (16, 16, 528, 80, 3, 3, 1, 1, 1, 1, HS, {"act_block_h": 8 * 32}, True), + (16, 32, 1056, 160, 3, 3, 1, 1, 1, 1, HS, {"act_block_h": 16 * 32}, True), + (1, 16, 1056, 160, 1, 1, 1, 1, 0, 0, HS, {"act_block_h": 5 * 32}, False), ), ) @pytest.mark.parametrize( @@ -1727,7 +1667,7 @@ def test_unet_conv_groups_2_wh( stride_w, pad_h, pad_w, - use_1d_systolic_array, + shard_layout, config_override, use_shallow_conv_variant, output_layout, @@ -1756,11 +1696,9 @@ def test_unet_conv_groups_2_wh( stride_w, pad_h, pad_w, - use_1d_systolic_array, config_override, + shard_layout=shard_layout, use_shallow_conv_variant=use_shallow_conv_variant, - transpose_mcast=use_1d_systolic_array, ## use RM (transpose_mcast=False) with 2D on WH - padded_input_channels=None, output_layout=output_layout, auto_shard=auto_shard, groups=groups, @@ -1778,24 +1716,24 @@ def test_unet_conv_groups_2_wh( ) @pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) @pytest.mark.parametrize( - "output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, use_1d_systolic_array, config_override, use_shallow_conv_variant", + "output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, shard_layout, config_override, use_shallow_conv_variant", ( - (16, 4, 1056, 160, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 2 * 32}, True), - (16, 16, 1056, 160, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 2 * 32}, True), - (16, 16, 528, 80, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 2 * 32}, True), - (32, 16, 264, 40, 3, 3, 1, 1, 1, 1, True, None, False), - (32, 32, 264, 40, 3, 3, 1, 1, 1, 1, True, None, False), - (32, 32, 132, 20, 3, 3, 1, 1, 1, 1, True, None, False), - (64, 32, 66, 10, 3, 3, 1, 1, 1, 1, True, None, False), - (64, 64, 66, 10, 3, 3, 1, 1, 1, 1, True, None, False), - (32, 96, 132, 20, 3, 3, 1, 1, 1, 1, True, None, False), - (32, 32, 132, 20, 3, 3, 1, 1, 1, 1, True, None, False), - (32, 64, 264, 40, 3, 3, 1, 1, 1, 1, True, None, False), - (32, 32, 264, 40, 3, 3, 1, 1, 1, 1, True, None, False), - # (16, 48, 528, 80, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 2 * 32}, True), # OOM - need inplace convolution - (16, 16, 528, 80, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 2 * 32}, True), - # (16, 32, 1056, 160, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 2 * 32}, True), # OOM - need inplace convolution - (1, 16, 1056, 160, 1, 1, 1, 1, 0, 0, True, {"act_block_h": 2 * 32}, False), + (16, 4, 1056, 160, 3, 3, 1, 1, 1, 1, HS, {"act_block_h": 2 * 
32}, True), + (16, 16, 1056, 160, 3, 3, 1, 1, 1, 1, HS, {"act_block_h": 2 * 32}, True), + (16, 16, 528, 80, 3, 3, 1, 1, 1, 1, HS, {"act_block_h": 2 * 32}, True), + (32, 16, 264, 40, 3, 3, 1, 1, 1, 1, HS, None, False), + (32, 32, 264, 40, 3, 3, 1, 1, 1, 1, HS, None, False), + (32, 32, 132, 20, 3, 3, 1, 1, 1, 1, HS, None, False), + (64, 32, 66, 10, 3, 3, 1, 1, 1, 1, HS, None, False), + (64, 64, 66, 10, 3, 3, 1, 1, 1, 1, HS, None, False), + (32, 96, 132, 20, 3, 3, 1, 1, 1, 1, HS, None, False), + (32, 32, 132, 20, 3, 3, 1, 1, 1, 1, HS, None, False), + (32, 64, 264, 40, 3, 3, 1, 1, 1, 1, HS, None, False), + (32, 32, 264, 40, 3, 3, 1, 1, 1, 1, HS, None, False), + # (16, 48, 528, 80, 3, 3, 1, 1, 1, 1, HS, {"act_block_h": 2 * 32}, True), # OOM - need inplace convolution + (16, 16, 528, 80, 3, 3, 1, 1, 1, 1, HS, {"act_block_h": 2 * 32}, True), + # (16, 32, 1056, 160, 3, 3, 1, 1, 1, 1, HS, {"act_block_h": 2 * 32}, True), # OOM - need inplace convolution + (1, 16, 1056, 160, 1, 1, 1, 1, 0, 0, HS, {"act_block_h": 2 * 32}, False), ), ) @pytest.mark.parametrize( @@ -1808,7 +1746,6 @@ def test_unet_conv_groups_2_wh( ) @pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.LoFi]) @pytest.mark.parametrize("output_layout", [ttnn.TILE_LAYOUT]) -@pytest.mark.parametrize("auto_shard", [False], ids=["no_auto_shard"]) def test_unet_conv_groups_4_6_wh( device, use_program_cache, @@ -1826,11 +1763,10 @@ def test_unet_conv_groups_4_6_wh( stride_w, pad_h, pad_w, - use_1d_systolic_array, + shard_layout, config_override, use_shallow_conv_variant, output_layout, - auto_shard, groups, ): if (device.compute_with_storage_grid_size().x, device.compute_with_storage_grid_size().y) == (8, 7): @@ -1855,13 +1791,10 @@ def test_unet_conv_groups_4_6_wh( stride_w, pad_h, pad_w, - use_1d_systolic_array, config_override, + shard_layout=shard_layout, use_shallow_conv_variant=use_shallow_conv_variant, - transpose_mcast=use_1d_systolic_array, ## use RM (transpose_mcast=False) with 2D on WH - padded_input_channels=None, output_layout=output_layout, - auto_shard=auto_shard, groups=groups, ) @@ -1877,23 +1810,23 @@ def test_unet_conv_groups_4_6_wh( ) @pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) @pytest.mark.parametrize( - "output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, use_1d_systolic_array, config_override, use_shallow_conv_variant", + "output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, shard_layout, config_override, use_shallow_conv_variant", ( - (16, 4, 1056, 160, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 2 * 32}, True), - # (16, 16, 1056, 160, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 2 * 32}, True), # OOM - need inplace convolution - (16, 16, 528, 80, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 2 * 32}, True), - (32, 16, 264, 40, 3, 3, 1, 1, 1, 1, True, None, False), - (32, 32, 264, 40, 3, 3, 1, 1, 1, 1, True, None, False), - (32, 32, 132, 20, 3, 3, 1, 1, 1, 1, True, None, False), - (64, 32, 66, 10, 3, 3, 1, 1, 1, 1, True, None, False), - (64, 64, 66, 10, 3, 3, 1, 1, 1, 1, True, None, False), - (32, 96, 132, 20, 3, 3, 1, 1, 1, 1, True, None, False), - (32, 32, 132, 20, 3, 3, 1, 1, 1, 1, True, None, False), - # (32, 64, 264, 40, 3, 3, 1, 1, 1, 1, True, None, False), # OOM - need inplace convolution - (32, 32, 264, 40, 3, 3, 1, 1, 1, 1, True, None, False), - # (16, 48, 528, 80, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 2 * 32}, True), # OOM - need inplace 
convolution - (16, 16, 528, 80, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 2 * 32}, True), - # (16, 32, 1056, 160, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 2 * 32}, True), # OOM - need inplace convolution + (16, 4, 1056, 160, 3, 3, 1, 1, 1, 1, HS, {"act_block_h": 2 * 32}, True), + # (16, 16, 1056, 160, 3, 3, 1, 1, 1, 1, HS, {"act_block_h": 2 * 32}, True), # OOM - need inplace convolution + (16, 16, 528, 80, 3, 3, 1, 1, 1, 1, HS, {"act_block_h": 2 * 32}, True), + (32, 16, 264, 40, 3, 3, 1, 1, 1, 1, HS, None, False), + (32, 32, 264, 40, 3, 3, 1, 1, 1, 1, HS, None, False), + (32, 32, 132, 20, 3, 3, 1, 1, 1, 1, HS, None, False), + (64, 32, 66, 10, 3, 3, 1, 1, 1, 1, HS, None, False), + (64, 64, 66, 10, 3, 3, 1, 1, 1, 1, HS, None, False), + (32, 96, 132, 20, 3, 3, 1, 1, 1, 1, HS, None, False), + (32, 32, 132, 20, 3, 3, 1, 1, 1, 1, HS, None, False), + # (32, 64, 264, 40, 3, 3, 1, 1, 1, 1, HS, None, False), # OOM - need inplace convolution + (32, 32, 264, 40, 3, 3, 1, 1, 1, 1, HS, None, False), + # (16, 48, 528, 80, 3, 3, 1, 1, 1, 1, HS, {"act_block_h": 2 * 32}, True), # OOM - need inplace convolution + (16, 16, 528, 80, 3, 3, 1, 1, 1, 1, HS, {"act_block_h": 2 * 32}, True), + # (16, 32, 1056, 160, 3, 3, 1, 1, 1, 1, HS, {"act_block_h": 2 * 32}, True), # OOM - need inplace convolution # (1, 16, 1056, 160, 1, 1, 1, 1, 0, 0, True, {"act_block_h": 2 * 32}, True), # OOM - need inplace convolution ), ) @@ -1925,7 +1858,7 @@ def test_unet_conv_groups_8_wh( stride_w, pad_h, pad_w, - use_1d_systolic_array, + shard_layout, config_override, use_shallow_conv_variant, output_layout, @@ -1954,11 +1887,9 @@ def test_unet_conv_groups_8_wh( stride_w, pad_h, pad_w, - use_1d_systolic_array, config_override, + shard_layout=shard_layout, use_shallow_conv_variant=use_shallow_conv_variant, - transpose_mcast=use_1d_systolic_array, ## use RM (transpose_mcast=False) with 2D on WH - padded_input_channels=None, output_layout=output_layout, auto_shard=auto_shard, groups=groups, @@ -1978,12 +1909,12 @@ def test_unet_conv_groups_8_wh( (2, 64, 64, 16, 16, 3, 3, 1, 1, 1, 1, {"act_reshard_num_cores_nhw": 4, "num_cores_nhw": 8}), ), ) -@pytest.mark.parametrize("use_1d_systolic_array", [False, True]) +@pytest.mark.parametrize("shard_layout", [BS, HS]) @pytest.mark.parametrize("auto_shard", [True, False], ids=["auto_shard", "no_auto_shard"]) def test_halo_reshard_conv( device, use_program_cache, - use_1d_systolic_array, + shard_layout, batch_size, output_channels, input_channels, @@ -2018,8 +1949,8 @@ def test_halo_reshard_conv( stride_w, pad_h, pad_w, - use_1d_systolic_array, config_override, + shard_layout=shard_layout, auto_shard=auto_shard, ) @@ -2036,12 +1967,12 @@ def test_halo_reshard_conv( (1, 64, 64, 23, 23, 3, 3, 1, 1, 1, 1, {"num_cores_nhw": 10}, True), ), ) -@pytest.mark.parametrize("use_1d_systolic_array", [False, True]) +@pytest.mark.parametrize("shard_layout", [BS, HS]) @pytest.mark.parametrize("auto_shard", [True, False], ids=["auto_shard", "no_auto_shard"]) def test_conv_core_nondivis( device, use_program_cache, - use_1d_systolic_array, + shard_layout, batch_size, output_channels, input_channels, @@ -2080,8 +2011,8 @@ def test_conv_core_nondivis( stride_w, pad_h, pad_w, - use_1d_systolic_array, config_override, + shard_layout=shard_layout, auto_shard=auto_shard, ) @@ -2094,17 +2025,17 @@ def test_conv_core_nondivis( @pytest.mark.parametrize( "output_channels, input_channels, input_height, input_width, act_block_w_div, shard_layout", ( - (768, 768, 16, 16, 1, ttnn.TensorMemoryLayout.WIDTH_SHARDED), - (1280, 1280, 16, 16, 
1, ttnn.TensorMemoryLayout.WIDTH_SHARDED), - (1280, 1280, 8, 8, 1, ttnn.TensorMemoryLayout.WIDTH_SHARDED), - (1280, 2560, 8, 8, 1, ttnn.TensorMemoryLayout.WIDTH_SHARDED), - (128, 128, 8, 8, 1, ttnn.TensorMemoryLayout.BLOCK_SHARDED), - (128, 128, 16, 16, 1, ttnn.TensorMemoryLayout.BLOCK_SHARDED), - (128, 128, 32, 32, 1, ttnn.TensorMemoryLayout.BLOCK_SHARDED), - (32, 32, 64, 64, 1, ttnn.TensorMemoryLayout.HEIGHT_SHARDED), - (32, 32, 128, 64, 1, ttnn.TensorMemoryLayout.HEIGHT_SHARDED), - (16, 16, 528, 80, 1, ttnn.TensorMemoryLayout.HEIGHT_SHARDED), - (32, 16, 264, 40, 1, ttnn.TensorMemoryLayout.HEIGHT_SHARDED), + (768, 768, 16, 16, 1, WS), + (1280, 1280, 16, 16, 1, WS), + (1280, 1280, 8, 8, 1, WS), + (1280, 2560, 8, 8, 1, WS), + (128, 128, 8, 8, 1, BS), + (128, 128, 16, 16, 1, BS), + (128, 128, 32, 32, 1, BS), + (32, 32, 64, 64, 1, HS), + (32, 32, 128, 64, 1, HS), + (16, 16, 528, 80, 1, HS), + (32, 16, 264, 40, 1, HS), ), ) @pytest.mark.parametrize( @@ -2117,7 +2048,6 @@ def test_conv_core_nondivis( ) @pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.LoFi]) @pytest.mark.parametrize("output_layout", [ttnn.TILE_LAYOUT]) -@pytest.mark.parametrize("auto_shard", [True, False], ids=["auto_shard", "no_auto_shard"]) @pytest.mark.parametrize( "filter, dilation, pad", [ @@ -2143,7 +2073,6 @@ def test_conv_dilation( pad, output_layout, dilation, - auto_shard, ): config_override = {"act_block_w_div": act_block_w_div} run_conv( @@ -2162,13 +2091,11 @@ def test_conv_dilation( stride, pad, pad, - True, config_override, shard_layout=shard_layout, output_layout=output_layout, dilation=dilation, has_bias=False, - auto_shard=auto_shard, ) @@ -2176,32 +2103,32 @@ def test_conv_dilation( @skip_for_grayskull() @pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) @pytest.mark.parametrize( - "batch_size, input_channels, output_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, groups, use_1d_systolic_array, config_override, use_shallow_conv_variant", + "batch_size, input_channels, output_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, groups, shard_layout, config_override, use_shallow_conv_variant", ( - (1, 64, 64, 16, 16, 3, 3, 1, 1, 1, 1, 2, True, None, False), - (1, 64, 64, 32, 32, 3, 3, 1, 1, 1, 1, 64, True, None, False), - (2, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, 1, True, None, False), - (2, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, 2, True, None, False), - (2, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, 8, True, None, False), - (1, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, 1, True, None, False), - (8, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, 64, True, None, False), - (4, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, 128, True, None, False), - (8, 256, 256, 28, 28, 3, 3, 2, 2, 1, 1, 128, True, None, False), - # (8, 256, 256, 28, 28, 3, 3, 2, 2, 1, 1, 256, False, None, False), circular buffer error - # (16, 512, 512, 14, 14, 3, 3, 2, 2, 1, 1, 256, False, None, False), # doesn't fit with bfloat16 weights - # (32, 512, 512, 14, 14, 3, 3, 2, 2, 1, 1, 512, False, None, False), # doesn't fit with bfloat16 weights - (32, 160, 160, 7, 7, 3, 3, 1, 1, 1, 1, 40, False, None, False), - (32, 160, 160, 7, 7, 3, 3, 1, 1, 1, 1, 10, False, None, False), - (1, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, 8, True, None, False), - (1, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, 16, True, None, False), - (8, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, 32, True, None, False), - (8, 256, 256, 14, 14, 3, 3, 1, 1, 1, 1, 2, False, None, False), - (8, 256, 
256, 14, 14, 3, 3, 1, 1, 1, 1, 4, False, None, False), - (1, 320, 320, 32, 32, 3, 3, 1, 1, 1, 1, 2, False, None, False), - (1, 640, 640, 16, 16, 3, 3, 1, 1, 1, 1, 320, False, None, False), - # (1, 1280, 1280, 32, 32, 3, 3, 1, 1, 1, 1, 1, False, None, False), # doesn't fit with bfloat16 weights - (2, 64, 32, 66, 10, 3, 3, 1, 1, 1, 1, 32, True, None, False), - (2, 32, 96, 132, 20, 3, 3, 1, 1, 1, 1, 2, True, None, False), + (1, 64, 64, 16, 16, 3, 3, 1, 1, 1, 1, 2, HS, None, False), + (1, 64, 64, 32, 32, 3, 3, 1, 1, 1, 1, 64, HS, None, False), + (2, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, 1, HS, None, False), + (2, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, 2, HS, None, False), + (2, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, 8, HS, None, False), + (1, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, 1, HS, None, False), + (8, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, 64, HS, None, False), + (4, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, 128, HS, None, False), + (8, 256, 256, 28, 28, 3, 3, 2, 2, 1, 1, 128, HS, None, False), + # (8, 256, 256, 28, 28, 3, 3, 2, 2, 1, 1, 256, BS, None, False), circular buffer error + # (16, 512, 512, 14, 14, 3, 3, 2, 2, 1, 1, 256, BS, None, False), # doesn't fit with bfloat16 weights + # (32, 512, 512, 14, 14, 3, 3, 2, 2, 1, 1, 512, BS, None, False), # doesn't fit with bfloat16 weights + (32, 160, 160, 7, 7, 3, 3, 1, 1, 1, 1, 40, BS, None, False), + (32, 160, 160, 7, 7, 3, 3, 1, 1, 1, 1, 10, BS, None, False), + (1, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, 8, HS, None, False), + (1, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, 16, HS, None, False), + (8, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, 32, HS, None, False), + (8, 256, 256, 14, 14, 3, 3, 1, 1, 1, 1, 2, BS, None, False), + (8, 256, 256, 14, 14, 3, 3, 1, 1, 1, 1, 4, BS, None, False), + (1, 320, 320, 32, 32, 3, 3, 1, 1, 1, 1, 2, BS, None, False), + (1, 640, 640, 16, 16, 3, 3, 1, 1, 1, 1, 320, BS, None, False), + # (1, 1280, 1280, 32, 32, 3, 3, 1, 1, 1, 1, 1, BS, None, False), # doesn't fit with bfloat16 weights + (2, 64, 32, 66, 10, 3, 3, 1, 1, 1, 1, 32, HS, None, False), + (2, 32, 96, 132, 20, 3, 3, 1, 1, 1, 1, 2, HS, None, False), ), ) @pytest.mark.parametrize( @@ -2233,7 +2160,7 @@ def test_conv_groups( stride_w, pad_h, pad_w, - use_1d_systolic_array, + shard_layout, config_override, use_shallow_conv_variant, groups, @@ -2255,8 +2182,8 @@ def test_conv_groups( stride_w, pad_h, pad_w, - use_1d_systolic_array, config_override, + shard_layout=shard_layout, use_shallow_conv_variant=use_shallow_conv_variant, groups=groups, output_layout=output_layout, @@ -2265,56 +2192,56 @@ def test_conv_groups( @pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) @pytest.mark.parametrize( - "batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, use_1d_systolic_array, config_override, use_shallow_conv_variant, groups", + "batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, shard_layout, config_override, use_shallow_conv_variant, groups", ( # yolov4 convs with batch size 1 # unique convs in yolov4 (complete list) # groups: number - # (1, 32, 32, 480, 640, 3, 3, 1, 1, 1, 1, True, None, False, 32), # groups: 32 - # (1, 32, 32, 480, 640, 3, 3, 1, 1, 1, 1, True, None, False, 32), # groups: 32 - # (1, 64, 64, 480, 640, 3, 3, 1, 1, 1, 1, True, None, False, 64), # groups: 64 - # (1, 64, 64, 480, 640, 3, 3, 1, 1, 1, 1, True, None, False, 64), # groups: 64 - # (1, 64, 64, 480, 640, 3, 3, 1, 1, 1, 1, True, 
None, False, 64), # groups: 64 - # (1, 64, 64, 480, 640, 3, 3, 1, 1, 1, 1, True, None, False, 64), # groups: 64 - # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128 - # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128 - # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128 - # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128 - # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128 - # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128 - # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128 - # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128 - # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128 - # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128 - # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128 - # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128 - # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128 - # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128 - # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128 - # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False, 128), # groups: 128 - # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256 - # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256 - # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256 - # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256 - # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256 - # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256 - # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256 - # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256 - # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256 - # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256 - # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256 - # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256 - # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256 - # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256 - # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256 - # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False, 256), # groups: 256 - # (1, 512, 512, 60, 80, 3, 3, 1, 1, 1, 1, True, None, False, 512), # groups: 512 - # (1, 512, 512, 60, 80, 3, 3, 1, 1, 1, 1, True, None, False, 512), # groups: 512 - # (1, 512, 512, 60, 80, 3, 3, 1, 1, 1, 1, True, None, False, 512), # groups: 512 - # (1, 512, 512, 60, 80, 3, 3, 1, 1, 1, 1, True, None, False, 512), # groups: 512 - # (1, 512, 512, 60, 80, 3, 3, 1, 1, 1, 1, True, None, False, 512), # groups: 512 - # (1, 512, 512, 60, 80, 3, 3, 1, 1, 1, 1, True, None, False, 512), # groups: 512 - # (1, 512, 512, 60, 80, 3, 3, 1, 1, 1, 1, True, None, False, 512), # groups: 512 - (1, 128, 128, 60, 80, 3, 3, 1, 1, 1, 1, True, None, False, 2), # groups: 512 + # (1, 32, 32, 480, 640, 3, 3, 1, 1, 1, 1, HS, None, False, 
32), # groups: 32 + # (1, 32, 32, 480, 640, 3, 3, 1, 1, 1, 1, HS, None, False, 32), # groups: 32 + # (1, 64, 64, 480, 640, 3, 3, 1, 1, 1, 1, HS, None, False, 64), # groups: 64 + # (1, 64, 64, 480, 640, 3, 3, 1, 1, 1, 1, HS, None, False, 64), # groups: 64 + # (1, 64, 64, 480, 640, 3, 3, 1, 1, 1, 1, HS, None, False, 64), # groups: 64 + # (1, 64, 64, 480, 640, 3, 3, 1, 1, 1, 1, HS, None, False, 64), # groups: 64 + # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, HS, None, False, 128), # groups: 128 + # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, HS, None, False, 128), # groups: 128 + # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, HS, None, False, 128), # groups: 128 + # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, HS, None, False, 128), # groups: 128 + # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, HS, None, False, 128), # groups: 128 + # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, HS, None, False, 128), # groups: 128 + # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, HS, None, False, 128), # groups: 128 + # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, HS, None, False, 128), # groups: 128 + # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, HS, None, False, 128), # groups: 128 + # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, HS, None, False, 128), # groups: 128 + # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, HS, None, False, 128), # groups: 128 + # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, HS, None, False, 128), # groups: 128 + # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, HS, None, False, 128), # groups: 128 + # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, HS, None, False, 128), # groups: 128 + # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, HS, None, False, 128), # groups: 128 + # (1, 128, 128, 240, 320, 3, 3, 1, 1, 1, 1, HS, None, False, 128), # groups: 128 + # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, HS, None, False, 256), # groups: 256 + # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, HS, None, False, 256), # groups: 256 + # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, HS, None, False, 256), # groups: 256 + # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, HS, None, False, 256), # groups: 256 + # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, HS, None, False, 256), # groups: 256 + # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, HS, None, False, 256), # groups: 256 + # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, HS, None, False, 256), # groups: 256 + # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, HS, None, False, 256), # groups: 256 + # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, HS, None, False, 256), # groups: 256 + # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, HS, None, False, 256), # groups: 256 + # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, HS, None, False, 256), # groups: 256 + # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, HS, None, False, 256), # groups: 256 + # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, HS, None, False, 256), # groups: 256 + # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, HS, None, False, 256), # groups: 256 + # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, HS, None, False, 256), # groups: 256 + # (1, 256, 256, 120, 160, 3, 3, 1, 1, 1, 1, HS, None, False, 256), # groups: 256 + # (1, 512, 512, 60, 80, 3, 3, 1, 1, 1, 1, HS, None, False, 512), # groups: 512 + # (1, 512, 512, 60, 80, 3, 3, 1, 1, 1, 1, HS, None, False, 512), # groups: 512 + # (1, 512, 512, 60, 80, 3, 3, 1, 1, 1, 1, HS, None, False, 512), # groups: 512 + # (1, 512, 512, 60, 80, 3, 3, 1, 1, 1, 1, HS, None, False, 512), # groups: 512 + # (1, 512, 512, 60, 80, 3, 3, 1, 1, 1, 1, HS, None, False, 512), # groups: 512 + # (1, 512, 512, 60, 80, 3, 3, 1, 1, 1, 1, HS, None, False, 512), # groups: 512 
+ # (1, 512, 512, 60, 80, 3, 3, 1, 1, 1, 1, HS, None, False, 512), # groups: 512 + (1, 128, 128, 60, 80, 3, 3, 1, 1, 1, 1, HS, None, False, 2), # groups: 512 ), ) @pytest.mark.parametrize( @@ -2323,11 +2250,9 @@ def test_conv_groups( ) @pytest.mark.parametrize( "activations_dtype", - # [ttnn.bfloat8_b, ttnn.bfloat16], [ttnn.bfloat8_b], ) @pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.LoFi]) -# @pytest.mark.parametrize("output_layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT]) @pytest.mark.parametrize("output_layout", [ttnn.TILE_LAYOUT]) @pytest.mark.parametrize("auto_shard", [True, False], ids=["auto_shard", "no_auto_shard"]) def test_yolov4_conv_groups_larger_than_one( @@ -2347,7 +2272,7 @@ def test_yolov4_conv_groups_larger_than_one( stride_w, pad_h, pad_w, - use_1d_systolic_array, + shard_layout, config_override, use_shallow_conv_variant, groups, @@ -2374,11 +2299,10 @@ def test_yolov4_conv_groups_larger_than_one( stride_w, pad_h, pad_w, - use_1d_systolic_array, config_override, + shard_layout=shard_layout, use_shallow_conv_variant=use_shallow_conv_variant, groups=groups, - padded_input_channels=16 if input_channels == 3 else None, output_layout=output_layout, auto_shard=auto_shard, ) @@ -2386,8 +2310,8 @@ def test_yolov4_conv_groups_larger_than_one( @pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) @pytest.mark.parametrize( - " output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, use_1d_systolic_array, config_override, use_shallow_conv_variant, groups", - ((96, 3, 512, 512, 4, 4, 4, 4, 0, 0, True, None, False, 1),), + " output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, shard_layout, config_override, use_shallow_conv_variant, groups", + ((96, 3, 512, 512, 4, 4, 4, 4, 0, 0, HS, None, False, 1),), ) @pytest.mark.parametrize( "weights_dtype", @@ -2421,7 +2345,7 @@ def test_swin_s_conv( stride_w, pad_h, pad_w, - use_1d_systolic_array, + shard_layout, config_override, use_shallow_conv_variant, groups, @@ -2448,89 +2372,11 @@ def test_swin_s_conv( stride_w, pad_h, pad_w, - use_1d_systolic_array, - config_override, - use_shallow_conv_variant=use_shallow_conv_variant, - groups=groups, - output_layout=output_layout, - auto_shard=auto_shard, - ) - - -@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) -@pytest.mark.parametrize( - "batch_size, input_channels, output_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, groups, shard_layout, config_override, use_shallow_conv_variant", - ( - (1, 32, 32, 128, 128, 8, 8, 8, 8, 0, 0, 1, ttnn.TensorMemoryLayout.HEIGHT_SHARDED, None, False), - (1, 64, 64, 64, 64, 4, 4, 4, 4, 0, 0, 1, ttnn.TensorMemoryLayout.HEIGHT_SHARDED, None, False), - (1, 256, 150, 128, 128, 1, 1, 1, 1, 0, 0, 1, ttnn.TensorMemoryLayout.HEIGHT_SHARDED, None, False), - (1, 32, 16, 64, 64, 1, 1, 1, 1, 0, 0, 1, ttnn.TensorMemoryLayout.HEIGHT_SHARDED, None, False), - (1, 96, 24, 32, 32, 1, 1, 1, 1, 0, 0, 1, ttnn.TensorMemoryLayout.HEIGHT_SHARDED, None, False), - (1, 576, 576, 8, 8, 3, 3, 1, 1, 0, 0, 576, ttnn.TensorMemoryLayout.WIDTH_SHARDED, None, False), - (1, 576, 576, 8, 8, 3, 3, 2, 2, 0, 0, 576, ttnn.TensorMemoryLayout.WIDTH_SHARDED, None, False), - (1, 960, 960, 4, 4, 3, 3, 1, 1, 0, 0, 960, ttnn.TensorMemoryLayout.WIDTH_SHARDED, None, False), - (1, 144, 24, 32, 32, 1, 1, 1, 1, 0, 0, 1, 
ttnn.TensorMemoryLayout.HEIGHT_SHARDED, None, False), - (1, 144, 32, 16, 16, 1, 1, 1, 1, 0, 0, 1, ttnn.TensorMemoryLayout.HEIGHT_SHARDED, None, False), - ), -) -@pytest.mark.parametrize( - "weights_dtype", - [ttnn.bfloat16], -) -@pytest.mark.parametrize( - "activations_dtype", - [ttnn.bfloat8_b, ttnn.bfloat16], -) -@pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.LoFi]) -@pytest.mark.parametrize("output_layout", [ttnn.TILE_LAYOUT]) -@pytest.mark.parametrize("auto_shard", [True, False], ids=["auto_shard", "no_auto_shard"]) -@skip_for_grayskull() -def test_conv_for_segformer_512x512( - device, - use_program_cache, - math_fidelity, - activations_dtype, - weights_dtype, - batch_size, - output_channels, - input_channels, - input_height, - input_width, - filter_height, - filter_width, - stride_h, - stride_w, - pad_h, - pad_w, - shard_layout, - config_override, - use_shallow_conv_variant, - groups, - output_layout, - auto_shard, -): - run_conv( - device, - math_fidelity, - activations_dtype, - weights_dtype, - batch_size, - output_channels, - input_channels, - input_height, - input_width, - filter_height, - filter_width, - stride_h, - stride_w, - pad_h, - pad_w, - False, config_override, + shard_layout=shard_layout, use_shallow_conv_variant=use_shallow_conv_variant, groups=groups, output_layout=output_layout, - shard_layout=shard_layout, auto_shard=auto_shard, ) @@ -2538,13 +2384,13 @@ def test_conv_for_segformer_512x512( @skip_for_grayskull() @pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) @pytest.mark.parametrize( - "batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, groups, use_1d_systolic_array", + "batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, dilation, shard_layout", ( - (1, 48, 32, 252, 252, 3, 3, 1, 1, 0, 0, 2, 2, 1, True), - (1, 56, 48, 248, 248, 3, 3, 1, 1, 0, 0, 4, 4, 1, True), - (1, 64, 56, 240, 240, 3, 3, 1, 1, 0, 0, 8, 8, 1, True), - (1, 48, 32, 124, 124, 3, 3, 1, 1, 0, 0, 2, 2, 1, True), - (1, 56, 48, 120, 120, 3, 3, 1, 1, 0, 0, 4, 4, 1, True), + (1, 48, 32, 252, 252, 3, 3, 1, 1, 0, 0, 2, HS), + (1, 56, 48, 248, 248, 3, 3, 1, 1, 0, 0, 4, HS), + (1, 64, 56, 240, 240, 3, 3, 1, 1, 0, 0, 8, HS), + (1, 48, 32, 124, 124, 3, 3, 1, 1, 0, 0, 2, HS), + (1, 56, 48, 120, 120, 3, 3, 1, 1, 0, 0, 4, HS), ), ) @pytest.mark.parametrize( @@ -2574,10 +2420,8 @@ def test_model_k_256x256( stride_w, pad_h, pad_w, - dilation_h, - dilation_w, - groups, - use_1d_systolic_array, + dilation, + shard_layout, auto_shard, ): run_conv( @@ -2596,37 +2440,37 @@ def test_model_k_256x256( stride_w, pad_h, pad_w, - use_1d_systolic_array, None, - dilation=dilation_h, + shard_layout=shard_layout, + dilation=dilation, auto_shard=auto_shard, ) @pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) @pytest.mark.parametrize( - "batch_size, output_channels,input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, use_1d_systolic_array, config_override, use_shallow_conv_variant", + "batch_size, output_channels,input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, shard_layout, config_override, use_shallow_conv_variant", ( - (1, 32, 3, 480, 640, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 64}, True), - (1, 32, 32, 480, 640, 3, 3, 1, 1, 1, 1, True, 
{"act_block_h": 32}, False), - (1, 64, 32, 240, 320, 3, 3, 1, 1, 1, 1, True, None, False), - (1, 64, 64, 240, 320, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 64}, False), - (1, 128, 64, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False), - (1, 128, 128, 120, 160, 3, 3, 1, 1, 1, 1, True, None, False), - (1, 256, 128, 60, 80, 3, 3, 1, 1, 1, 1, True, None, False), - (1, 256, 256, 60, 80, 3, 3, 1, 1, 1, 1, True, None, False), - (1, 512, 256, 30, 40, 3, 3, 1, 1, 1, 1, True, None, False), - (1, 512, 512, 30, 40, 3, 3, 1, 1, 1, 1, False, None, False), - (1, 256, 512, 60, 80, 3, 3, 1, 1, 1, 1, False, {"act_block_h": 32}, False), - (1, 128, 256, 120, 160, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 32}, False), - (1, 64, 128, 240, 320, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 32}, False), - (1, 32, 64, 256, 256, 3, 3, 1, 1, 1, 1, True, {"act_block_h": 32}, False), - (1, 1, 32, 480, 640, 1, 1, 1, 1, 0, 0, True, None, False), + (1, 32, 3, 480, 640, 3, 3, 1, 1, 1, 1, HS, {"act_block_h": 64}, True), + (1, 32, 32, 480, 640, 3, 3, 1, 1, 1, 1, HS, {"act_block_h": 32}, False), + (1, 64, 32, 240, 320, 3, 3, 1, 1, 1, 1, HS, None, False), + (1, 64, 64, 240, 320, 3, 3, 1, 1, 1, 1, HS, {"act_block_h": 64}, False), + (1, 128, 64, 120, 160, 3, 3, 1, 1, 1, 1, HS, None, False), + (1, 128, 128, 120, 160, 3, 3, 1, 1, 1, 1, HS, None, False), + (1, 256, 128, 60, 80, 3, 3, 1, 1, 1, 1, HS, None, False), + (1, 256, 256, 60, 80, 3, 3, 1, 1, 1, 1, HS, None, False), + (1, 512, 256, 30, 40, 3, 3, 1, 1, 1, 1, HS, None, False), + (1, 512, 512, 30, 40, 3, 3, 1, 1, 1, 1, BS, None, False), + (1, 256, 512, 60, 80, 3, 3, 1, 1, 1, 1, BS, {"act_block_h": 32}, False), + (1, 128, 256, 120, 160, 3, 3, 1, 1, 1, 1, HS, {"act_block_h": 32}, False), + (1, 64, 128, 240, 320, 3, 3, 1, 1, 1, 1, HS, {"act_block_h": 32}, False), + (1, 32, 64, 256, 256, 3, 3, 1, 1, 1, 1, HS, {"act_block_h": 32}, False), + (1, 1, 32, 480, 640, 1, 1, 1, 1, 0, 0, HS, None, False), ), ) @pytest.mark.parametrize( "weights_dtype", - [ttnn.bfloat16, ttnn.bfloat8_b], + [ttnn.bfloat16], ) @pytest.mark.parametrize( "activations_dtype", @@ -2652,7 +2496,7 @@ def test_conv_for_vanilla_unet( stride_w, pad_h, pad_w, - use_1d_systolic_array, + shard_layout, config_override, use_shallow_conv_variant, output_layout, @@ -2675,8 +2519,8 @@ def test_conv_for_vanilla_unet( stride_w, pad_h, pad_w, - use_1d_systolic_array, config_override, + shard_layout=shard_layout, use_shallow_conv_variant=use_shallow_conv_variant, groups=1, output_layout=output_layout, @@ -2686,24 +2530,24 @@ def test_conv_for_vanilla_unet( @pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) @pytest.mark.parametrize( - "batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, use_1d_systolic_array, config_override", + "batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, shard_layout, config_override", ( # unique convs in rn50 (complete list) # first conv post folding and input_channels padding to tile width - (16, 64, 64, 14, 14, 3, 3, 1, 1, 1, 1, True, None), + (16, 64, 64, 14, 14, 3, 3, 1, 1, 1, 1, HS, None), # rn50 layer1 - (8, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None), - (16, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None), - (20, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, True, None), + (8, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, HS, None), + (16, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, HS, None), + (20, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, HS, None), # 
rn50 layer2 - (8, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, True, None), - (16, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, True, None), - (20, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, True, None), - (8, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, True, None), - (16, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, True, None), - (20, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, True, None), - (1, 32, 32, 240, 320, 3, 3, 1, 1, 1, 1, True, None), - (1, 64, 32, 240, 320, 3, 3, 1, 1, 1, 1, True, None), + (8, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, HS, None), + (16, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, HS, None), + (20, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, HS, None), + (8, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, HS, None), + (16, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, HS, None), + (20, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, HS, None), + (1, 32, 32, 240, 320, 3, 3, 1, 1, 1, 1, HS, None), + (1, 64, 32, 240, 320, 3, 3, 1, 1, 1, 1, HS, None), ), ) @pytest.mark.parametrize( @@ -2735,7 +2579,7 @@ def test_non_tile_multiple_height_conv_wh( stride_w, pad_h, pad_w, - use_1d_systolic_array, + shard_layout, config_override, fp32_accum, packer_l1_acc, @@ -2758,7 +2602,7 @@ def test_non_tile_multiple_height_conv_wh( ): pytest.skip("Skipping test because it won't fit in L1!") - if activations_dtype == ttnn.float32 and (batch_size >= 16 or (output_channels == 64 and input_height == 240)): + if activations_dtype == ttnn.float32 and (batch_size >= 16 or (output_channels == 64 or input_height >= 240)): pytest.skip("Skipping test because it won't fit in L1!") if ( @@ -2788,10 +2632,9 @@ def test_non_tile_multiple_height_conv_wh( stride_w, pad_h, pad_w, - use_1d_systolic_array, config_override=config_override, + shard_layout=shard_layout, use_shallow_conv_variant=use_shallow_conv_variant, - transpose_mcast=use_1d_systolic_array, ## use RM (transpose_mcast=False) with 2D on WH packer_l1_acc=packer_l1_acc, fp32_accum=fp32_accum, has_bias=has_bias, @@ -2802,30 +2645,30 @@ def test_non_tile_multiple_height_conv_wh( @skip_for_grayskull() @pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) @pytest.mark.parametrize( - "batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, use_1d_systolic_array, config_override", + "batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, shard_layout, config_override", ( - (1, 64, 64, 16, 16, 3, 3, 1, 1, 1, 1, False, None), - (1, 64, 128, 16, 16, 3, 3, 1, 1, 1, 1, False, None), - (1, 64, 192, 16, 16, 3, 3, 1, 1, 1, 1, False, None), - (1, 64, 256, 16, 16, 3, 3, 1, 1, 1, 1, False, None), - (1, 64, 320, 16, 16, 3, 3, 1, 1, 1, 1, False, None), - (1, 64, 384, 16, 16, 3, 3, 1, 1, 1, 1, False, None), - (1, 64, 448, 16, 16, 3, 3, 1, 1, 1, 1, False, None), - (1, 64, 512, 16, 16, 3, 3, 1, 1, 1, 1, False, None), - (1, 64, 576, 16, 16, 3, 3, 1, 1, 1, 1, False, None), - (1, 64, 640, 16, 16, 3, 3, 1, 1, 1, 1, False, None), - (1, 128, 64, 16, 16, 3, 3, 1, 1, 1, 1, False, None), - (1, 128, 128, 16, 16, 3, 3, 1, 1, 1, 1, False, None), - (1, 128, 192, 16, 16, 3, 3, 1, 1, 1, 1, False, None), - (1, 128, 256, 16, 16, 3, 3, 1, 1, 1, 1, False, None), - (1, 128, 320, 16, 16, 3, 3, 1, 1, 1, 1, False, None), - (1, 128, 384, 16, 16, 3, 3, 1, 1, 1, 1, False, None), - (1, 128, 448, 16, 16, 3, 3, 1, 1, 1, 1, False, None), - (1, 128, 512, 16, 16, 3, 3, 1, 1, 1, 1, False, None), - (1, 128, 576, 16, 16, 3, 3, 1, 1, 1, 1, False, None), - (1, 128, 640, 16, 16, 3, 3, 1, 1, 1, 1, False, None), - 
(1, 320, 320, 16, 16, 3, 3, 1, 1, 1, 1, False, None), - (1, 640, 640, 16, 16, 3, 3, 1, 1, 1, 1, False, None), + (1, 64, 64, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), + (1, 64, 128, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), + (1, 64, 192, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), + (1, 64, 256, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), + (1, 64, 320, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), + (1, 64, 384, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), + (1, 64, 448, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), + (1, 64, 512, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), + (1, 64, 576, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), + (1, 64, 640, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), + (1, 128, 64, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), + (1, 128, 128, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), + (1, 128, 192, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), + (1, 128, 256, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), + (1, 128, 320, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), + (1, 128, 384, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), + (1, 128, 448, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), + (1, 128, 512, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), + (1, 128, 576, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), + (1, 128, 640, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), + (1, 320, 320, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), + (1, 640, 640, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), ), ) @pytest.mark.parametrize( @@ -2837,7 +2680,6 @@ def test_non_tile_multiple_height_conv_wh( [ttnn.bfloat16], ) @pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.LoFi]) -@pytest.mark.parametrize("enable_auto_formatting", [False]) def test_non_tile_multiple_width_conv_wh( device, use_program_cache, @@ -2855,9 +2697,8 @@ def test_non_tile_multiple_width_conv_wh( stride_w, pad_h, pad_w, - use_1d_systolic_array, + shard_layout, config_override, - enable_auto_formatting, ): run_conv( device, @@ -2875,12 +2716,9 @@ def test_non_tile_multiple_width_conv_wh( stride_w, pad_h, pad_w, - use_1d_systolic_array, config_override, + shard_layout=shard_layout, use_shallow_conv_variant=(input_channels == 16), - transpose_mcast=use_1d_systolic_array, - enable_auto_formatting=enable_auto_formatting, - padded_input_channels=16 if input_channels == 16 else None, output_layout=ttnn.ROW_MAJOR_LAYOUT, ) @@ -2907,7 +2745,7 @@ def test_shallow_conv_with_tiled_input(device): tt_input = ttnn.reshape(tt_input, (1, 1, batch_size * img_h * img_w, in_channels)) tt_input = ttnn.to_layout(tt_input, ttnn.TILE_LAYOUT) - [tt_out, [out_height, out_width], [_, _]] = ttnn.conv2d( + [tt_out, [out_height, out_width]] = ttnn.conv2d( input_tensor=tt_input, weight_tensor=tt_kernel, in_channels=in_channels, @@ -2924,7 +2762,6 @@ def test_shallow_conv_with_tiled_input(device): groups=1, memory_config=ttnn.DRAM_MEMORY_CONFIG, return_output_dim=True, - return_weights_and_bias=True, ) tt_output_tensor = ttnn.from_device(tt_out) diff --git a/tests/ttnn/unit_tests/operations/test_pad.py b/tests/ttnn/unit_tests/operations/test_pad.py index 909fcff1168..00ef1461791 100644 --- a/tests/ttnn/unit_tests/operations/test_pad.py +++ b/tests/ttnn/unit_tests/operations/test_pad.py @@ -226,8 +226,10 @@ def test_pad_rm_sharded_stickwise( ttnn_input_tensor = ttnn.from_torch( torch_input_tensor, dtype=ttnn.float32, layout=ttnn.ROW_MAJOR_LAYOUT, device=device ) - ttnn_sharded_input_tensor = ttnn.to_memory_config(ttnn_input_tensor, input_shard_memory_config) - + # Still relay on keep_l1_aligned = True to make it work with the current implementation + ttnn_sharded_input_tensor = ttnn.interleaved_to_sharded( + ttnn_input_tensor, input_shard_memory_config, keep_l1_aligned=True + ) padded_tensor = 
ttnn.pad(ttnn_sharded_input_tensor, pad_to_shape, input_tensor_start, pad_value) tt_output_tensor = ttnn.to_memory_config(padded_tensor, ttnn.L1_MEMORY_CONFIG) diff --git a/tests/ttnn/unit_tests/operations/test_reduction.py b/tests/ttnn/unit_tests/operations/test_reduction.py index 399ddcf2e25..87a6fb3c584 100644 --- a/tests/ttnn/unit_tests/operations/test_reduction.py +++ b/tests/ttnn/unit_tests/operations/test_reduction.py @@ -7,8 +7,9 @@ import torch import ttnn + from tests.ttnn.utils_for_testing import assert_with_pcc -from models.utility_functions import skip_for_grayskull +from models.utility_functions import skip_for_grayskull, torch_random @pytest.mark.parametrize("batch_size", [1, 16]) @@ -304,3 +305,60 @@ def test_mean_2d_tensor_dims(device, h, w, dim, keepdim): output_tensor = ttnn.to_torch(output_tensor) assert_with_pcc(torch_output_tensor, output_tensor, pcc=0.99) + + +def run_maxpool(device, input_shape, kernel_size, stride, padding, dilation): + torch_input = torch.rand(input_shape, dtype=torch.bfloat16) + batch_size, in_c, in_h, in_w = input_shape + input_tensor = torch.permute(torch_input, (0, 2, 3, 1)) + input_tensor = torch.reshape(input_tensor, (1, 1, -1, in_c)) + input_tensor = ttnn.from_torch(input_tensor, layout=ttnn.ROW_MAJOR_LAYOUT, device=device) + output_tensor = ttnn.max_pool2d( + input_tensor, + batch_size, + in_h, + in_w, + in_c, + kernel_size, + stride, + padding, + dilation, + ) + + torch_output_tensor = torch.nn.functional.max_pool2d(torch_input, kernel_size, stride, padding) + + output_tensor = ttnn.to_torch(output_tensor) + _, out_c, out_h, out_w = torch_output_tensor.shape + output_tensor = torch.reshape(output_tensor, (batch_size, out_h, out_w, out_c)) + output_tensor = torch.permute(output_tensor, (0, 3, 1, 2)) + assert_with_pcc(output_tensor, torch_output_tensor) + + +def run_reduce_sum_h(device, batch_size, h, w, dim): + torch_input_tensor = torch_random((batch_size, h, w), -1, 1, dtype=torch.bfloat16) + torch_output_tensor = torch.mean(torch_input_tensor, dim=dim, keepdim=True, dtype=torch.bfloat16) + + input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device) + output_tensor = ttnn.mean(input_tensor, dim=dim) + output_tensor = ttnn.to_torch(output_tensor) + assert_with_pcc(torch_output_tensor, output_tensor) + + +@skip_for_grayskull("Not a tile size multiple, will fail on GS. #17132") +@pytest.mark.parametrize("device_params", [{"l1_small_size": 4096}], indirect=True) +@pytest.mark.parametrize( + "input_shape", + [ + (1, 192, 56, 56), # Multi core face height not default + ], +) +@pytest.mark.parametrize( + "kernel_size", + [ + (2, 2), # Small kernel + (5, 5), # Large kernel + ], +) +def test_run_reduce_sum_h_after_max_pool(device, input_shape, kernel_size): + run_maxpool(device, input_shape, kernel_size, kernel_size, (0, 0), (1, 1)) + run_reduce_sum_h(device, 1, 32, 32, -2) diff --git a/tests/ttnn/unit_tests/operations/test_repeat.py b/tests/ttnn/unit_tests/operations/test_repeat.py index 59589cd53db..73af42df968 100644 --- a/tests/ttnn/unit_tests/operations/test_repeat.py +++ b/tests/ttnn/unit_tests/operations/test_repeat.py @@ -1,43 +1,78 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
# SPDX-License-Identifier: Apache-2.0 -import pytest +from functools import reduce +from math import prod +import pytest import torch - import ttnn from models.utility_functions import comp_pcc from tests.ttnn.utils_for_testing import assert_with_pcc -from functools import reduce -dtypes = [torch.bfloat16] -shapes = [(1, 2, 4, 4), (1, 1, 1, 1)] -repeat_specs = [(1, 2, 1, 1), (2, 2, 2, 2)] +layouts = [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT] + +dtypes = [(torch.float32, ttnn.float32), (torch.bfloat16, ttnn.bfloat16), (torch.bfloat16, ttnn.bfloat8_b)] +shapes = [(1,), (2,), (2, 1), (2, 3), (2, 1, 3), (4, 16, 3, 2), (4, 3, 1, 2, 2)] +repeat_shapes = [ + (1,), + (2,), + (1, 2), + (1, 4), + (2, 1, 3), + (1, 2, 3), + (4, 3, 2, 1), + (2, 3, 4, 5, 2), + (2, 1, 3, 1, 3, 1), + (2048,), +] + + +def _get_size(larger, smaller) -> int: + return prod([a * b for a, b in zip(((1,) * (len(larger) - len(smaller)) + smaller), larger)]) -shape_and_repeat_specs = list(zip(shapes, repeat_specs)) +def _get_final_size(shape, reshape): + if len(shape) > len(reshape): + return _get_size(shape, reshape) + else: + return _get_size(reshape, shape) + +@pytest.mark.parametrize("layout", layouts) @pytest.mark.parametrize("dtype", dtypes) -@pytest.mark.parametrize("shape_and_repeat_spec", shape_and_repeat_specs) -def test_repeat(device, dtype, shape_and_repeat_spec): - shape, repeat_shape = shape_and_repeat_spec - if dtype == torch.bfloat16 and shape[-1] < 2 and repeat_shape[-1] < 2: - pytest.skip("bfloat16 needs 4 byte inner dim on the output.") +@pytest.mark.parametrize("shape", shapes) +@pytest.mark.parametrize("repeat_shape", repeat_shapes) +def test_repeat(device, layout, dtype, shape, repeat_shape): + torch_dtype, ttnn_dtype = dtype + + # trying to avoid the `buffer not divisible by page size` error. Does this make sense? + if layout == ttnn.TILE_LAYOUT and ( + prod(shape) % ttnn.TILE_SIZE != 0 or _get_final_size(shape, repeat_shape) % ttnn.TILE_SIZE != 0 + ): + pytest.skip("Tensor not suitable for tile layout") + + if len(repeat_shape) < len(shape): + pytest.skip("PyTorch repeat dim must be >= tensor dim (although we can handle this).") + + if layout == ttnn.ROW_MAJOR_LAYOUT and ttnn_dtype == ttnn.bfloat8_b: + pytest.skip("Illegal config") mul = lambda x, y: x * y - torch_input_tensor = torch.arange(0, reduce(mul, shape, 1), dtype=dtype).reshape(shape) + torch_input_tensor = torch.arange(0, reduce(mul, shape, 1), dtype=torch_dtype).reshape(shape) torch_result = torch_input_tensor.repeat(repeat_shape) - - input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device) + input_tensor = ttnn.from_torch(torch_input_tensor, layout=layout, device=device, dtype=ttnn_dtype) output = ttnn.repeat(input_tensor, ttnn.Shape(repeat_shape)) output = ttnn.to_torch(output) - assert ( output.shape == torch_result.shape ), f"Output shape {output.shape} does not match torch shape {torch_result.shape}" assert_with_pcc(torch_result, output, 0.9999) + + +# TODO! 
test program cache when it is implemented diff --git a/tests/ttnn/unit_tests/test_expand.py b/tests/ttnn/unit_tests/test_expand.py index e202c61e1a9..e19daba93e8 100644 --- a/tests/ttnn/unit_tests/test_expand.py +++ b/tests/ttnn/unit_tests/test_expand.py @@ -8,24 +8,14 @@ @pytest.mark.parametrize( - "input_shape", + "input_shape, output_shape", [ - [1, 32], - [8, 1], - [32, 1], + [(4, 1), (4, 2)], + [(1, 32), (32, -1)], + [(1, 32), (64, 32)], + [(8, 1), (8, 8)], + [(8, 1), (-1, 32)], ], - ids=["1d", "2d", "large_2d"], -) -@pytest.mark.parametrize( - "output_shape", - [ - [-1, 32], - [32, -1, 32], - [32, 32, -1, 32], - [32, 32, 32, -1, 32], - [4, 4, 4096, -1, 32], - ], - ids=["2d", "3d", "4d", "5d", "random_large_5d"], ) @pytest.mark.parametrize( "tensor_layout", @@ -43,7 +33,7 @@ def test_expand(input_shape, output_shape, tensor_layout, device): output_tensor = ttnn.expand(input_tensor, output_shape) output_tensor = ttnn.to_torch(output_tensor) - assert torch.equal(torch_output_tensor, output_tensor) + assert torch.allclose(torch_output_tensor, output_tensor, atol=1e-1, rtol=1e-2) @pytest.mark.parametrize( @@ -56,7 +46,7 @@ def test_expand(input_shape, output_shape, tensor_layout, device): def test_expand_callback(tensor_layout, device, use_program_cache): num_program_cache_entries_list = [] for i in range(2): - test_expand([32, 1], [32, 32, 32], tensor_layout, device) + test_expand([32, 1], [32, 32], tensor_layout, device) num_program_cache_entries_list.append(device.num_program_cache_entries()) assert num_program_cache_entries_list[0] > 0 diff --git a/tests/ttnn/unit_tests/test_multi_device.py b/tests/ttnn/unit_tests/test_multi_device.py index c4ea6a0fd2a..f81039d1728 100644 --- a/tests/ttnn/unit_tests/test_multi_device.py +++ b/tests/ttnn/unit_tests/test_multi_device.py @@ -694,3 +694,25 @@ def model(submesh): submesh_devices = mesh_device.create_submeshes(ttnn.MeshShape(2, 2)) for submesh in submesh_devices: model(submesh) + + +@pytest.mark.parametrize("mesh_device", [pytest.param((1, 8), id="1x8_line")], indirect=True) +def test_line_all_gather_after_reshape(mesh_device): + if mesh_device.get_num_devices() < 8: + pytest.skip() + mesh_device.reshape(ttnn.MeshShape(2, 4)) + torch_input_tensor = torch.rand((1, 1, 64, 128), dtype=torch.bfloat16) + + mesh_tensor = ttnn.from_torch( + torch_input_tensor, + layout=ttnn.TILE_LAYOUT, + device=mesh_device, + mesh_mapper=ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=list(mesh_device.shape), dims=(2, 3)), + ) + output_tensor = ttnn.all_gather( + mesh_tensor, + dim=2, + cluster_axis=0, + mesh_device=mesh_device, + topology=ttnn.Topology.Linear, + ) diff --git a/tests/ttnn/unit_tests/test_single_device_trace.py b/tests/ttnn/unit_tests/test_single_device_trace.py index 2e4ad3d26f5..8ffa2774978 100644 --- a/tests/ttnn/unit_tests/test_single_device_trace.py +++ b/tests/ttnn/unit_tests/test_single_device_trace.py @@ -11,7 +11,7 @@ from tests.ttnn.utils_for_testing import assert_with_pcc -@pytest.mark.parametrize("shape", [[1, 3, 1024, 1024], (1, 1, 512, 512), (1, 3, 512, 512), (1, 3, 32, 32)]) +@pytest.mark.parametrize("shape", [[1, 3, 1024, 1024], (1, 1, 512, 512), (1, 3, 32, 32)]) @pytest.mark.parametrize("enable_async", [True, False]) @pytest.mark.parametrize("blocking", [True, False]) @pytest.mark.parametrize("device_params", [{"trace_region_size": 200000}], indirect=True) @@ -71,7 +71,7 @@ def run_op_chain(input_0, input_1): device.enable_async(False) -@pytest.mark.parametrize("shape", [(1, 1, 512, 512), (1, 1, 32, 32), (1, 3, 512, 512), (1, 3, 
32, 32)]) +@pytest.mark.parametrize("shape", [(1, 1, 512, 512), (1, 1, 32, 32), (1, 3, 32, 32)]) @pytest.mark.parametrize("enable_async", [True, False]) @pytest.mark.parametrize("blocking", [True, False]) @pytest.mark.parametrize("device_params", [{"trace_region_size": 266240}], indirect=True) diff --git a/tests/ttnn/unit_tests/test_to_layout.py b/tests/ttnn/unit_tests/test_to_layout.py index 135c119072b..4e0fe5c29bc 100644 --- a/tests/ttnn/unit_tests/test_to_layout.py +++ b/tests/ttnn/unit_tests/test_to_layout.py @@ -283,3 +283,59 @@ def test_to_layout_page_error(shape, device): torch_output = torch_tensor assert torch_output.shape == output_tensor.shape assert_with_pcc(torch_output, output_tensor, 0.9999) + + +@pytest.mark.parametrize("shape", [[64, 7680]]) +@pytest.mark.parametrize("output_layout", [ttnn.ROW_MAJOR_LAYOUT]) +@pytest.mark.parametrize("input_layout", [ttnn.TILE_LAYOUT]) +def test_untilize_w1(shape, input_layout, output_layout, device): + torch.manual_seed(0) + input_a = torch.randn(shape, dtype=torch.bfloat16) + + input_tensor = ttnn.from_torch(input_a, device=device, layout=input_layout, dtype=ttnn.bfloat16) + output_tensor = ttnn.untilize_with_unpadding(input_tensor, [36, 7667]) + output_tensor = ttnn.to_torch(output_tensor) + + assert_with_pcc(input_a[:37, :7668], output_tensor) + + +@pytest.mark.parametrize("shape", [[2, 32, 6144]]) +@pytest.mark.parametrize("output_layout", [ttnn.ROW_MAJOR_LAYOUT]) +@pytest.mark.parametrize("input_layout", [ttnn.TILE_LAYOUT]) +def test_untilize_w2(shape, input_layout, output_layout, device): + torch.manual_seed(0) + input_a = torch.randn(shape, dtype=torch.bfloat16) + + input_tensor = ttnn.from_torch(input_a, device=device, layout=input_layout, dtype=ttnn.bfloat16) + output_tensor = ttnn.untilize_with_unpadding(input_tensor, [1, 30, 6140]) + output_tensor = ttnn.to_torch(output_tensor) + + assert_with_pcc(input_a[:, :31, :6141], output_tensor) + + +@pytest.mark.parametrize("shape", [[1, 1, 32, 1536]]) +@pytest.mark.parametrize("output_layout", [ttnn.ROW_MAJOR_LAYOUT]) +@pytest.mark.parametrize("input_layout", [ttnn.TILE_LAYOUT]) +def test_untilize_w3(shape, input_layout, output_layout, device): + torch.manual_seed(0) + input_a = torch.randn(shape, dtype=torch.bfloat16) + + input_tensor = ttnn.from_torch(input_a, device=device, layout=input_layout, dtype=ttnn.bfloat16) + output_tensor = ttnn.untilize_with_unpadding(input_tensor, [0, 0, 31, 1535]) + output_tensor = ttnn.to_torch(output_tensor) + + assert_with_pcc(input_a[:, :, :32, :1536], output_tensor) + + +@pytest.mark.parametrize("shape", [[1, 1, 32, 10912]]) +@pytest.mark.parametrize("output_layout", [ttnn.ROW_MAJOR_LAYOUT]) +@pytest.mark.parametrize("input_layout", [ttnn.TILE_LAYOUT]) +def test_untilize_w4(shape, input_layout, output_layout, device): + torch.manual_seed(0) + input_a = torch.randn(shape, dtype=torch.bfloat16) + + input_tensor = ttnn.from_torch(input_a, device=device, layout=input_layout, dtype=ttnn.bfloat16) + output_tensor = ttnn.untilize_with_unpadding(input_tensor, [0, 0, 0, 10911]) + output_tensor = ttnn.to_torch(output_tensor) + + assert_with_pcc(input_a[:, :, :1, :10912], output_tensor) diff --git a/tt-train/cmake/dependencies.cmake b/tt-train/cmake/dependencies.cmake index ecf77f6dfb4..d9ea7849b21 100644 --- a/tt-train/cmake/dependencies.cmake +++ b/tt-train/cmake/dependencies.cmake @@ -1,13 +1,25 @@ -set(ENV{CPM_SOURCE_CACHE} "${PROJECT_SOURCE_DIR}/.cpmcache") 
+############################################################################################################################ +# CPM +############################################################################################################################ +include(${PROJECT_SOURCE_DIR}/cmake/CPM.cmake) ############################################################################################################################ # Boost ############################################################################################################################ -include(${PROJECT_SOURCE_DIR}/cmake/fetch_boost.cmake) -fetch_boost_library(core) -fetch_boost_library(smart_ptr) -fetch_boost_library(container) +CPMAddPackage( + NAME Boost + VERSION 1.86.0 + URL + https://github.com/boostorg/boost/releases/download/boost-1.86.0/boost-1.86.0-cmake.tar.xz + URL_HASH + SHA256=2c5ec5edcdff47ff55e27ed9560b0a0b94b07bd07ed9928b476150e16b0efc57 + OPTIONS + "BOOST_ENABLE_CMAKE ON" + "BOOST_SKIP_INSTALL_RULES ON" + "BUILD_SHARED_LIBS OFF" + "BOOST_INCLUDE_LIBRARIES core\\\;container\\\;smart_ptr" +) ############################################################################################################################ # yaml-cpp @@ -76,12 +88,13 @@ CPMAddPackage(NAME taskflow GITHUB_REPOSITORY taskflow/taskflow GIT_TAG v3.7.0 O include(${PROJECT_SOURCE_DIR}/cmake/fetch_cli11.cmake) +# gersemi: off CPMAddPackage( NAME msgpack GIT_REPOSITORY https://github.com/msgpack/msgpack-c.git GIT_TAG cpp-6.1.0 - PATCHES - msgpack.patch + PATCH_COMMAND + patch --dry-run -p1 -R < ${CMAKE_CURRENT_LIST_DIR}/msgpack.patch || patch -p1 < ${CMAKE_CURRENT_LIST_DIR}/msgpack.patch OPTIONS "CMAKE_MESSAGE_LOG_LEVEL NOTICE" "MSGPACK_BUILD_EXAMPLES OFF" @@ -100,8 +113,9 @@ CPMAddPackage( NAME tokenizers-cpp GITHUB_REPOSITORY mlc-ai/tokenizers-cpp GIT_TAG 5de6f656c06da557d4f0fb1ca611b16d6e9ff11d - PATCHES - tokenizers-cpp.patch + PATCH_COMMAND + patch --dry-run -p1 -R < ${CMAKE_CURRENT_LIST_DIR}/tokenizers-cpp.patch || patch -p1 < ${CMAKE_CURRENT_LIST_DIR}/tokenizers-cpp.patch OPTIONS "CMAKE_MESSAGE_LOG_LEVEL NOTICE" ) +# gersemi: on diff --git a/tt-train/cmake/fetch_boost.cmake b/tt-train/cmake/fetch_boost.cmake deleted file mode 100644 index 4987d256c45..00000000000 --- a/tt-train/cmake/fetch_boost.cmake +++ /dev/null @@ -1,27 +0,0 @@ -include(${PROJECT_SOURCE_DIR}/cmake/CPM.cmake) - -function(fetch_boost_library BOOST_PROJECT_NAME) - CPMAddPackage( - NAME boost_${BOOST_PROJECT_NAME} - GITHUB_REPOSITORY boostorg/${BOOST_PROJECT_NAME} - GIT_TAG boost-1.85.0 - OPTIONS - "BUILD_SHARED_LIBS OFF" - ) - - get_target_property(BOOST_INTERFACE_LINK_LIBRARIES boost_${BOOST_PROJECT_NAME} INTERFACE_LINK_LIBRARIES) - - if(NOT BOOST_INTERFACE_LINK_LIBRARIES STREQUAL BOOST_INTERFACE_LINK_LIBRARIES-NOTFOUND) - foreach(BOOST_INTERFACE_LINK_LIBRARY IN ITEMS ${BOOST_INTERFACE_LINK_LIBRARIES}) - if( - NOT TARGET - ${BOOST_INTERFACE_LINK_LIBRARY} - AND BOOST_INTERFACE_LINK_LIBRARY - MATCHES - "^Boost::([a-z0-9_]+)$" - ) - fetch_boost_library(${CMAKE_MATCH_1}) - endif() - endforeach() - endif() -endfunction() diff --git a/tt-train/sources/examples/mnist_mlp/CMakeLists.txt b/tt-train/sources/examples/mnist_mlp/CMakeLists.txt index 0c26c08e294..64e53de5aac 100644 --- a/tt-train/sources/examples/mnist_mlp/CMakeLists.txt +++ b/tt-train/sources/examples/mnist_mlp/CMakeLists.txt @@ -3,6 +3,7 @@ project(mnist_mlp) set(SOURCES main.cpp utils.cpp + model.cpp ) CPMAddPackage(NAME mnist_dataset GITHUB_REPOSITORY wichtounet/mnist GIT_TAG master) 
include_directories(${mnist_dataset_SOURCE_DIR}/include) diff --git a/tt-train/sources/examples/mnist_mlp/main.cpp b/tt-train/sources/examples/mnist_mlp/main.cpp index 868e827d296..649e7463c26 100644 --- a/tt-train/sources/examples/mnist_mlp/main.cpp +++ b/tt-train/sources/examples/mnist_mlp/main.cpp @@ -2,25 +2,27 @@ // // SPDX-License-Identifier: Apache-2.0 +#include + #include #include #include #include #include -#include -#include -#include +#include #include "autograd/auto_context.hpp" #include "autograd/tensor.hpp" #include "core/tt_tensor_utils.hpp" #include "datasets/dataloader.hpp" #include "datasets/in_memory_dataset.hpp" +#include "model.hpp" #include "models/mlp.hpp" #include "ops/losses.hpp" #include "optimizers/sgd.hpp" +#include "serialization/serializable.hpp" #include "utils.hpp" -#include "yaml-cpp/node/node.h" + using ttml::autograd::TensorPtr; using DatasetSample = std::pair, uint8_t>; @@ -30,32 +32,10 @@ using DataLoader = ttml::datasets::DataLoader< std::function &&samples)>, BatchType>; -constexpr auto model_name = "mlp"; -constexpr auto optimizer_name = "optimizer"; +using Model = std::variant, std::shared_ptr>; -template -float evaluate(DataLoader &test_dataloader, Model &model, size_t num_targets) { - model->eval(); - float num_correct = 0; - float num_samples = 0; - for (const auto &[data, target] : test_dataloader) { - auto output = (*model)(data); - auto output_vec = ttml::core::to_vector(output->get_value()); - auto target_vec = ttml::core::to_vector(target->get_value()); - for (size_t i = 0; i < output_vec.size(); i += num_targets) { - auto predicted_class = std::distance( - output_vec.begin() + i, - std::max_element(output_vec.begin() + i, output_vec.begin() + (i + num_targets))); - auto target_class = std::distance( - target_vec.begin() + i, - std::max_element(target_vec.begin() + i, target_vec.begin() + (i + num_targets))); - num_correct += static_cast(predicted_class == target_class); - num_samples++; - } - } - model->train(); - return num_correct / num_samples; -}; +const std::string model_name = "mlp"; +const std::string optimizer_name = "optimizer"; struct TrainingConfig { uint32_t batch_size = 128; @@ -84,18 +64,100 @@ TrainingConfig parse_config(const YAML::Node &yaml_config) { return config; } +void initialize_device(bool enable_tp) { + if (enable_tp) { + // we support only N300 for now + ttml::autograd::ctx().set_mesh_shape({1, 2}); + } +} + +void model_to_eval(Model &model) { + std::visit([](auto &model) { model->eval(); }, model); +} + +void model_to_train(Model &model) { + std::visit([](auto &model) { model->train(); }, model); +} + +ttml::autograd::TensorPtr run_model(Model &model, const ttml::autograd::TensorPtr &data) { + return std::visit([&data](auto &model) { return (*model)(data); }, model); +} + +ttml::serialization::NamedParameters get_model_parameters(Model &model) { + return std::visit([](auto &model) { return model->parameters(); }, model); +} + +void load_model( + Model &model, + const TrainingConfig &config, + ttml::optimizers::SGD &optimizer, + const std::string &model_name, + const std::string &optimizer_name) { + std::visit( + [&config, &optimizer, &model_name, &optimizer_name](auto &model) { + load_training_state(config.model_path, model, optimizer, model_name, optimizer_name); + }, + model); +} + +void save_model( + Model &model, + const TrainingConfig &config, + ttml::optimizers::SGD &optimizer, + const std::string &model_name, + const std::string &optimizer_name) { + std::visit( + [&config, &optimizer, &model_name, 
&optimizer_name](auto &model) { + save_training_state(config.model_path, model, optimizer, model_name, optimizer_name); + }, + model); +} + +template +float evaluate(DataLoader &test_dataloader, Model &model, size_t num_targets) { + model_to_eval(model); + float num_correct = 0; + float num_samples = 0; + auto *device = &ttml::autograd::ctx().get_device(); + for (const auto &[data, target] : test_dataloader) { + auto output = run_model(model, data); + ttml::core::MeshToXTensorVariant composer = ttml::core::VectorMeshToXTensor(device->shape()); + auto output_xtensor = ttml::core::to_xtensor(output->get_value(), composer)[0]; + auto target_xtensor = ttml::core::to_xtensor(target->get_value(), composer)[0]; + auto output_vec = std::vector(output_xtensor.begin(), output_xtensor.end()); + auto target_vec = std::vector(target_xtensor.begin(), target_xtensor.end()); + for (size_t i = 0; i < output_vec.size(); i += num_targets) { + auto predicted_class = std::distance( + output_vec.begin() + i, + std::max_element(output_vec.begin() + i, output_vec.begin() + (i + num_targets))); + auto target_class = std::distance( + target_vec.begin() + i, + std::max_element(target_vec.begin() + i, target_vec.begin() + (i + num_targets))); + num_correct += static_cast(predicted_class == target_class); + num_samples++; + } + } + model_to_train(model); + return num_correct / num_samples; +}; + int main(int argc, char **argv) { CLI::App app{"Mnist Example"}; argv = app.ensure_utf8(argv); std::string config_name = std::string(CONFIGS_FOLDER) + "/training_mnist_mlp.yaml"; bool is_eval = false; + bool enable_tp = false; app.add_option("-c,--config", config_name, "Yaml Config name")->default_val(config_name); - app.add_option("-e,--eval", config_name, "Evaluate")->default_val(is_eval); + app.add_option("-e,--eval", is_eval, "Evaluate")->default_val(is_eval); + app.add_option("-p, --enable_tp", enable_tp, "Enable tensor parallelism")->default_val(enable_tp); CLI11_PARSE(app, argc, argv); auto yaml_config = YAML::LoadFile(config_name); TrainingConfig config = parse_config(yaml_config); + + initialize_device(enable_tp); + // Load MNIST data const size_t num_targets = 10; const size_t num_features = 784; @@ -134,7 +196,12 @@ int main(int argc, char **argv) { auto train_dataloader = DataLoader(training_dataset, config.batch_size, /* shuffle */ true, collate_fn); auto test_dataloader = DataLoader(test_dataset, config.batch_size, /* shuffle */ false, collate_fn); - auto model = ttml::models::mlp::create(config.mlp_config); + Model model; + if (enable_tp) { + model = std::make_shared(); + } else { + model = ttml::models::mlp::create(config.mlp_config); + } const float learning_rate = config.learning_rate * (static_cast(config.batch_size) / 128.F); const float momentum = config.momentum; @@ -148,10 +215,15 @@ int main(int argc, char **argv) { fmt::print(" Dampening {}\n", sgd_config.dampening); fmt::print(" Weight decay: {}\n", sgd_config.weight_decay); fmt::print(" Nesterov: {}\n", sgd_config.nesterov); - auto optimizer = ttml::optimizers::SGD(model->parameters(), sgd_config); + auto parameters = get_model_parameters(model); + auto optimizer = ttml::optimizers::SGD(parameters, sgd_config); if (!config.model_path.empty() && std::filesystem::exists(config.model_path)) { - fmt::print("Loading model from {}\n", config.model_path); - load_training_state(config.model_path, model, optimizer, model_name, optimizer_name); + if (enable_tp) { + fmt::println("Loading model for tensor parallelism is not supported yet. 
Loading model has been skipped."); + } else { + fmt::print("Loading model from {}\n", config.model_path); + load_model(model, config, optimizer, model_name, optimizer_name); + } } // evaluate model before training (sanity check to get reasonable accuracy @@ -164,19 +236,32 @@ int main(int argc, char **argv) { LossAverageMeter loss_meter; int training_step = 0; + + auto get_loss_value = [device](const TensorPtr &loss) { + ttml::core::MeshToXTensorVariant composer = ttml::core::VectorMeshToXTensor(device->shape()); + auto loss_xtensors = ttml::core::to_xtensor(loss->get_value(), composer); + // sum of loss xtensors + float loss_float = + std::accumulate(loss_xtensors.begin(), loss_xtensors.end(), 0.0F, [](float acc, auto &xtensor) { + return acc + xtensor(0); + }); + + return loss_float / static_cast(loss_xtensors.size()); + }; + for (size_t epoch = 0; epoch < config.num_epochs; ++epoch) { for (const auto &[data, target] : train_dataloader) { optimizer.zero_grad(); - auto output = (*model)(data); + auto output = run_model(model, data); auto loss = ttml::ops::cross_entropy_loss(output, target); - auto loss_float = ttml::core::to_vector(loss->get_value())[0]; + auto loss_float = get_loss_value(loss); loss_meter.update(loss_float, config.batch_size); if (training_step % config.logging_interval == 0) { fmt::print("Step: {:5d} | Average Loss: {:.4f}\n", training_step, loss_meter.average()); } if (!config.model_path.empty() && training_step % config.model_save_interval == 0) { fmt::print("Saving model to {}\n", config.model_path); - save_training_state(config.model_path, model, optimizer, model_name, optimizer_name); + save_model(model, config, optimizer, model_name, optimizer_name); } loss->backward(); @@ -196,7 +281,7 @@ int main(int argc, char **argv) { if (!config.model_path.empty()) { fmt::print("Saving model to {}\n", config.model_path); - save_training_state(config.model_path, model, optimizer, model_name, optimizer_name); + save_model(model, config, optimizer, model_name, optimizer_name); } return 0; diff --git a/tt-train/sources/examples/mnist_mlp/model.cpp b/tt-train/sources/examples/mnist_mlp/model.cpp new file mode 100644 index 00000000000..40d2aa4d024 --- /dev/null +++ b/tt-train/sources/examples/mnist_mlp/model.cpp @@ -0,0 +1,24 @@ +// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "model.hpp" + +#include "ops/unary_ops.hpp" + +MNISTTensorParallel::MNISTTensorParallel() { + m_linear1 = std::make_shared( + 784, 128, /* has_bias */ true, /* gather_output */ false); + m_linear2 = std::make_shared( + 128, 10, /* has_bias */ true, /* input_is_parallel */ true); + create_name("mlp"); + register_module(m_linear1, "linear1"); + register_module(m_linear2, "linear2"); +} + +ttml::autograd::TensorPtr MNISTTensorParallel::operator()(ttml::autograd::TensorPtr tensor) { + tensor = (*m_linear1)(tensor); + tensor = ttml::ops::relu(tensor); + tensor = (*m_linear2)(tensor); + return tensor; +} diff --git a/tt-train/sources/examples/mnist_mlp/model.hpp b/tt-train/sources/examples/mnist_mlp/model.hpp new file mode 100644 index 00000000000..900b75afc99 --- /dev/null +++ b/tt-train/sources/examples/mnist_mlp/model.hpp @@ -0,0 +1,18 @@ +// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "autograd/module_base.hpp" +#include "modules/distributed/linear.hpp" + +class MNISTTensorParallel : public ttml::autograd::ModuleBase { +public: + MNISTTensorParallel(); + 
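For context on the dispatch pattern introduced in main.cpp above: the example now holds the network in a std::variant of shared_ptrs and routes every call (model_to_eval, run_model, get_model_parameters, save/load) through std::visit, so the training loop is identical whether the plain MLP or the tensor-parallel MNISTTensorParallel is in use. A minimal, self-contained sketch of that pattern follows, using hypothetical stand-in types (PlainMLP, TensorParallelMLP) rather than the real ttml classes:

// Sketch only: simplified stand-ins for the two model alternatives.
#include <iostream>
#include <memory>
#include <variant>

struct PlainMLP {
    float operator()(float x) const { return x * 2.0f; }  // stand-in for the forward pass
    void eval() { std::cout << "PlainMLP eval\n"; }
};

struct TensorParallelMLP {
    float operator()(float x) const { return x * 2.0f; }  // same math, sharded execution in the real code
    void eval() { std::cout << "TensorParallelMLP eval\n"; }
};

using Model = std::variant<std::shared_ptr<PlainMLP>, std::shared_ptr<TensorParallelMLP>>;

// Mirrors model_to_eval / run_model: std::visit selects whichever alternative
// the variant currently holds and forwards the call through the shared_ptr.
void model_to_eval(Model& model) {
    std::visit([](auto& m) { m->eval(); }, model);
}

float run_model(Model& model, float input) {
    return std::visit([&](auto& m) { return (*m)(input); }, model);
}

int main() {
    bool enable_tp = true;
    Model model;
    if (enable_tp) {
        model = std::make_shared<TensorParallelMLP>();
    } else {
        model = std::make_shared<PlainMLP>();
    }
    model_to_eval(model);
    std::cout << run_model(model, 3.0f) << "\n";  // prints 6
}

Only construction (and the load path, which is skipped under tensor parallelism) branches on enable_tp; everything downstream of the variant stays shared between the two configurations.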
ttml::autograd::TensorPtr operator()(ttml::autograd::TensorPtr tensor); + +private: + std::shared_ptr m_linear1; + std::shared_ptr m_linear2; +}; diff --git a/tt-train/sources/examples/mnist_mlp/utils.hpp b/tt-train/sources/examples/mnist_mlp/utils.hpp index 863cb9311eb..7af7ccc3f65 100644 --- a/tt-train/sources/examples/mnist_mlp/utils.hpp +++ b/tt-train/sources/examples/mnist_mlp/utils.hpp @@ -39,7 +39,7 @@ class Timers { template void save_training_state( - std::string &model_path, + const std::string &model_path, const std::shared_ptr &model, Optimizer &optimizer, const std::string &model_name, @@ -52,7 +52,7 @@ void save_training_state( template void load_training_state( - std::string &model_path, + const std::string &model_path, const std::shared_ptr &model, Optimizer &optimizer, const std::string &model_name, diff --git a/tt-train/sources/examples/sample_app/main.cpp b/tt-train/sources/examples/sample_app/main.cpp index e5e66603129..49f78f16a32 100644 --- a/tt-train/sources/examples/sample_app/main.cpp +++ b/tt-train/sources/examples/sample_app/main.cpp @@ -81,7 +81,7 @@ int main() { // processing tt::tt_metal::Layout::TILE); // Once created, the tensor "on host" and we must move it to the device to perform operations on it - x = x.to(device); + x = x.to_device(device); // Print the tensor to see what it looks like std::cout << "Tensot x:\n"; diff --git a/tt-train/sources/ttml/autograd/tensor.cpp b/tt-train/sources/ttml/autograd/tensor.cpp index 5e2d2a0ca1e..f54282032b4 100644 --- a/tt-train/sources/ttml/autograd/tensor.cpp +++ b/tt-train/sources/ttml/autograd/tensor.cpp @@ -150,4 +150,12 @@ const std::optional& Tensor::get_node() const { return m_node_id; } +const ttnn::Shape& Tensor::get_shape() const { + return get_value().get_logical_shape(); +} + +uint32_t Tensor::get_rank() const { + return get_shape().rank(); +} + } // namespace ttml::autograd diff --git a/tt-train/sources/ttml/autograd/tensor.hpp b/tt-train/sources/ttml/autograd/tensor.hpp index 9f4cbe72a5c..c1ab301b997 100644 --- a/tt-train/sources/ttml/autograd/tensor.hpp +++ b/tt-train/sources/ttml/autograd/tensor.hpp @@ -41,6 +41,8 @@ class Tensor : public std::enable_shared_from_this { tt::tt_metal::Tensor &get_grad(); bool get_requires_grad() const; const std::optional &get_node() const; + const ttnn::Shape &get_shape() const; + uint32_t get_rank() const; void backward(bool retain_graph = false); diff --git a/tt-train/sources/ttml/core/compute_kernel_config.cpp b/tt-train/sources/ttml/core/compute_kernel_config.cpp index 3f69e6538d0..94af110530e 100644 --- a/tt-train/sources/ttml/core/compute_kernel_config.cpp +++ b/tt-train/sources/ttml/core/compute_kernel_config.cpp @@ -17,7 +17,7 @@ ttnn::WormholeComputeKernelConfig ComputeKernelConfig::precise() { ttnn::WormholeComputeKernelConfig ComputeKernelConfig::softmax() { ttnn::WormholeComputeKernelConfig config; - config.fp32_dest_acc_en = true; + config.fp32_dest_acc_en = false; config.math_approx_mode = false; config.math_fidelity = MathFidelity::HiFi4; config.packer_l1_acc = true; diff --git a/tt-train/sources/ttml/core/tt_tensor_utils.hpp b/tt-train/sources/ttml/core/tt_tensor_utils.hpp index 3035e7eca1e..f3a2900e080 100644 --- a/tt-train/sources/ttml/core/tt_tensor_utils.hpp +++ b/tt-train/sources/ttml/core/tt_tensor_utils.hpp @@ -66,7 +66,7 @@ template template auto to_xtensor(const tt::tt_metal::Tensor& tensor, const MeshToXTensorVariant& composer) { auto cpu_tensor = tensor.cpu(); - cpu_tensor = cpu_tensor.to(Layout::ROW_MAJOR); + cpu_tensor = 
cpu_tensor.to_layout(Layout::ROW_MAJOR); auto cpu_tensors = ttnn::distributed::get_device_tensors(cpu_tensor); std::vector> res; res.reserve(cpu_tensors.size()); diff --git a/tt-train/sources/ttml/core/ttnn_all_includes.hpp b/tt-train/sources/ttml/core/ttnn_all_includes.hpp index 90e21b64b00..0dc4a096ea8 100644 --- a/tt-train/sources/ttml/core/ttnn_all_includes.hpp +++ b/tt-train/sources/ttml/core/ttnn_all_includes.hpp @@ -9,18 +9,18 @@ #pragma GCC diagnostic ignored "-Wdeprecated-volatile" #pragma GCC diagnostic ignored "-Wdeprecated-this-capture" -#include // NOLINT -#include // NOLINT +#include // NOLINT +#include // NOLINT +#include // NOLINT +#include // NOLINT #include // NOLINT -#include // NOLINT -#include // NOLINT -#include // NOLINT -#include // NOLINT +#include // NOLINT +#include // NOLINT +#include // NOLINT +#include // NOLINT +#include // NOLINT +#include // NOLINT #include // NOLINT -#include // NOLINT -#include // NOLINT -#include // NOLINT -#include // NOLINT #include // NOLINT #include // NOLINT #include // NOLINT diff --git a/tt-train/sources/ttml/init/cpu_initializers.cpp b/tt-train/sources/ttml/init/cpu_initializers.cpp index b493095d951..4dbc333f4df 100644 --- a/tt-train/sources/ttml/init/cpu_initializers.cpp +++ b/tt-train/sources/ttml/init/cpu_initializers.cpp @@ -7,10 +7,17 @@ #include #include "autograd/auto_context.hpp" -#include "fmt/core.h" namespace ttml::init { +xt::xarray uniform_init(const ttnn::Shape& shape, UniformRange range) { + std::vector data(shape.volume()); + uniform_init(data, range); + std::vector shape_vec(shape.cbegin(), shape.cend()); + // adapt creates view of the vector, but return will copy this data anyway (by creation of xt::array) + return xt::adapt(std::move(data), shape_vec); +} + void uniform_init(std::vector& vec, UniformRange range) { auto& [a, b] = range; diff --git a/tt-train/sources/ttml/init/cpu_initializers.hpp b/tt-train/sources/ttml/init/cpu_initializers.hpp index 4743ba8db79..743e48689ae 100644 --- a/tt-train/sources/ttml/init/cpu_initializers.hpp +++ b/tt-train/sources/ttml/init/cpu_initializers.hpp @@ -3,8 +3,11 @@ // SPDX-License-Identifier: Apache-2.0 #pragma once + #include +#include "core/xtensor_utils.hpp" + namespace ttml::init { struct UniformRange { @@ -22,6 +25,8 @@ struct FanParams { uint32_t fan_out = 1; }; +xt::xarray uniform_init(const ttnn::Shape& shape, UniformRange range); + void uniform_init(std::vector& vec, UniformRange range); void normal_init(std::vector& vec, NormalParams params); diff --git a/tt-train/sources/ttml/modules/distributed/linear.cpp b/tt-train/sources/ttml/modules/distributed/linear.cpp new file mode 100644 index 00000000000..0e1e22d421e --- /dev/null +++ b/tt-train/sources/ttml/modules/distributed/linear.cpp @@ -0,0 +1,125 @@ +// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "linear.hpp" + +#include "autograd/auto_context.hpp" +#include "autograd/tensor.hpp" +#include "core/tt_tensor_utils.hpp" +#include "init/cpu_initializers.hpp" +#include "init/tensor_initializers.hpp" +#include "ops/binary_ops.hpp" +#include "ops/distributed/comm_ops.hpp" +#include "ops/linear_op.hpp" + +namespace ttml::modules::distributed { + +RowParallelLinear::RowParallelLinear( + uint32_t in_features, uint32_t out_features, bool has_bias, bool input_is_parallel) : + m_input_is_parallel(input_is_parallel) { + initialize_tensors(in_features, out_features, has_bias); + + create_name("row_parallel_linear"); + register_tensor(m_weight, "weight"); + if 
(m_bias != nullptr) { + register_tensor(m_bias, "bias"); + } +} + +autograd::TensorPtr RowParallelLinear::operator()(autograd::TensorPtr tensor) { + if (!m_input_is_parallel) { + tensor = ops::distributed::scatter(tensor, tensor->get_rank() - 1U); + } + + // do not pass bias + tensor = ops::linear_op(tensor, m_weight, /* bias */ nullptr); + tensor = ops::distributed::all_reduce(tensor); + if (m_bias != nullptr) { + tensor = ops::add(tensor, m_bias); + } + return tensor; +} + +void RowParallelLinear::initialize_tensors(uint32_t in_features, uint32_t out_features, bool has_bias) { + auto* device = &autograd::ctx().get_device(); + auto num_devices = static_cast(device->num_devices()); + if (out_features % num_devices != 0) { + throw std::runtime_error(fmt::format( + "Output features must be divisible by the number of devices. Output features = {}, devices = {}", + out_features, + num_devices)); + } + + auto weight_shape = core::create_shape({1, 1, out_features, in_features}); + + uint32_t rank = 4U; + auto mesh_shape = device->shape(); + const float init_k = std::sqrtf(1.F / static_cast(in_features)); + + ttml::core::XTensorToMeshVariant shard_composer = + ttml::core::ShardXTensorToMesh(mesh_shape, rank - 1U); + auto weight = init::uniform_init(weight_shape, init::UniformRange{-init_k, init_k}); + m_weight = + autograd::create_tensor(ttml::core::from_xtensor(weight, device, shard_composer)); + + if (has_bias) { + auto bias_shape = core::create_shape({1, 1, 1, out_features}); + m_bias = ttml::autograd::create_tensor(); + init::uniform_init(m_bias, bias_shape, init::UniformRange{-init_k, init_k}); + } +} + +ColumnParallelLinear::ColumnParallelLinear( + uint32_t in_features, uint32_t out_features, bool has_bias, bool gather_output) : + m_gather_output(gather_output) { + initialize_tensors(in_features, out_features, has_bias); + + create_name("column_parallel_linear"); + register_tensor(m_weight, "weight"); + if (m_bias != nullptr) { + register_tensor(m_bias, "bias"); + } +} + +autograd::TensorPtr ColumnParallelLinear::operator()(autograd::TensorPtr tensor) { + tensor = ops::linear_op(tensor, m_weight, m_bias); + if (m_gather_output) { + tensor = ops::distributed::all_gather(tensor, tensor->get_rank() - 1U); + } + return tensor; +} + +void ColumnParallelLinear::initialize_tensors(uint32_t in_features, uint32_t out_features, bool has_bias) { + auto* device = &autograd::ctx().get_device(); + auto num_devices = static_cast(device->num_devices()); + if (in_features % num_devices != 0) { + throw std::runtime_error(fmt::format( + "Input features must be divisible by the number of devices. 
Input features = {}, devices = {}", + in_features, + num_devices)); + } + + auto weight_shape = core::create_shape({1, 1, out_features, in_features}); + + uint32_t rank = 4U; + auto mesh_shape = device->shape(); + const float init_k = std::sqrtf(1.F / static_cast(in_features)); + + ttml::core::XTensorToMeshVariant shard_composer = + ttml::core::ShardXTensorToMesh(mesh_shape, rank - 2U); + auto weight = init::uniform_init(weight_shape, init::UniformRange{-init_k, init_k}); + m_weight = + autograd::create_tensor(ttml::core::from_xtensor(weight, device, shard_composer)); + + if (has_bias) { + auto bias_shape = core::create_shape({1, 1, 1, out_features}); + auto bias = init::uniform_init(bias_shape, init::UniformRange{-init_k, init_k}); + ttml::core::XTensorToMeshVariant shard_composer = + ttml::core::ShardXTensorToMesh(mesh_shape, rank - 1U); + m_bias = + autograd::create_tensor(ttml::core::from_xtensor(bias, device, shard_composer)); + } +} + +} // namespace ttml::modules::distributed diff --git a/tt-train/sources/ttml/modules/distributed/linear.hpp b/tt-train/sources/ttml/modules/distributed/linear.hpp new file mode 100644 index 00000000000..c7bf0795580 --- /dev/null +++ b/tt-train/sources/ttml/modules/distributed/linear.hpp @@ -0,0 +1,39 @@ +// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "autograd/module_base.hpp" +#include "autograd/tensor.hpp" + +namespace ttml::modules::distributed { + +class RowParallelLinear : public autograd::ModuleBase { +public: + RowParallelLinear( + uint32_t in_features, uint32_t out_features, bool has_bias = true, bool input_is_parallel = false); + autograd::TensorPtr operator()(autograd::TensorPtr tensor); + +private: + void initialize_tensors(uint32_t in_features, uint32_t out_features, bool has_bias = true); + + autograd::TensorPtr m_weight; + autograd::TensorPtr m_bias; + bool m_input_is_parallel{false}; +}; + +class ColumnParallelLinear : public autograd::ModuleBase { +public: + ColumnParallelLinear(uint32_t in_features, uint32_t out_features, bool has_bias = true, bool gather_output = false); + autograd::TensorPtr operator()(autograd::TensorPtr tensor); + +private: + void initialize_tensors(uint32_t in_features, uint32_t out_features, bool has_bias = true); + + autograd::TensorPtr m_weight; + autograd::TensorPtr m_bias; + bool m_gather_output{false}; +}; + +} // namespace ttml::modules::distributed diff --git a/tt-train/sources/ttml/modules/multi_layer_perceptron.cpp b/tt-train/sources/ttml/modules/multi_layer_perceptron.cpp index 81f5ad2c85a..72d3ca8e092 100644 --- a/tt-train/sources/ttml/modules/multi_layer_perceptron.cpp +++ b/tt-train/sources/ttml/modules/multi_layer_perceptron.cpp @@ -14,6 +14,7 @@ void add_linear_layer(Layers& layers, Args&&... 
args) { } MultiLayerPerceptron::MultiLayerPerceptron(const MultiLayerPerceptronParameters& params) { + m_layers.reserve(params.hidden_features.size() + 1UL); uint32_t current_input_features = params.input_features; for (auto hidden_features : params.hidden_features) { add_linear_layer(m_layers, current_input_features, hidden_features); diff --git a/tt-train/sources/ttml/ops/distributed/comm_ops.cpp b/tt-train/sources/ttml/ops/distributed/comm_ops.cpp new file mode 100644 index 00000000000..278c9b6ce4b --- /dev/null +++ b/tt-train/sources/ttml/ops/distributed/comm_ops.cpp @@ -0,0 +1,43 @@ +// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "comm_ops.hpp" + +#include + +#include "autograd/auto_context.hpp" +#include "autograd/graph.hpp" +#include "autograd/graph_utils.hpp" +#include "ttnn_fixed/distributed/ttnn_ops.hpp" + +namespace ttml::ops::distributed { + +autograd::TensorPtr scatter(const autograd::TensorPtr& tensor, int dim) { + auto out = autograd::create_tensor(ttnn_fixed::distributed::scatter(tensor->get_value(), dim)); + autograd::GradFunction grad = [tensor, out, dim]() { tensor->set_grad(ttnn::all_gather(out->get_grad(), dim)); }; + auto links = autograd::get_links(tensor); + out->set_node(autograd::ctx().add_backward_node(std::move(grad), links)); + return out; +} + +autograd::TensorPtr all_gather(const autograd::TensorPtr& tensor, int dim) { + auto out = autograd::create_tensor(ttnn::all_gather(tensor->get_value(), dim)); + autograd::GradFunction grad = [tensor, out, dim]() { + tensor->set_grad(ttnn_fixed::distributed::scatter(out->get_grad(), dim)); + }; + auto links = autograd::get_links(tensor); + out->set_node(autograd::ctx().add_backward_node(std::move(grad), links)); + return out; +} + +autograd::TensorPtr all_reduce(const autograd::TensorPtr& tensor) { + auto out = autograd::create_tensor(ttnn::experimental::all_reduce( + tensor->get_value(), ttnn::operations::reduction::ReduceType::Sum, 1, std::nullopt, ttnn::ccl::Topology::Ring)); + autograd::GradFunction grad = [tensor, out]() { tensor->set_grad(out->get_grad()); }; + auto links = autograd::get_links(tensor); + out->set_node(autograd::ctx().add_backward_node(std::move(grad), links)); + return out; +} + +} // namespace ttml::ops::distributed diff --git a/tt-train/sources/ttml/ops/distributed/comm_ops.hpp b/tt-train/sources/ttml/ops/distributed/comm_ops.hpp new file mode 100644 index 00000000000..4e051fbc1bc --- /dev/null +++ b/tt-train/sources/ttml/ops/distributed/comm_ops.hpp @@ -0,0 +1,15 @@ +// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "autograd/tensor.hpp" + +namespace ttml::ops::distributed { + +autograd::TensorPtr scatter(const autograd::TensorPtr& tensor, int dim); +autograd::TensorPtr all_reduce(const autograd::TensorPtr& tensor); +autograd::TensorPtr all_gather(const autograd::TensorPtr& tensor, int dim); + +} // namespace ttml::ops::distributed diff --git a/tt-train/sources/ttml/ops/unary_ops.cpp b/tt-train/sources/ttml/ops/unary_ops.cpp index e2e76fb881c..a9ec11094eb 100644 --- a/tt-train/sources/ttml/ops/unary_ops.cpp +++ b/tt-train/sources/ttml/ops/unary_ops.cpp @@ -72,7 +72,7 @@ autograd::TensorPtr log_softmax_moreh(const autograd::TensorPtr& tensor, int dim ttnn::operations::moreh::moreh_softmax::MorehSoftmaxOp::LOGSOFTMAX, ttnn::operations::moreh::moreh_softmax::MorehSoftmaxOpParallelizationStrategy::NONE, /* output_mem_config */ std::nullopt, - /* 
compute_kernel_config */ core::ComputeKernelConfig::precise()); + /* compute_kernel_config */ core::ComputeKernelConfig::softmax()); auto out = autograd::create_tensor(log_softmax); autograd::GradFunction grad = [tensor, out, dim]() { diff --git a/tt-train/tests/modules/distributed/linear_test.cpp b/tt-train/tests/modules/distributed/linear_test.cpp new file mode 100644 index 00000000000..39fc1c587f3 --- /dev/null +++ b/tt-train/tests/modules/distributed/linear_test.cpp @@ -0,0 +1,368 @@ +// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "modules/distributed/linear.hpp" + +#include + +#include +#include + +#include "autograd/auto_context.hpp" +#include "core/distributed_mapping.hpp" +#include "core/tt_tensor_utils.hpp" + +namespace { + +auto check_board_is_n300() { + return tt::Cluster::instance().get_board_type(0) == BoardType::N300; +} + +ttml::autograd::TensorPtr get_parameter(auto& parameters, const std::string& name_substring) { + for (const auto& [name, parameter] : parameters) { + if (name.find(name_substring) != std::string::npos) { + return parameter; + } + } + throw std::logic_error(fmt::format("Parameter for a given name substring {} not found", name_substring)); +} + +} // namespace + +class N300TensorParallelLinearTest : public ::testing::Test { +protected: + void SetUp() override { + if (!check_board_is_n300()) { + GTEST_SKIP() << "Skipping N300 specific tests"; + } + ttml::autograd::ctx().set_mesh_shape({1, 2}); + ttml::autograd::ctx().open_device(); + } + + void TearDown() override { + ttml::autograd::ctx().close_device(); + } +}; + +TEST_F(N300TensorParallelLinearTest, RowParallelLinearHasBiasNotInputParallel) { + uint32_t in_features = 64U; + uint32_t out_features = 64U; + bool has_bias = true; + bool input_is_parallel = false; + + auto layer = ttml::modules::distributed::RowParallelLinear(in_features, out_features, has_bias, input_is_parallel); + auto parameters = layer.parameters(); + EXPECT_EQ(parameters.size(), 1UL + static_cast(has_bias)); + + auto weight = get_parameter(parameters, "weight"); + auto bias = get_parameter(parameters, "bias"); + + auto* device = &ttml::autograd::ctx().get_device(); + auto mesh_shape = device->shape(); + + xt::xarray test_data = xt::random::rand({in_features}, 0.F, 1.F).reshape({1U, 1U, 1U, in_features}); + ttml::core::XTensorToMeshVariant replicate_composer = ttml::core::ReplicateXTensorToMesh(mesh_shape); + auto tt_tensor = ttml::core::from_xtensor(test_data, device, replicate_composer); + auto tensor = ttml::autograd::create_tensor(tt_tensor); + auto output = layer(tensor); + + ttml::core::MeshToXTensorVariant identity_composer = ttml::core::VectorMeshToXTensor(mesh_shape); + auto output_xtensor = ttml::core::to_xtensor(output->get_value(), identity_composer); + EXPECT_TRUE(xt::allclose(output_xtensor[0], output_xtensor[1], /* rtol */ 1e-3, /* atol */ 1e-2)); + + ttml::core::MeshToXTensorVariant concat_composer = ttml::core::ConcatMeshToXTensor(mesh_shape, 3U); + // (1, 1, out_features, in_features) + auto weight_xtensor = ttml::core::to_xtensor(weight->get_value(), concat_composer); + auto bias_xtensor = ttml::core::to_xtensor(bias->get_value(), identity_composer); + + auto weight_xtensor_shape = weight_xtensor[0].shape(); + auto test_data_shape = test_data.shape(); + + auto expected_output = xt::linalg::dot(test_data, xt::transpose(weight_xtensor[0], {0, 1, 3, 2})); + if (has_bias) { + expected_output += bias_xtensor[0]; + } + + EXPECT_TRUE(xt::allclose(expected_output, 
output_xtensor[0], /* rtol */ 1e-3, /* atol */ 1e-2)); + EXPECT_TRUE(xt::allclose(expected_output, output_xtensor[1], /* rtol */ 1e-3, /* atol */ 1e-2)); +}; + +TEST_F(N300TensorParallelLinearTest, RowParallelLinearNoBiasNotInputParallel) { + uint32_t in_features = 64U; + uint32_t out_features = 64U; + bool has_bias = false; + bool input_is_parallel = false; + + auto layer = ttml::modules::distributed::RowParallelLinear(in_features, out_features, has_bias, input_is_parallel); + auto parameters = layer.parameters(); + EXPECT_EQ(parameters.size(), 1UL + static_cast(has_bias)); + + auto weight = get_parameter(parameters, "weight"); + + auto* device = &ttml::autograd::ctx().get_device(); + auto mesh_shape = device->shape(); + + xt::xarray test_data = xt::random::rand({in_features}, 0.F, 1.F).reshape({1U, 1U, 1U, in_features}); + ttml::core::XTensorToMeshVariant replicate_composer = ttml::core::ReplicateXTensorToMesh(mesh_shape); + auto tt_tensor = ttml::core::from_xtensor(test_data, device, replicate_composer); + auto tensor = ttml::autograd::create_tensor(tt_tensor); + auto output = layer(tensor); + + ttml::core::MeshToXTensorVariant identity_composer = ttml::core::VectorMeshToXTensor(mesh_shape); + auto output_xtensor = ttml::core::to_xtensor(output->get_value(), identity_composer); + EXPECT_TRUE(xt::allclose(output_xtensor[0], output_xtensor[1], /* rtol */ 1e-3, /* atol */ 1e-2)); + + ttml::core::MeshToXTensorVariant concat_composer = ttml::core::ConcatMeshToXTensor(mesh_shape, 3U); + // (1, 1, out_features, in_features) + auto weight_xtensor = ttml::core::to_xtensor(weight->get_value(), concat_composer); + + auto weight_xtensor_shape = weight_xtensor[0].shape(); + auto test_data_shape = test_data.shape(); + + auto expected_output = xt::linalg::dot(test_data, xt::transpose(weight_xtensor[0], {0, 1, 3, 2})); + EXPECT_TRUE(xt::allclose(expected_output, output_xtensor[0], /* rtol */ 1e-3, /* atol */ 1e-2)); + EXPECT_TRUE(xt::allclose(expected_output, output_xtensor[1], /* rtol */ 1e-3, /* atol */ 1e-2)); +}; + +TEST_F(N300TensorParallelLinearTest, RowParallelLinearHasBiasInputParallel) { + uint32_t in_features = 64U; + uint32_t out_features = 64U; + bool has_bias = true; + bool input_is_parallel = true; + + auto layer = ttml::modules::distributed::RowParallelLinear(in_features, out_features, has_bias, input_is_parallel); + auto parameters = layer.parameters(); + EXPECT_EQ(parameters.size(), 1UL + static_cast(has_bias)); + + auto weight = get_parameter(parameters, "weight"); + auto bias = get_parameter(parameters, "bias"); + + auto* device = &ttml::autograd::ctx().get_device(); + auto mesh_shape = device->shape(); + + xt::xarray test_data = xt::random::rand({in_features}, 0.F, 1.F).reshape({1U, 1U, 1U, in_features}); + ttml::core::XTensorToMeshVariant shard_composer = ttml::core::ShardXTensorToMesh(mesh_shape, 3); + auto tt_tensor = ttml::core::from_xtensor(test_data, device, shard_composer); + auto tensor = ttml::autograd::create_tensor(tt_tensor); + auto output = layer(tensor); + + ttml::core::MeshToXTensorVariant identity_composer = ttml::core::VectorMeshToXTensor(mesh_shape); + auto output_xtensor = ttml::core::to_xtensor(output->get_value(), identity_composer); + EXPECT_TRUE(xt::allclose(output_xtensor[0], output_xtensor[1], /* rtol */ 1e-3, /* atol */ 1e-2)); + + ttml::core::MeshToXTensorVariant concat_composer = ttml::core::ConcatMeshToXTensor(mesh_shape, 3U); + // (1, 1, out_features, in_features) + auto weight_xtensor = ttml::core::to_xtensor(weight->get_value(), concat_composer); + 
auto bias_xtensor = ttml::core::to_xtensor(bias->get_value(), identity_composer); + auto expected_output = xt::linalg::dot(test_data, xt::transpose(weight_xtensor[0], {0, 1, 3, 2})); + if (has_bias) { + expected_output += bias_xtensor[0]; + } + + EXPECT_TRUE(xt::allclose(expected_output, output_xtensor[0], /* rtol */ 1e-3, /* atol */ 1e-2)); + EXPECT_TRUE(xt::allclose(expected_output, output_xtensor[1], /* rtol */ 1e-3, /* atol */ 1e-2)); +}; + +TEST_F(N300TensorParallelLinearTest, RowParallelLinearNoBiasInputParallel) { + uint32_t in_features = 64U; + uint32_t out_features = 64U; + bool has_bias = false; + bool input_is_parallel = true; + + auto layer = ttml::modules::distributed::RowParallelLinear(in_features, out_features, has_bias, input_is_parallel); + auto parameters = layer.parameters(); + EXPECT_EQ(parameters.size(), 1UL + static_cast(has_bias)); + + auto weight = get_parameter(parameters, "weight"); + + auto* device = &ttml::autograd::ctx().get_device(); + auto mesh_shape = device->shape(); + + xt::xarray test_data = xt::random::rand({in_features}, 0.F, 1.F).reshape({1U, 1U, 1U, in_features}); + ttml::core::XTensorToMeshVariant shard_composer = ttml::core::ShardXTensorToMesh(mesh_shape, 3); + auto tt_tensor = ttml::core::from_xtensor(test_data, device, shard_composer); + auto tensor = ttml::autograd::create_tensor(tt_tensor); + auto output = layer(tensor); + + ttml::core::MeshToXTensorVariant identity_composer = ttml::core::VectorMeshToXTensor(mesh_shape); + auto output_xtensor = ttml::core::to_xtensor(output->get_value(), identity_composer); + EXPECT_TRUE(xt::allclose(output_xtensor[0], output_xtensor[1], /* rtol */ 1e-3, /* atol */ 1e-2)); + + ttml::core::MeshToXTensorVariant concat_composer = ttml::core::ConcatMeshToXTensor(mesh_shape, 3U); + // (1, 1, out_features, in_features) + auto weight_xtensor = ttml::core::to_xtensor(weight->get_value(), concat_composer); + auto expected_output = xt::linalg::dot(test_data, xt::transpose(weight_xtensor[0], {0, 1, 3, 2})); + + EXPECT_TRUE(xt::allclose(expected_output, output_xtensor[0], /* rtol */ 1e-3, /* atol */ 1e-2)); + EXPECT_TRUE(xt::allclose(expected_output, output_xtensor[1], /* rtol */ 1e-3, /* atol */ 1e-2)); +}; + +TEST_F(N300TensorParallelLinearTest, ColumnParallelLinearHasBiasAllGather) { + uint32_t in_features = 64U; + uint32_t out_features = 64U; + bool has_bias = true; + bool use_all_gather = true; + + auto layer = ttml::modules::distributed::ColumnParallelLinear(in_features, out_features, has_bias, use_all_gather); + auto parameters = layer.parameters(); + EXPECT_EQ(parameters.size(), 1UL + static_cast(has_bias)); + + auto weight = get_parameter(parameters, "weight"); + auto bias = get_parameter(parameters, "bias"); + + auto* device = &ttml::autograd::ctx().get_device(); + auto mesh_shape = device->shape(); + + xt::xarray test_data = xt::random::rand({in_features}, 0.F, 1.F).reshape({1U, 1U, 1U, in_features}); + ttml::core::XTensorToMeshVariant replicate_composer = ttml::core::ReplicateXTensorToMesh(mesh_shape); + auto tt_tensor = ttml::core::from_xtensor(test_data, device, replicate_composer); + auto tensor = ttml::autograd::create_tensor(tt_tensor); + auto output = layer(tensor); + + ttml::core::MeshToXTensorVariant identity_composer = ttml::core::VectorMeshToXTensor(mesh_shape); + auto output_xtensor = ttml::core::to_xtensor(output->get_value(), identity_composer); + EXPECT_TRUE(xt::allclose(output_xtensor[0], output_xtensor[1], /* rtol */ 1e-3, /* atol */ 1e-2)); + + ttml::core::MeshToXTensorVariant concat_composer_2 
= ttml::core::ConcatMeshToXTensor(mesh_shape, 2U); + ttml::core::MeshToXTensorVariant concat_composer_3 = ttml::core::ConcatMeshToXTensor(mesh_shape, 3U); + // (1, 1, out_features, in_features) + auto weight_xtensor = ttml::core::to_xtensor(weight->get_value(), concat_composer_2); + auto bias_xtensor = ttml::core::to_xtensor(bias->get_value(), concat_composer_3); + + auto expected_output = xt::linalg::dot(test_data, xt::transpose(weight_xtensor[0], {0, 1, 3, 2})); + if (has_bias) { + expected_output += bias_xtensor[0]; + } + + EXPECT_TRUE(xt::allclose(expected_output, output_xtensor[0], /* rtol */ 1e-2, /* atol */ 1e-2)); + EXPECT_TRUE(xt::allclose(expected_output, output_xtensor[1], /* rtol */ 1e-2, /* atol */ 1e-2)); +}; + +TEST_F(N300TensorParallelLinearTest, ColumnParallelLinearNoBiasAllGather) { + uint32_t in_features = 64U; + uint32_t out_features = 64U; + bool has_bias = false; + bool use_all_gather = true; + + auto layer = ttml::modules::distributed::ColumnParallelLinear(in_features, out_features, has_bias, use_all_gather); + auto parameters = layer.parameters(); + EXPECT_EQ(parameters.size(), 1UL + static_cast(has_bias)); + + auto weight = get_parameter(parameters, "weight"); + + auto* device = &ttml::autograd::ctx().get_device(); + auto mesh_shape = device->shape(); + + xt::xarray test_data = xt::random::rand({in_features}, 0.F, 1.F).reshape({1U, 1U, 1U, in_features}); + ttml::core::XTensorToMeshVariant replicate_composer = ttml::core::ReplicateXTensorToMesh(mesh_shape); + auto tt_tensor = ttml::core::from_xtensor(test_data, device, replicate_composer); + auto tensor = ttml::autograd::create_tensor(tt_tensor); + auto output = layer(tensor); + + ttml::core::MeshToXTensorVariant identity_composer = ttml::core::VectorMeshToXTensor(mesh_shape); + auto output_xtensor = ttml::core::to_xtensor(output->get_value(), identity_composer); + EXPECT_TRUE(xt::allclose(output_xtensor[0], output_xtensor[1], /* rtol */ 1e-3, /* atol */ 1e-2)); + + ttml::core::MeshToXTensorVariant concat_composer_2 = ttml::core::ConcatMeshToXTensor(mesh_shape, 2U); + ttml::core::MeshToXTensorVariant concat_composer_3 = ttml::core::ConcatMeshToXTensor(mesh_shape, 3U); + // (1, 1, out_features, in_features) + auto weight_xtensor = ttml::core::to_xtensor(weight->get_value(), concat_composer_2); + auto expected_output = xt::linalg::dot(test_data, xt::transpose(weight_xtensor[0], {0, 1, 3, 2})); + + EXPECT_TRUE(xt::allclose(expected_output, output_xtensor[0], /* rtol */ 1e-2, /* atol */ 1e-2)); + EXPECT_TRUE(xt::allclose(expected_output, output_xtensor[1], /* rtol */ 1e-2, /* atol */ 1e-2)); +}; + +TEST_F(N300TensorParallelLinearTest, ColumnParallelLinearHasBiasNoAllGather) { + uint32_t in_features = 64U; + uint32_t out_features = 64U; + bool has_bias = true; + bool use_all_gather = false; + + auto layer = ttml::modules::distributed::ColumnParallelLinear(in_features, out_features, has_bias, use_all_gather); + auto parameters = layer.parameters(); + EXPECT_EQ(parameters.size(), 1UL + static_cast(has_bias)); + + auto weight = get_parameter(parameters, "weight"); + auto bias = get_parameter(parameters, "bias"); + + auto* device = &ttml::autograd::ctx().get_device(); + auto mesh_shape = device->shape(); + + xt::xarray test_data = xt::random::rand({in_features}, 0.F, 1.F).reshape({1U, 1U, 1U, in_features}); + ttml::core::XTensorToMeshVariant replicate_composer = ttml::core::ReplicateXTensorToMesh(mesh_shape); + auto tt_tensor = ttml::core::from_xtensor(test_data, device, replicate_composer); + auto tensor = 
ttml::autograd::create_tensor(tt_tensor); + auto output = layer(tensor); + + ttml::core::MeshToXTensorVariant identity_composer = ttml::core::VectorMeshToXTensor(mesh_shape); + auto output_xtensor = ttml::core::to_xtensor(output->get_value(), identity_composer); + + ttml::core::MeshToXTensorVariant concat_composer_2 = ttml::core::ConcatMeshToXTensor(mesh_shape, 2U); + ttml::core::MeshToXTensorVariant concat_composer_3 = ttml::core::ConcatMeshToXTensor(mesh_shape, 3U); + // (1, 1, out_features, in_features) + auto weight_xtensor = ttml::core::to_xtensor(weight->get_value(), concat_composer_2); + auto bias_xtensor = ttml::core::to_xtensor(bias->get_value(), concat_composer_3); + + auto expected_output = xt::linalg::dot(test_data, xt::transpose(weight_xtensor[0], {0, 1, 3, 2})); + expected_output = expected_output.reshape({1U, 1U, 1U, out_features}); + if (has_bias) { + expected_output += bias_xtensor[0]; + } + + EXPECT_TRUE(xt::allclose( + xt::view(expected_output, xt::all(), xt::all(), xt::all(), xt::range(0, out_features / 2)), + output_xtensor[0], + /* rtol */ 1e-2, + /* atol */ 1e-2)); + EXPECT_TRUE(xt::allclose( + xt::view(expected_output, xt::all(), xt::all(), xt::all(), xt::range(out_features / 2, out_features)), + output_xtensor[1], + /* rtol */ 1e-2, + /* atol */ 1e-2)); +}; + +TEST_F(N300TensorParallelLinearTest, ColumnParallelLinearNoBiasNoAllGather) { + uint32_t in_features = 64U; + uint32_t out_features = 64U; + bool has_bias = false; + bool use_all_gather = false; + + auto layer = ttml::modules::distributed::ColumnParallelLinear(in_features, out_features, has_bias, use_all_gather); + auto parameters = layer.parameters(); + EXPECT_EQ(parameters.size(), 1UL + static_cast(has_bias)); + + auto weight = get_parameter(parameters, "weight"); + + auto* device = &ttml::autograd::ctx().get_device(); + auto mesh_shape = device->shape(); + + xt::xarray test_data = xt::random::rand({in_features}, 0.F, 1.F).reshape({1U, 1U, 1U, in_features}); + ttml::core::XTensorToMeshVariant replicate_composer = ttml::core::ReplicateXTensorToMesh(mesh_shape); + auto tt_tensor = ttml::core::from_xtensor(test_data, device, replicate_composer); + auto tensor = ttml::autograd::create_tensor(tt_tensor); + auto output = layer(tensor); + + ttml::core::MeshToXTensorVariant identity_composer = ttml::core::VectorMeshToXTensor(mesh_shape); + auto output_xtensor = ttml::core::to_xtensor(output->get_value(), identity_composer); + + ttml::core::MeshToXTensorVariant concat_composer_2 = ttml::core::ConcatMeshToXTensor(mesh_shape, 2U); + ttml::core::MeshToXTensorVariant concat_composer_3 = ttml::core::ConcatMeshToXTensor(mesh_shape, 3U); + // (1, 1, out_features, in_features) + auto weight_xtensor = ttml::core::to_xtensor(weight->get_value(), concat_composer_2); + + auto expected_output = xt::linalg::dot(test_data, xt::transpose(weight_xtensor[0], {0, 1, 3, 2})); + expected_output = expected_output.reshape({1U, 1U, 1U, out_features}); + + EXPECT_TRUE(xt::allclose( + xt::view(expected_output, xt::all(), xt::all(), xt::all(), xt::range(0, out_features / 2)), + output_xtensor[0], + /* rtol */ 1e-2, + /* atol */ 1e-2)); + EXPECT_TRUE(xt::allclose( + xt::view(expected_output, xt::all(), xt::all(), xt::all(), xt::range(out_features / 2, out_features)), + output_xtensor[1], + /* rtol */ 1e-2, + /* atol */ 1e-2)); +}; diff --git a/tt-train/tests/ops/distributed/comm_ops_test.cpp b/tt-train/tests/ops/distributed/comm_ops_test.cpp new file mode 100644 index 00000000000..e9ca096998e --- /dev/null +++ 
b/tt-train/tests/ops/distributed/comm_ops_test.cpp @@ -0,0 +1,313 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ops/distributed/comm_ops.hpp" + +#include + +#include +#include + +#include "autograd/auto_context.hpp" +#include "core/distributed_mapping.hpp" +#include "core/tt_tensor_utils.hpp" +#include "init/cpu_initializers.hpp" + +namespace { + +auto check_board_is_n300() { + return tt::Cluster::instance().get_board_type(0) == BoardType::N300; +} + +} // namespace + +class N300CommOpsTest : public ::testing::Test { +protected: + void SetUp() override { + if (!check_board_is_n300()) { + GTEST_SKIP() << "Skipping N300 specific tests"; + } + ttml::autograd::ctx().set_mesh_shape({1, 2}); + ttml::autograd::ctx().open_device(); + } + + void TearDown() override { + ttml::autograd::ctx().close_device(); + } +}; + +TEST_F(N300CommOpsTest, TestAllReduceNotFullyTiled) { + auto* device = &ttml::autograd::ctx().get_device(); + auto mesh_shape = device->shape(); + + size_t size = 64UL; + std::vector test_data_vec(size); + std::iota(test_data_vec.begin(), test_data_vec.end(), 0.0F); + xt::xarray test_data = xt::adapt(test_data_vec); + xt::xarray xtensor = test_data.reshape({1U, 1U, 1U, size}); + ttml::core::XTensorToMeshVariant shard_composer = ttml::core::ShardXTensorToMesh(mesh_shape, 3); + auto tt_tensor = ttml::core::from_xtensor(xtensor, device, shard_composer); + auto tensor = ttml::autograd::create_tensor(tt_tensor); + auto all_reduce_tensor = ttml::ops::distributed::all_reduce(tensor); + + ttml::core::MeshToXTensorVariant identity_composer = ttml::core::VectorMeshToXTensor(mesh_shape); + auto all_reduce_xtensor = ttml::core::to_xtensor(all_reduce_tensor->get_value(), identity_composer); + + xt::xarray all_reduce_expected = + xt::view(xtensor, xt::all(), xt::all(), xt::all(), xt::range(0, size / 2)) + + xt::view(xtensor, xt::all(), xt::all(), xt::all(), xt::range(size / 2, size)); + + EXPECT_TRUE(xt::allclose(all_reduce_expected, all_reduce_xtensor[0], /* rtol */ 1e-3, /* atol */ 1e-2)); + EXPECT_TRUE(xt::allclose(all_reduce_expected, all_reduce_xtensor[1], /* rtol */ 1e-3, /* atol */ 1e-2)); + + xt::xarray grad_data = xt::random::rand(all_reduce_expected.shape(), 0.F, 1.F); + ttml::core::XTensorToMeshVariant replicate_composer = ttml::core::ReplicateXTensorToMesh(mesh_shape); + auto tt_grad_tensor = ttml::core::from_xtensor(grad_data, device, replicate_composer); + all_reduce_tensor->set_grad(tt_grad_tensor); + all_reduce_tensor->backward(); + + auto result_tensor_grad = tensor->get_grad(); + EXPECT_TRUE(ttml::core::is_tensor_initialized(result_tensor_grad)); + + auto grad_xtensor = ttml::core::to_xtensor(tensor->get_grad(), identity_composer); + EXPECT_EQ(grad_xtensor[0].shape(), grad_xtensor[1].shape()); + EXPECT_TRUE(xt::allclose( + grad_data, + grad_xtensor[0], + /* rtol */ 1e-3, + /* atol */ 1e-2)); + EXPECT_TRUE(xt::allclose( + grad_data, + grad_xtensor[1], + /* rtol */ 1e-3, + /* atol */ 1e-2)); +} + +TEST_F(N300CommOpsTest, TestAllReduceFullyTiled) { + auto* device = &ttml::autograd::ctx().get_device(); + auto mesh_shape = device->shape(); + + size_t size = 64UL; + size_t height = 32UL; + std::vector test_data_vec(size * height); + ttml::init::uniform_init(test_data_vec, {0.F, 0.001F}); + xt::xarray test_data = xt::adapt(test_data_vec); + xt::xarray xtensor = test_data.reshape({1U, 1U, height, size}); + ttml::core::XTensorToMeshVariant shard_composer = ttml::core::ShardXTensorToMesh(mesh_shape, 3); + auto tt_tensor = 
ttml::core::from_xtensor(xtensor, device, shard_composer); + auto tensor = ttml::autograd::create_tensor(tt_tensor); + auto all_reduce_tensor = ttml::ops::distributed::all_reduce(tensor); + + ttml::core::MeshToXTensorVariant identity_composer = ttml::core::VectorMeshToXTensor(mesh_shape); + auto all_reduce_xtensor = ttml::core::to_xtensor(all_reduce_tensor->get_value(), identity_composer); + + xt::xarray all_reduce_expected = + xt::view(xtensor, xt::all(), xt::all(), xt::all(), xt::range(0, size / 2)) + + xt::view(xtensor, xt::all(), xt::all(), xt::all(), xt::range(size / 2, size)); + + EXPECT_TRUE(xt::allclose(all_reduce_expected, all_reduce_xtensor[0], /* rtol */ 1e-3, /* atol */ 1e-2)); + EXPECT_TRUE(xt::allclose(all_reduce_expected, all_reduce_xtensor[1], /* rtol */ 1e-3, /* atol */ 1e-2)); + + xt::xarray grad_data = xt::random::rand(all_reduce_expected.shape(), 0.F, 1.F); + ttml::core::XTensorToMeshVariant replicate_composer = ttml::core::ReplicateXTensorToMesh(mesh_shape); + auto tt_grad_tensor = ttml::core::from_xtensor(grad_data, device, replicate_composer); + all_reduce_tensor->set_grad(tt_grad_tensor); + all_reduce_tensor->backward(); + + auto result_tensor_grad = tensor->get_grad(); + EXPECT_TRUE(ttml::core::is_tensor_initialized(result_tensor_grad)); + + auto grad_xtensor = ttml::core::to_xtensor(tensor->get_grad(), identity_composer); + EXPECT_EQ(grad_xtensor[0].shape(), grad_xtensor[1].shape()); + EXPECT_TRUE(xt::allclose( + grad_data, + grad_xtensor[0], + /* rtol */ 1e-3, + /* atol */ 1e-2)); + EXPECT_TRUE(xt::allclose( + grad_data, + grad_xtensor[1], + /* rtol */ 1e-3, + /* atol */ 1e-2)); +} + +TEST_F(N300CommOpsTest, TestAllGatherNotFullyTiled) { + auto* device = &ttml::autograd::ctx().get_device(); + auto mesh_shape = device->shape(); + + size_t size = 64UL; + std::vector test_data_vec(size); + std::iota(test_data_vec.begin(), test_data_vec.end(), 0.0F); + xt::xarray test_data = xt::adapt(test_data_vec); + xt::xarray xtensor = test_data.reshape({1U, 1U, 1U, size}); + ttml::core::XTensorToMeshVariant shard_composer = ttml::core::ShardXTensorToMesh(mesh_shape, 3); + auto tt_tensor = ttml::core::from_xtensor(xtensor, device, shard_composer); + auto tensor = ttml::autograd::create_tensor(tt_tensor); + auto gathered_tensor = ttml::ops::distributed::all_gather(tensor, 3); + + ttml::core::MeshToXTensorVariant identity_composer = ttml::core::VectorMeshToXTensor(mesh_shape); + auto gathered_xtensor = ttml::core::to_xtensor(gathered_tensor->get_value(), identity_composer); + EXPECT_TRUE(xt::allclose(xtensor, gathered_xtensor[0], /* rtol */ 1e-3, /* atol */ 1e-2)); + EXPECT_TRUE(xt::allclose(xtensor, gathered_xtensor[1], /* rtol */ 1e-3, /* atol */ 1e-2)); + + xt::xarray grad_data = xt::random::rand(xtensor.shape(), 0.F, 1.F); + ttml::core::XTensorToMeshVariant replicate_composer = ttml::core::ReplicateXTensorToMesh(mesh_shape); + auto tt_grad_tensor = ttml::core::from_xtensor(grad_data, device, replicate_composer); + gathered_tensor->set_grad(tt_grad_tensor); + gathered_tensor->backward(); + + auto result_tensor_grad = tensor->get_grad(); + EXPECT_TRUE(ttml::core::is_tensor_initialized(result_tensor_grad)); + + auto grad_xtensor = ttml::core::to_xtensor(tensor->get_grad(), identity_composer); + EXPECT_EQ(grad_xtensor[0].shape(), grad_xtensor[1].shape()); + EXPECT_TRUE(xt::allclose( + xt::view(grad_data, xt::all(), xt::all(), xt::all(), xt::range(0, size / 2)), + grad_xtensor[0], + /* rtol */ 1e-3, + /* atol */ 1e-2)); + EXPECT_TRUE(xt::allclose( + xt::view(grad_data, xt::all(), 
xt::all(), xt::all(), xt::range(size / 2, size)), + grad_xtensor[1], + /* rtol */ 1e-3, + /* atol */ 1e-2)); +} + +TEST_F(N300CommOpsTest, TestAllGatherFullyTiled) { + auto* device = &ttml::autograd::ctx().get_device(); + auto mesh_shape = device->shape(); + + size_t size = 64UL; + size_t height = 32UL; + std::vector test_data_vec(size * height); + ttml::init::uniform_init(test_data_vec, {0.F, 0.001F}); + xt::xarray test_data = xt::adapt(test_data_vec); + xt::xarray xtensor = test_data.reshape({1U, 1U, height, size}); + ttml::core::XTensorToMeshVariant shard_composer = ttml::core::ShardXTensorToMesh(mesh_shape, 3); + auto tt_tensor = ttml::core::from_xtensor(xtensor, device, shard_composer); + auto tensor = ttml::autograd::create_tensor(tt_tensor); + auto gathered_tensor = ttml::ops::distributed::all_gather(tensor, 3); + + ttml::core::MeshToXTensorVariant identity_composer = ttml::core::VectorMeshToXTensor(mesh_shape); + auto gathered_xtensor = ttml::core::to_xtensor(gathered_tensor->get_value(), identity_composer); + EXPECT_TRUE(xt::allclose(xtensor, gathered_xtensor[0], /* rtol */ 1e-3, /* atol */ 1e-2)); + EXPECT_TRUE(xt::allclose(xtensor, gathered_xtensor[1], /* rtol */ 1e-3, /* atol */ 1e-2)); + + xt::xarray grad_data = xt::random::rand(xtensor.shape(), 0.F, 1.F); + ttml::core::XTensorToMeshVariant replicate_composer = ttml::core::ReplicateXTensorToMesh(mesh_shape); + auto tt_grad_tensor = ttml::core::from_xtensor(grad_data, device, replicate_composer); + gathered_tensor->set_grad(tt_grad_tensor); + gathered_tensor->backward(); + + auto result_tensor_grad = tensor->get_grad(); + EXPECT_TRUE(ttml::core::is_tensor_initialized(result_tensor_grad)); + + auto grad_xtensor = ttml::core::to_xtensor(tensor->get_grad(), identity_composer); + EXPECT_EQ(grad_xtensor[0].shape(), grad_xtensor[1].shape()); + EXPECT_TRUE(xt::allclose( + xt::view(grad_data, xt::all(), xt::all(), xt::all(), xt::range(0, size / 2)), + grad_xtensor[0], + /* rtol */ 1e-3, + /* atol */ 1e-2)); + EXPECT_TRUE(xt::allclose( + xt::view(grad_data, xt::all(), xt::all(), xt::all(), xt::range(size / 2, size)), + grad_xtensor[1], + /* rtol */ 1e-3, + /* atol */ 1e-2)); +} + +TEST_F(N300CommOpsTest, TestScatterNotFullyTiled) { + auto* device = &ttml::autograd::ctx().get_device(); + auto mesh_shape = device->shape(); + + size_t size = 64UL; + std::vector test_data_vec(size); + std::iota(test_data_vec.begin(), test_data_vec.end(), 0.0F); + xt::xarray test_data = xt::adapt(test_data_vec); + xt::xarray xtensor = test_data.reshape({1U, 1U, 1U, size}); + ttml::core::XTensorToMeshVariant replicate_composer = ttml::core::ReplicateXTensorToMesh(mesh_shape); + auto tt_tensor = ttml::core::from_xtensor(xtensor, device, replicate_composer); + auto tensor = ttml::autograd::create_tensor(tt_tensor); + auto scattered_tensor = ttml::ops::distributed::scatter(tensor, 3); + + // check forward + ttml::core::MeshToXTensorVariant identity_composer = ttml::core::VectorMeshToXTensor(mesh_shape); + auto xtensors_back = ttml::core::to_xtensor(scattered_tensor->get_value(), identity_composer); + EXPECT_TRUE( + xt::allclose(xt::view(xtensor, xt::all(), xt::all(), xt::all(), xt::range(0, size / 2)), xtensors_back[0])); + EXPECT_TRUE( + xt::allclose(xt::view(xtensor, xt::all(), xt::all(), xt::all(), xt::range(size / 2, size)), xtensors_back[1])); + + // check backward + xt::xarray grad_data = xt::random::rand(xtensor.shape(), 0.F, 1.F); + ttml::core::XTensorToMeshVariant shard_composer = ttml::core::ShardXTensorToMesh(mesh_shape, 3); + auto tt_grad_tensor = 
ttml::core::from_xtensor(grad_data, device, shard_composer); + scattered_tensor->set_grad(tt_grad_tensor); + scattered_tensor->backward(); + + auto result_tensor_grad = tensor->get_grad(); + EXPECT_TRUE(ttml::core::is_tensor_initialized(result_tensor_grad)); + + auto grad_xtensor = ttml::core::to_xtensor(tensor->get_grad(), identity_composer); + EXPECT_TRUE(grad_data.shape() == grad_xtensor[0].shape()); + EXPECT_TRUE(grad_data.shape() == grad_xtensor[1].shape()); + + EXPECT_EQ(grad_xtensor[0], grad_xtensor[1]); + EXPECT_TRUE(xt::allclose(grad_data, grad_xtensor[0], /* rtol */ 1e-3, /* atol */ 1e-2)); + EXPECT_TRUE(xt::allclose(grad_data, grad_xtensor[1], /* rtol */ 1e-3, /* atol */ 1e-2)); +} + +TEST_F(N300CommOpsTest, TestScatterFullyTiled) { + auto* device = &ttml::autograd::ctx().get_device(); + auto mesh_shape = device->shape(); + + size_t size = 64UL; + size_t height = 32UL; + std::vector test_data_vec(size * height); + ttml::init::uniform_init(test_data_vec, {0.F, 0.001F}); + xt::xarray test_data = xt::adapt(test_data_vec); + xt::xarray xtensor = test_data.reshape({1U, 1U, height, size}); + + ttml::core::MeshToXTensorVariant identity_composer = ttml::core::VectorMeshToXTensor(mesh_shape); + ttml::core::XTensorToMeshVariant replicate_composer = ttml::core::ReplicateXTensorToMesh(mesh_shape); + auto tt_tensor = ttml::core::from_xtensor(xtensor, device, replicate_composer); + + auto xtensor_after_replication = ttml::core::to_xtensor(tt_tensor, identity_composer); + EXPECT_TRUE(xt::allclose(xtensor, xtensor_after_replication[0], /* rtol */ 1e-3, /* atol */ 1e-2)); + EXPECT_TRUE(xt::allclose(xtensor, xtensor_after_replication[1], /* rtol */ 1e-3, /* atol */ 1e-2)); + + auto tensor = ttml::autograd::create_tensor(tt_tensor); + auto scattered_tensor = ttml::ops::distributed::scatter(tensor, 3); + + // check forward + auto xtensors_back = ttml::core::to_xtensor(scattered_tensor->get_value(), identity_composer); + EXPECT_TRUE(xt::allclose( + xt::view(xtensor, xt::all(), xt::all(), xt::all(), xt::range(0, size / 2)), + xtensors_back[0], + /* rtol */ 1e-3, + /* atol */ 1e-2)); + EXPECT_TRUE(xt::allclose( + xt::view(xtensor, xt::all(), xt::all(), xt::all(), xt::range(size / 2, size)), + xtensors_back[1], + /* rtol */ 1e-3, + /* atol */ 1e-2)); + + // check backward + xt::xarray grad_data = xt::random::rand(xtensor.shape(), 0.F, 0.001F); + ttml::core::XTensorToMeshVariant shard_composer = ttml::core::ShardXTensorToMesh(mesh_shape, 3); + auto tt_grad_tensor = ttml::core::from_xtensor(grad_data, device, shard_composer); + scattered_tensor->set_grad(tt_grad_tensor); + scattered_tensor->backward(); + + auto result_tensor_grad = tensor->get_grad(); + EXPECT_TRUE(ttml::core::is_tensor_initialized(result_tensor_grad)); + + auto grad_xtensor = ttml::core::to_xtensor(tensor->get_grad(), identity_composer); + EXPECT_TRUE(grad_data.shape() == grad_xtensor[0].shape()); + EXPECT_TRUE(grad_data.shape() == grad_xtensor[1].shape()); + + EXPECT_EQ(grad_xtensor[0], grad_xtensor[1]); + EXPECT_TRUE(xt::allclose(grad_data, grad_xtensor[0], /* rtol */ 1e-3, /* atol */ 1e-2)); + EXPECT_TRUE(xt::allclose(grad_data, grad_xtensor[1], /* rtol */ 1e-3, /* atol */ 1e-2)); +} diff --git a/tt_fabric/control_plane.cpp b/tt_fabric/control_plane.cpp index 5fb284c7a64..d57cc6b884d 100644 --- a/tt_fabric/control_plane.cpp +++ b/tt_fabric/control_plane.cpp @@ -484,6 +484,7 @@ void ControlPlane::write_routing_tables_to_chip(mesh_id_t mesh_id, chip_id_t chi tt_metal::hal.get_dev_addr( tt_metal::HalProgrammableCoreType::ACTIVE_ETH, 
tt_metal::HalL1MemAddrType::FABRIC_ROUTER_CONFIG), false); + tt::Cluster::instance().l1_barrier(physical_chip_id); } } } @@ -589,6 +590,17 @@ std::vector> ControlPlane::get_fabric_route( return route; } +std::vector ControlPlane::get_intra_chip_neighbors( + mesh_id_t src_mesh_id, chip_id_t src_chip_id, RoutingDirection routing_direction) const { + for (const auto& [_, routing_edge] : + this->routing_table_generator_->get_intra_mesh_connectivity()[src_mesh_id][src_chip_id]) { + if (routing_edge.port_direction == routing_direction) { + return routing_edge.connected_chip_ids; + } + } + return {}; +} + void ControlPlane::configure_routing_tables() const { // Configure the routing tables on the chips TT_ASSERT( diff --git a/tt_fabric/control_plane.hpp b/tt_fabric/control_plane.hpp index e9faa1377c3..7c829b7ea3c 100644 --- a/tt_fabric/control_plane.hpp +++ b/tt_fabric/control_plane.hpp @@ -43,6 +43,9 @@ class ControlPlane { chip_id_t dst_chip_id, chan_id_t src_chan_id) const; + std::vector get_intra_chip_neighbors( + mesh_id_t src_mesh_id, chip_id_t src_chip_id, RoutingDirection routing_direction) const; + private: std::unique_ptr routing_table_generator_; std::vector> logical_mesh_chip_id_to_physical_chip_id_mapping_; diff --git a/tt_fabric/hw/inc/tt_fabric.h b/tt_fabric/hw/inc/tt_fabric.h index 7151390998b..c84ba88094a 100644 --- a/tt_fabric/hw/inc/tt_fabric.h +++ b/tt_fabric/hw/inc/tt_fabric.h @@ -29,7 +29,7 @@ extern volatile chan_payload_ptr remote_rdptr; uint64_t tt_fabric_send_pull_request(uint64_t dest_addr, volatile local_pull_request_t* local_pull_request); uint32_t num_words_available_to_pull(volatile pull_request_t* pull_request); -uint32_t words_before_buffer_wrap(uint32_t buffer_size, uint32_t rd_ptr); +uint32_t words_before_pull_buffer_wrap(uint32_t buffer_size, uint32_t rd_ptr); uint32_t advance_ptr(uint32_t buffer_size, uint32_t ptr, uint32_t inc_words); uint32_t get_rd_ptr_offset_words(pull_request_t* pull_request); @@ -40,7 +40,6 @@ inline uint64_t get_timestamp() { } typedef struct fvc_consumer_state { - uint32_t remote_ptr_update_addr; uint8_t chan_num; uint8_t pad[3]; uint32_t packet_in_progress; @@ -53,28 +52,26 @@ typedef struct fvc_consumer_state { uint32_t remote_buffer_start; uint32_t pull_words_in_flight; uint32_t total_words_to_forward; - uint32_t* words_sent_remote_update; + uint32_t words_sent_remote_update; + volatile uint32_t* sender_buffer_space; + volatile uint32_t* update_sender_buffer_space; + volatile uint32_t* receiver_buffer_space; + volatile uint32_t* update_receiver_buffer_space; - uint32_t get_num_words_free() { - uint32_t rd_ptr = remote_rdptr.ptr; - uint32_t words_occupied = 0; - if (fvc_pull_wrptr != rd_ptr) { - words_occupied = - fvc_pull_wrptr > rd_ptr ? fvc_pull_wrptr - rd_ptr : buffer_size * 2 + fvc_pull_wrptr - rd_ptr; - } - return buffer_size - words_occupied; - } + FORCE_INLINE uint32_t get_num_words_free() { return *sender_buffer_space; } - uint32_t get_remote_num_words_free() { - uint32_t rd_ptr = remote_rdptr.ptr_cleared; - uint32_t words_occupied = 0; - if (fvc_out_rdptr != rd_ptr) { - words_occupied = fvc_out_rdptr > rd_ptr ? 
fvc_out_rdptr - rd_ptr : buffer_size * 2 + fvc_out_rdptr - rd_ptr; - } - return buffer_size - words_occupied; + FORCE_INLINE uint32_t get_remote_num_words_free() { return *receiver_buffer_space; } + + inline void reset_buffer_space(uint32_t buf_size_words) { + // Setting STREAM_REMOTE_DEST_BUF_SIZE_REG_INDEX resets the credit register + volatile uint32_t* ptr = + reinterpret_cast(STREAM_REG_ADDR(1, STREAM_REMOTE_DEST_BUF_SIZE_REG_INDEX)); + *ptr = buf_size_words; + ptr = reinterpret_cast(STREAM_REG_ADDR(2, STREAM_REMOTE_DEST_BUF_SIZE_REG_INDEX)); + *ptr = buf_size_words; } - inline void init(uint32_t data_buf_start, uint32_t data_buf_size_words, uint32_t ptr_update_addr) { + inline void init(uint32_t data_buf_start, uint32_t data_buf_size_words) { uint32_t words = sizeof(fvc_consumer_state) / 4; uint32_t* ptr = (uint32_t*)this; for (uint32_t i = 0; i < words; i++) { @@ -85,70 +82,55 @@ typedef struct fvc_consumer_state { buffer_size = data_buf_size_words; buffer_size_2x = data_buf_size_words * 2; remote_buffer_start = data_buf_start + buffer_size * PACKET_WORD_SIZE_BYTES; - remote_ptr_update_addr = ptr_update_addr; - words_sent_remote_update = - reinterpret_cast(STREAM_REG_ADDR(0, STREAM_REMOTE_DEST_BUF_SPACE_AVAILABLE_UPDATE_REG_INDEX)); + words_sent_remote_update = (STREAM_REG_ADDR(0, STREAM_REMOTE_DEST_BUF_SPACE_AVAILABLE_UPDATE_REG_INDEX)); + sender_buffer_space = + reinterpret_cast(STREAM_REG_ADDR(1, STREAM_REMOTE_DEST_BUF_SPACE_AVAILABLE_REG_INDEX)); + update_sender_buffer_space = + reinterpret_cast(STREAM_REG_ADDR(1, STREAM_REMOTE_DEST_BUF_SPACE_AVAILABLE_UPDATE_REG_INDEX)); + receiver_buffer_space = + reinterpret_cast(STREAM_REG_ADDR(2, STREAM_REMOTE_DEST_BUF_SPACE_AVAILABLE_REG_INDEX)); + update_receiver_buffer_space = + reinterpret_cast(STREAM_REG_ADDR(2, STREAM_REMOTE_DEST_BUF_SPACE_AVAILABLE_UPDATE_REG_INDEX)); + reset_buffer_space(data_buf_size_words); } - inline uint32_t words_before_buffer_wrap(uint32_t ptr) { - if (ptr >= buffer_size) { - return buffer_size_2x - ptr; - } else { - return buffer_size - ptr; - } - } + FORCE_INLINE uint32_t words_before_buffer_wrap(uint32_t ptr) { return buffer_size - ptr; } - inline uint32_t words_before_local_buffer_wrap() { - if (fvc_pull_wrptr >= buffer_size) { - return buffer_size_2x - fvc_pull_wrptr; - } else { - return buffer_size - fvc_pull_wrptr; - } - } + FORCE_INLINE uint32_t words_before_local_buffer_wrap() { return buffer_size - fvc_pull_wrptr; } - inline uint32_t get_local_buffer_pull_addr() { - uint32_t addr = buffer_start; - uint32_t offset = fvc_pull_wrptr; - if (offset >= buffer_size) { - offset -= buffer_size; - } - addr = addr + (offset * PACKET_WORD_SIZE_BYTES); - return addr; + FORCE_INLINE uint32_t get_local_buffer_pull_addr() { + return buffer_start + (fvc_pull_wrptr * PACKET_WORD_SIZE_BYTES); } - inline uint32_t get_local_buffer_read_addr() { - uint32_t addr = buffer_start; - uint32_t offset = fvc_out_rdptr; - if (offset >= buffer_size) { - offset -= buffer_size; - } - addr = addr + (offset * PACKET_WORD_SIZE_BYTES); - return addr; + FORCE_INLINE uint32_t get_local_buffer_read_addr() { + return buffer_start + (fvc_out_rdptr * PACKET_WORD_SIZE_BYTES); } - inline void advance_pull_wrptr(uint32_t num_words) { + FORCE_INLINE void advance_pull_wrptr(uint32_t num_words) { uint32_t temp = fvc_pull_wrptr + num_words; - if (temp >= buffer_size_2x) { - temp -= buffer_size_2x; + if (temp >= buffer_size) { + temp -= buffer_size; } fvc_pull_wrptr = temp; + *update_sender_buffer_space = (-num_words) << REMOTE_DEST_BUF_WORDS_FREE_INC; 
} - inline void advance_out_rdptr(uint32_t num_words) { + FORCE_INLINE void advance_out_rdptr(uint32_t num_words) { uint32_t temp = fvc_out_rdptr + num_words; - if (temp >= buffer_size_2x) { - temp -= buffer_size_2x; + if (temp >= buffer_size) { + temp -= buffer_size; } fvc_out_rdptr = temp; + *update_receiver_buffer_space = (-num_words) << REMOTE_DEST_BUF_WORDS_FREE_INC; } - inline void register_pull_data(uint32_t num_words_to_pull) { + FORCE_INLINE void register_pull_data(uint32_t num_words_to_pull) { pull_words_in_flight += num_words_to_pull; advance_pull_wrptr(num_words_to_pull); packet_words_remaining -= num_words_to_pull; } - inline void register_move_data(uint32_t num_words_to_move) { + FORCE_INLINE void register_move_data(uint32_t num_words_to_move) { advance_pull_wrptr(num_words_to_move); packet_words_remaining -= num_words_to_move; total_words_to_forward += num_words_to_move; @@ -196,6 +178,103 @@ typedef struct fvc_consumer_state { return words_to_forward; } + + inline uint32_t get_num_words_to_pull(volatile pull_request_t* pull_request) { + uint32_t num_words_to_pull = num_words_available_to_pull(pull_request); + uint32_t num_words_before_wrap = words_before_pull_buffer_wrap(pull_request->buffer_size, pull_request->rd_ptr); + + num_words_to_pull = std::min(num_words_to_pull, num_words_before_wrap); + uint32_t fvc_buffer_space = get_num_words_free(); + num_words_to_pull = std::min(num_words_to_pull, fvc_buffer_space); + + if (num_words_to_pull == 0) { + return 0; + } + + uint32_t fvc_space_before_wptr_wrap = words_before_local_buffer_wrap(); + num_words_to_pull = std::min(num_words_to_pull, fvc_space_before_wptr_wrap); + + num_words_to_pull = std::min(num_words_to_pull, buffer_size / 2); + + return num_words_to_pull; + } + + FORCE_INLINE uint32_t pull_data_to_fvc_buffer(volatile pull_request_t* pull_request) { + if (packet_in_progress == 0) { + uint32_t size = pull_request->size; + packet_words_remaining = (size + PACKET_WORD_SIZE_BYTES - 1) >> 4; + packet_in_progress = 1; + } + + uint32_t num_words_to_pull = get_num_words_to_pull(pull_request); + if (num_words_to_pull == 0) { + return 0; + } + + uint32_t rd_offset = get_rd_ptr_offset_words((pull_request_t*)pull_request); + uint64_t src_addr = pull_request->buffer_start + (rd_offset * PACKET_WORD_SIZE_BYTES); + uint32_t fvc_addr = get_local_buffer_pull_addr(); + + // pull_data_from_remote(); + noc_async_read(src_addr, fvc_addr, num_words_to_pull * PACKET_WORD_SIZE_BYTES); + register_pull_data(num_words_to_pull); + pull_request->rd_ptr = advance_ptr(pull_request->buffer_size, pull_request->rd_ptr, num_words_to_pull); + pull_request->words_read += num_words_to_pull; + + return num_words_to_pull; + } + + inline uint32_t move_data_to_fvc_buffer(volatile pull_request_t* pull_request) { + if (packet_in_progress == 0) { + packet_words_remaining = PACKET_HEADER_SIZE_WORDS; + packet_in_progress = 1; + } + + // if fvc does not have enough space, try again later. 
+ if (get_num_words_free() < PACKET_HEADER_SIZE_WORDS) { + return 0; + } + + uint32_t fvc_space_before_wptr_wrap = words_before_local_buffer_wrap(); + uint32_t* fvc_addr = (uint32_t*)get_local_buffer_pull_addr(); + uint32_t* src = (uint32_t*)pull_request; + + switch (fvc_space_before_wptr_wrap) { + case 1: + fvc_addr[0] = src[0]; + fvc_addr[1] = src[1]; + fvc_addr[2] = src[2]; + fvc_addr[3] = src[3]; + fvc_addr = (uint32_t*)buffer_start; + fvc_addr[0] = src[4]; + fvc_addr[1] = src[5]; + fvc_addr[2] = src[6]; + fvc_addr[3] = src[7]; + fvc_addr[4] = src[8]; + fvc_addr[5] = src[9]; + fvc_addr[6] = src[10]; + fvc_addr[7] = src[11]; + break; + case 2: + // uint32_t i = 0; + for (uint32_t i = 0; i < (PACKET_HEADER_SIZE_WORDS - 1) * PACKET_WORD_SIZE_BYTES / 4; i++) { + fvc_addr[i] = src[i]; + } + fvc_addr = (uint32_t*)buffer_start; + fvc_addr[0] = src[8]; + fvc_addr[1] = src[9]; + fvc_addr[2] = src[10]; + fvc_addr[3] = src[11]; + break; + default: + for (uint32_t i = 0; i < PACKET_HEADER_SIZE_BYTES / 4; i++) { + fvc_addr[i] = src[i]; + } + } + + register_move_data(PACKET_HEADER_SIZE_WORDS); + return PACKET_HEADER_SIZE_WORDS; + } } fvc_consumer_state_t; static_assert(sizeof(fvc_consumer_state_t) % 4 == 0); @@ -203,6 +282,11 @@ static_assert(sizeof(fvc_consumer_state_t) % 4 == 0); #define FVC_MODE_ROUTER 1 #define FVC_MODE_ENDPOINT 2 +enum ProcessingFlags : uint8_t { + UCAST_DEST = 1, + MCAST_DEST = 2, + NOT_DEST = 3, +}; // FVC Producer holds data that needs to be forwarded to other destinations. // This producer receives data over ethernet from neighboring chip. // Data in the producer is either destined for local chip, or has to make a noc hop @@ -215,17 +299,17 @@ static_assert(sizeof(fvc_consumer_state_t) % 4 == 0); typedef struct fvc_producer_state { chan_payload_ptr inbound_wrptr; chan_payload_ptr inbound_rdptr; - uint32_t remote_ptr_update_addr; uint32_t my_id; uint8_t chan_num; uint8_t packet_in_progress; - uint8_t packet_end_flush; - uint8_t pad2; + uint8_t packet_processing_flags; + uint8_t mcast_direction; + uint32_t mcast_router_noc_xy; uint32_t words_inbound; + uint32_t words_cleared; uint32_t packet_words_remaining; - uint32_t fvc_out_wrptr; + uint32_t hop_words_remaining; uint32_t fvc_out_rdptr; - uint32_t fvc_pull_rdptr; uint32_t buffer_size; uint32_t buffer_size_2x; uint32_t buffer_start; @@ -236,10 +320,13 @@ typedef struct fvc_producer_state { bool packet_corrupted; uint64_t packet_timestamp; uint64_t packet_dest; + uint64_t hop_dest; packet_header_t current_packet_header; uint32_t* packet_id; volatile uint32_t* words_received; uint32_t* words_received_local_update; + uint32_t update_sender_buffer_space; + uint32_t update_receiver_buffer_space; inline void reset_words_received() { // Setting STREAM_REMOTE_DEST_BUF_SIZE_REG_INDEX resets the credit register @@ -248,7 +335,7 @@ typedef struct fvc_producer_state { *ptr = 0; } - inline void init(uint32_t data_buf_start, uint32_t data_buf_size_words, uint32_t ptr_update_addr) { + inline void init(uint32_t data_buf_start, uint32_t data_buf_size_words) { uint32_t words = sizeof(fvc_producer_state) / 4; uint32_t* ptr = (uint32_t*)this; for (uint32_t i = 0; i < words; i++) { @@ -259,64 +346,83 @@ typedef struct fvc_producer_state { buffer_start = data_buf_start; buffer_size = data_buf_size_words; buffer_size_2x = data_buf_size_words * 2; - remote_ptr_update_addr = ptr_update_addr; words_received = reinterpret_cast(STREAM_REG_ADDR(0, STREAM_REMOTE_DEST_BUF_SPACE_AVAILABLE_REG_INDEX)); words_received_local_update = 
reinterpret_cast(STREAM_REG_ADDR(0, STREAM_REMOTE_DEST_BUF_SPACE_AVAILABLE_UPDATE_REG_INDEX));
+ update_sender_buffer_space = (STREAM_REG_ADDR(1, STREAM_REMOTE_DEST_BUF_SPACE_AVAILABLE_UPDATE_REG_INDEX));
+ update_receiver_buffer_space = (STREAM_REG_ADDR(2, STREAM_REMOTE_DEST_BUF_SPACE_AVAILABLE_UPDATE_REG_INDEX));
+ reset_words_received();
packet_id = (uint32_t*)&current_packet_header.routing.dst_mesh_id;
+ tt::tt_fabric::chan_id_t my_chan = routing_table->intra_mesh_table.dest_entry[routing_table->my_device_id];
+ tt::tt_fabric::chan_id_t mcast_channel = 0;
+ if (routing_table->port_direction.east == my_chan) {
+ mcast_channel = routing_table->port_direction.west;
+ mcast_direction = 1;
+ } else if (routing_table->port_direction.west == my_chan) {
+ mcast_channel = routing_table->port_direction.east;
+ mcast_direction = 0;
+ } else if (routing_table->port_direction.north == my_chan) {
+ mcast_channel = routing_table->port_direction.south;
+ mcast_direction = 3;
+ } else if (routing_table->port_direction.south == my_chan) {
+ mcast_channel = routing_table->port_direction.north;
+ mcast_direction = 2;
+ }
+ mcast_router_noc_xy = eth_chan_to_noc_xy[noc_index][mcast_channel];
}
inline uint32_t inc_ptr_with_wrap(uint32_t ptr, uint32_t inc) {
uint32_t temp = ptr + inc;
- if (temp >= buffer_size_2x) {
- temp -= buffer_size_2x;
+ if (temp >= buffer_size) {
+ temp -= buffer_size;
}
return temp;
}
inline void advance_local_wrptr(uint32_t num_words) {
inbound_wrptr.ptr = inc_ptr_with_wrap(inbound_wrptr.ptr, num_words);
+ words_inbound += num_words;
}
- inline void advance_out_rdptr(uint32_t num_words) { fvc_out_rdptr = inc_ptr_with_wrap(fvc_out_rdptr, num_words); }
-
- inline uint32_t words_before_buffer_wrap(uint32_t ptr) {
- if (ptr >= buffer_size) {
- return buffer_size_2x - ptr;
- } else {
- return buffer_size - ptr;
+ template
+ FORCE_INLINE void advance_out_rdptr(uint32_t num_words) {
+ uint32_t temp = fvc_out_rdptr + num_words;
+ if (temp >= buffer_size) {
+ temp -= buffer_size;
+ }
+ fvc_out_rdptr = temp;
+ if constexpr (fvc_mode == FVC_MODE_ROUTER) {
+ words_inbound -= num_words;
}
}
+ FORCE_INLINE uint32_t words_before_buffer_wrap(uint32_t ptr) { return buffer_size - ptr; }
+
template
- inline uint32_t get_num_words_available() {
+ FORCE_INLINE uint32_t get_num_words_available() {
if constexpr (fvc_mode == FVC_MODE_ROUTER) {
uint32_t new_words = *words_received;
- advance_local_wrptr(new_words);
*words_received_local_update = (-new_words) << REMOTE_DEST_BUF_WORDS_FREE_INC;
+ words_inbound += new_words;
+ uint32_t temp = inbound_wrptr.ptr + new_words;
+ if (temp >= buffer_size) {
+ temp -= buffer_size;
+ }
+ inbound_wrptr.ptr = temp;
+ free_sender_buffer_space(new_words);
+ return words_inbound;
+ } else {
+ return words_inbound;
}
- uint32_t wrptr = inbound_wrptr.ptr;
- uint32_t words_occupied = 0;
- if (fvc_out_rdptr != wrptr) {
- words_occupied = wrptr > fvc_out_rdptr ? wrptr - fvc_out_rdptr : buffer_size_2x + wrptr - fvc_out_rdptr;
- }
- words_inbound = words_occupied;
- return words_occupied;
}
- inline uint32_t get_num_words_free() {
- uint32_t wrptr = inbound_wrptr.ptr;
- uint32_t words_occupied = 0;
- if (fvc_pull_rdptr != wrptr) {
- words_occupied = wrptr > fvc_pull_rdptr ?
wrptr - fvc_pull_rdptr : buffer_size_2x + wrptr - fvc_pull_rdptr; - } - return buffer_size - words_occupied; - } + FORCE_INLINE + uint32_t get_num_words_free() { return buffer_size - words_inbound; } template - inline bool get_curr_packet_valid() { + FORCE_INLINE bool get_curr_packet_valid() { if (!curr_packet_valid) { if (get_num_words_available() >= PACKET_HEADER_SIZE_WORDS) { // Wait for a full packet header to arrive before advancing to next packet. @@ -326,69 +432,48 @@ typedef struct fvc_producer_state { return this->curr_packet_valid; } - inline uint32_t get_local_buffer_read_addr() { - uint32_t addr = buffer_start; - uint32_t offset = fvc_out_rdptr; - if (offset >= buffer_size) { - offset -= buffer_size; - } - addr = addr + (offset * PACKET_WORD_SIZE_BYTES); - return addr; + FORCE_INLINE uint32_t get_local_buffer_read_addr() { + return buffer_start + (fvc_out_rdptr * PACKET_WORD_SIZE_BYTES); } - inline uint32_t get_local_buffer_write_addr() { - uint32_t addr = buffer_start; - uint32_t offset = inbound_wrptr.ptr; - if (offset >= buffer_size) { - offset -= buffer_size; - } - addr = addr + (offset * PACKET_WORD_SIZE_BYTES); - return addr; + FORCE_INLINE uint32_t get_local_buffer_write_addr() { + return buffer_start + (inbound_wrptr.ptr * PACKET_WORD_SIZE_BYTES); } - inline uint32_t words_before_local_buffer_wrap() { - if (inbound_wrptr.ptr >= buffer_size) { - return buffer_size_2x - inbound_wrptr.ptr; - } else { - return buffer_size - inbound_wrptr.ptr; - } - } + FORCE_INLINE uint32_t words_before_local_buffer_wrap() { return buffer_size - inbound_wrptr.ptr; } - template - inline void update_remote_rdptr_sent() { - if (inbound_wrptr.ptr != inbound_rdptr.ptr) { - inbound_rdptr.ptr = inbound_wrptr.ptr; - if constexpr (fvc_mode == FVC_MODE_ROUTER) { - inbound_rdptr_ack.ptr = inbound_wrptr.ptr; - internal_::eth_send_packet( - 0, - ((uint32_t)&inbound_rdptr_ack) / PACKET_WORD_SIZE_BYTES, - remote_ptr_update_addr / PACKET_WORD_SIZE_BYTES, - 1); - } - } + FORCE_INLINE void free_sender_buffer_space(uint32_t words) { + // send received word credits to receiver + eth_write_remote_reg((uint32_t)update_sender_buffer_space, words << REMOTE_DEST_BUF_WORDS_FREE_INC); } template - inline void update_remote_rdptr_cleared() { - if (fvc_pull_rdptr != inbound_rdptr.ptr_cleared) { - inbound_rdptr.ptr_cleared = fvc_pull_rdptr; - if constexpr (fvc_mode == FVC_MODE_ROUTER) { - inbound_rdptr_ack.ptr_cleared = fvc_pull_rdptr; - internal_::eth_send_packet( - 0, - ((uint32_t)&inbound_rdptr_ack) / PACKET_WORD_SIZE_BYTES, - remote_ptr_update_addr / PACKET_WORD_SIZE_BYTES, - 1); - } + FORCE_INLINE void free_receiver_buffer_space(uint32_t words) { + if constexpr (fvc_mode == FVC_MODE_ROUTER) { + // send received word credits to receiver + eth_write_remote_reg((uint32_t)update_receiver_buffer_space, words << REMOTE_DEST_BUF_WORDS_FREE_INC); + } else { + words_inbound -= words; } } template - inline void advance_next_packet() { + FORCE_INLINE void advance_next_packet() { + // The following code makes the following assumptions regarding packet header structure + // 1 - packet_parameters is always the first word of the packet header, and doesn't span across packet word + // boundary. + // 2 - routing is always the last word of the packet header doesn't span across packet word boundary. 
+ static_assert(
+ offsetof(packet_header_t, packet_parameters) == 0 && sizeof(packet_params) <= PACKET_WORD_SIZE_BYTES);
+ static_assert(
+ offsetof(packet_header_t, routing) >= (PACKET_HEADER_SIZE_BYTES - PACKET_WORD_SIZE_BYTES) &&
+ offsetof(packet_header_t, routing) % PACKET_WORD_SIZE_BYTES + sizeof(tt_routing) <= PACKET_WORD_SIZE_BYTES);
tt_l1_ptr uint32_t* packet_header_ptr = (uint32_t*)&current_packet_header;
volatile tt_l1_ptr uint32_t* next_header_ptr = reinterpret_cast(get_local_buffer_read_addr());
+ // TODO: Should we just extract the specific field we want here (mcast_params)
+ packet_params* next_packet_params_ptr = (packet_params*)(next_header_ptr);
+ tt_routing* next_routing_ptr;
uint32_t words_before_wrap = words_before_buffer_wrap(fvc_out_rdptr);
uint32_t dwords_to_copy = PACKET_HEADER_SIZE_BYTES / 4;
if (words_before_wrap < PACKET_HEADER_SIZE_WORDS) {
@@ -403,16 +488,58 @@
for (uint32_t i = 0; i < dwords_after_wrap; i++) {
packet_header_ptr[i + dwords_before_wrap] = next_header_ptr[i];
}
+ next_routing_ptr =
+ (tt_routing*)(next_header_ptr + packet_header_routing_offset_dwords - dwords_before_wrap);
} else {
+#pragma GCC unroll 12
for (uint32_t i = 0; i < dwords_to_copy; i++) {
packet_header_ptr[i] = next_header_ptr[i];
}
+ next_routing_ptr = (tt_routing*)(next_header_ptr + packet_header_routing_offset_dwords);
}
this->packet_words_remaining = (this->current_packet_header.routing.packet_size_bytes + PACKET_WORD_SIZE_BYTES - 1) >> 4;
if (tt_fabric_is_header_valid(&current_packet_header)) {
this->curr_packet_valid = true;
+ if (packet_is_for_local_chip()) {
+ if (packet_mcast_is_required()) {
+ // If it's an mcast packet, update the hop count.
+ // Packet arrival on this router accounts for 1 hop.
+ // Decrement the respective direction hop count and determine
+ // whether mcast needs further hops.
+ update_mcast_hops(next_packet_params_ptr, next_routing_ptr);
+ }
+ // After updating mcast hop counts, we check whether the mcast packet still qualifies to be
+ // an mcast packet.
+ packet_processing_flags =
+ packet_mcast_is_active() ? ProcessingFlags::MCAST_DEST : ProcessingFlags::UCAST_DEST;
+ } else {
+ if (packet_mcast_is_active()) {
+ // Mcast packets have dest dev/mesh id set to the device where mcast starts.
+ // Hence packet_is_for_local_chip() returns true only for the first mcast target device.
+ // All other devices need to check the mcast active flag to determine if they should consume
+ // the data or not.
+ // Any device that receives a packet with the mcast active flag set consumes the data and forwards
+ // it as well if it's not the last hop of the mcast group.
+
+ // If mcast is active, update the hop count.
+ // Decrement hop count.
+ update_mcast_hops(next_packet_params_ptr, next_routing_ptr);
+ // After decrementing the hop count, check whether mcast is still active.
+ // If mcast has been deactivated here, that means this chip is the last hop for the mcast packet.
+ // If so, we service the last hop as a unicast dest.
+ // Otherwise, we handle it as an mcast dest, which means the packet is consumed locally as well as
+ // forwarded to the next hop in the mcast direction.
+ packet_processing_flags =
+ packet_mcast_is_active() ? ProcessingFlags::MCAST_DEST : ProcessingFlags::UCAST_DEST;
+ } else {
+ // We are here for one of 2 reasons.
+ // 1 - Packet is not meant for this chip.
+ // 2 - Packet is not under active mcast.
+ packet_processing_flags = ProcessingFlags::NOT_DEST; + } + } } else { this->packet_corrupted = true; } @@ -445,25 +572,28 @@ typedef struct fvc_producer_state { if (packet_in_progress == 0) { if (current_packet_header.routing.flags == INLINE_FORWARD) { copy_header((pull_request_t*)&local_pull_request->pull_request); + words_cleared = words_available; } else { - local_pull_request->pull_request.wr_ptr = inc_ptr_with_wrap(fvc_out_rdptr, words_available); local_pull_request->pull_request.rd_ptr = fvc_out_rdptr; local_pull_request->pull_request.size = current_packet_header.routing.packet_size_bytes; local_pull_request->pull_request.buffer_size = buffer_size; local_pull_request->pull_request.buffer_start = xy_local_addr + buffer_start; + local_pull_request->pull_request.words_written = words_available; + local_pull_request->pull_request.words_read = 0; + words_cleared = 0; local_pull_request->pull_request.ack_addr = - xy_local_addr + (uint32_t)&local_pull_request->pull_request.rd_ptr; + xy_local_addr + (uint32_t)&local_pull_request->pull_request.words_read; local_pull_request->pull_request.flags = FORWARD; packet_in_progress = 1; } packet_words_remaining -= words_available; - advance_out_rdptr(words_available); + advance_out_rdptr(words_available); // issue noc write to noc target of pull request. uint64_t dest_addr = socket_mode == false ? ((uint64_t)get_next_hop_router_noc_xy() << 32) | FABRIC_ROUTER_REQ_QUEUE_START : ((uint64_t)current_packet_header.session.target_offset_h << 32) | current_packet_header.session.target_offset_l; - packet_dest = tt_fabric_send_pull_request(dest_addr, local_pull_request); + hop_dest = tt_fabric_send_pull_request(dest_addr, local_pull_request); if (current_packet_header.routing.flags == INLINE_FORWARD) { curr_packet_valid = false; flush_async_writes(); @@ -472,28 +602,32 @@ typedef struct fvc_producer_state { } else { // pull_request.rd_ptr is updated by remote puller when data is read out of producer's local buffer. // it is used to determine when it it safe to reclaim local buffer memory for more data. - fvc_pull_rdptr = local_pull_request->pull_request.rd_ptr; + uint32_t curr_words_read = local_pull_request->pull_request.words_read; + uint32_t words_to_clear = curr_words_read - words_cleared; + if (words_to_clear) { + free_receiver_buffer_space(words_to_clear); + words_cleared += words_to_clear; + } if (packet_words_remaining) { if (words_available) { - advance_out_rdptr(words_available); + advance_out_rdptr(words_available); // packet_dest is returned by tt_fabric_send_pull_request() as the address of request q entry + - // pull_request.wr_ptr. - noc_inline_dw_write(packet_dest, fvc_out_rdptr); + // pull_request.words_written. + local_pull_request->pull_request.words_written += words_available; + noc_inline_dw_write(hop_dest, local_pull_request->pull_request.words_written); packet_words_remaining -= words_available; } - } else if (fvc_pull_rdptr == fvc_out_rdptr) { + } else if (curr_words_read == local_pull_request->pull_request.words_written) { // all data has been pulled and cleared from local buffer packet_in_progress = 0; curr_packet_valid = false; } } - // send ptr cleared to ethernet sender. 
- update_remote_rdptr_cleared();
return words_available;
}
template
- inline uint32_t issue_async_write() {
+ FORCE_INLINE uint32_t issue_async_write() {
if constexpr (resample) {
get_num_words_available();
}
@@ -504,20 +638,184 @@
noc_async_write(get_local_buffer_read_addr(), packet_dest, words_available * PACKET_WORD_SIZE_BYTES);
packet_words_remaining -= words_available;
advance_out_rdptr(words_available);
+ words_cleared += words_available;
packet_dest += words_available * PACKET_WORD_SIZE_BYTES;
- // if (packet_words_remaining == 0) {
- // packet_end_flush = 1;
- // }
}
return words_available;
}
- inline bool packet_is_for_local_chip() { return my_id == *packet_id; }
+ FORCE_INLINE bool packet_is_for_local_chip() { return my_id == *packet_id; }
+
+ inline bool packet_mcast_is_active() { return (current_packet_header.routing.flags & MCAST_ACTIVE) != 0; }
+
+ inline bool packet_mcast_is_required() { return (current_packet_header.routing.flags & MCAST_DATA) != 0; }
+
+ inline void update_mcast_hops(packet_params* packet_parameters, tt_routing* routing) {
+ uint32_t hop_count = 0;
+ if (mcast_direction == 0) {
+ hop_count = current_packet_header.packet_parameters.mcast_parameters.east;
+ if (hop_count) {
+ hop_count--;
+ current_packet_header.packet_parameters.mcast_parameters.east = hop_count;
+ packet_parameters->mcast_parameters.east = hop_count;
+ }
+ } else if (mcast_direction == 1) {
+ hop_count = current_packet_header.packet_parameters.mcast_parameters.west;
+ if (hop_count) {
+ hop_count--;
+ current_packet_header.packet_parameters.mcast_parameters.west = hop_count;
+ packet_parameters->mcast_parameters.west = hop_count;
+ }
+ } else if (mcast_direction == 2) {
+ hop_count = current_packet_header.packet_parameters.mcast_parameters.north;
+ if (hop_count) {
+ hop_count--;
+ current_packet_header.packet_parameters.mcast_parameters.north = hop_count;
+ packet_parameters->mcast_parameters.north = hop_count;
+ }
+ } else if (mcast_direction == 3) {
+ hop_count = current_packet_header.packet_parameters.mcast_parameters.south;
+ if (hop_count) {
+ hop_count--;
+ current_packet_header.packet_parameters.mcast_parameters.south = hop_count;
+ packet_parameters->mcast_parameters.south = hop_count;
+ }
+ }
+ if (hop_count == 0) {
+ // on last hop clear the mcast flag bits.
+ // last hop treats packet as normal ucast async write.
+ current_packet_header.routing.flags &= ~(MCAST_ACTIVE | MCAST_DATA);
+ routing->flags &= ~(MCAST_ACTIVE | MCAST_DATA);
+ } else {
+ current_packet_header.routing.flags |= MCAST_ACTIVE;
+ routing->flags |= MCAST_ACTIVE;
+ // calculate new header checksum after mcast updates.
+ tt_fabric_add_header_checksum(&current_packet_header);
+ // copy new checksum to packet header in fvc buffer.
+ packet_parameters->misc_parameters.words[0] = + current_packet_header.packet_parameters.misc_parameters.words[0]; + } + } + + template + inline uint32_t process_mcast_packet() { + uint32_t words_processed = 0; + if (current_packet_header.session.command & ASYNC_WR) { + uint32_t words_available = get_num_words_available(); + words_available = std::min(words_available, packet_words_remaining); + words_processed = words_available; + if (packet_in_progress == 0) { + local_pull_request->pull_request.rd_ptr = fvc_out_rdptr; + local_pull_request->pull_request.size = current_packet_header.routing.packet_size_bytes; + local_pull_request->pull_request.buffer_size = buffer_size; + local_pull_request->pull_request.buffer_start = xy_local_addr + buffer_start; + local_pull_request->pull_request.words_written = words_available; + local_pull_request->pull_request.words_read = 0; + words_cleared = 0; + local_pull_request->pull_request.ack_addr = + xy_local_addr + (uint32_t)&local_pull_request->pull_request.words_read; + local_pull_request->pull_request.flags = FORWARD; + + packet_words_remaining -= words_available; + // issue noc write to noc target of pull request. + // figure out next hop for mcast forwarding + uint64_t dest_addr = ((uint64_t)mcast_router_noc_xy << 32) | FABRIC_ROUTER_REQ_QUEUE_START; + hop_dest = tt_fabric_send_pull_request(dest_addr, local_pull_request); + + packet_dest = ((uint64_t)current_packet_header.session.target_offset_h << 32) | + current_packet_header.session.target_offset_l; + + advance_out_rdptr(PACKET_HEADER_SIZE_WORDS); + words_available -= PACKET_HEADER_SIZE_WORDS; + + uint32_t local_words_available = std::min(words_available, words_before_buffer_wrap(fvc_out_rdptr)); + // write available data till end of input buffer + if (local_words_available) { + // need to check local_words_available > 0 since it is possible that we only received the packet + // header so far, and words_available == 0 after words_available -= PACKET_HEADER_SIZE_WORDS above. + noc_async_write( + get_local_buffer_read_addr(), packet_dest, local_words_available * PACKET_WORD_SIZE_BYTES); + advance_out_rdptr(local_words_available); + packet_dest += local_words_available * PACKET_WORD_SIZE_BYTES; + } + local_words_available = words_available - local_words_available; + // write remaining available data from beginning of buffer + if (local_words_available) { + noc_async_write( + get_local_buffer_read_addr(), packet_dest, local_words_available * PACKET_WORD_SIZE_BYTES); + advance_out_rdptr(local_words_available); + packet_dest += local_words_available * PACKET_WORD_SIZE_BYTES; + } + // subtract the header words. Remaining words are the data to be written to packet_dest. + // Remember to account for trailing bytes which may not be a full packet word. + packet_in_progress = 1; + } else { + noc_async_writes_flushed(); + // pull_request.rd_ptr is updated by remote puller when data is read out of producer's local buffer. + // it is used to determine when it it safe to reclaim local buffer memory for more data. 
+ uint32_t curr_words_read = local_pull_request->pull_request.words_read; + uint32_t words_to_clear = curr_words_read - words_cleared; + if (words_to_clear) { + free_receiver_buffer_space(words_to_clear); + words_cleared += words_to_clear; + } + + if (packet_words_remaining) { + if (words_available) { + uint32_t local_words_available = + std::min(words_available, words_before_buffer_wrap(fvc_out_rdptr)); + // write available data till end of input buffer + noc_async_write( + get_local_buffer_read_addr(), packet_dest, local_words_available * PACKET_WORD_SIZE_BYTES); + advance_out_rdptr(local_words_available); + packet_dest += local_words_available * PACKET_WORD_SIZE_BYTES; + local_words_available = words_available - local_words_available; + // write remaining available data from beginning of buffer + if (local_words_available) { + noc_async_write( + get_local_buffer_read_addr(), + packet_dest, + local_words_available * PACKET_WORD_SIZE_BYTES); + advance_out_rdptr(local_words_available); + packet_dest += local_words_available * PACKET_WORD_SIZE_BYTES; + } + + // hop_dest is returned by tt_fabric_send_pull_request() as the address of request q entry + + // pull_request.wr_ptr. + local_pull_request->pull_request.words_written += words_available; + noc_inline_dw_write(hop_dest, local_pull_request->pull_request.words_written); + packet_words_remaining -= words_available; + } + } else if (curr_words_read == local_pull_request->pull_request.words_written) { + // for fused command issue the atomic inc before invalidating the current packet + if (current_packet_header.session.command & ATOMIC_INC) { + uint64_t noc_addr = + ((uint64_t)current_packet_header.packet_parameters.async_wr_atomic_parameters.noc_xy + << 32) | + current_packet_header.packet_parameters.async_wr_atomic_parameters.l1_offset; + noc_fast_atomic_increment( + noc_index, + NCRISC_AT_CMD_BUF, + noc_addr, + NOC_UNICAST_WRITE_VC, + current_packet_header.packet_parameters.async_wr_atomic_parameters.increment, + 31, + false); + } + // all data has been pulled and cleared from local buffer + packet_in_progress = 0; + curr_packet_valid = false; + packet_timestamp = get_timestamp(); + } + } + } + return words_processed; + } template inline uint32_t process_inbound_packet() { uint32_t words_processed = 0; - if (packet_is_for_local_chip()) { + if (packet_processing_flags == ProcessingFlags::UCAST_DEST) { if (current_packet_header.routing.flags == FORWARD) { if (current_packet_header.session.command & ASYNC_WR) { if (packet_in_progress == 0) { @@ -525,10 +823,10 @@ typedef struct fvc_producer_state { current_packet_header.session.target_offset_l; packet_words_remaining -= PACKET_HEADER_SIZE_WORDS; advance_out_rdptr(PACKET_HEADER_SIZE_WORDS); + words_cleared = PACKET_HEADER_SIZE_WORDS; // subtract the header words. Remaining words are the data to be written to packet_dest. // Remember to account for trailing bytes which may not be a full packet word. 
packet_in_progress = 1; - words_inbound -= PACKET_HEADER_SIZE_WORDS; issue_async_write(); } else { flush_async_writes(); @@ -575,12 +873,13 @@ typedef struct fvc_producer_state { packet_words_remaining -= PACKET_HEADER_SIZE_WORDS; advance_out_rdptr(PACKET_HEADER_SIZE_WORDS); words_processed = PACKET_HEADER_SIZE_WORDS; - fvc_pull_rdptr = fvc_out_rdptr; - update_remote_rdptr_cleared(); + free_receiver_buffer_space(PACKET_HEADER_SIZE_WORDS); curr_packet_valid = false; packet_timestamp = get_timestamp(); } } + } else if (packet_processing_flags == ProcessingFlags::MCAST_DEST) { + words_processed = process_mcast_packet(); } else { pull_data_from_fvc_buffer(); } @@ -588,19 +887,10 @@ typedef struct fvc_producer_state { } template - inline void flush_async_writes() { + FORCE_INLINE void flush_async_writes() { noc_async_writes_flushed(); - fvc_pull_rdptr = fvc_out_rdptr; - update_remote_rdptr_cleared(); - } - - inline void check_packet_end_flush() { - if (packet_end_flush) { - flush_async_writes(); - packet_in_progress = 0; - curr_packet_valid = false; - packet_end_flush = 0; - } + free_receiver_buffer_space(words_cleared); + words_cleared = 0; } } fvc_producer_state_t; @@ -1022,7 +1312,7 @@ typedef struct socket_reader_state { inline uint32_t get_num_words_to_pull(volatile pull_request_t* pull_request) { uint32_t num_words_to_pull = num_words_available_to_pull(pull_request); - uint32_t num_words_before_wrap = words_before_buffer_wrap(pull_request->buffer_size, pull_request->rd_ptr); + uint32_t num_words_before_wrap = words_before_pull_buffer_wrap(pull_request->buffer_size, pull_request->rd_ptr); num_words_to_pull = std::min(num_words_to_pull, num_words_before_wrap); uint32_t socket_buffer_space = get_num_words_free(); @@ -1089,8 +1379,8 @@ typedef struct socket_reader_state { uint32_t dest_addr = 0; // should be second half of fvc buffer. uint32_t words_remaining = total_words_to_forward; while (words_remaining) { - uint32_t num_words_before_local_wrap = words_before_buffer_wrap(buffer_size, fvc_out_rdptr); - uint32_t num_words_before_remote_wrap = words_before_buffer_wrap(buffer_size, fvc_out_wrptr); + uint32_t num_words_before_local_wrap = words_before_pull_buffer_wrap(buffer_size, fvc_out_rdptr); + uint32_t num_words_before_remote_wrap = words_before_pull_buffer_wrap(buffer_size, fvc_out_wrptr); uint32_t words_to_forward = std::min(num_words_before_local_wrap, num_words_before_remote_wrap); words_to_forward = std::min(words_to_forward, words_remaining); // max 8K bytes @@ -1184,147 +1474,25 @@ inline bool fvc_req_valid(const volatile chan_req_buf* req_buf) { } inline uint32_t num_words_available_to_pull(volatile pull_request_t* pull_request) { - uint32_t wr_ptr = pull_request->wr_ptr; - uint32_t rd_ptr = pull_request->rd_ptr; - uint32_t buf_size = pull_request->buffer_size; - - if (wr_ptr == rd_ptr) { - // buffer empty. - return 0; - } - uint32_t num_words = wr_ptr > rd_ptr ? 
wr_ptr - rd_ptr : buf_size * 2 + wr_ptr - rd_ptr; - - // num_words = std::min(num_words, this->get_curr_packet_words_remaining()); - return num_words; + return pull_request->words_written - pull_request->words_read; } inline uint32_t advance_ptr(uint32_t buffer_size, uint32_t ptr, uint32_t inc_words) { uint32_t temp = ptr + inc_words; - if (temp >= buffer_size * 2) { - temp -= buffer_size * 2; + if (temp >= buffer_size) { + temp -= buffer_size; } return temp; } -inline uint32_t words_before_buffer_wrap(uint32_t buffer_size, uint32_t rd_ptr) { - if (rd_ptr >= buffer_size) { - return buffer_size * 2 - rd_ptr; - } else { - return buffer_size - rd_ptr; - } -} +inline uint32_t words_before_pull_buffer_wrap(uint32_t buffer_size, uint32_t rd_ptr) { return buffer_size - rd_ptr; } -inline uint32_t get_rd_ptr_offset_words(pull_request_t* pull_request) { - uint32_t offset = pull_request->rd_ptr; - if (pull_request->rd_ptr >= pull_request->buffer_size) { - offset -= pull_request->buffer_size; - } - return offset; -} +inline uint32_t get_rd_ptr_offset_words(pull_request_t* pull_request) { return pull_request->rd_ptr; } inline void update_pull_request_words_cleared(pull_request_t* pull_request) { - noc_inline_dw_write(pull_request->ack_addr, pull_request->rd_ptr); -} - -inline uint32_t get_num_words_to_pull(volatile pull_request_t* pull_request, fvc_consumer_state_t* fvc_consumer_state) { - uint32_t num_words_to_pull = num_words_available_to_pull(pull_request); - uint32_t num_words_before_wrap = words_before_buffer_wrap(pull_request->buffer_size, pull_request->rd_ptr); - - num_words_to_pull = std::min(num_words_to_pull, num_words_before_wrap); - uint32_t fvc_buffer_space = fvc_consumer_state->get_num_words_free(); - num_words_to_pull = std::min(num_words_to_pull, fvc_buffer_space); - - if (num_words_to_pull == 0) { - return 0; - } - - uint32_t fvc_space_before_wptr_wrap = fvc_consumer_state->words_before_local_buffer_wrap(); - num_words_to_pull = std::min(num_words_to_pull, fvc_space_before_wptr_wrap); - - num_words_to_pull = std::min(num_words_to_pull, fvc_consumer_state->buffer_size / 2); - - return num_words_to_pull; -} - -inline uint32_t pull_data_to_fvc_buffer( - volatile pull_request_t* pull_request, fvc_consumer_state_t* fvc_consumer_state) { - if (fvc_consumer_state->packet_in_progress == 0) { - uint32_t size = pull_request->size; - fvc_consumer_state->packet_words_remaining = (size + PACKET_WORD_SIZE_BYTES - 1) >> 4; - fvc_consumer_state->packet_in_progress = 1; - } - - uint32_t num_words_to_pull = get_num_words_to_pull(pull_request, fvc_consumer_state); - if (num_words_to_pull == 0) { - return 0; - } - - uint32_t rd_offset = get_rd_ptr_offset_words((pull_request_t*)pull_request); - uint64_t src_addr = pull_request->buffer_start + (rd_offset * PACKET_WORD_SIZE_BYTES); - uint32_t fvc_addr = fvc_consumer_state->get_local_buffer_pull_addr(); - - // pull_data_from_remote(); - noc_async_read(src_addr, fvc_addr, num_words_to_pull * PACKET_WORD_SIZE_BYTES); - fvc_consumer_state->register_pull_data(num_words_to_pull); - pull_request->rd_ptr = advance_ptr(pull_request->buffer_size, pull_request->rd_ptr, num_words_to_pull); - - // TODO: this->remote_wptr_update(num_words_to_forward); - - return num_words_to_pull; + noc_inline_dw_write(pull_request->ack_addr, pull_request->words_read); } -inline uint32_t move_data_to_fvc_buffer( - volatile pull_request_t* pull_request, fvc_consumer_state_t* fvc_consumer_state) { - if (fvc_consumer_state->packet_in_progress == 0) { - 
fvc_consumer_state->packet_words_remaining = PACKET_HEADER_SIZE_WORDS; - fvc_consumer_state->packet_in_progress = 1; - } - - // if fvc does not have enough space, try again later. - if (fvc_consumer_state->get_num_words_free() < PACKET_HEADER_SIZE_WORDS) { - return 0; - } - - uint32_t fvc_space_before_wptr_wrap = fvc_consumer_state->words_before_local_buffer_wrap(); - uint32_t* fvc_addr = (uint32_t*)fvc_consumer_state->get_local_buffer_pull_addr(); - uint32_t* src = (uint32_t*)pull_request; - - switch (fvc_space_before_wptr_wrap) { - case 1: - fvc_addr[0] = src[0]; - fvc_addr[1] = src[1]; - fvc_addr[2] = src[2]; - fvc_addr[3] = src[3]; - fvc_addr = (uint32_t*)fvc_consumer_state->buffer_start; - fvc_addr[0] = src[4]; - fvc_addr[1] = src[5]; - fvc_addr[2] = src[6]; - fvc_addr[3] = src[7]; - fvc_addr[4] = src[8]; - fvc_addr[5] = src[9]; - fvc_addr[6] = src[10]; - fvc_addr[7] = src[11]; - break; - case 2: - // uint32_t i = 0; - for (uint32_t i = 0; i < (PACKET_HEADER_SIZE_WORDS - 1) * PACKET_WORD_SIZE_BYTES / 4; i++) { - fvc_addr[i] = src[i]; - } - fvc_addr = (uint32_t*)fvc_consumer_state->buffer_start; - fvc_addr[0] = src[8]; - fvc_addr[1] = src[9]; - fvc_addr[2] = src[10]; - fvc_addr[3] = src[11]; - break; - default: - for (uint32_t i = 0; i < PACKET_HEADER_SIZE_BYTES / 4; i++) { - fvc_addr[i] = src[i]; - } - } - - fvc_consumer_state->register_move_data(PACKET_HEADER_SIZE_WORDS); - return PACKET_HEADER_SIZE_WORDS; -} /** * Polling for ready signal from the remote peers of all input and output queues. * Blocks until all are ready, but doesn't block polling on each individual queue. @@ -1418,8 +1586,8 @@ inline uint64_t tt_fabric_send_pull_request(uint64_t dest_addr, volatile local_p // This will happen, if the producer did not have all the availale data in its buffer when // the pull request was first issued. In this case, as the producer gets more data in its buffer, // it updates write pointer in the consumer request buffer pull request entry. 
- uint64_t wr_ptr_addr = noc_addr + offsetof(pull_request_t, wr_ptr); - return wr_ptr_addr; + uint64_t words_written_addr = noc_addr + offsetof(pull_request_t, words_written); + return words_written_addr; } inline void tt_fabric_init() { diff --git a/tt_fabric/hw/inc/tt_fabric_api.h b/tt_fabric/hw/inc/tt_fabric_api.h index 67c1c3c322b..63fa69e4688 100644 --- a/tt_fabric/hw/inc/tt_fabric_api.h +++ b/tt_fabric/hw/inc/tt_fabric_api.h @@ -38,8 +38,10 @@ inline void fabric_setup_pull_request(uint32_t src_addr, uint32_t size) { client_interface->local_pull_request.pull_request.size = size; client_interface->local_pull_request.pull_request.buffer_size = size_in_words; client_interface->local_pull_request.pull_request.buffer_start = xy_local_addr + src_addr; + client_interface->local_pull_request.pull_request.words_written = size_in_words; + client_interface->local_pull_request.pull_request.words_read = 0; client_interface->local_pull_request.pull_request.ack_addr = - xy_local_addr + (uint32_t)&client_interface->local_pull_request.pull_request.rd_ptr; + xy_local_addr + (uint32_t)&client_interface->local_pull_request.pull_request.words_read; client_interface->local_pull_request.pull_request.flags = FORWARD; } @@ -90,6 +92,58 @@ inline void fabric_async_write( } } +inline void fabric_async_write_multicast_add_header( + uint32_t src_addr, // source address in sender’s memory + uint16_t dst_mesh_id, + uint16_t dst_dev_id, + uint64_t dst_addr, + uint32_t size, // number of bytes to write to remote destination + uint32_t e_depth, + uint32_t w_depth, + uint32_t n_depth, + uint32_t s_depth) { + packet_header_t* packet_header = (packet_header_t*)(src_addr); + packet_header->routing.flags = FORWARD | MCAST_DATA; + packet_header->routing.packet_size_bytes = size; + packet_header->routing.dst_mesh_id = dst_mesh_id; + packet_header->routing.dst_dev_id = dst_dev_id; + packet_header->session.command = ASYNC_WR; + packet_header->session.target_offset_l = (uint32_t)dst_addr; + packet_header->session.target_offset_h = dst_addr >> 32; + packet_header->packet_parameters.mcast_parameters.east = e_depth; + packet_header->packet_parameters.mcast_parameters.west = w_depth; + packet_header->packet_parameters.mcast_parameters.north = n_depth; + packet_header->packet_parameters.mcast_parameters.south = s_depth; + tt_fabric_add_header_checksum(packet_header); +} +// Write packetized data over fabric to dst_mesh, dst_dev. +// Packet is at src_addr in sender L1. 
+template +inline void fabric_async_write_multicast( + uint32_t routing_plane, // the network plane to use for this transaction + uint32_t src_addr, // source address in sender’s memory + uint16_t dst_mesh_id, + uint16_t dst_dev_id, + uint64_t dst_addr, + uint32_t size, // number of bytes to write to remote destination + uint32_t e_depth, + uint32_t w_depth, + uint32_t n_depth, + uint32_t s_depth) { + if constexpr (mode == ASYNC_WR_ALL or mode == ASYNC_WR_ADD_HEADER) { + fabric_async_write_multicast_add_header( + src_addr, dst_mesh_id, dst_dev_id, dst_addr, size, e_depth, w_depth, n_depth, s_depth); + } + + if constexpr (mode == ASYNC_WR_ALL or mode == ASYNC_WR_ADD_PR) { + fabric_setup_pull_request(src_addr, size); + } + + if constexpr (mode == ASYNC_WR_ALL or mode == ASYNC_WR_SEND) { + fabric_send_pull_request(routing_plane, dst_mesh_id, dst_dev_id); + } +} + inline void send_message_to_gk() { uint64_t gk_noc_base = client_interface->gk_msg_buf_addr; uint64_t noc_addr = gk_noc_base + offsetof(ctrl_chan_msg_buf, wrptr); diff --git a/tt_fabric/hw/inc/tt_fabric_interface.h b/tt_fabric/hw/inc/tt_fabric_interface.h index cb4967d21cd..1c4f69afe09 100644 --- a/tt_fabric/hw/inc/tt_fabric_interface.h +++ b/tt_fabric/hw/inc/tt_fabric_interface.h @@ -35,7 +35,7 @@ constexpr uint32_t FVC_SYNC_THRESHOLD = 256; #define SOCKET_CONNECT (0x1 << 10) #define INVALID 0x0 -#define DATA 0x1 +#define MCAST_ACTIVE 0x1 #define MCAST_DATA 0x2 #define SYNC 0x4 #define FORWARD 0x8 @@ -70,11 +70,11 @@ typedef struct _tt_session { static_assert(sizeof(tt_session) == 20); typedef struct _mcast_params { + uint32_t socket_id; // Socket Id for DSocket Multicast. Ignored for ASYNC multicast. uint16_t east; uint16_t west; uint16_t north; uint16_t south; - uint32_t socket_id; // Socket Id for DSocket Multicast. Ignored for ASYNC multicast. } mcast_params; typedef struct _socket_params { @@ -128,11 +128,15 @@ typedef struct _packet_header { tt_routing routing; } packet_header_t; -const uint32_t PACKET_HEADER_SIZE_BYTES = 48; -const uint32_t PACKET_HEADER_SIZE_WORDS = PACKET_HEADER_SIZE_BYTES / PACKET_WORD_SIZE_BYTES; +constexpr uint32_t PACKET_HEADER_SIZE_BYTES = 48; +constexpr uint32_t PACKET_HEADER_SIZE_WORDS = PACKET_HEADER_SIZE_BYTES / PACKET_WORD_SIZE_BYTES; static_assert(sizeof(packet_header_t) == PACKET_HEADER_SIZE_BYTES); +static_assert(offsetof(packet_header_t, routing) % 4 == 0); + +constexpr uint32_t packet_header_routing_offset_dwords = offsetof(packet_header_t, routing) / 4; + void tt_fabric_add_header_checksum(packet_header_t* p_header) { uint16_t* ptr = (uint16_t*)p_header; uint32_t sum = 0; @@ -180,11 +184,13 @@ typedef struct _pull_request { uint64_t buffer_start; // Producer local buffer start. Used for wrapping rd/wr_ptr at the end of buffer. uint64_t ack_addr; // Producer local address to send rd_ptr updates. fabric router pushes its rd_ptr to requestor // at this address. - uint8_t padding[15]; + uint32_t words_written; + uint32_t words_read; + uint8_t padding[7]; uint8_t flags; // Router command. 
} pull_request_t; -const uint32_t PULL_REQ_SIZE_BYTES = 48; +constexpr uint32_t PULL_REQ_SIZE_BYTES = 48; static_assert(sizeof(pull_request_t) == PULL_REQ_SIZE_BYTES); static_assert(sizeof(pull_request_t) == sizeof(packet_header_t)); @@ -195,18 +201,18 @@ typedef union _chan_request_entry { uint8_t bytes[48]; } chan_request_entry_t; -const uint32_t CHAN_PTR_SIZE_BYTES = 16; +constexpr uint32_t CHAN_PTR_SIZE_BYTES = 16; typedef struct _chan_ptr { uint32_t ptr; uint32_t pad[3]; } chan_ptr; static_assert(sizeof(chan_ptr) == CHAN_PTR_SIZE_BYTES); -const uint32_t CHAN_REQ_BUF_LOG_SIZE = 4; // must be 2^N -const uint32_t CHAN_REQ_BUF_SIZE = 16; // must be 2^N -const uint32_t CHAN_REQ_BUF_SIZE_MASK = (CHAN_REQ_BUF_SIZE - 1); -const uint32_t CHAN_REQ_BUF_PTR_MASK = ((CHAN_REQ_BUF_SIZE << 1) - 1); -const uint32_t CHAN_REQ_BUF_SIZE_BYTES = 2 * CHAN_PTR_SIZE_BYTES + CHAN_REQ_BUF_SIZE * PULL_REQ_SIZE_BYTES; +constexpr uint32_t CHAN_REQ_BUF_LOG_SIZE = 4; // must be 2^N +constexpr uint32_t CHAN_REQ_BUF_SIZE = 16; // must be 2^N +constexpr uint32_t CHAN_REQ_BUF_SIZE_MASK = (CHAN_REQ_BUF_SIZE - 1); +constexpr uint32_t CHAN_REQ_BUF_PTR_MASK = ((CHAN_REQ_BUF_SIZE << 1) - 1); +constexpr uint32_t CHAN_REQ_BUF_SIZE_BYTES = 2 * CHAN_PTR_SIZE_BYTES + CHAN_REQ_BUF_SIZE * PULL_REQ_SIZE_BYTES; typedef struct _chan_req_buf { chan_ptr wrptr; @@ -234,12 +240,12 @@ static_assert(sizeof(chan_payload_ptr) == CHAN_PTR_SIZE_BYTES); // Each control channel message is 48 Bytes. // FVCC buffer is a 16 message buffer each for incoming and outgoing messages. // Control message capacity can be increased by increasing FVCC_BUF_SIZE. -const uint32_t FVCC_BUF_SIZE = 16; // must be 2^N -const uint32_t FVCC_BUF_LOG_SIZE = 4; // must be log2(FVCC_BUF_SIZE) -const uint32_t FVCC_SIZE_MASK = (FVCC_BUF_SIZE - 1); -const uint32_t FVCC_PTR_MASK = ((FVCC_BUF_SIZE << 1) - 1); -const uint32_t FVCC_BUF_SIZE_BYTES = PULL_REQ_SIZE_BYTES * FVCC_BUF_SIZE + 2 * CHAN_PTR_SIZE_BYTES; -const uint32_t FVCC_SYNC_BUF_SIZE_BYTES = CHAN_PTR_SIZE_BYTES * FVCC_BUF_SIZE; +constexpr uint32_t FVCC_BUF_SIZE = 16; // must be 2^N +constexpr uint32_t FVCC_BUF_LOG_SIZE = 4; // must be log2(FVCC_BUF_SIZE) +constexpr uint32_t FVCC_SIZE_MASK = (FVCC_BUF_SIZE - 1); +constexpr uint32_t FVCC_PTR_MASK = ((FVCC_BUF_SIZE << 1) - 1); +constexpr uint32_t FVCC_BUF_SIZE_BYTES = PULL_REQ_SIZE_BYTES * FVCC_BUF_SIZE + 2 * CHAN_PTR_SIZE_BYTES; +constexpr uint32_t FVCC_SYNC_BUF_SIZE_BYTES = CHAN_PTR_SIZE_BYTES * FVCC_BUF_SIZE; inline bool fvcc_buf_ptrs_empty(uint32_t wrptr, uint32_t rdptr) { return (wrptr == rdptr); } diff --git a/tt_fabric/impl/kernels/tt_fabric_router.cpp b/tt_fabric/impl/kernels/tt_fabric_router.cpp index 335fcd170b8..5453c5f6ca3 100644 --- a/tt_fabric/impl/kernels/tt_fabric_router.cpp +++ b/tt_fabric/impl/kernels/tt_fabric_router.cpp @@ -17,8 +17,6 @@ fvcc_outbound_state_t fvcc_outbound_state __attribute__((aligned(16))); // outb #endif volatile local_pull_request_t local_pull_request_temp __attribute__((aligned(16))); // replicate for each fvc volatile local_pull_request_t* local_pull_request = &local_pull_request_temp; // replicate for each fvc -chan_payload_ptr inbound_rdptr_ack __attribute__((aligned(16))); -volatile chan_payload_ptr remote_rdptr __attribute__((aligned(16))); constexpr uint32_t fvc_data_buf_size_words = get_compile_time_arg_val(0); constexpr uint32_t fvc_data_buf_size_bytes = fvc_data_buf_size_words * PACKET_WORD_SIZE_BYTES; @@ -100,10 +98,6 @@ void kernel_main() { router_state.sync_in = 0; router_state.sync_out = 0; - inbound_rdptr_ack.ptr = 
0; - inbound_rdptr_ack.ptr_cleared = 0; - inbound_rdptr_ack.pad[0] = 0; - inbound_rdptr_ack.pad[1] = 0; zero_l1_buf((tt_l1_ptr uint32_t*)fvc_consumer_req_buf, sizeof(chan_req_buf)); zero_l1_buf((tt_l1_ptr uint32_t*)FVCC_IN_BUF_START, FVCC_IN_BUF_SIZE); @@ -112,12 +106,10 @@ void kernel_main() { write_kernel_status(kernel_status, PQ_TEST_WORD_CNT_INDEX + 1, (uint32_t)&fvc_consumer_state); write_kernel_status(kernel_status, PQ_TEST_STATUS_INDEX + 1, (uint32_t)&fvc_producer_state); - fvc_consumer_state.init( - FABRIC_ROUTER_DATA_BUF_START, fvc_data_buf_size_words / 2, (uint32_t)&fvc_producer_state.inbound_wrptr); + fvc_consumer_state.init(FABRIC_ROUTER_DATA_BUF_START, fvc_data_buf_size_words / 2); fvc_producer_state.init( FABRIC_ROUTER_DATA_BUF_START + (fvc_data_buf_size_words * PACKET_WORD_SIZE_BYTES / 2), - fvc_data_buf_size_words / 2, - (uint32_t)&remote_rdptr); + fvc_data_buf_size_words / 2); #ifdef FVCC_SUPPORT fvcc_outbound_state.init( @@ -149,7 +141,7 @@ void kernel_main() { pull_request_t* pull_req = &req->pull_request; if (req->bytes[47] == FORWARD) { // Data is packetized. - pull_data_to_fvc_buffer(pull_req, &fvc_consumer_state); + fvc_consumer_state.pull_data_to_fvc_buffer(pull_req); if (fvc_consumer_state.packet_words_remaining == 0 || fvc_consumer_state.pull_words_in_flight >= FVC_SYNC_THRESHOLD) { fvc_consumer_state.total_words_to_forward += fvc_consumer_state.pull_words_in_flight; @@ -159,7 +151,7 @@ void kernel_main() { update_pull_request_words_cleared(pull_req); } } else if (req->bytes[47] == INLINE_FORWARD) { - move_data_to_fvc_buffer(pull_req, &fvc_consumer_state); + fvc_consumer_state.move_data_to_fvc_buffer(pull_req); } if (fvc_consumer_state.packet_in_progress == 1 and fvc_consumer_state.packet_words_remaining == 0) { @@ -176,7 +168,6 @@ void kernel_main() { } // Handle Ethernet Inbound Data - fvc_producer_state.update_remote_rdptr_sent(); if (fvc_producer_state.get_curr_packet_valid()) { fvc_producer_state.process_inbound_packet(); loop_count = 0; diff --git a/tt_metal/CMakeLists.txt b/tt_metal/CMakeLists.txt index 768c9318eac..bee22b18640 100644 --- a/tt_metal/CMakeLists.txt +++ b/tt_metal/CMakeLists.txt @@ -16,6 +16,7 @@ target_link_libraries( magic_enum fmt::fmt-header-only span + small_vector TracyClient nlohmann_json::nlohmann_json TT::Metalium::HostDevCommon @@ -33,6 +34,7 @@ target_link_libraries( HAL::grayskull HAL::wormhole HAL::blackhole + FlatBuffers::FlatBuffers ) target_precompile_headers( diff --git a/tt_metal/api/tt-metalium/allocator.hpp b/tt_metal/api/tt-metalium/allocator.hpp index 0a3fa43818e..bcc190e2684 100644 --- a/tt_metal/api/tt-metalium/allocator.hpp +++ b/tt_metal/api/tt-metalium/allocator.hpp @@ -56,6 +56,9 @@ class Allocator { DeviceAddr get_base_allocator_addr(const HalMemType& mem_type) const; const AllocatorConfig& get_config() const; + // Alignment can be pulled out of the AllocatorConfig but this getter is a helper + // so client code does not need to condition based on BufferType + uint32_t get_alignment(BufferType buffer_type) const; Statistics get_statistics(const BufferType& buffer_type) const; MemoryBlockTable get_memory_block_table(const BufferType& buffer_type) const; diff --git a/tt_metal/api/tt-metalium/allocator_types.hpp b/tt_metal/api/tt-metalium/allocator_types.hpp index f918796c629..4a7b6ed2625 100644 --- a/tt_metal/api/tt-metalium/allocator_types.hpp +++ b/tt_metal/api/tt-metalium/allocator_types.hpp @@ -41,6 +41,7 @@ struct AllocatorConfig { size_t dram_bank_size = 0; std::vector dram_bank_offsets = {}; uint32_t 
dram_unreserved_base = 0; + uint32_t dram_alignment = 0; //! worker specific configuration uint32_t l1_unreserved_base = 0; CoreRangeSet worker_grid = {}; @@ -54,7 +55,7 @@ struct AllocatorConfig { BankMapping l1_bank_remap = {}; // for remapping which l1 bank points to which bank if we assume normal row-major assignment CoreRangeSet compute_grid = {}; - uint32_t alignment = 0; + uint32_t l1_alignment = 0; bool disable_interleaved = false; void reset(); ~AllocatorConfig() { reset(); } diff --git a/tt_metal/api/tt-metalium/device.hpp b/tt_metal/api/tt-metalium/device.hpp index 3a3238668d7..821eeaf5c9d 100644 --- a/tt_metal/api/tt-metalium/device.hpp +++ b/tt_metal/api/tt-metalium/device.hpp @@ -48,6 +48,10 @@ class CommandQueue; class TraceBuffer; struct TraceDescriptor; +namespace detail { +struct TraceDescriptor; +} + inline namespace v0 { class IDevice { @@ -157,7 +161,6 @@ class IDevice { virtual void initialize_and_launch_firmware() = 0; virtual void init_command_queue_host() = 0; virtual void init_command_queue_device() = 0; - virtual void update_dispatch_cores_for_multi_cq_eth_dispatch() = 0; // Puts device into reset virtual bool close() = 0; diff --git a/tt_metal/api/tt-metalium/device_impl.hpp b/tt_metal/api/tt-metalium/device_impl.hpp index a1ffd887efa..375e515ad62 100644 --- a/tt_metal/api/tt-metalium/device_impl.hpp +++ b/tt_metal/api/tt-metalium/device_impl.hpp @@ -149,7 +149,6 @@ class Device : public IDevice { void initialize_and_launch_firmware() override; void init_command_queue_host() override; void init_command_queue_device() override; - void update_dispatch_cores_for_multi_cq_eth_dispatch() override; // Puts device into reset bool close() override; @@ -211,6 +210,8 @@ class Device : public IDevice { void initialize_default_sub_device_state(size_t l1_small_size, size_t trace_region_size, tt::stl::Span l1_bank_remap); + void update_dispatch_cores_for_multi_cq_eth_dispatch(); + void compile_command_queue_programs(); void configure_command_queue_programs(); void clear_l1_state(); diff --git a/tt_metal/api/tt-metalium/dispatch_core_manager.hpp b/tt_metal/api/tt-metalium/dispatch_core_manager.hpp index 5e5676be3b5..62433e832b5 100644 --- a/tt_metal/api/tt-metalium/dispatch_core_manager.hpp +++ b/tt_metal/api/tt-metalium/dispatch_core_manager.hpp @@ -4,9 +4,12 @@ #pragma once +#include +#include +#include + #include "core_descriptor.hpp" #include "core_coord.hpp" -#include #include "dispatch_core_common.hpp" namespace tt::tt_metal { @@ -49,20 +52,9 @@ class dispatch_core_manager { //TODO: this should probably be in command_queue_interface.hpp, but it's here for now due to circular dependency static constexpr uint8_t MAX_NUM_HW_CQS = 2; - static void initialize(const DispatchCoreConfig &dispatch_core_config, uint8_t num_hw_cqs) noexcept { - log_debug(tt::LogMetal, "DevicePool initialize"); - if (_inst == nullptr) { - static dispatch_core_manager dispatch_core_manager(dispatch_core_config, num_hw_cqs); - _inst = &dispatch_core_manager; - } else if (_inst->dispatch_core_config_by_device[0] != dispatch_core_config or num_hw_cqs != _inst->num_hw_cqs) { - _inst->reset_dispatch_core_manager(dispatch_core_config, num_hw_cqs); - } - } - - static dispatch_core_manager &instance() { - TT_ASSERT(_inst != nullptr, "Trying to get dispatch_core_manager without initializing it"); - return *_inst; - } + static void initialize(const DispatchCoreConfig& dispatch_core_config, uint8_t num_hw_cqs) noexcept; + + static dispatch_core_manager& instance(); /// @brief Gets the location of the kernel 
desginated to read from the issue queue region from a particular command queue /// Each command queue has an issue queue where host enqueues commands. This core relays to the dispatcher core to interpret and launch @@ -71,26 +63,9 @@ class dispatch_core_manager { /// @param channel assigned to the command queue where commands are enqueued /// @param cq_id ID of the command queue within the channel /// @return tt_cxy_pair logical location (chip + core coordinate) of the issue queue interface - const tt_cxy_pair &prefetcher_core(chip_id_t device_id, uint16_t channel, uint8_t cq_id) { - dispatch_core_placement_t &assignment = this->dispatch_core_assignments[device_id][channel][cq_id]; - if (assignment.prefetcher.has_value()) { - return assignment.prefetcher.value(); - } - // Issue queue interface is on the MMIO device - chip_id_t mmio_device_id = tt::Cluster::instance().get_associated_mmio_device(device_id); - CoreCoord issue_queue_coord = this->get_next_available_dispatch_core(mmio_device_id); - assignment.prefetcher = tt_cxy_pair(mmio_device_id, issue_queue_coord.x, issue_queue_coord.y); - log_dispatch_assignment("Prefetcher", assignment.prefetcher.value(), device_id, channel, cq_id); - return assignment.prefetcher.value(); - } - - bool is_prefetcher_core_allocated(chip_id_t device_id, uint16_t channel, uint8_t cq_id) { - dispatch_core_placement_t &assignment = this->dispatch_core_assignments[device_id][channel][cq_id]; - if (assignment.prefetcher.has_value()) { - return true; - } - return false; - } + const tt_cxy_pair& prefetcher_core(chip_id_t device_id, uint16_t channel, uint8_t cq_id); + + bool is_prefetcher_core_allocated(chip_id_t device_id, uint16_t channel, uint8_t cq_id); /// @brief Gets the location of the kernel desginated to interface with prefetcher kernel running on mmio device. /// Prefetcher kernel on mmio device relays commands to prefetcher_d running on remote device. @@ -98,50 +73,18 @@ class dispatch_core_manager { /// @param channel assigned to the command queue where commands are enqueued /// @param cq_id ID of the command queue within the channel /// @return tt_cxy_pair logical location (chip + core coordinate) of the issue queue interface - const tt_cxy_pair &prefetcher_d_core(chip_id_t device_id, uint16_t channel, uint8_t cq_id) { - dispatch_core_placement_t &assignment = this->dispatch_core_assignments[device_id][channel][cq_id]; - if (assignment.prefetcher_d.has_value()) { - return assignment.prefetcher_d.value(); - } - CoreCoord prefetch_d_coord = this->get_next_available_dispatch_core(device_id); - assignment.prefetcher_d = tt_cxy_pair(device_id, prefetch_d_coord.x, prefetch_d_coord.y); - log_dispatch_assignment("Prefetcher D", assignment.prefetcher_d.value(), device_id, channel, cq_id); - return assignment.prefetcher_d.value(); - } - - bool is_prefetcher_d_core_allocated(chip_id_t device_id, uint16_t channel, uint8_t cq_id) { - dispatch_core_placement_t &assignment = this->dispatch_core_assignments[device_id][channel][cq_id]; - if (assignment.prefetcher_d.has_value()) { - return true; - } - return false; - } + const tt_cxy_pair& prefetcher_d_core(chip_id_t device_id, uint16_t channel, uint8_t cq_id); + + bool is_prefetcher_d_core_allocated(chip_id_t device_id, uint16_t channel, uint8_t cq_id); /// @brief Gets the location of the kernel desginated for multiplexing issue queue traffic to tunneler. 
/// @param device_id ID of the device that a fast dispatch command targets /// @param channel assigned to the command queue where commands are enqueued /// @param cq_id ID of the command queue within the channel /// @return tt_cxy_pair logical location (chip + core coordinate) of the mux core - const tt_cxy_pair &mux_core(chip_id_t device_id, uint16_t channel, uint8_t cq_id) { - dispatch_core_placement_t &assignment = this->dispatch_core_assignments[device_id][channel][cq_id]; - if (assignment.mux.has_value()) { - return assignment.mux.value(); - } - // Mux interface is on the MMIO device - chip_id_t mmio_device_id = tt::Cluster::instance().get_associated_mmio_device(device_id); - CoreCoord mux_coord = this->get_next_available_dispatch_core(mmio_device_id); - assignment.mux = tt_cxy_pair(mmio_device_id, mux_coord.x, mux_coord.y); - log_dispatch_assignment("Mux", assignment.mux.value(), device_id, channel, cq_id); - return assignment.mux.value(); - } - - bool is_mux_core_allocated(chip_id_t device_id, uint16_t channel, uint8_t cq_id) { - dispatch_core_placement_t &assignment = this->dispatch_core_assignments[device_id][channel][cq_id]; - if (assignment.mux.has_value()) { - return true; - } - return false; - } + const tt_cxy_pair& mux_core(chip_id_t device_id, uint16_t channel, uint8_t cq_id); + + bool is_mux_core_allocated(chip_id_t device_id, uint16_t channel, uint8_t cq_id); /// @brief Gets the location of the kernel desginated for multiplexing traffic back towards mmio chip. /// @param device_id ID of the device that a fast dispatch command targets @@ -149,92 +92,33 @@ class dispatch_core_manager { /// @param cq_id ID of the command queue within the channel /// @return tt_cxy_pair logical location (chip + core coordinate) of the mux_d core - const tt_cxy_pair &mux_d_core(chip_id_t device_id, uint16_t channel, uint8_t cq_id) { - dispatch_core_placement_t &assignment = this->dispatch_core_assignments[device_id][channel][cq_id]; - if (assignment.mux_d.has_value()) { - return assignment.mux_d.value(); - } - // mux_d is on remote device - CoreCoord mux_d_coord = this->get_next_available_dispatch_core(device_id); - assignment.mux_d = tt_cxy_pair(device_id, mux_d_coord.x, mux_d_coord.y); - log_dispatch_assignment("Mux D", assignment.mux_d.value(), device_id, channel, cq_id); - return assignment.mux_d.value(); - } + const tt_cxy_pair& mux_d_core(chip_id_t device_id, uint16_t channel, uint8_t cq_id); /// @brief Gets the location of the kernel desginated for demultiplexing traffic to completion queues. 
/// @param device_id ID of the device that a fast dispatch command targets /// @param channel assigned to the command queue where commands are enqueued /// @param cq_id ID of the command queue within the channel /// @return tt_cxy_pair logical location (chip + core coordinate) of the mux core - const tt_cxy_pair &demux_core(chip_id_t device_id, uint16_t channel, uint8_t cq_id) { - dispatch_core_placement_t &assignment = this->dispatch_core_assignments[device_id][channel][cq_id]; - if (assignment.demux.has_value()) { - return assignment.demux.value(); - } - // demux interface is on the MMIO device - chip_id_t mmio_device_id = tt::Cluster::instance().get_associated_mmio_device(device_id); - CoreCoord demux_coord = this->get_next_available_dispatch_core(mmio_device_id); - assignment.demux = tt_cxy_pair(mmio_device_id, demux_coord.x, demux_coord.y); - log_dispatch_assignment("Demux", assignment.demux.value(), device_id, channel, cq_id); - return assignment.demux.value(); - } - - bool is_demux_core_allocated(chip_id_t device_id, uint16_t channel, uint8_t cq_id) { - dispatch_core_placement_t &assignment = this->dispatch_core_assignments[device_id][channel][cq_id]; - if (assignment.demux.has_value()) { - return true; - } - return false; - } + const tt_cxy_pair& demux_core(chip_id_t device_id, uint16_t channel, uint8_t cq_id); + + bool is_demux_core_allocated(chip_id_t device_id, uint16_t channel, uint8_t cq_id); /// @brief Gets the location of the kernel desginated for demultiplexing traffic on remote chip. /// @param device_id ID of the device that a fast dispatch command targets /// @param channel assigned to the command queue where commands are enqueued /// @param cq_id ID of the command queue within the channel /// @return tt_cxy_pair logical location (chip + core coordinate) of the demux_d core - const tt_cxy_pair &demux_d_core(chip_id_t device_id, uint16_t channel, uint8_t cq_id) { - dispatch_core_placement_t &assignment = this->dispatch_core_assignments[device_id][channel][cq_id]; - if (assignment.demux_d.has_value()) { - return assignment.demux_d.value(); - } - // demux_d is on remote device - CoreCoord demux_d_coord = this->get_next_available_dispatch_core(device_id); - assignment.demux_d = tt_cxy_pair(device_id, demux_d_coord.x, demux_d_coord.y); - log_dispatch_assignment("Demux D", assignment.demux_d.value(), device_id, channel, cq_id); - return assignment.demux_d.value(); - } + const tt_cxy_pair& demux_d_core(chip_id_t device_id, uint16_t channel, uint8_t cq_id); /// @brief Gets the location of the kernel desginated for tunneling over ethernet. 
/// @param device_id ID of the device that a fast dispatch command targets /// @param channel assigned to the command queue where commands are enqueued /// @param cq_id ID of the command queue within the channel /// @return tt_cxy_pair logical location (chip + core coordinate) of the ethernet tunnel core - const tt_cxy_pair &tunneler_core(chip_id_t upstream_device_id, chip_id_t device_id, uint16_t channel, uint8_t cq_id) { - dispatch_core_placement_t &assignment = this->dispatch_core_assignments[device_id][channel][cq_id]; - if (assignment.tunneler.has_value()) { - return assignment.tunneler.value(); - } - - auto[us_core, ds_core] = tt::Cluster::instance().get_eth_tunnel_core(upstream_device_id, device_id, EthRouterMode::BI_DIR_TUNNELING); - - assignment.tunneler = us_core; - assignment.tunneler_d = ds_core; - - log_dispatch_assignment("Tunneler Remote", assignment.tunneler.value(), device_id, channel, cq_id, true); - log_dispatch_assignment("Tunneler Local", assignment.tunneler_d.value(), device_id, channel, cq_id, true); - return assignment.tunneler.value(); - } - - const tt_cxy_pair &us_tunneler_core_local(chip_id_t device_id, uint16_t channel, uint8_t cq_id) { - dispatch_core_placement_t &assignment = this->dispatch_core_assignments[device_id][channel][cq_id]; - if (assignment.tunneler_d.has_value()) { - return assignment.tunneler_d.value(); - } - TT_ASSERT(false, "Device {} has no allocation for Local Upstream Tunneler Core.", device_id); - assignment.tunneler_d = tt_cxy_pair(0, 0, 0); - return assignment.tunneler_d.value(); - } + const tt_cxy_pair& tunneler_core( + chip_id_t upstream_device_id, chip_id_t device_id, uint16_t channel, uint8_t cq_id); + const tt_cxy_pair& us_tunneler_core_local(chip_id_t device_id, uint16_t channel, uint8_t cq_id); /// @brief Gets the location of the kernel desginated to write to the completion queue region for a particular command queue /// Each command queue has one completion queue @@ -244,174 +128,62 @@ class dispatch_core_manager { /// @param channel assigned to the command queue /// @param cq_id ID of the command queue within the channel /// @return tt_cxy_pair logical location (chip + core coordinate) of the completion queue interface - const tt_cxy_pair &completion_queue_writer_core(chip_id_t device_id, uint16_t channel, uint8_t cq_id) { - dispatch_core_placement_t &assignment = this->dispatch_core_assignments[device_id][channel][cq_id]; - if (assignment.completion_queue_writer.has_value()) { - return assignment.completion_queue_writer.value(); - } - // Completion queue interface is on the MMIO device - chip_id_t mmio_device_id = tt::Cluster::instance().get_associated_mmio_device(device_id); - CoreCoord completion_queue_coord = this->get_next_available_dispatch_core(mmio_device_id); - assignment.completion_queue_writer = tt_cxy_pair(mmio_device_id, completion_queue_coord.x, completion_queue_coord.y); - TT_ASSERT(not assignment.dispatcher.has_value(), "Command dispatcher core {} must match completion queue interface core for MMIO device {}", assignment.dispatcher.value().str(), device_id); - assignment.dispatcher = assignment.completion_queue_writer; - log_dispatch_assignment("Completion Queue Writer", assignment.completion_queue_writer.value(), device_id, channel, cq_id); - return assignment.completion_queue_writer.value(); - } - - bool is_completion_queue_writer_core_allocated(chip_id_t device_id, uint16_t channel, uint8_t cq_id) { - dispatch_core_placement_t &assignment = this->dispatch_core_assignments[device_id][channel][cq_id]; - if 
(assignment.completion_queue_writer.has_value()) { - return true; - } - return false; - } + const tt_cxy_pair& completion_queue_writer_core(chip_id_t device_id, uint16_t channel, uint8_t cq_id); + + bool is_completion_queue_writer_core_allocated(chip_id_t device_id, uint16_t channel, uint8_t cq_id); /// @brief Gets the location of the kernel designated to relay fast dispatch commands to worker cores from a particular command queue /// @param device_id ID of the device that should be running the command /// @param channel assigned to the command queue where commands are enqueued /// @param cq_id ID of the command queue within the channel /// @return tt_cxy_pair logical location (chip + core coordinate) of the dispatcher core - const tt_cxy_pair &dispatcher_core(chip_id_t device_id, uint16_t channel, uint8_t cq_id) { - dispatch_core_placement_t &assignment = this->dispatch_core_assignments[device_id][channel][cq_id]; - if (assignment.dispatcher.has_value()) { - return assignment.dispatcher.value(); - } - chip_id_t mmio_device_id = tt::Cluster::instance().get_associated_mmio_device(device_id); - CoreCoord dispatcher_coord = this->get_next_available_dispatch_core(mmio_device_id); - assignment.dispatcher = tt_cxy_pair(mmio_device_id, dispatcher_coord.x, dispatcher_coord.y); - TT_ASSERT(not assignment.completion_queue_writer.has_value(), "Command dispatcher core must match completion queue interface core for MMIO device {}", device_id); - assignment.completion_queue_writer = assignment.dispatcher; - log_dispatch_assignment("Dispatcher", assignment.dispatcher.value(), device_id, channel, cq_id); - return assignment.dispatcher.value(); - } - - bool is_dispatcher_core_allocated(chip_id_t device_id, uint16_t channel, uint8_t cq_id) { - dispatch_core_placement_t &assignment = this->dispatch_core_assignments[device_id][channel][cq_id]; - if (assignment.dispatcher.has_value()) { - return true; - } - return false; - } - - bool is_dispatcher_s_core_allocated(chip_id_t device_id, uint16_t channel, uint8_t cq_id) { - dispatch_core_placement_t &assignment = this->dispatch_core_assignments[device_id][channel][cq_id]; - return assignment.dispatcher_s.has_value(); - } + const tt_cxy_pair& dispatcher_core(chip_id_t device_id, uint16_t channel, uint8_t cq_id); + + bool is_dispatcher_core_allocated(chip_id_t device_id, uint16_t channel, uint8_t cq_id); + + bool is_dispatcher_s_core_allocated(chip_id_t device_id, uint16_t channel, uint8_t cq_id); /// @brief Gets the location of the kernel designated to relay fast dispatch commands to worker cores from a particular command queue /// @param device_id ID of the device that should be running the command /// @param channel assigned to the command queue where commands are enqueued /// @param cq_id ID of the command queue within the channel /// @return tt_cxy_pair logical location (chip + core coordinate) of the dispatcher_d core - const tt_cxy_pair &dispatcher_d_core(chip_id_t device_id, uint16_t channel, uint8_t cq_id) { - dispatch_core_placement_t &assignment = this->dispatch_core_assignments[device_id][channel][cq_id]; - if (assignment.dispatcher_d.has_value()) { - return assignment.dispatcher_d.value(); - } - CoreCoord dispatcher_d_coord = this->get_next_available_dispatch_core(device_id); - assignment.dispatcher_d = tt_cxy_pair(device_id, dispatcher_d_coord.x, dispatcher_d_coord.y); - log_dispatch_assignment("Dispatcher D", assignment.dispatcher_d.value(), device_id, channel, cq_id); - return assignment.dispatcher_d.value(); - } - - const tt_cxy_pair 
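
> Reviewer note: all of the `*_core()` accessors trimmed above share one allocate-on-first-request pattern, now living in `dispatch_core_manager.cpp`. The sketch below restates that pattern using `prefetcher_core` as the example; it mirrors the inline body removed by this diff rather than introducing new behavior.

```cpp
// Distilled from the removed inline getters: allocate lazily, cache in the
// per-(device, channel, cq) assignment table, log, then return the cached value.
const tt_cxy_pair& dispatch_core_manager::prefetcher_core(
    chip_id_t device_id, uint16_t channel, uint8_t cq_id) {
    dispatch_core_placement_t& assignment = this->dispatch_core_assignments[device_id][channel][cq_id];
    if (assignment.prefetcher.has_value()) {
        return assignment.prefetcher.value();  // already placed on a core
    }
    // Host-facing kernels (prefetcher, mux, demux, completion-queue writer) are
    // placed on the associated MMIO device; the *_d variants stay on the remote device.
    chip_id_t mmio_device_id = tt::Cluster::instance().get_associated_mmio_device(device_id);
    CoreCoord core = this->get_next_available_dispatch_core(mmio_device_id);
    assignment.prefetcher = tt_cxy_pair(mmio_device_id, core.x, core.y);
    log_dispatch_assignment("Prefetcher", assignment.prefetcher.value(), device_id, channel, cq_id);
    return assignment.prefetcher.value();
}
```
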
&dispatcher_s_core(chip_id_t device_id, uint16_t channel, uint8_t cq_id) { - dispatch_core_placement_t &assignment = this->dispatch_core_assignments[device_id][channel][cq_id]; - if (assignment.dispatcher_s.has_value()) { - return assignment.dispatcher_s.value(); - } - CoreCoord dispatcher_s_coord; - if (this->get_dispatch_core_type(device_id) == CoreType::WORKER) { - chip_id_t mmio_device_id = tt::Cluster::instance().get_associated_mmio_device(device_id); - if (mmio_device_id == device_id) { - // dispatch_s is on the same tensix core as dispatch_hd - dispatcher_s_coord = this->dispatcher_core(device_id, channel, cq_id); - } else { - // dispatch_s is on the same tensix as dispatch_d - dispatcher_s_coord = this->dispatcher_d_core(device_id, channel, cq_id); - } - } else { - dispatcher_s_coord = this->get_next_available_dispatch_core(device_id); - } - assignment.dispatcher_s = tt_cxy_pair(device_id, dispatcher_s_coord.x, dispatcher_s_coord.y); - log_dispatch_assignment("Dispatcher S", assignment.dispatcher_s.value(), device_id, channel, cq_id); - return assignment.dispatcher_s.value(); - } - - CoreType get_dispatch_core_type(chip_id_t device_id) { - return this->dispatch_core_config_by_device[device_id].get_core_type(); - } - - DispatchCoreConfig get_dispatch_core_config(chip_id_t device_id) { - return this->dispatch_core_config_by_device[device_id]; - } - - void add_dispatch_core_to_device(chip_id_t device_id, const CoreCoord& core) { - // TODO: remove this API, we should read the core descriptor once, should not have backdoors like this to add cores - auto& dispatch_cores = available_dispatch_cores_by_device.at(device_id); - if (std::find(dispatch_cores.begin(), dispatch_cores.end(), core) == dispatch_cores.end()) { - dispatch_cores.push_back(core); - } - } - - std::vector get_all_logical_dispatch_cores(chip_id_t device_id) { - return tt::get_logical_dispatch_cores(device_id, MAX_NUM_HW_CQS, this->dispatch_core_config_by_device[device_id]); - } - private: + const tt_cxy_pair& dispatcher_d_core(chip_id_t device_id, uint16_t channel, uint8_t cq_id); + + const tt_cxy_pair& dispatcher_s_core(chip_id_t device_id, uint16_t channel, uint8_t cq_id); + + CoreType get_dispatch_core_type(chip_id_t device_id); + + DispatchCoreConfig get_dispatch_core_config(chip_id_t device_id); + + // TODO: remove this API, we should read the core descriptor once, should not have backdoors like this to add cores + void add_dispatch_core_to_device(chip_id_t device_id, const CoreCoord& core); + + std::vector get_all_logical_dispatch_cores(chip_id_t device_id); + +private: /// @brief dispatch_core_manager constructor initializes a list of cores per device that are designated for any dispatch functionality /// This list contains dispatch cores that have not been assigned to a particular dispatch function /// @param num_hw_cqs is used to get the correct collection of dispatch cores for a particular device /// @param dispatch_core_config specfies the core type that is designated for dispatch functionality - dispatch_core_manager(const DispatchCoreConfig &dispatch_core_config, uint8_t num_hw_cqs) { - this->reset_dispatch_core_manager(dispatch_core_config, num_hw_cqs); - } - + dispatch_core_manager(const DispatchCoreConfig& dispatch_core_config, uint8_t num_hw_cqs); /// @brief reset_dispatch_core_manager initializes vector of cores per device for dispatch kernels /// @param dispatch_core_config specfies the core type for dispatch kernels - void reset_dispatch_core_manager(const DispatchCoreConfig &dispatch_core_config, 
uint8_t num_hw_cqs) { - this->dispatch_core_assignments.clear(); - this->available_dispatch_cores_by_device.clear(); - this->dispatch_core_config_by_device.clear(); - for (chip_id_t device_id = 0; device_id < tt::Cluster::instance().number_of_devices(); device_id++) { - std::list &logical_dispatch_cores = this->available_dispatch_cores_by_device[device_id]; - for (const CoreCoord &logical_dispatch_core : - tt::get_logical_dispatch_cores(device_id, MAX_NUM_HW_CQS, dispatch_core_config)) { - logical_dispatch_cores.push_back(logical_dispatch_core); - } - - this->dispatch_core_config_by_device[device_id] = dispatch_core_config; - this->num_hw_cqs = num_hw_cqs; - } - } + void reset_dispatch_core_manager(const DispatchCoreConfig& dispatch_core_config, uint8_t num_hw_cqs); /// @brief getting any available dispatch core for a device /// @param device_id /// @return - CoreCoord get_next_available_dispatch_core(chip_id_t device_id) { - if (this->available_dispatch_cores_by_device.find(device_id) == this->available_dispatch_cores_by_device.end()) { - TT_THROW("Invalid device ID to assign dispatch cores {}", device_id); - } - if (this->available_dispatch_cores_by_device.at(device_id).empty()) { - TT_THROW("No more available dispatch cores on device {} to assign. Expand dispatch cores specified in core descriptor YAML", device_id); - } - CoreCoord avail_dispatch_core = this->available_dispatch_cores_by_device.at(device_id).front(); - this->available_dispatch_cores_by_device.at(device_id).pop_front(); - return avail_dispatch_core; - } - - void log_dispatch_assignment(std::string name, tt_cxy_pair &cxy, chip_id_t device_id, uint16_t channel, uint8_t cq_id, bool force_ethernet = false) { - log_debug( - tt::LogMetal, - "Allocated {} Core: {}({}) for Device {} Channel {} CQ ID {}", - name, - cxy.str(), - tt::Cluster::instance().get_virtual_coordinate_from_logical_coordinates(cxy, force_ethernet? 
CoreType::ETH : get_dispatch_core_type(cxy.chip)).str(), - device_id, - channel, - cq_id); - } - + CoreCoord get_next_available_dispatch_core(chip_id_t device_id); + + void log_dispatch_assignment( + std::string name, + tt_cxy_pair& cxy, + chip_id_t device_id, + uint16_t channel, + uint8_t cq_id, + bool force_ethernet = false); // {device ID : {channel (hugepage) : {cq_id : dispatch assignment}}} // Each device has an assigned hugepage at a specific channel that holds (up to 2) hardware command queues (represented by cq_id) diff --git a/tt_metal/api/tt-metalium/distributed.hpp b/tt_metal/api/tt-metalium/distributed.hpp index a94cbaa9ecc..96b3a23ed10 100644 --- a/tt_metal/api/tt-metalium/distributed.hpp +++ b/tt_metal/api/tt-metalium/distributed.hpp @@ -31,7 +31,12 @@ void WriteShard( std::vector& src, const Coordinate& coord, bool blocking = false) { - mesh_cq.enqueue_write_shard(mesh_buffer, src.data(), coord, blocking); + std::vector shard_data_transfers = {{ + .shard_coord = coord, + .host_data = src.data(), + .region = std::nullopt, + }}; + mesh_cq.enqueue_write_shards(mesh_buffer, shard_data_transfers, blocking); } template @@ -43,7 +48,12 @@ void ReadShard( bool blocking = true) { auto shard = mesh_buffer->get_device_buffer(coord); dst.resize(shard->page_size() * shard->num_pages() / sizeof(DType)); - mesh_cq.enqueue_read_shard(dst.data(), mesh_buffer, coord, blocking); + std::vector shard_data_transfers = {{ + .shard_coord = coord, + .host_data = dst.data(), + .region = std::nullopt, + }}; + mesh_cq.enqueue_read_shards(shard_data_transfers, mesh_buffer, blocking); } template diff --git a/tt_metal/api/tt-metalium/lightmetal_capture_utils.hpp b/tt_metal/api/tt-metalium/lightmetal_capture_utils.hpp new file mode 100644 index 00000000000..5c6aec97b59 --- /dev/null +++ b/tt_metal/api/tt-metalium/lightmetal_capture_utils.hpp @@ -0,0 +1,54 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "lightmetal/host_api_capture_helpers.hpp" +#include + +namespace tt::tt_metal { + +// Note: LightMetalCompare functions could have been inside host_api.hpp / command_queue.cpp but seems better +// to not make as visible, since these are APIs used at light-metal capture time for verification purposes. + +// clang-format off +/** + * Reads a buffer from the device and captures return data as golden inside Light Metal Binary, and optionally returns to user. + * When replaying Light Metal Binary, buffer is read and data is compared to the capture-time golden data. + * + * Return value: void + * + * | Argument | Description | Type | Valid Range | Required | + * |----------------|-----------------------------------------------------------------------------------|-------------------------------------|----------------------------------------|----------| + * | cq | The command queue object which dispatches the command to the hardware | CommandQueue & | | Yes | + * | buffer | The device buffer we are reading from | Buffer & or std::shared_ptr | | Yes | + * | dst | The memory where the result will be stored, if provided | void* | | No | + */ +// clang-format on +void LightMetalCompareToCapture( + CommandQueue& cq, + const std::variant, std::shared_ptr>& buffer, + void* dst = nullptr); + +// clang-format off +/** + * Accepts user-supplied golden data, stored inside Light Metal Binary. + * When replaying Light Metal Binary, buffer is read and data is compared to the user-supplied golden data. 
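
> Reviewer note: a hedged usage sketch for the two capture-time verification helpers declared in the new `lightmetal_capture_utils.hpp`. Buffer creation, capture setup, and the include path are assumed; only the two calls themselves come from this header.

```cpp
#include <cstdint>
#include <memory>
#include <vector>
#include <tt-metalium/lightmetal_capture_utils.hpp>  // assumed include path

void capture_with_golden(tt::tt_metal::CommandQueue& cq,
                         const std::shared_ptr<tt::tt_metal::Buffer>& buffer,
                         std::vector<uint32_t>& expected) {
    // Read the buffer now and store the result as golden data inside the
    // Light Metal binary; replay re-reads the buffer and compares against it.
    tt::tt_metal::LightMetalCompareToCapture(cq, buffer, /*dst=*/nullptr);

    // Alternatively, store user-supplied golden data for the same buffer;
    // replay compares the read-back contents against `expected`.
    tt::tt_metal::LightMetalCompareToGolden(cq, buffer, expected.data());
}
```
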
+ * + * Return value: void + * + * | Argument | Description | Type | Valid Range | Required | + * |----------------|-----------------------------------------------------------------------------------|-------------------------------------|----------------------------------------|----------| + * | cq | The command queue object which dispatches the command to the hardware | CommandQueue & | | Yes | + * | buffer | The device buffer we are reading from | Buffer & or std::shared_ptr | | Yes | + * | golden_data | User supplied expected/golden data for buffer | void* | | Yes | + */ +// clang-format on + +void LightMetalCompareToGolden( + CommandQueue& cq, + const std::variant, std::shared_ptr>& buffer, + void* golden_data); + +} // namespace tt::tt_metal diff --git a/tt_metal/api/tt-metalium/memcpy.hpp b/tt_metal/api/tt-metalium/memcpy.hpp index 0905032697e..298d8dd3dc0 100644 --- a/tt_metal/api/tt-metalium/memcpy.hpp +++ b/tt_metal/api/tt-metalium/memcpy.hpp @@ -32,6 +32,14 @@ static inline void memcpy_to_device(void* __restrict dst, const void* __restrict uint8_t* dst8 = (uint8_t*)dst; if (size_t num_lines = n / inner_blk_size) { + if ((uintptr_t)dst8 % sizeof(__m256i) != 0) { + __m128i blk = _mm_loadu_si128((const __m128i *)src8); + _mm_stream_si128((__m128i *)dst8, blk); + src8 += sizeof(__m128i); + dst8 += sizeof(__m128i); + n -= sizeof(__m128i); + num_lines = n / inner_blk_size; + } for (size_t i = 0; i < num_lines; ++i) { for (size_t j = 0; j < inner_loop; ++j) { __m256i blk = _mm256_loadu_si256((const __m256i*)src8); @@ -45,6 +53,14 @@ static inline void memcpy_to_device(void* __restrict dst, const void* __restrict if (n > 0) { if (size_t num_lines = n / sizeof(__m256i)) { + if ((uintptr_t)dst8 % sizeof(__m256i) != 0) { + __m128i blk = _mm_loadu_si128((const __m128i *)src8); + _mm_stream_si128((__m128i *)dst8, blk); + src8 += sizeof(__m128i); + dst8 += sizeof(__m128i); + n -= sizeof(__m128i); + num_lines = n / sizeof(__m256i); + } for (size_t i = 0; i < num_lines; ++i) { __m256i blk = _mm256_loadu_si256((const __m256i*)src8); _mm256_stream_si256((__m256i*)dst8, blk); diff --git a/tt_metal/api/tt-metalium/mesh_command_queue.hpp b/tt_metal/api/tt-metalium/mesh_command_queue.hpp index 38d13891095..61263207b9c 100644 --- a/tt_metal/api/tt-metalium/mesh_command_queue.hpp +++ b/tt_metal/api/tt-metalium/mesh_command_queue.hpp @@ -4,6 +4,8 @@ #pragma once +#include +#include "buffer.hpp" #include "command_queue_interface.hpp" #include "mesh_buffer.hpp" #include "mesh_device.hpp" @@ -21,20 +23,28 @@ class MeshCommandQueue { void populate_dispatch_core_type(); CoreCoord virtual_program_dispatch_core() const; CoreType dispatch_core_type() const; + // Helper functions for reading and writing individual shards void write_shard_to_device( - std::shared_ptr& shard_view, const void* src, tt::stl::Span sub_device_ids = {}); + std::shared_ptr& shard_view, + const void* src, + const BufferRegion& region, + tt::stl::Span sub_device_ids = {}); void read_shard_from_device( - std::shared_ptr& shard_view, void* dst, tt::stl::Span sub_device_ids = {}); + std::shared_ptr& shard_view, + void* dst, + const BufferRegion& region, + tt::stl::Span sub_device_ids = {}); + // Helper functions for read and write entire Sharded-MeshBuffers void write_sharded_buffer(const MeshBuffer& buffer, const void* src); void read_sharded_buffer(MeshBuffer& buffer, void* dst); std::array config_buffer_mgr_; std::array expected_num_workers_completed_; - MeshDevice* mesh_device_; - uint32_t id_; + MeshDevice* mesh_device_ = nullptr; + 
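
> Reviewer note on the `memcpy_to_device` change above: `_mm256_stream_si256` requires a 32-byte-aligned destination, so when `dst` is only 16-byte aligned a single 16-byte `_mm_stream_si128` advances it onto a 32-byte boundary before the streaming loop. The sketch below restates that peel as a standalone helper, assuming (as the existing fast path does) that the destination is at least 16-byte aligned and that at least 32 bytes remain.

```cpp
#include <cstddef>
#include <cstdint>
#include <immintrin.h>

// Caller guarantees: dst8 is 16-byte aligned and n >= 32.
static inline void peel_to_32_byte_boundary(const uint8_t*& src8, uint8_t*& dst8, size_t& n) {
    if ((uintptr_t)dst8 % sizeof(__m256i) != 0) {
        // One non-temporal 16-byte store moves a 16-byte-aligned destination
        // onto a 32-byte boundary so the 256-bit streaming stores can proceed.
        __m128i blk = _mm_loadu_si128((const __m128i*)src8);
        _mm_stream_si128((__m128i*)dst8, blk);
        src8 += sizeof(__m128i);
        dst8 += sizeof(__m128i);
        n -= sizeof(__m128i);
    }
}
```
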
uint32_t id_ = 0; CoreCoord dispatch_core_; - CoreType dispatch_core_type_; + CoreType dispatch_core_type_ = CoreType::WORKER; public: MeshCommandQueue(MeshDevice* mesh_device, uint32_t id); @@ -42,16 +52,30 @@ class MeshCommandQueue { uint32_t id() const { return id_; } WorkerConfigBufferMgr& get_config_buffer_mgr(uint32_t index) { return config_buffer_mgr_[index]; }; void enqueue_mesh_workload(MeshWorkload& mesh_workload, bool blocking); + + // Specifies host data to be written to or read from a MeshBuffer shard. + struct ShardDataTransfer { + Coordinate shard_coord; + void* host_data = nullptr; + std::optional region; + }; + // MeshBuffer Write APIs - void enqueue_write_shard( - std::shared_ptr& mesh_buffer, const void* host_data, const Coordinate& coord, bool blocking); void enqueue_write_shard_to_sub_grid( const MeshBuffer& buffer, const void* host_data, const LogicalDeviceRange& device_range, bool blocking); void enqueue_write_mesh_buffer(const std::shared_ptr& buffer, const void* host_data, bool blocking); + void enqueue_write_shards( + const std::shared_ptr& mesh_buffer, + const std::vector& shard_data_transfers, + bool blocking); + // MeshBuffer Read APIs - void enqueue_read_shard( - void* host_data, const std::shared_ptr& mesh_buffer, const Coordinate& coord, bool blocking); void enqueue_read_mesh_buffer(void* host_data, const std::shared_ptr& buffer, bool blocking); + void enqueue_read_shards( + const std::vector& shard_data_transfers, + const std::shared_ptr& mesh_buffer, + bool blocking); + void finish(); void reset_worker_state( bool reset_launch_msg_state, diff --git a/tt_metal/api/tt-metalium/mesh_device.hpp b/tt_metal/api/tt-metalium/mesh_device.hpp index 9b768da3a32..ec04ada058f 100644 --- a/tt_metal/api/tt-metalium/mesh_device.hpp +++ b/tt_metal/api/tt-metalium/mesh_device.hpp @@ -64,6 +64,9 @@ class MeshDevice : public IDevice, public std::enable_shared_from_this get_row_major_devices(const MeshShape& new_shape) const; + public: MeshDevice( std::shared_ptr mesh_handle, @@ -154,7 +157,6 @@ class MeshDevice : public IDevice, public std::enable_shared_from_this; + using Container = tt::stl::SmallVector; ShapeBase() { init(); }; explicit ShapeBase(const Container& shape) : value_(shape) { init(); } diff --git a/ttnn/cpp/ttnn/tensor/shape/small_vector.hpp b/tt_metal/api/tt-metalium/small_vector.hpp similarity index 61% rename from ttnn/cpp/ttnn/tensor/shape/small_vector.hpp rename to tt_metal/api/tt-metalium/small_vector.hpp index 41e71c9792a..75c8759606a 100644 --- a/ttnn/cpp/ttnn/tensor/shape/small_vector.hpp +++ b/tt_metal/api/tt-metalium/small_vector.hpp @@ -6,13 +6,9 @@ #include -#include +#include "reflection.hpp" -#if TTNN_WITH_PYTHON_BINDINGS -#include -#endif - -namespace tt::tt_metal { +namespace tt::stl { static constexpr size_t SMALL_VECTOR_SIZE = 8; @@ -35,15 +31,16 @@ std::ostream& operator<<(std::ostream& os, const SmallVector -struct std::hash> { - size_t operator()(const ttnn::SmallVector& vec) const noexcept { +struct std::hash> { + size_t operator()(const tt::stl::SmallVector& vec) const noexcept { size_t hash = 0; for (const auto& element : vec) { hash = tt::stl::hash::detail::hash_objects(hash, element); @@ -53,23 +50,13 @@ struct std::hash> { }; template -struct fmt::formatter> { +struct fmt::formatter> { constexpr auto parse(format_parse_context& ctx) -> format_parse_context::iterator { return ctx.end(); } - auto format(const tt::tt_metal::SmallVector& vector, format_context& ctx) const + auto format(const tt::stl::SmallVector& vector, format_context& 
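
> Reviewer note: `SmallVector` now lives in `tt::stl` under `tt-metalium`, with the `std::hash` and `fmt::formatter` specializations moved along with it. A minimal consumer sketch follows; the include path is an assumption based on the new file location.

```cpp
#include <cstdint>
#include <fmt/format.h>
#include <tt-metalium/small_vector.hpp>  // assumed include path after the move

void small_vector_example() {
    tt::stl::SmallVector<uint32_t> shape;  // previously ttnn::SmallVector
    shape.push_back(32);
    shape.push_back(32);
    // The fmt::formatter specialization moved with the type, so it still
    // formats directly.
    fmt::print("{}\n", shape);
}
```
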
ctx) const -> format_context::iterator { std::stringstream ss; ss << vector; return fmt::format_to(ctx.out(), "{}", ss.str()); } }; - -#if TTNN_WITH_PYTHON_BINDINGS -namespace PYBIND11_NAMESPACE { -namespace detail { -template -struct type_caster> - : list_caster, T> {}; -} // namespace detail -} // namespace PYBIND11_NAMESPACE -#endif diff --git a/tt_metal/common/CMakeLists.txt b/tt_metal/common/CMakeLists.txt index b34d189262b..551051ea52b 100644 --- a/tt_metal/common/CMakeLists.txt +++ b/tt_metal/common/CMakeLists.txt @@ -4,6 +4,7 @@ set(COMMON_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/core_descriptor.cpp ${CMAKE_CURRENT_SOURCE_DIR}/metal_soc_descriptor.cpp ${CMAKE_CURRENT_SOURCE_DIR}/shape2d.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/shape_base.cpp ${CMAKE_CURRENT_SOURCE_DIR}/tt_backend_api_types.cpp ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/work_split.cpp @@ -20,6 +21,7 @@ target_link_libraries( magic_enum fmt::fmt-header-only span + small_vector Metalium::Metal::STL umd::Firmware umd::device diff --git a/ttnn/cpp/ttnn/tensor/shape/shape_base.cpp b/tt_metal/common/shape_base.cpp similarity index 98% rename from ttnn/cpp/ttnn/tensor/shape/shape_base.cpp rename to tt_metal/common/shape_base.cpp index 687c2312f94..57e69bb49e6 100644 --- a/ttnn/cpp/ttnn/tensor/shape/shape_base.cpp +++ b/tt_metal/common/shape_base.cpp @@ -2,10 +2,10 @@ // // SPDX-License-Identifier: Apache-2.0 +#include "assert.hpp" #include "shape_base.hpp" #include #include "fmt/color.h" -#include namespace tt::tt_metal { diff --git a/tt_metal/distributed/mesh_command_queue.cpp b/tt_metal/distributed/mesh_command_queue.cpp index cb409bdb4eb..89eaaff1b03 100644 --- a/tt_metal/distributed/mesh_command_queue.cpp +++ b/tt_metal/distributed/mesh_command_queue.cpp @@ -4,8 +4,10 @@ #include #include +#include #include +#include "buffer.hpp" #include "tt_metal/distributed/mesh_workload_utils.hpp" #include "tt_metal/impl/buffers/dispatch.hpp" #include "tt_metal/impl/program/dispatch.hpp" @@ -164,16 +166,21 @@ void MeshCommandQueue::finish() { } void MeshCommandQueue::write_shard_to_device( - std::shared_ptr& shard_view, const void* src, tt::stl::Span sub_device_ids) { + std::shared_ptr& shard_view, + const void* src, + const BufferRegion& region, + tt::stl::Span sub_device_ids) { auto device = shard_view->device(); - BufferRegion region(0, shard_view->size()); sub_device_ids = buffer_dispatch::select_sub_device_ids(mesh_device_, sub_device_ids); buffer_dispatch::write_to_device_buffer( src, *shard_view, region, id_, expected_num_workers_completed_, this->dispatch_core_type(), sub_device_ids); } void MeshCommandQueue::read_shard_from_device( - std::shared_ptr& shard_view, void* dst, tt::stl::Span sub_device_ids) { + std::shared_ptr& shard_view, + void* dst, + const BufferRegion& region, + tt::stl::Span sub_device_ids) { auto device = shard_view->device(); chip_id_t mmio_device_id = tt::Cluster::instance().get_associated_mmio_device(device->id()); uint16_t channel = tt::Cluster::instance().get_assigned_channel_for_device(device->id()); @@ -181,7 +188,6 @@ void MeshCommandQueue::read_shard_from_device( bool exit_condition = false; - BufferRegion region(0, shard_view->size()); if (is_sharded(shard_view->buffer_layout())) { auto dispatch_params = buffer_dispatch::initialize_sharded_buf_read_dispatch_params( *shard_view, id_, expected_num_workers_completed_, region); @@ -211,23 +217,6 @@ void MeshCommandQueue::read_shard_from_device( } } -void MeshCommandQueue::enqueue_write_shard( - std::shared_ptr& mesh_buffer, const void* 
host_data, const Coordinate& coord, bool blocking) { - auto shard = mesh_buffer->get_device_buffer(coord); - this->write_shard_to_device(shard, host_data); - - if (blocking) { - this->finish(); - } -} - -void MeshCommandQueue::enqueue_read_shard( - void* host_data, const std::shared_ptr& mesh_buffer, const Coordinate& coord, bool blocking) { - TT_FATAL(blocking, "Only blocking reads are currently supported from MeshBuffer shards."); - auto shard = mesh_buffer->get_device_buffer(coord); - this->read_shard_from_device(shard, host_data); -} - void MeshCommandQueue::write_sharded_buffer(const MeshBuffer& buffer, const void* src) { auto global_buffer_shape = buffer.global_shard_spec().global_buffer_shape; auto global_buffer_size = buffer.global_shard_spec().global_size; @@ -269,26 +258,30 @@ void MeshCommandQueue::write_sharded_buffer(const MeshBuffer& buffer, const void replicated_device_y++) { auto device_shard_view = buffer.get_device_buffer(Coordinate(replicated_device_y, replicated_device_x)); - this->write_shard_to_device(device_shard_view, shard_data.data()); + const BufferRegion region(0, device_shard_view->size()); + this->write_shard_to_device(device_shard_view, shard_data.data(), region); } } } else if (height_replicated or width_replicated) { if (buffer.global_shard_spec().shard_orientation == ShardOrientation::ROW_MAJOR) { for (auto replicated_device_y = 0; replicated_device_y < num_devices_y; replicated_device_y++) { auto device_shard_view = buffer.get_device_buffer(Coordinate(replicated_device_y, device_x)); - this->write_shard_to_device(device_shard_view, shard_data.data()); + const BufferRegion region(0, device_shard_view->size()); + this->write_shard_to_device(device_shard_view, shard_data.data(), region); } device_x++; } else { for (auto replicated_device_x = 0; replicated_device_x < num_devices_x; replicated_device_x++) { auto device_shard_view = buffer.get_device_buffer(Coordinate(device_y, replicated_device_x)); - this->write_shard_to_device(device_shard_view, shard_data.data()); + const BufferRegion region(0, device_shard_view->size()); + this->write_shard_to_device(device_shard_view, shard_data.data(), region); } device_y++; } } else { auto device_shard_view = buffer.get_device_buffer(Coordinate(device_y, device_x)); - this->write_shard_to_device(device_shard_view, shard_data.data()); + const BufferRegion region(0, device_shard_view->size()); + this->write_shard_to_device(device_shard_view, shard_data.data(), region); if (buffer.global_shard_spec().shard_orientation == ShardOrientation::ROW_MAJOR) { if (++device_x == num_devices_x) { device_x = 0; @@ -328,7 +321,9 @@ void MeshCommandQueue::read_sharded_buffer(MeshBuffer& buffer, void* dst) { for (std::size_t shard_y = 0; shard_y < num_shards_y; shard_y++) { for (std::size_t shard_x = 0; shard_x < num_shards_x; shard_x++) { auto device_shard_view = buffer.get_device_buffer(Coordinate(device_y, device_x)); - this->read_shard_from_device(device_shard_view, shard_data.data()); + const BufferRegion region(0, device_shard_view->size()); + this->read_shard_from_device(device_shard_view, shard_data.data(), region); + uint32_t write_offset = shard_x * single_write_size + shard_y * stride_size_bytes * shard_shape.height(); uint32_t size_to_write = total_write_size_per_shard; uint32_t local_offset = 0; @@ -363,7 +358,8 @@ void MeshCommandQueue::enqueue_write_shard_to_sub_grid( for (std::size_t logical_y = device_range.start_coord.y; logical_y < device_range.end_coord.y + 1; logical_y++) { auto device_shard_view = 
buffer.get_device_buffer(Coordinate(logical_y, logical_x)); - this->write_shard_to_device(device_shard_view, host_data); + const BufferRegion region(0, device_shard_view->size()); + this->write_shard_to_device(device_shard_view, host_data, region); } } } else { @@ -387,6 +383,40 @@ void MeshCommandQueue::enqueue_read_mesh_buffer( this->read_sharded_buffer(*buffer, host_data); } +void MeshCommandQueue::enqueue_write_shards( + const std::shared_ptr& buffer, + const std::vector& shard_data_transfers, + bool blocking) { + // TODO: #17215 - this API is used by TTNN, as it currently implements rich ND sharding API for multi-devices. + // In the long run, the multi-device sharding API in Metal will change, and this will most likely be replaced. + for (const auto& shard_data_transfer : shard_data_transfers) { + auto device_shard_view = buffer->get_device_buffer(shard_data_transfer.shard_coord); + write_shard_to_device( + device_shard_view, + shard_data_transfer.host_data, + shard_data_transfer.region.value_or(BufferRegion(0, device_shard_view->size()))); + } + if (blocking) { + this->finish(); + } +} + +void MeshCommandQueue::enqueue_read_shards( + const std::vector& shard_data_transfers, + const std::shared_ptr& buffer, + bool blocking) { + // TODO: #17215 - this API is used by TTNN, as it currently implements rich ND sharding API for multi-devices. + // In the long run, the multi-device sharding API in Metal will change, and this will most likely be replaced. + const auto [num_rows, num_cols] = buffer->device()->shape(); + for (const auto& shard_data_transfer : shard_data_transfers) { + auto device_shard_view = buffer->get_device_buffer(shard_data_transfer.shard_coord); + read_shard_from_device( + device_shard_view, + shard_data_transfer.host_data, + shard_data_transfer.region.value_or(BufferRegion(0, device_shard_view->size()))); + } +} + void MeshCommandQueue::reset_worker_state( bool reset_launch_msg_state, uint32_t num_sub_devices, const vector_memcpy_aligned& go_signal_noc_data) { for (auto device : mesh_device_->get_devices()) { diff --git a/tt_metal/distributed/mesh_device.cpp b/tt_metal/distributed/mesh_device.cpp index abed02ea41a..e02498c3c28 100644 --- a/tt_metal/distributed/mesh_device.cpp +++ b/tt_metal/distributed/mesh_device.cpp @@ -252,11 +252,23 @@ size_t MeshDevice::num_cols() const { return mesh_shape_.num_cols; } MeshShape MeshDevice::shape() const { return mesh_shape_; } -void MeshDevice::reshape(const MeshShape& new_shape) { - TT_FATAL( - new_shape.num_rows * new_shape.num_cols == this->num_devices(), - "New shape must have the same number of devices as current shape"); - +std::vector MeshDevice::get_row_major_devices(const MeshShape& new_shape) const { + // MeshDeviceView requires devices to be provided as a 1D array in row-major order for the target mesh shape. + // The physical connectivity between devices must be preserved when reshaping. 
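
> Reviewer note: a hedged sketch of the new batched shard API. `MeshBuffer`/`MeshCommandQueue` setup is assumed, as is the `tt::tt_metal::distributed` namespace; leaving `region` unset writes the full shard, matching the removed per-shard `enqueue_write_shard` behavior.

```cpp
#include <cstdint>
#include <memory>
#include <vector>

using namespace tt::tt_metal::distributed;  // assumed namespace for the mesh APIs

void write_two_shards(MeshCommandQueue& mesh_cq,
                      const std::shared_ptr<MeshBuffer>& mesh_buffer,
                      std::vector<uint32_t>& shard0,
                      std::vector<uint32_t>& shard1) {
    std::vector<MeshCommandQueue::ShardDataTransfer> transfers = {
        {.shard_coord = Coordinate(0, 0), .host_data = shard0.data()},  // region omitted: full shard
        {.shard_coord = Coordinate(0, 1), .host_data = shard1.data()},
    };
    // One call covers both shards; blocking waits for completion.
    mesh_cq.enqueue_write_shards(mesh_buffer, transfers, /*blocking=*/true);
}
```
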
+ // + // Example: + // Given 4 devices physically connected in a 2x2 grid like this: + // [0]--[1] + // | | + // [3]--[2] + // + // For a 1x4 mesh shape: + // - Devices must form a line: 0->1->2->3 + // - Row-major order will be: [0,1,2,3] + // + // For a 2x2 mesh shape: + // - Preserves original 2x2 physical connectivity + // - Row-major order will be: [0,1,3,2] std::unordered_map physical_device_id_to_linearized_index; for (size_t i = 0; i < this->num_devices(); i++) { physical_device_id_to_linearized_index[this->get_devices()[i]->id()] = i; @@ -264,6 +276,7 @@ void MeshDevice::reshape(const MeshShape& new_shape) { // From an MxN mesh, we can always reduce rank to a 1xM*N Line mesh. // However, going from a Line mesh to an MxN mesh is not always possible. + std::vector new_device_order; if (new_shape.num_rows != 1 and new_shape.num_cols != 1) { auto new_physical_device_ids = SystemMesh::instance().request_available_devices( @@ -285,10 +298,22 @@ void MeshDevice::reshape(const MeshShape& new_shape) { this->num_cols()); } } + for (size_t i = 0; i < new_physical_device_ids.size(); i++) { + new_device_order.push_back(this->get_device(new_physical_device_ids[i])); + } + } else { + new_device_order = view_->get_line_devices(); } + return new_device_order; +} + +void MeshDevice::reshape(const MeshShape& new_shape) { + TT_FATAL( + new_shape.num_rows * new_shape.num_cols == this->num_devices(), + "New shape must have the same number of devices as current shape"); mesh_shape_ = new_shape; - view_ = std::make_unique(scoped_devices_->get_devices(), mesh_shape_); + view_ = std::make_unique(this->get_row_major_devices(new_shape), new_shape); } bool MeshDevice::close() { @@ -628,12 +653,6 @@ void MeshDevice::init_command_queue_device() { TT_THROW("init_command_queue_device() is not supported on MeshDevice - use individual devices instead"); reference_device()->init_command_queue_device(); } -void MeshDevice::update_dispatch_cores_for_multi_cq_eth_dispatch() { - TT_THROW( - "update_dispatch_cores_for_multi_cq_eth_dispatch() is not supported on MeshDevice - use individual devices " - "instead"); - reference_device()->update_dispatch_cores_for_multi_cq_eth_dispatch(); -} void MeshDevice::synchronize() { // Nothing to synchronize, as all work is executed by MeshDevice is synchronous. 
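
> Reviewer note: the `get_row_major_devices()` comment above explains how devices are reordered on reshape. The sketch below only shows the call pattern; mesh creation and the exact `MeshShape` construction are assumptions.

```cpp
void reshape_to_line(tt::tt_metal::distributed::MeshDevice& mesh_device) {
    // A 2x2 mesh can always be reshaped into a 1x4 line: the device count must
    // stay the same, and reshape() re-derives the row-major device order so the
    // physical ethernet connectivity is preserved (0->1->2->3 in the example above).
    mesh_device.reshape(tt::tt_metal::distributed::MeshShape{1, 4});
}
```
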
} diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_pack_api.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_pack_api.h index 62d85c771b0..fe51c0f060e 100644 --- a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_pack_api.h +++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_pack_api.h @@ -75,6 +75,46 @@ inline void llk_pack_hw_configure_disaggregated(std::uint32_t pack_output) { llk_pack_hw_configure(&llk_pack_params); } +template +inline void llk_pack_untilize_hw_configure( + const llk_pack_params_t* pack_params, const std::uint32_t face_r_dim, const std::uint32_t num_faces) { + const std::uint32_t output_id = get_output_id(pack_params->pack_output); + const std::uint32_t tile_c_dim = get_output_tile_c_dim(output_id); + const bool partial_face = get_output_partial_face(output_id); + const bool narrow_tile = get_output_narrow_tile(output_id); + + const std::uint32_t tile_size = get_local_cb_interface(output_id).fifo_page_size; + + _llk_pack_hw_configure_( + pack_src_format[output_id], + pack_dst_format[output_id], + tile_size, + face_r_dim, + tile_c_dim, + num_faces, + partial_face, + narrow_tile, + pack_params->relu_config.val); +} + +template < + bool untilize = false, + bool is_fp32_dest_acc_en = false, + ReluType relu_type = ReluType::NO_RELU, + std::uint32_t relu_threshold = 0, + bool tilize = false> +inline void llk_pack_untilize_hw_configure_disaggregated( + std::uint32_t pack_output, std::uint32_t face_r_dim = 16, std::uint32_t num_faces = 4) { + llk_pack_params_t llk_pack_params = { + .pack_output = pack_output, + .relu_config = { + .f = { + .ApplyRelu = (std::uint32_t)relu_type, + .Threshold = relu_threshold, + }}}; + llk_pack_untilize_hw_configure(&llk_pack_params, face_r_dim, num_faces); +} + template inline void llk_pack_reduce_hw_configure(const llk_pack_params_t* pack_params) { const std::uint32_t output_id = get_output_id(pack_params->pack_output); diff --git a/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_pack_api.h b/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_pack_api.h index 890ea64a2bf..96440735c40 100644 --- a/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_pack_api.h +++ b/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_pack_api.h @@ -45,7 +45,7 @@ inline void llk_pack_hw_configure(const llk_pack_params_t* pack_params) { template < bool untilize = false, - bool is_fp32_dest_acc_en = false /*not used*/, + bool is_fp32_dest_acc_en = false /*unused*/, ReluType relu_type = ReluType::NO_RELU, std::uint32_t relu_threshold = 0, bool tilize = false /*unused*/> diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_pack_api.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_pack_api.h index 13e8e5bbe98..df6985fb96d 100644 --- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_pack_api.h +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_pack_api.h @@ -73,6 +73,44 @@ inline void llk_pack_hw_configure_disaggregated(std::uint32_t pack_output) { llk_pack_hw_configure(&llk_pack_params); } +template +inline void llk_pack_untilize_hw_configure( + const llk_pack_params_t* pack_params, const std::uint32_t face_r_dim, const std::uint32_t num_faces) { + const std::uint32_t output_id = get_output_id(pack_params->pack_output); + const bool partial_face = get_output_partial_face(output_id); + const bool narrow_tile = get_output_narrow_tile(output_id); + + const std::uint32_t tile_size = get_local_cb_interface(output_id).fifo_page_size; + + _llk_pack_hw_configure_( + pack_src_format[output_id], + pack_dst_format[output_id], + tile_size, + 
face_r_dim, + num_faces, + partial_face, + narrow_tile, + pack_params->relu_config.val); +} + +template < + bool untilize = false, + bool is_fp32_dest_acc_en = false, + ReluType relu_type = ReluType::NO_RELU, + std::uint32_t relu_threshold = 0, + bool tilize = false /*unused*/> +inline void llk_pack_untilize_hw_configure_disaggregated( + std::uint32_t pack_output, std::uint32_t face_r_dim = 16, std::uint32_t num_faces = 4) { + llk_pack_params_t llk_pack_params = { + .pack_output = pack_output, + .relu_config = { + .f = { + .ApplyRelu = (std::uint32_t)relu_type, + .Threshold = relu_threshold, + }}}; + llk_pack_untilize_hw_configure(&llk_pack_params, face_r_dim, num_faces); +} + template inline void llk_pack_reduce_hw_configure(const llk_pack_params_t* pack_params) { const std::uint32_t output_id = get_output_id(pack_params->pack_output); @@ -198,7 +236,7 @@ inline void llk_pack_untilize_init( } else if constexpr (narrow_row) { TT_SETADCXX(p_setadc::PAC, row_num_datums - 1, 0x0); } else { - TT_SETADCXX(p_setadc::PAC, FACE_R_DIM - 1, 0x0); + TT_SETADCXX(p_setadc::PAC, FACE_C_DIM - 1, 0x0); } } diff --git a/tt_metal/hw/inc/blackhole/core_config.h b/tt_metal/hw/inc/blackhole/core_config.h index 9e4ba749e7b..beab0ab565c 100644 --- a/tt_metal/hw/inc/blackhole/core_config.h +++ b/tt_metal/hw/inc/blackhole/core_config.h @@ -25,5 +25,5 @@ constexpr uint8_t NumEthDispatchClasses = 2; constexpr uint8_t NumDramDispatchClasses = 1; constexpr uint8_t noc_size_x = 17; constexpr uint8_t noc_size_y = 12; -#define ALLOCATOR_ALIGNMENT 64 -#define LOG_BASE_2_OF_ALLOCATOR_ALIGNMENT 6 +#define LOG_BASE_2_OF_DRAM_ALIGNMENT 6 +#define LOG_BASE_2_OF_L1_ALIGNMENT 4 diff --git a/tt_metal/hw/inc/blackhole/dev_mem_map.h b/tt_metal/hw/inc/blackhole/dev_mem_map.h index 075edd005ca..b97e3c5601b 100644 --- a/tt_metal/hw/inc/blackhole/dev_mem_map.h +++ b/tt_metal/hw/inc/blackhole/dev_mem_map.h @@ -48,7 +48,7 @@ ///////////// // Firmware/kernel code holes -#define MEM_BRISC_FIRMWARE_SIZE (5 * 1024 + 128) +#define MEM_BRISC_FIRMWARE_SIZE (5 * 1024 + 256) // TODO: perhaps put NCRISC FW in the scratch area and free 1.5K after init (GS/WH) #define MEM_NCRISC_FIRMWARE_SIZE 1536 #define MEM_TRISC0_FIRMWARE_SIZE 1536 diff --git a/tt_metal/hw/inc/dataflow_api.h b/tt_metal/hw/inc/dataflow_api.h index 8d1d95dec80..88038173b3f 100644 --- a/tt_metal/hw/inc/dataflow_api.h +++ b/tt_metal/hw/inc/dataflow_api.h @@ -12,6 +12,7 @@ #endif #include +#include #include #include "core_config.h" @@ -22,7 +23,6 @@ #include "eth_l1_address_map.h" #include "hostdevcommon/common_values.hpp" #include "risc_attribs.h" -#include "umd/device/tt_silicon_driver_common.hpp" #include "utils/utils.h" #include "debug/assert.h" #include "dev_msgs.h" @@ -119,6 +119,25 @@ FORCE_INLINE uint32_t get_bank_offset(uint32_t bank_index) { } } +template +FORCE_INLINE +constexpr uint32_t get_allocator_alignment() { + if constexpr (DRAM) { + return DRAM_ALIGNMENT; + } else { + return L1_ALIGNMENT; + } +} + +template +FORCE_INLINE +constexpr uint32_t get_log_base2_of_allocator_alignment() { + if constexpr (DRAM) { + return LOG_BASE_2_OF_DRAM_ALIGNMENT; + } else { + return LOG_BASE_2_OF_L1_ALIGNMENT; + } +} } // namespace interleaved_addr_gen // clang-format off @@ -630,8 +649,9 @@ uint64_t get_dram_noc_addr( uint8_t noc = noc_index) { uint32_t bank_offset_index = interleaved_addr_gen::get_bank_offset_index(id); uint32_t bank_index = interleaved_addr_gen::get_bank_index(id, bank_offset_index); - uint32_t addr = (bank_offset_index * align_power_of_2(page_size, 
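
> Reviewer note: the new `interleaved_addr_gen::get_allocator_alignment<DRAM>()` helpers replace the single `ALLOCATOR_ALIGNMENT` macro with a compile-time choice between DRAM and L1 alignment. Kernel-side sketch below; it assumes the `dataflow_api.h` context (the per-arch alignment macros and `align_power_of_2` come from the device headers), and the function name is hypothetical.

```cpp
template <bool DRAM>
uint32_t example_aligned_page_size(uint32_t page_size) {
    // DRAM pages round up to the DRAM alignment, L1 pages to the (typically
    // smaller) L1 alignment, both resolved at compile time per architecture.
    return align_power_of_2(page_size, interleaved_addr_gen::get_allocator_alignment<DRAM>());
}
```
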
ALLOCATOR_ALIGNMENT)) + bank_base_address + - offset + bank_to_dram_offset[bank_index]; + uint32_t addr = + (bank_offset_index * align_power_of_2(page_size, interleaved_addr_gen::get_allocator_alignment())) + + bank_base_address + offset + bank_to_dram_offset[bank_index]; uint32_t noc_xy = interleaved_addr_gen::get_noc_xy(bank_index, noc); uint64_t noc_addr = get_noc_addr_helper(noc_xy, addr); return noc_addr; @@ -645,8 +665,9 @@ uint64_t get_l1_noc_addr( uint8_t noc = noc_index) { uint32_t bank_offset_index = interleaved_addr_gen::get_bank_offset_index(id); uint32_t bank_index = interleaved_addr_gen::get_bank_index(id, bank_offset_index); - uint32_t addr = (bank_offset_index * align_power_of_2(page_size, ALLOCATOR_ALIGNMENT)) + bank_base_address + - offset + bank_to_dram_offset[bank_index]; + uint32_t addr = + (bank_offset_index * align_power_of_2(page_size, interleaved_addr_gen::get_allocator_alignment())) + + bank_base_address + offset + bank_to_dram_offset[bank_index]; uint32_t noc_xy = interleaved_addr_gen::get_noc_xy(bank_index, noc); uint64_t noc_addr = get_noc_addr_helper(noc_xy, addr); return noc_addr; @@ -1018,7 +1039,7 @@ template struct InterleavedAddrGen { uint32_t bank_base_address; // Base address for the whole tensor. const uint32_t page_size; // Num bytes in page. - const uint32_t aligned_page_size = align_power_of_2(page_size, ALLOCATOR_ALIGNMENT); + const uint32_t aligned_page_size = align_power_of_2(page_size, interleaved_addr_gen::get_allocator_alignment()); FORCE_INLINE uint32_t get_addr( @@ -1053,9 +1074,11 @@ struct InterleavedPow2AddrGen { const uint32_t bank_base_address; const uint32_t log_base_2_of_page_size; // WARNING: This struct is used for optimized get_noc_addr in which case // you know that bank_unit_size is a power of 2 - const uint32_t aligned_log_base_2_of_page_size = this->log_base_2_of_page_size > LOG_BASE_2_OF_ALLOCATOR_ALIGNMENT + static constexpr uint32_t log_base_2_of_allocator_alignment = + interleaved_addr_gen::get_log_base2_of_allocator_alignment(); + const uint32_t aligned_log_base_2_of_page_size = this->log_base_2_of_page_size > log_base_2_of_allocator_alignment ? this->log_base_2_of_page_size - : LOG_BASE_2_OF_ALLOCATOR_ALIGNMENT; + : log_base_2_of_allocator_alignment; FORCE_INLINE uint32_t get_addr( @@ -1168,9 +1191,11 @@ template struct InterleavedPow2AddrGenFast { uint32_t bank_base_address; // Base address for the whole tensor. const uint32_t log_base_2_of_page_size; // Num bytes in bank unit. - const uint32_t aligned_log_base_2_of_page_size = this->log_base_2_of_page_size > LOG_BASE_2_OF_ALLOCATOR_ALIGNMENT + static constexpr uint32_t log_base_2_of_allocator_alignment = + interleaved_addr_gen::get_log_base2_of_allocator_alignment(); + const uint32_t aligned_log_base_2_of_page_size = this->log_base_2_of_page_size > log_base_2_of_allocator_alignment ? this->log_base_2_of_page_size - : LOG_BASE_2_OF_ALLOCATOR_ALIGNMENT; + : log_base_2_of_allocator_alignment; FORCE_INLINE uint32_t get_addr( @@ -2022,6 +2047,51 @@ void noc_async_read_barrier_with_trid(uint32_t trid, uint8_t noc = noc_index) { WAYPOINT("NBTD"); } +inline void noc_async_write_one_packet_with_trid_set_state(std::uint64_t dst_noc_addr, uint8_t noc = noc_index) { +#ifndef ARCH_GRAYSKULL + WAYPOINT("NAWW"); + while (!noc_cmd_buf_ready(noc, write_cmd_buf)); + WAYPOINT("NAWD"); + uint32_t noc_cmd_field = NOC_CMD_CPY | NOC_CMD_WR | NOC_CMD_VC_STATIC | NOC_CMD_STATIC_VC(NOC_UNICAST_WRITE_VC) | + 0x0 | // (linked ? NOC_CMD_VC_LINKED : 0x0) + 0x0 | // (mcast ? 
(NOC_CMD_PATH_RESERVE | NOC_CMD_BRCST_PACKET) : 0x0) + NOC_CMD_RESP_MARKED; + + NOC_CMD_BUF_WRITE_REG(noc, write_cmd_buf, NOC_CTRL, noc_cmd_field); +#ifdef ARCH_BLACKHOLE + // Handles writing to PCIe + NOC_CMD_BUF_WRITE_REG(noc, write_cmd_buf, NOC_RET_ADDR_MID, (uint32_t)(dst_noc_addr >> 32) & 0x1000000F); +#endif + NOC_CMD_BUF_WRITE_REG( + noc, + write_cmd_buf, + NOC_RET_ADDR_COORDINATE, + (uint32_t)(dst_noc_addr >> NOC_ADDR_COORD_SHIFT) & NOC_COORDINATE_MASK); +#endif +} + +FORCE_INLINE void noc_async_write_one_packet_with_trid_with_state( + std::uint32_t src_local_l1_addr, + std::uint32_t dst_noc_addr, + std::uint32_t size, + std::uint32_t trid, + uint8_t noc = noc_index) { +#ifndef ARCH_GRAYSKULL + WAYPOINT("NWPW"); + while (!noc_cmd_buf_ready(noc, write_cmd_buf)); + WAYPOINT("NWPD"); + + // In order to sanitize, need to grab full noc addr + xfer size from state. + DEBUG_SANITIZE_NOC_WRITE_TRANSACTION_WITH_ADDR_AND_SIZE_STATE(noc, dst_noc_addr, src_local_l1_addr); + + NOC_CMD_BUF_WRITE_REG(noc, write_cmd_buf, NOC_PACKET_TAG, NOC_PACKET_TAG_TRANSACTION_ID(trid)); + NOC_CMD_BUF_WRITE_REG(noc, write_cmd_buf, NOC_TARG_ADDR_LO, src_local_l1_addr); + NOC_CMD_BUF_WRITE_REG(noc, write_cmd_buf, NOC_RET_ADDR_LO, dst_noc_addr); + NOC_CMD_BUF_WRITE_REG(noc, write_cmd_buf, NOC_AT_LEN_BE, size); + NOC_CMD_BUF_WRITE_REG(noc, write_cmd_buf, NOC_CMD_CTRL, NOC_CTRL_SEND_REQ); +#endif +} + inline void noc_async_write_one_packet_with_trid( std::uint32_t src_local_l1_addr, std::uint64_t dst_noc_addr, @@ -2032,7 +2102,18 @@ inline void noc_async_write_one_packet_with_trid( DEBUG_SANITIZE_NOC_WRITE_TRANSACTION(noc, dst_noc_addr, src_local_l1_addr, size); #ifndef ARCH_GRAYSKULL ncrisc_noc_fast_write_any_len( - noc, write_cmd_buf, src_local_l1_addr, dst_noc_addr, size, NOC_UNICAST_WRITE_VC, false, false, 1, true, trid); + noc, + write_cmd_buf, + src_local_l1_addr, + dst_noc_addr, + size, + NOC_UNICAST_WRITE_VC, + false /*mcast*/, + false /*linked*/, + 1 /*num_dests*/, + false /*multicast_path_reserve*/, + false /*posted*/, + trid /*trid*/); #endif WAYPOINT("NAWD"); } diff --git a/tt_metal/hw/inc/debug/sanitize_noc.h b/tt_metal/hw/inc/debug/sanitize_noc.h index 1d3b9aa202f..12431bc266d 100644 --- a/tt_metal/hw/inc/debug/sanitize_noc.h +++ b/tt_metal/hw/inc/debug/sanitize_noc.h @@ -198,8 +198,12 @@ inline uint16_t debug_valid_eth_addr(uint64_t addr, uint64_t len, bool write) { if (addr + len > MEM_ETH_BASE + MEM_ETH_SIZE) { return DebugSanitizeNocAddrOverflow; } + constexpr uint64_t mem_mailbox_end = MEM_IERISC_MAILBOX_END < eth_l1_mem::address_map::ERISC_MEM_MAILBOX_END + ? MEM_IERISC_MAILBOX_END + : eth_l1_mem::address_map::ERISC_MEM_MAILBOX_END; + #if !defined(DISPATCH_KERNEL) || (DISPATCH_KERNEL == 0) - if (write && (addr < eth_l1_mem::address_map::ERISC_MEM_MAILBOX_END)) { + if (write && (addr < mem_mailbox_end)) { return DebugSanitizeNocAddrMailbox; } #endif diff --git a/tt_metal/hw/inc/ethernet/dataflow_api.h b/tt_metal/hw/inc/ethernet/dataflow_api.h index b300aef4715..2ee188b911b 100644 --- a/tt_metal/hw/inc/ethernet/dataflow_api.h +++ b/tt_metal/hw/inc/ethernet/dataflow_api.h @@ -215,6 +215,12 @@ void eth_send_bytes_over_channel_payload_only_unsafe( } } +FORCE_INLINE +void eth_send_bytes_over_channel_payload_only_unsafe_one_packet( + uint32_t src_addr, uint32_t dst_addr, uint32_t num_bytes) { + internal_::eth_send_packet_bytes_unsafe(0, src_addr, dst_addr, num_bytes); +} + /* * Sends the write completion signal to the receiver ethernet core, for transfers where the payload was already sent. 
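
> Reviewer note: the `noc_async_write_one_packet_with_trid_set_state` / `_with_state` pair added in `dataflow_api.h` above follows the usual set-state-once, issue-many split. Kernel-side sketch below, with the loop structure and function name as illustrative assumptions; barrier/cleanup handling is out of scope here.

```cpp
void send_packets_with_trid(uint64_t dst_base_noc_addr, uint32_t src_l1_base,
                            uint32_t packet_size, uint32_t num_packets, uint32_t trid) {
    // One-time setup: latches the destination NOC coordinate and command fields
    // into the write command buffer.
    noc_async_write_one_packet_with_trid_set_state(dst_base_noc_addr);

    uint32_t src = src_l1_base;
    uint32_t dst = (uint32_t)dst_base_noc_addr;  // low 32 bits: local offset at the destination
    for (uint32_t i = 0; i < num_packets; ++i) {
        // Per-packet: only source, destination offset, size, and transaction id.
        noc_async_write_one_packet_with_trid_with_state(src, dst, packet_size, trid);
        src += packet_size;
        dst += packet_size;
    }
    // Completion is tracked later per transaction id (e.g. a trid-based barrier).
}
```
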
* The second half of a full ethernet send. diff --git a/tt_metal/hw/inc/ethernet/tunneling.h b/tt_metal/hw/inc/ethernet/tunneling.h index fbbf252619b..37d1422d2f6 100644 --- a/tt_metal/hw/inc/ethernet/tunneling.h +++ b/tt_metal/hw/inc/ethernet/tunneling.h @@ -78,6 +78,15 @@ void eth_send_packet_unsafe(uint32_t q_num, uint32_t src_word_addr, uint32_t des eth_txq_reg_write(q_num, ETH_TXQ_CMD, ETH_TXQ_CMD_START_DATA); } +FORCE_INLINE +void eth_send_packet_bytes_unsafe(uint32_t q_num, uint32_t src_addr, uint32_t dest_addr, uint32_t num_bytes) { + ASSERT(eth_txq_reg_read(q_num, ETH_TXQ_CMD) == 0); + eth_txq_reg_write(q_num, ETH_TXQ_TRANSFER_START_ADDR, src_addr); + eth_txq_reg_write(q_num, ETH_TXQ_DEST_ADDR, dest_addr); + eth_txq_reg_write(q_num, ETH_TXQ_TRANSFER_SIZE_BYTES, num_bytes); + eth_txq_reg_write(q_num, ETH_TXQ_CMD, ETH_TXQ_CMD_START_DATA); +} + FORCE_INLINE void eth_write_remote_reg(uint32_t q_num, uint32_t reg_addr, uint32_t val) { while (eth_txq_reg_read(q_num, ETH_TXQ_CMD) != 0) { diff --git a/tt_metal/hw/inc/grayskull/core_config.h b/tt_metal/hw/inc/grayskull/core_config.h index 5f73abc2364..066d86376c0 100644 --- a/tt_metal/hw/inc/grayskull/core_config.h +++ b/tt_metal/hw/inc/grayskull/core_config.h @@ -17,5 +17,5 @@ constexpr uint8_t MaxProcessorsPerCoreType = 5; constexpr uint8_t NumTensixDispatchClasses = 3; constexpr uint8_t noc_size_x = 13; constexpr uint8_t noc_size_y = 12; -#define ALLOCATOR_ALIGNMENT 32 -#define LOG_BASE_2_OF_ALLOCATOR_ALIGNMENT 5 +#define LOG_BASE_2_OF_DRAM_ALIGNMENT 5 +#define LOG_BASE_2_OF_L1_ALIGNMENT 4 diff --git a/tt_metal/hw/inc/wormhole/core_config.h b/tt_metal/hw/inc/wormhole/core_config.h index 491ab6bb54a..e1d0c168036 100644 --- a/tt_metal/hw/inc/wormhole/core_config.h +++ b/tt_metal/hw/inc/wormhole/core_config.h @@ -22,5 +22,5 @@ constexpr uint8_t NumTensixDispatchClasses = 3; constexpr uint8_t NumEthDispatchClasses = 1; constexpr uint8_t noc_size_x = 10; constexpr uint8_t noc_size_y = 12; -#define ALLOCATOR_ALIGNMENT 32 -#define LOG_BASE_2_OF_ALLOCATOR_ALIGNMENT 5 +#define LOG_BASE_2_OF_DRAM_ALIGNMENT 5 +#define LOG_BASE_2_OF_L1_ALIGNMENT 4 diff --git a/tt_metal/impl/CMakeLists.txt b/tt_metal/impl/CMakeLists.txt index 3ba20f30f52..c72409857bf 100644 --- a/tt_metal/impl/CMakeLists.txt +++ b/tt_metal/impl/CMakeLists.txt @@ -25,6 +25,7 @@ set(IMPL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/dispatch/host_runtime_commands.cpp ${CMAKE_CURRENT_SOURCE_DIR}/dispatch/dispatch_query_manager.cpp ${CMAKE_CURRENT_SOURCE_DIR}/dispatch/dispatch_core_common.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/dispatch/dispatch_core_manager.cpp ${CMAKE_CURRENT_SOURCE_DIR}/dispatch/hardware_command_queue.cpp ${CMAKE_CURRENT_SOURCE_DIR}/dispatch/launch_message_ring_buffer_state.cpp ${CMAKE_CURRENT_SOURCE_DIR}/dispatch/worker_config_buffer.cpp @@ -52,6 +53,10 @@ set(IMPL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/flatbuffer/buffer_types_to_flatbuffer.cpp ${CMAKE_CURRENT_SOURCE_DIR}/flatbuffer/program_types_from_flatbuffer.cpp ${CMAKE_CURRENT_SOURCE_DIR}/flatbuffer/program_types_to_flatbuffer.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/lightmetal/lightmetal_replay.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/lightmetal/lightmetal_capture.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/lightmetal/lightmetal_capture_utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/lightmetal/host_api_capture_helpers.cpp ) # Include helper functions and generate headers from flatbuffer schemas diff --git a/tt_metal/impl/allocator/allocator.cpp b/tt_metal/impl/allocator/allocator.cpp index f1424478570..052c7a97423 100644 --- a/tt_metal/impl/allocator/allocator.cpp 
+++ b/tt_metal/impl/allocator/allocator.cpp @@ -32,7 +32,7 @@ void Allocator::init_one_bank_per_channel() { BufferType::DRAM, bank_offsets, dram_bank_size, - config_.alignment, + config_.dram_alignment, config_.dram_unreserved_base, config_.disable_interleaved); for (uint32_t bank_id = 0; bank_id < config_.num_dram_channels; bank_id++) { @@ -47,7 +47,7 @@ void Allocator::init_one_bank_per_channel() { BufferType::TRACE, bank_offsets, config_.trace_region_size, - config_.alignment, + config_.dram_alignment, dram_bank_size + config_.dram_unreserved_base, config_.disable_interleaved); for (uint32_t bank_id = 0; bank_id < config_.num_dram_channels; bank_id++) { @@ -68,7 +68,7 @@ void Allocator::init_one_bank_per_l1() { BufferType::L1, bank_offsets, l1_bank_size, - config_.alignment, + config_.l1_alignment, config_.l1_unreserved_base, config_.disable_interleaved); @@ -220,6 +220,18 @@ const std::vector& Allocator::get_bank_ids_from_logical_core( const AllocatorConfig& Allocator::get_config() const { return config_; } +uint32_t Allocator::get_alignment(BufferType buffer_type) const { + switch (buffer_type) { + case BufferType::DRAM: + case BufferType::TRACE: return config_.dram_alignment; + case BufferType::L1: + case BufferType::L1_SMALL: return config_.l1_alignment; + default: { + TT_THROW("Unsupported buffer type!"); + } + } +} + DeviceAddr Allocator::get_base_allocator_addr(const HalMemType& mem_type) const { switch (mem_type) { case HalMemType::DRAM: return config_.dram_unreserved_base; diff --git a/tt_metal/impl/allocator/bank_manager.cpp b/tt_metal/impl/allocator/bank_manager.cpp index 644e3a75318..7ebc5feed2a 100644 --- a/tt_metal/impl/allocator/bank_manager.cpp +++ b/tt_metal/impl/allocator/bank_manager.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include "tt_metal/impl/allocator/algorithms/free_list_opt.hpp" namespace tt { @@ -51,7 +52,7 @@ BankManager::BankManager( } interleaved_address_limit_ = 0; validate_num_banks(bank_id_to_bank_offset_.size(), buffer_type_, disable_interleaved); - this->init_allocator(size_bytes, alignment_bytes, alloc_offset); + this->init_allocator(size_bytes, hal.get_alignment(HalMemType::DRAM), alloc_offset); } BankManager::BankManager( @@ -67,7 +68,7 @@ BankManager::BankManager( interleaved_address_limit_(interleaved_address_limit), alignment_bytes_(alignment_bytes) { validate_num_banks(bank_id_to_bank_offset_.size(), buffer_type_, disable_interleaved); - this->init_allocator(size_bytes, alignment_bytes, alloc_offset); + this->init_allocator(size_bytes, hal.get_alignment(HalMemType::DRAM), alloc_offset); } uint32_t BankManager::num_banks() const { return bank_id_to_bank_offset_.size(); } diff --git a/tt_metal/impl/allocator/l1_banking_allocator.cpp b/tt_metal/impl/allocator/l1_banking_allocator.cpp index 4840fc02098..b2cde04dd9c 100644 --- a/tt_metal/impl/allocator/l1_banking_allocator.cpp +++ b/tt_metal/impl/allocator/l1_banking_allocator.cpp @@ -185,7 +185,7 @@ void Allocator::init_compute_and_storage_l1_bank_manager() { // Storage only cores only need to reserve mailbox space to hold barriers uint32_t mem_mailbox_base = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalL1MemAddrType::MAILBOX); uint32_t storage_core_unreserved_base = - ((mem_mailbox_base + config_.alignment - 1) / config_.alignment) * config_.alignment; + ((mem_mailbox_base + config_.l1_alignment - 1) / config_.l1_alignment) * config_.l1_alignment; // There is only l1_bank_size bytes available for L1 buffers to be allocated in uint64_t l1_bank_size = 
config_.storage_core_bank_size.has_value() @@ -201,7 +201,7 @@ void Allocator::init_compute_and_storage_l1_bank_manager() { bank_id_to_bank_offset, allocatable_l1_size, interleaved_address_limit, - config_.alignment, + config_.l1_alignment, config_.l1_unreserved_base, config_.disable_interleaved); @@ -215,7 +215,7 @@ void Allocator::init_compute_and_storage_l1_bank_manager() { small_bank_id_to_bank_offset, config_.l1_small_size, small_interleaved_address_limit, - config_.alignment, + config_.l1_alignment, small_alloc_offset, config_.disable_interleaved); } diff --git a/tt_metal/impl/buffers/buffer.cpp b/tt_metal/impl/buffers/buffer.cpp index 1e6986e730f..e615e87669c 100644 --- a/tt_metal/impl/buffers/buffer.cpp +++ b/tt_metal/impl/buffers/buffer.cpp @@ -532,7 +532,7 @@ DeviceAddr Buffer::bank_local_page_address(uint32_t bank_id, uint32_t page_index return this->address() + offset; } -uint32_t Buffer::alignment() const { return this->allocator_->get_config().alignment; } +uint32_t Buffer::alignment() const { return allocator_->get_alignment(this->buffer_type()); } DeviceAddr Buffer::aligned_page_size() const { return align(page_size(), this->alignment()); diff --git a/tt_metal/impl/device/device.cpp b/tt_metal/impl/device/device.cpp index af2e642de32..05b8e5fead8 100644 --- a/tt_metal/impl/device/device.cpp +++ b/tt_metal/impl/device/device.cpp @@ -12,6 +12,7 @@ #include #include #include +#include "lightmetal/lightmetal_capture.hpp" #include "tracy/Tracy.hpp" #include #include "dprint_server.hpp" @@ -51,6 +52,7 @@ Device::Device( id_(device_id), worker_thread_core_(worker_thread_core), completion_queue_reader_core_(completion_queue_reader_core), work_executor_(worker_thread_core, device_id) { ZoneScoped; + update_dispatch_cores_for_multi_cq_eth_dispatch(); this->initialize(num_hw_cqs, l1_small_size, trace_region_size, l1_bank_remap, minimal); } @@ -272,25 +274,28 @@ std::unique_ptr Device::initialize_allocator(size_t l1_small_size, si .dram_bank_offsets = {}, .dram_unreserved_base = hal.get_dev_addr(HalDramMemAddrType::DRAM_BARRIER) + hal.get_dev_size(HalDramMemAddrType::DRAM_BARRIER), - .l1_unreserved_base = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalL1MemAddrType::UNRESERVED), + .dram_alignment = hal.get_alignment(HalMemType::DRAM), + .l1_unreserved_base = align( + hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalL1MemAddrType::UNRESERVED), + hal.get_alignment(HalMemType::DRAM)), .worker_grid = CoreRangeSet(CoreRange(CoreCoord(0, 0), CoreCoord(logical_size.x - 1, logical_size.y - 1))), .worker_l1_size = static_cast(soc_desc.worker_l1_size), .storage_core_bank_size = get_storage_core_bank_size(id_, num_hw_cqs_, dispatch_core_config), - .l1_small_size = tt::align(l1_small_size, hal.get_alignment(HalMemType::L1)), - .trace_region_size = tt::align(trace_region_size, hal.get_alignment(HalMemType::DRAM)), + .l1_small_size = align(l1_small_size, hal.get_alignment(HalMemType::DRAM)), + .trace_region_size = align(trace_region_size, hal.get_alignment(HalMemType::DRAM)), .core_type_from_noc_coord_table = {}, // Populated later .worker_log_to_virtual_routing_x = tt::Cluster::instance().get_worker_logical_to_virtual_x(this->id()), .worker_log_to_virtual_routing_y = tt::Cluster::instance().get_worker_logical_to_virtual_y(this->id()), .l1_bank_remap = {l1_bank_remap.begin(), l1_bank_remap.end()}, .compute_grid = CoreRangeSet(CoreRange(CoreCoord(0, 0), CoreCoord(compute_size.x - 1, compute_size.y - 1))), - .alignment = std::max(hal.get_alignment(HalMemType::DRAM), 
hal.get_alignment(HalMemType::L1)), + .l1_alignment = hal.get_alignment(HalMemType::L1), .disable_interleaved = false}); TT_FATAL(config.l1_small_size < (config.storage_core_bank_size.has_value() ? config.storage_core_bank_size.value() : config.worker_l1_size - config.l1_unreserved_base), "Reserved size must be less than bank size"); TT_FATAL( - config.l1_small_size % config.alignment == 0, - "Reserved size must be aligned to allocator alignment {}", - config.alignment); + config.l1_small_size % config.l1_alignment == 0, + "Reserved size must be aligned to L1 allocator alignment {}", + config.l1_alignment); // Initialize dram_offsets from soc_descriptor for (auto channel = 0; channel < soc_desc.get_num_dram_views(); channel++) { config.dram_bank_offsets.push_back(soc_desc.get_address_offset(channel)); @@ -872,6 +877,7 @@ void Device::clear_l1_state() { hal.get_dev_addr(HalProgrammableCoreType::ACTIVE_ETH, HalL1MemAddrType::TILE_HEADER_BUFFER)); } // TODO: clear idle eriscs as well + tt::Cluster::instance().l1_barrier(this->id()); } void Device::compile_command_queue_programs() { @@ -1017,6 +1023,13 @@ bool Device::initialize(const uint8_t num_hw_cqs, size_t l1_small_size, size_t t log_debug(tt::LogMetal, "Running with {} cqs ", num_hw_cqs); TT_FATAL(num_hw_cqs > 0 and num_hw_cqs <= dispatch_core_manager::MAX_NUM_HW_CQS, "num_hw_cqs can be between 1 and {}", dispatch_core_manager::MAX_NUM_HW_CQS); this->using_fast_dispatch_ = false; + // Trying to preserve logic that was in device_pool.cpp + // However, I honestly don't understand it + if (!initialized_ && (num_hw_cqs_ != num_hw_cqs)) { + // The dispatch core manager was reset, since the number of CQs was toggled. + // Account for chip specific idle eth dispatch cores. + update_dispatch_cores_for_multi_cq_eth_dispatch(); + } this->num_hw_cqs_ = num_hw_cqs; constexpr uint32_t harvesting_map_bits = 12; constexpr uint32_t num_hw_cq_bits = 8; @@ -1451,6 +1464,13 @@ void Device::end_trace(const uint8_t cq_id, const uint32_t tid) { this->id_, active_sub_device_manager->id()); this->command_queues_[cq_id]->record_end(); + + // Capture Trace if light metal trace capturing is enabled. 
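// (With capture enabled, the descriptor of the just-recorded metal trace is serialized into the
// LightMetalCaptureContext below, so a later replay can reconstruct the same trace.) A rough
// host-side sketch of how this hook is reached, assuming the usual trace-capture host APIs and an
// already-built program; illustrative only, not part of the change above:
//
//   auto& ctx = LightMetalCaptureContext::get();
//   ctx.set_tracing(true);                                    // record host API calls from here on
//   uint32_t tid = BeginTraceCapture(device, cq_id);
//   EnqueueProgram(device->command_queue(cq_id), program, /*blocking=*/false);
//   EndTraceCapture(device, cq_id, tid);                      // reaches end_trace(); descriptor captured here
//   ctx.set_tracing(false);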
+ auto& lm_capture_ctx = LightMetalCaptureContext::get(); + if (lm_capture_ctx.is_tracing()) { + lm_capture_ctx.capture_trace_descriptor(*trace_buffer->desc, tid); + } + Trace::initialize_buffer(this->command_queue(cq_id), trace_buffer); this->mark_allocations_unsafe(); }, diff --git a/tt_metal/impl/device/device_pool.cpp b/tt_metal/impl/device/device_pool.cpp index c8a870c8e4c..657be4dc5c3 100644 --- a/tt_metal/impl/device/device_pool.cpp +++ b/tt_metal/impl/device/device_pool.cpp @@ -159,8 +159,6 @@ void bind_current_thread_to_free_cores(const std::unordered_set& free_ } // namespace device_cpu_allocator DevicePool* DevicePool::_inst = nullptr; -// Should probably add a dispatch_core_manager.cpp and move this there -tt_metal::dispatch_core_manager* tt_metal::dispatch_core_manager::_inst = nullptr; void DevicePool::init_profiler_devices() const { #if defined(TRACY_ENABLE) @@ -306,7 +304,6 @@ void DevicePool::activate_device(chip_id_t id) { false, worker_core_thread_core, completion_queue_reader_core); - device->update_dispatch_cores_for_multi_cq_eth_dispatch(); if (!this->firmware_built_keys.contains(device->build_key())) { device->build_firmware(); this->firmware_built_keys.insert(device->build_key()); @@ -315,11 +312,6 @@ void DevicePool::activate_device(chip_id_t id) { } else { log_debug(tt::LogMetal, "DevicePool re-initialize device {}", id); if (not device->is_initialized()) { - if (device->num_hw_cqs() != num_hw_cqs) { - // The dispatch core manager was reset, since the number of CQs was toggled. - // Account for chip specific idle eth dispatch cores. - device->update_dispatch_cores_for_multi_cq_eth_dispatch(); - } device->initialize(num_hw_cqs, this->l1_small_size, this->trace_region_size, this->l1_bank_remap); if (!this->firmware_built_keys.contains(device->build_key())) { device->build_firmware(); diff --git a/tt_metal/impl/dispatch/dispatch_core_manager.cpp b/tt_metal/impl/dispatch/dispatch_core_manager.cpp new file mode 100644 index 00000000000..09b8f7e4b4a --- /dev/null +++ b/tt_metal/impl/dispatch/dispatch_core_manager.cpp @@ -0,0 +1,336 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include "dispatch_core_manager.hpp" + +#include "core_descriptor.hpp" +#include "core_coord.hpp" +#include +#include "dispatch_core_common.hpp" + +#include "tt_cluster.hpp" + +namespace tt::tt_metal { + +dispatch_core_manager* dispatch_core_manager::_inst = nullptr; + +void dispatch_core_manager::initialize(const DispatchCoreConfig& dispatch_core_config, uint8_t num_hw_cqs) noexcept { + log_debug(tt::LogMetal, "DevicePool initialize"); + if (_inst == nullptr) { + static dispatch_core_manager dispatch_core_manager(dispatch_core_config, num_hw_cqs); + _inst = &dispatch_core_manager; + } else if (_inst->dispatch_core_config_by_device[0] != dispatch_core_config or num_hw_cqs != _inst->num_hw_cqs) { + _inst->reset_dispatch_core_manager(dispatch_core_config, num_hw_cqs); + } +} + +dispatch_core_manager& dispatch_core_manager::instance() { + TT_ASSERT(dispatch_core_manager::_inst != nullptr, "Trying to get dispatch_core_manager without initializing it"); + return *dispatch_core_manager::_inst; +} + +const tt_cxy_pair& dispatch_core_manager::prefetcher_core(chip_id_t device_id, uint16_t channel, uint8_t cq_id) { + dispatch_core_placement_t& assignment = this->dispatch_core_assignments[device_id][channel][cq_id]; + if (assignment.prefetcher.has_value()) { + return assignment.prefetcher.value(); + } + // Issue queue interface is on the MMIO device + chip_id_t mmio_device_id = tt::Cluster::instance().get_associated_mmio_device(device_id); + CoreCoord issue_queue_coord = this->get_next_available_dispatch_core(mmio_device_id); + assignment.prefetcher = tt_cxy_pair(mmio_device_id, issue_queue_coord.x, issue_queue_coord.y); + log_dispatch_assignment("Prefetcher", assignment.prefetcher.value(), device_id, channel, cq_id); + return assignment.prefetcher.value(); +} + +bool dispatch_core_manager::is_prefetcher_core_allocated(chip_id_t device_id, uint16_t channel, uint8_t cq_id) { + dispatch_core_placement_t& assignment = this->dispatch_core_assignments[device_id][channel][cq_id]; + if (assignment.prefetcher.has_value()) { + return true; + } + return false; +} + +const tt_cxy_pair& dispatch_core_manager::prefetcher_d_core(chip_id_t device_id, uint16_t channel, uint8_t cq_id) { + dispatch_core_placement_t& assignment = this->dispatch_core_assignments[device_id][channel][cq_id]; + if (assignment.prefetcher_d.has_value()) { + return assignment.prefetcher_d.value(); + } + CoreCoord prefetch_d_coord = this->get_next_available_dispatch_core(device_id); + assignment.prefetcher_d = tt_cxy_pair(device_id, prefetch_d_coord.x, prefetch_d_coord.y); + log_dispatch_assignment("Prefetcher D", assignment.prefetcher_d.value(), device_id, channel, cq_id); + return assignment.prefetcher_d.value(); +} + +bool dispatch_core_manager::is_prefetcher_d_core_allocated(chip_id_t device_id, uint16_t channel, uint8_t cq_id) { + dispatch_core_placement_t& assignment = this->dispatch_core_assignments[device_id][channel][cq_id]; + if (assignment.prefetcher_d.has_value()) { + return true; + } + return false; +} + +const tt_cxy_pair& dispatch_core_manager::mux_core(chip_id_t device_id, uint16_t channel, uint8_t cq_id) { + dispatch_core_placement_t& assignment = this->dispatch_core_assignments[device_id][channel][cq_id]; + if (assignment.mux.has_value()) { + return assignment.mux.value(); + } + // Mux interface is on the MMIO device + chip_id_t mmio_device_id = tt::Cluster::instance().get_associated_mmio_device(device_id); + CoreCoord mux_coord = 
this->get_next_available_dispatch_core(mmio_device_id); + assignment.mux = tt_cxy_pair(mmio_device_id, mux_coord.x, mux_coord.y); + log_dispatch_assignment("Mux", assignment.mux.value(), device_id, channel, cq_id); + return assignment.mux.value(); +} + +bool dispatch_core_manager::is_mux_core_allocated(chip_id_t device_id, uint16_t channel, uint8_t cq_id) { + dispatch_core_placement_t& assignment = this->dispatch_core_assignments[device_id][channel][cq_id]; + if (assignment.mux.has_value()) { + return true; + } + return false; +} + +const tt_cxy_pair& dispatch_core_manager::mux_d_core(chip_id_t device_id, uint16_t channel, uint8_t cq_id) { + dispatch_core_placement_t& assignment = this->dispatch_core_assignments[device_id][channel][cq_id]; + if (assignment.mux_d.has_value()) { + return assignment.mux_d.value(); + } + // mux_d is on remote device + CoreCoord mux_d_coord = this->get_next_available_dispatch_core(device_id); + assignment.mux_d = tt_cxy_pair(device_id, mux_d_coord.x, mux_d_coord.y); + log_dispatch_assignment("Mux D", assignment.mux_d.value(), device_id, channel, cq_id); + return assignment.mux_d.value(); +} + +const tt_cxy_pair& dispatch_core_manager::demux_core(chip_id_t device_id, uint16_t channel, uint8_t cq_id) { + dispatch_core_placement_t& assignment = this->dispatch_core_assignments[device_id][channel][cq_id]; + if (assignment.demux.has_value()) { + return assignment.demux.value(); + } + // demux interface is on the MMIO device + chip_id_t mmio_device_id = tt::Cluster::instance().get_associated_mmio_device(device_id); + CoreCoord demux_coord = this->get_next_available_dispatch_core(mmio_device_id); + assignment.demux = tt_cxy_pair(mmio_device_id, demux_coord.x, demux_coord.y); + log_dispatch_assignment("Demux", assignment.demux.value(), device_id, channel, cq_id); + return assignment.demux.value(); +} + +bool dispatch_core_manager::is_demux_core_allocated(chip_id_t device_id, uint16_t channel, uint8_t cq_id) { + dispatch_core_placement_t& assignment = this->dispatch_core_assignments[device_id][channel][cq_id]; + if (assignment.demux.has_value()) { + return true; + } + return false; +} + +const tt_cxy_pair& dispatch_core_manager::demux_d_core(chip_id_t device_id, uint16_t channel, uint8_t cq_id) { + dispatch_core_placement_t& assignment = this->dispatch_core_assignments[device_id][channel][cq_id]; + if (assignment.demux_d.has_value()) { + return assignment.demux_d.value(); + } + // demux_d is on remote device + CoreCoord demux_d_coord = this->get_next_available_dispatch_core(device_id); + assignment.demux_d = tt_cxy_pair(device_id, demux_d_coord.x, demux_d_coord.y); + log_dispatch_assignment("Demux D", assignment.demux_d.value(), device_id, channel, cq_id); + return assignment.demux_d.value(); +} + +const tt_cxy_pair& dispatch_core_manager::tunneler_core( + chip_id_t upstream_device_id, chip_id_t device_id, uint16_t channel, uint8_t cq_id) { + dispatch_core_placement_t& assignment = this->dispatch_core_assignments[device_id][channel][cq_id]; + if (assignment.tunneler.has_value()) { + return assignment.tunneler.value(); + } + + auto [us_core, ds_core] = + tt::Cluster::instance().get_eth_tunnel_core(upstream_device_id, device_id, EthRouterMode::BI_DIR_TUNNELING); + + assignment.tunneler = us_core; + assignment.tunneler_d = ds_core; + + log_dispatch_assignment("Tunneler Remote", assignment.tunneler.value(), device_id, channel, cq_id, true); + log_dispatch_assignment("Tunneler Local", assignment.tunneler_d.value(), device_id, channel, cq_id, true); + return 
assignment.tunneler.value(); +} + +const tt_cxy_pair& dispatch_core_manager::us_tunneler_core_local(chip_id_t device_id, uint16_t channel, uint8_t cq_id) { + dispatch_core_placement_t& assignment = this->dispatch_core_assignments[device_id][channel][cq_id]; + if (assignment.tunneler_d.has_value()) { + return assignment.tunneler_d.value(); + } + TT_ASSERT(false, "Device {} has no allocation for Local Upstream Tunneler Core.", device_id); + assignment.tunneler_d = tt_cxy_pair(0, 0, 0); + return assignment.tunneler_d.value(); +} + +const tt_cxy_pair& dispatch_core_manager::completion_queue_writer_core( + chip_id_t device_id, uint16_t channel, uint8_t cq_id) { + dispatch_core_placement_t& assignment = this->dispatch_core_assignments[device_id][channel][cq_id]; + if (assignment.completion_queue_writer.has_value()) { + return assignment.completion_queue_writer.value(); + } + // Completion queue interface is on the MMIO device + chip_id_t mmio_device_id = tt::Cluster::instance().get_associated_mmio_device(device_id); + CoreCoord completion_queue_coord = this->get_next_available_dispatch_core(mmio_device_id); + assignment.completion_queue_writer = + tt_cxy_pair(mmio_device_id, completion_queue_coord.x, completion_queue_coord.y); + TT_ASSERT( + not assignment.dispatcher.has_value(), + "Command dispatcher core {} must match completion queue interface core for MMIO device {}", + assignment.dispatcher.value().str(), + device_id); + assignment.dispatcher = assignment.completion_queue_writer; + log_dispatch_assignment( + "Completion Queue Writer", assignment.completion_queue_writer.value(), device_id, channel, cq_id); + return assignment.completion_queue_writer.value(); +} + +bool dispatch_core_manager::is_completion_queue_writer_core_allocated( + chip_id_t device_id, uint16_t channel, uint8_t cq_id) { + dispatch_core_placement_t& assignment = this->dispatch_core_assignments[device_id][channel][cq_id]; + if (assignment.completion_queue_writer.has_value()) { + return true; + } + return false; +} + +const tt_cxy_pair& dispatch_core_manager::dispatcher_core(chip_id_t device_id, uint16_t channel, uint8_t cq_id) { + dispatch_core_placement_t& assignment = this->dispatch_core_assignments[device_id][channel][cq_id]; + if (assignment.dispatcher.has_value()) { + return assignment.dispatcher.value(); + } + chip_id_t mmio_device_id = tt::Cluster::instance().get_associated_mmio_device(device_id); + CoreCoord dispatcher_coord = this->get_next_available_dispatch_core(mmio_device_id); + assignment.dispatcher = tt_cxy_pair(mmio_device_id, dispatcher_coord.x, dispatcher_coord.y); + TT_ASSERT( + not assignment.completion_queue_writer.has_value(), + "Command dispatcher core must match completion queue interface core for MMIO device {}", + device_id); + assignment.completion_queue_writer = assignment.dispatcher; + log_dispatch_assignment("Dispatcher", assignment.dispatcher.value(), device_id, channel, cq_id); + return assignment.dispatcher.value(); +} + +bool dispatch_core_manager::is_dispatcher_core_allocated(chip_id_t device_id, uint16_t channel, uint8_t cq_id) { + dispatch_core_placement_t& assignment = this->dispatch_core_assignments[device_id][channel][cq_id]; + if (assignment.dispatcher.has_value()) { + return true; + } + return false; +} + +bool dispatch_core_manager::is_dispatcher_s_core_allocated(chip_id_t device_id, uint16_t channel, uint8_t cq_id) { + dispatch_core_placement_t& assignment = this->dispatch_core_assignments[device_id][channel][cq_id]; + return assignment.dispatcher_s.has_value(); +} + +const 
tt_cxy_pair& dispatch_core_manager::dispatcher_d_core(chip_id_t device_id, uint16_t channel, uint8_t cq_id) { + dispatch_core_placement_t& assignment = this->dispatch_core_assignments[device_id][channel][cq_id]; + if (assignment.dispatcher_d.has_value()) { + return assignment.dispatcher_d.value(); + } + CoreCoord dispatcher_d_coord = this->get_next_available_dispatch_core(device_id); + assignment.dispatcher_d = tt_cxy_pair(device_id, dispatcher_d_coord.x, dispatcher_d_coord.y); + log_dispatch_assignment("Dispatcher D", assignment.dispatcher_d.value(), device_id, channel, cq_id); + return assignment.dispatcher_d.value(); +} + +const tt_cxy_pair& dispatch_core_manager::dispatcher_s_core(chip_id_t device_id, uint16_t channel, uint8_t cq_id) { + dispatch_core_placement_t& assignment = this->dispatch_core_assignments[device_id][channel][cq_id]; + if (assignment.dispatcher_s.has_value()) { + return assignment.dispatcher_s.value(); + } + CoreCoord dispatcher_s_coord; + if (this->get_dispatch_core_type(device_id) == CoreType::WORKER) { + chip_id_t mmio_device_id = tt::Cluster::instance().get_associated_mmio_device(device_id); + if (mmio_device_id == device_id) { + // dispatch_s is on the same tensix core as dispatch_hd + dispatcher_s_coord = this->dispatcher_core(device_id, channel, cq_id); + } else { + // dispatch_s is on the same tensix as dispatch_d + dispatcher_s_coord = this->dispatcher_d_core(device_id, channel, cq_id); + } + } else { + dispatcher_s_coord = this->get_next_available_dispatch_core(device_id); + } + assignment.dispatcher_s = tt_cxy_pair(device_id, dispatcher_s_coord.x, dispatcher_s_coord.y); + log_dispatch_assignment("Dispatcher S", assignment.dispatcher_s.value(), device_id, channel, cq_id); + return assignment.dispatcher_s.value(); +} + +CoreType dispatch_core_manager::get_dispatch_core_type(chip_id_t device_id) { + return this->dispatch_core_config_by_device[device_id].get_core_type(); +} + +DispatchCoreConfig dispatch_core_manager::get_dispatch_core_config(chip_id_t device_id) { + return this->dispatch_core_config_by_device[device_id]; +} + +void dispatch_core_manager::add_dispatch_core_to_device(chip_id_t device_id, const CoreCoord& core) { + // TODO: remove this API, we should read the core descriptor once, should not have backdoors like this to add cores + auto& dispatch_cores = available_dispatch_cores_by_device.at(device_id); + if (std::find(dispatch_cores.begin(), dispatch_cores.end(), core) == dispatch_cores.end()) { + dispatch_cores.push_back(core); + } +} + +std::vector dispatch_core_manager::get_all_logical_dispatch_cores(chip_id_t device_id) { + return tt::get_logical_dispatch_cores(device_id, MAX_NUM_HW_CQS, this->dispatch_core_config_by_device[device_id]); +} + +// private methods + +dispatch_core_manager::dispatch_core_manager(const DispatchCoreConfig& dispatch_core_config, uint8_t num_hw_cqs) { + this->reset_dispatch_core_manager(dispatch_core_config, num_hw_cqs); +} + +void dispatch_core_manager::reset_dispatch_core_manager( + const DispatchCoreConfig& dispatch_core_config, uint8_t num_hw_cqs) { + this->dispatch_core_assignments.clear(); + this->available_dispatch_cores_by_device.clear(); + this->dispatch_core_config_by_device.clear(); + for (chip_id_t device_id = 0; device_id < tt::Cluster::instance().number_of_devices(); device_id++) { + std::list& logical_dispatch_cores = this->available_dispatch_cores_by_device[device_id]; + for (const CoreCoord& logical_dispatch_core : + tt::get_logical_dispatch_cores(device_id, MAX_NUM_HW_CQS, dispatch_core_config)) { 
+ logical_dispatch_cores.push_back(logical_dispatch_core); + } + + this->dispatch_core_config_by_device[device_id] = dispatch_core_config; + this->num_hw_cqs = num_hw_cqs; + } +} + +CoreCoord dispatch_core_manager::get_next_available_dispatch_core(chip_id_t device_id) { + if (this->available_dispatch_cores_by_device.find(device_id) == this->available_dispatch_cores_by_device.end()) { + TT_THROW("Invalid device ID to assign dispatch cores {}", device_id); + } + if (this->available_dispatch_cores_by_device.at(device_id).empty()) { + TT_THROW( + "No more available dispatch cores on device {} to assign. Expand dispatch cores specified in core " + "descriptor YAML", + device_id); + } + CoreCoord avail_dispatch_core = this->available_dispatch_cores_by_device.at(device_id).front(); + this->available_dispatch_cores_by_device.at(device_id).pop_front(); + return avail_dispatch_core; +} + +void dispatch_core_manager::log_dispatch_assignment( + std::string name, tt_cxy_pair& cxy, chip_id_t device_id, uint16_t channel, uint8_t cq_id, bool force_ethernet) { + log_debug( + tt::LogMetal, + "Allocated {} Core: {}({}) for Device {} Channel {} CQ ID {}", + name, + cxy.str(), + tt::Cluster::instance() + .get_virtual_coordinate_from_logical_coordinates( + cxy, force_ethernet ? CoreType::ETH : get_dispatch_core_type(cxy.chip)) + .str(), + device_id, + channel, + cq_id); +} + +} // namespace tt::tt_metal diff --git a/tt_metal/impl/dispatch/dispatch_query_manager.cpp b/tt_metal/impl/dispatch/dispatch_query_manager.cpp index 82bfa7d0a04..e49af46ef7e 100644 --- a/tt_metal/impl/dispatch/dispatch_query_manager.cpp +++ b/tt_metal/impl/dispatch/dispatch_query_manager.cpp @@ -4,6 +4,8 @@ #include "tt_metal/impl/dispatch/dispatch_query_manager.hpp" +#include "tt_cluster.hpp" + namespace { tt::tt_metal::DispatchCoreConfig dispatch_core_config() { diff --git a/tt_metal/impl/dispatch/dispatch_query_manager.hpp b/tt_metal/impl/dispatch/dispatch_query_manager.hpp index 25ca7307bc7..e01cae1d068 100644 --- a/tt_metal/impl/dispatch/dispatch_query_manager.hpp +++ b/tt_metal/impl/dispatch/dispatch_query_manager.hpp @@ -3,7 +3,6 @@ // SPDX-License-Identifier: Apache-2.0 #include -#include namespace tt::tt_metal { diff --git a/tt_metal/impl/dispatch/host_runtime_commands.cpp b/tt_metal/impl/dispatch/host_runtime_commands.cpp index 68eb075e998..e1e0dfa8b5b 100644 --- a/tt_metal/impl/dispatch/host_runtime_commands.cpp +++ b/tt_metal/impl/dispatch/host_runtime_commands.cpp @@ -43,6 +43,7 @@ #include #include +#include "lightmetal/host_api_capture_helpers.hpp" using namespace tt::tt_metal; @@ -513,6 +514,8 @@ void EnqueueReadBuffer( const std::variant, std::shared_ptr>& buffer, void* dst, bool blocking) { + LIGHT_METAL_TRACE_FUNCTION_ENTRY(); + LIGHT_METAL_TRACE_FUNCTION_CALL(CaptureEnqueueReadBuffer, cq, buffer, dst, blocking); Buffer& buffer_obj = detail::GetBufferObject(buffer); BufferRegion region(0, buffer_obj.size()); EnqueueReadSubBuffer(cq, buffer, dst, region, blocking); @@ -543,6 +546,8 @@ void EnqueueWriteBuffer( const std::variant, std::shared_ptr>& buffer, HostDataType src, bool blocking) { + LIGHT_METAL_TRACE_FUNCTION_ENTRY(); + LIGHT_METAL_TRACE_FUNCTION_CALL(CaptureEnqueueWriteBuffer, cq, buffer, src, blocking); Buffer& buffer_obj = detail::GetBufferObject(buffer); BufferRegion region(0, buffer_obj.size()); EnqueueWriteSubBuffer(cq, buffer, std::move(src), region, blocking); @@ -562,6 +567,8 @@ void EnqueueWriteSubBuffer( void EnqueueProgram(CommandQueue& cq, Program& program, bool blocking) { ZoneScoped; + 
LIGHT_METAL_TRACE_FUNCTION_ENTRY(); + LIGHT_METAL_TRACE_FUNCTION_CALL(CaptureEnqueueProgram, cq, program, blocking); detail::DispatchStateCheck(true); IDevice* device = cq.device(); @@ -632,6 +639,8 @@ bool EventQuery(const std::shared_ptr& event) { } void Finish(CommandQueue& cq, tt::stl::Span sub_device_ids) { + LIGHT_METAL_TRACE_FUNCTION_ENTRY(); + LIGHT_METAL_TRACE_FUNCTION_CALL(CaptureFinish, cq, sub_device_ids); detail::DispatchStateCheck(true); cq.finish(sub_device_ids); TT_ASSERT( @@ -643,6 +652,8 @@ void Finish(CommandQueue& cq, tt::stl::Span sub_device_ids) { } void EnqueueTrace(CommandQueue& cq, uint32_t trace_id, bool blocking) { + LIGHT_METAL_TRACE_FUNCTION_ENTRY(); + LIGHT_METAL_TRACE_FUNCTION_CALL(CaptureEnqueueTrace, cq, trace_id, blocking); detail::DispatchStateCheck(true); TT_FATAL(cq.device()->get_trace(trace_id) != nullptr, "Trace instance {} must exist on device", trace_id); cq.enqueue_trace(trace_id, blocking); diff --git a/tt_metal/impl/flatbuffer/buffer_types_to_flatbuffer.cpp b/tt_metal/impl/flatbuffer/buffer_types_to_flatbuffer.cpp index 0c3f4c3822b..33d1fe52571 100644 --- a/tt_metal/impl/flatbuffer/buffer_types_to_flatbuffer.cpp +++ b/tt_metal/impl/flatbuffer/buffer_types_to_flatbuffer.cpp @@ -3,6 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 #include "flatbuffer/buffer_types_to_flatbuffer.hpp" +#include "lightmetal/lightmetal_capture.hpp" // For LightMetalCaptureContext namespace tt::tt_metal { @@ -54,10 +55,8 @@ flatbuffers::Offset to_flatbuffer( }; // Optional shadow buffer for dynamically allocated CBs, get global_id or use 0 as none/nullptr. - // auto& ctx = LightMetalCaptureContext::Get(); - // auto shadow_buf_global_id = config.shadow_global_buffer ? ctx.GetGlobalId(config.shadow_global_buffer) : 0; - // TODO (kmabee) - Uncomment above code once capture library is merged. Temp hack here for now. - uint32_t shadow_buf_global_id = 0; + auto& ctx = LightMetalCaptureContext::get(); + auto shadow_buf_global_id = config.shadow_global_buffer ? 
ctx.get_global_id(config.shadow_global_buffer) : 0; // Create the FlatBuffer object return flatbuffer::CreateCircularBufferConfig( diff --git a/tt_metal/impl/flatbuffer/program_types_from_flatbuffer.hpp b/tt_metal/impl/flatbuffer/program_types_from_flatbuffer.hpp index f8176eb0f98..930ebe230e7 100644 --- a/tt_metal/impl/flatbuffer/program_types_from_flatbuffer.hpp +++ b/tt_metal/impl/flatbuffer/program_types_from_flatbuffer.hpp @@ -54,8 +54,7 @@ std::variant kernel_config_fr return from_flatbuffer(cmd->kernel_config_as_DataMovementConfig()); case flatbuffer::KernelConfig::ComputeConfig: return from_flatbuffer(cmd->kernel_config_as_ComputeConfig()); case flatbuffer::KernelConfig::EthernetConfig: return from_flatbuffer(cmd->kernel_config_as_EthernetConfig()); - case flatbuffer::KernelConfig::NONE: - throw std::runtime_error("Unhandled KernelConfig type in from_flatbuffer."); + case flatbuffer::KernelConfig::NONE: TT_THROW("Unhandled KernelConfig type in from_flatbuffer."); } TT_THROW("Unhandled KernelConfig type in from_flatbuffer."); } diff --git a/tt_metal/impl/flatbuffer/program_types_to_flatbuffer.cpp b/tt_metal/impl/flatbuffer/program_types_to_flatbuffer.cpp index a3d8e875819..6c8f1570604 100644 --- a/tt_metal/impl/flatbuffer/program_types_to_flatbuffer.cpp +++ b/tt_metal/impl/flatbuffer/program_types_to_flatbuffer.cpp @@ -4,7 +4,9 @@ #include "flatbuffer/base_types_to_flatbuffer.hpp" #include "flatbuffer/program_types_to_flatbuffer.hpp" +#include "lightmetal/lightmetal_capture.hpp" // For LightMetalCaptureContext #include + namespace tt::tt_metal { // Original types defined in core_coord.hpp @@ -155,10 +157,8 @@ flatbuffers::Offset create_runtime_arg( return builder.CreateStruct(tt_metal::flatbuffer::UInt32Value{arg_value}).Union(); }, [&](Buffer* arg_value) -> flatbuffers::Offset { - // auto& ctx = LightMetalCaptureContext::Get(); - // uint32_t buffer_global_id = ctx.GetGlobalId(arg_value); - // TODO (kmabee) - Uncomment above code once capture library is merged. Temp hack here for now. - uint32_t buffer_global_id = 0; + auto& ctx = LightMetalCaptureContext::get(); + uint32_t buffer_global_id = ctx.get_global_id(arg_value); value_type = flatbuffer::RuntimeArgValue::BufferGlobalId; return builder.CreateStruct(tt_metal::flatbuffer::BufferGlobalId{buffer_global_id}).Union(); }}, diff --git a/tt_metal/impl/lightmetal/host_api_capture_helpers.cpp b/tt_metal/impl/lightmetal/host_api_capture_helpers.cpp new file mode 100644 index 00000000000..9d4905bb2c6 --- /dev/null +++ b/tt_metal/impl/lightmetal/host_api_capture_helpers.cpp @@ -0,0 +1,394 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include +#include +#include +#include "lightmetal/host_api_capture_helpers.hpp" +#include "command_generated.h" +#include "lightmetal/lightmetal_capture.hpp" +#include "flatbuffer/base_types_to_flatbuffer.hpp" +#include "flatbuffer/program_types_to_flatbuffer.hpp" +#include "flatbuffer/buffer_types_to_flatbuffer.hpp" + +namespace tt::tt_metal { + +////////////////////////////////////////////////////////////// +// Debug Code // +////////////////////////////////////////////////////////////// + +namespace { +// This can be useful for debug. Not all data types are currently supported, can use this during development.
+void PrintHostDataType(const HostDataType& data) { + std::visit( + tt::stl::overloaded{ + [](const std::shared_ptr>& value) { + log_info(tt::LogMetalTrace, "HostDataType contains: std::shared_ptr>"); + }, + [](const std::shared_ptr>& value) { + log_info(tt::LogMetalTrace, "HostDataType contains: std::shared_ptr>"); + }, + [](const std::shared_ptr>& value) { + log_info(tt::LogMetalTrace, "HostDataType contains: std::shared_ptr>"); + }, + [](const std::shared_ptr>& value) { + log_info(tt::LogMetalTrace, "HostDataType contains: std::shared_ptr>"); + }, + [](const std::shared_ptr>& value) { + log_info(tt::LogMetalTrace, "HostDataType contains: std::shared_ptr>"); + }, + [](const std::shared_ptr>& value) { + log_info(tt::LogMetalTrace, "HostDataType contains: std::shared_ptr>"); + }, + [](const void* value) { log_info(tt::LogMetalTrace, "HostDataType contains: const void*"); }, + [](auto&&) { log_info(tt::LogMetalTrace, "HostDataType contains: Unknown type"); }}, + data); +} +} // namespace + +////////////////////////////////////////////////////////////// +// Host API tracing helper functions // +////////////////////////////////////////////////////////////// + +// Generic helper to build command and add to vector of cmds (CQ) - no need to make public +namespace { +void CaptureCommand(tt::tt_metal::flatbuffer::CommandType cmd_type, ::flatbuffers::Offset fb_offset) { + auto& ctx = LightMetalCaptureContext::get(); + ctx.get_cmds_vector().push_back(tt::tt_metal::flatbuffer::CreateCommand(ctx.get_builder(), cmd_type, fb_offset)); +} +} // namespace + +void CaptureReplayTrace(IDevice* device, uint8_t cq_id, uint32_t trace_id, bool blocking) { + auto& ctx = LightMetalCaptureContext::get(); + log_debug(tt::LogMetalTrace, "{}: cq_id: {} trace_id: {} blocking: {}", __FUNCTION__, cq_id, trace_id, blocking); + auto cmd = tt::tt_metal::flatbuffer::CreateReplayTraceCommand(ctx.get_builder(), cq_id, trace_id, blocking); + CaptureCommand(tt::tt_metal::flatbuffer::CommandType::ReplayTraceCommand, cmd.Union()); +} + +void CaptureEnqueueTrace(CommandQueue& cq, uint32_t trace_id, bool blocking) { + auto& ctx = LightMetalCaptureContext::get(); + log_debug(tt::LogMetalTrace, "{}: cq_id: {} trace_id: {} blocking: {}", __FUNCTION__, cq.id(), trace_id, blocking); + auto cmd = tt::tt_metal::flatbuffer::CreateEnqueueTraceCommand(ctx.get_builder(), cq.id(), trace_id, blocking); + CaptureCommand(tt::tt_metal::flatbuffer::CommandType::EnqueueTraceCommand, cmd.Union()); +} + +void CaptureLoadTrace(IDevice* device, uint8_t cq_id, uint32_t trace_id) { + auto& ctx = LightMetalCaptureContext::get(); + log_debug(tt::LogMetalTrace, "{}: cq_id: {} trace_id: {}", __FUNCTION__, cq_id, trace_id); + auto cmd = tt::tt_metal::flatbuffer::CreateLoadTraceCommand(ctx.get_builder(), trace_id, cq_id); + CaptureCommand(tt::tt_metal::flatbuffer::CommandType::LoadTraceCommand, cmd.Union()); +} + +void CaptureReleaseTrace(IDevice* device, uint32_t trace_id) { + auto& ctx = LightMetalCaptureContext::get(); + log_debug(tt::LogMetalTrace, "{}: trace_id: {}", __FUNCTION__, trace_id); + auto cmd = tt::tt_metal::flatbuffer::CreateReleaseTraceCommand(ctx.get_builder(), trace_id); + CaptureCommand(tt::tt_metal::flatbuffer::CommandType::ReleaseTraceCommand, cmd.Union()); +} + +void CaptureCreateBuffer(const std::shared_ptr& buffer, const InterleavedBufferConfig& config) { + auto& ctx = LightMetalCaptureContext::get(); + + uint32_t buffer_global_id = ctx.add_to_map(buffer.get()); + log_debug( + tt::LogMetalTrace, + "{}: size: {} page_size: {} buffer_type: 
{} buffer_layout: {} buffer_global_id: {}", + __FUNCTION__, + config.size, + config.page_size, + config.buffer_type, + config.buffer_layout, + buffer_global_id); + + assert(config.device->id() == 0 && "multichip not supported yet"); + auto buffer_config_offset = tt::tt_metal::flatbuffer::CreateInterleavedBufferConfig( + ctx.get_builder(), + config.device->id(), + config.size, + config.page_size, + to_flatbuffer(config.buffer_type), + to_flatbuffer(config.buffer_layout)); + auto cmd = + tt::tt_metal::flatbuffer::CreateCreateBufferCommand(ctx.get_builder(), buffer_global_id, buffer_config_offset); + CaptureCommand(tt::tt_metal::flatbuffer::CommandType::CreateBufferCommand, cmd.Union()); +} + +void CaptureDeallocateBuffer(Buffer& buffer) { + auto& ctx = LightMetalCaptureContext::get(); + + // Kind of a workaround, but Program Binaries buffer is created via Buffer::create() but can be + // deallocated on Program destruction while capturing is still enabled depending on test structure (scope) + // so let's just not capture these DeallocateBuffer() calls since they will occur on playback naturally. + if (!ctx.is_in_map(&buffer)) { + log_debug(tt::LogMetalTrace, "Cannot capture DeallocateBuffer() without CreateBuffer() - ignoring."); + return; + } + + auto buffer_global_id = ctx.get_global_id(&buffer); + + log_debug( + tt::LogMetalTrace, + "{}: buffer_global_id: {} size: {} address: {}", + __FUNCTION__, + buffer_global_id, + buffer.size(), + buffer.address()); + + auto cmd = tt::tt_metal::flatbuffer::CreateDeallocateBufferCommand(ctx.get_builder(), buffer_global_id); + CaptureCommand(tt::tt_metal::flatbuffer::CommandType::DeallocateBufferCommand, cmd.Union()); +} + +void CaptureEnqueueWriteBuffer( + CommandQueue& cq, + std::variant, std::shared_ptr> buffer, + HostDataType src, + bool blocking) { + auto& ctx = LightMetalCaptureContext::get(); + + // We don't want to use shared_ptr to extend lifetime of buffer when adding to global_id map. + Buffer* buffer_ptr = std::holds_alternative>(buffer) + ? std::get>(buffer).get() + : &std::get>(buffer).get(); + + uint32_t cq_global_id = cq.id(); // TODO (kmabee) - consider storing/getting CQ from global map instead. + uint32_t buffer_global_id = ctx.get_global_id(buffer_ptr); + + log_debug( + tt::LogMetalTrace, "{}: cq_global_id: {} buffer_global_id: {}", __FUNCTION__, cq_global_id, buffer_global_id); + // PrintHostDataType(src); // Debug + + // TODO (kmabee) - Currently support limited data formats. Long term we might not store data in flatbuffer, + // but have it provided at runtime so just do what's easiest here and support few types for now. + ::flatbuffers::Offset<::flatbuffers::Vector> src_vector; + if (auto* uint32_vec = std::get_if>>(&src)) { + src_vector = ctx.get_builder().CreateVector(**uint32_vec); + } else if (auto* uint16_vec = std::get_if>>(&src)) { + // Convert uint16_t to uint32_t before creating the FlatBuffers vector + std::vector converted(uint16_vec->get()->begin(), uint16_vec->get()->end()); + src_vector = ctx.get_builder().CreateVector(converted); + } else if (auto* void_ptr = std::get_if(&src)) { + // Assuming the void* points to a buffer of uint32_t values. Infer size, cast to uint32_t. 
+ size_t num_elements = buffer_ptr->size() / sizeof(uint32_t); + auto uint32_data = static_cast(*void_ptr); + src_vector = ctx.get_builder().CreateVector(uint32_data, num_elements); + } else { + TT_THROW("Unsupported HostDataType for captureEnqueueWriteBuffer()"); + } + + auto cmd = tt::tt_metal::flatbuffer::CreateEnqueueWriteBufferCommand( + ctx.get_builder(), cq_global_id, buffer_global_id, src_vector, blocking); + CaptureCommand(tt::tt_metal::flatbuffer::CommandType::EnqueueWriteBufferCommand, cmd.Union()); +} + +void CaptureEnqueueReadBuffer( + CommandQueue& cq, + std::variant, std::shared_ptr> buffer, + void* dst, + bool blocking) { + auto& ctx = LightMetalCaptureContext::get(); + + // We don't want to use shared_ptr to extend lifetime of buffer when adding to global_id map. + Buffer* buffer_ptr = std::holds_alternative>(buffer) + ? std::get>(buffer).get() + : &std::get>(buffer).get(); + + uint32_t cq_global_id = cq.id(); // TODO (kmabee) - consider storing/getting CQ from global map instead. + uint32_t buffer_global_id = ctx.get_global_id(buffer_ptr); + + log_debug( + tt::LogMetalTrace, "{}: cq_global_id: {} buffer_global_id: {}", __FUNCTION__, cq_global_id, buffer_global_id); + + // Idea store a read_global_id to keep track of read results. + auto cmd = tt::tt_metal::flatbuffer::CreateEnqueueReadBufferCommand( + ctx.get_builder(), cq_global_id, buffer_global_id, blocking); + CaptureCommand(tt::tt_metal::flatbuffer::CommandType::EnqueueReadBufferCommand, cmd.Union()); +} + +void CaptureFinish(CommandQueue& cq, tt::stl::Span sub_device_ids) { + auto& ctx = LightMetalCaptureContext::get(); + uint32_t cq_global_id = cq.id(); // TODO (kmabee) - consider storing/getting CQ from global map instead. + + // Use to_flatbuffer to convert SubDeviceIds to FlatBuffer vector + auto fb_sub_device_ids = to_flatbuffer(ctx.get_builder(), sub_device_ids); + + log_debug( + tt::LogMetalTrace, "{}: cq_global_id: {} sub_devices: {}", __FUNCTION__, cq_global_id, sub_device_ids.size()); + auto cmd = tt::tt_metal::flatbuffer::CreateFinishCommand(ctx.get_builder(), cq_global_id, fb_sub_device_ids); + CaptureCommand(tt::tt_metal::flatbuffer::CommandType::FinishCommand, cmd.Union()); +} + +void CaptureCreateProgram(Program& program) { + auto& ctx = LightMetalCaptureContext::get(); + uint32_t program_global_id = ctx.add_to_map(&program); + log_debug(tt::LogMetalTrace, "{}: program_global_id: {}", __FUNCTION__, program_global_id); + + auto cmd = tt::tt_metal::flatbuffer::CreateCreateProgramCommand(ctx.get_builder(), program_global_id); + CaptureCommand(tt::tt_metal::flatbuffer::CommandType::CreateProgramCommand, cmd.Union()); +} + +void CaptureEnqueueProgram(CommandQueue& cq, Program& program, bool blocking) { + auto& ctx = LightMetalCaptureContext::get(); + + // When Metal Trace is enabled, skip EnqueueProgram capture (replaced with LoadTrace + ReplayTrace) + if (cq.sysmem_manager().get_bypass_mode()) { + return; + } + + uint32_t cq_global_id = cq.id(); // TODO (kmabee) - consider storing/getting CQ from global map instead. 
+ uint32_t program_global_id = ctx.get_global_id(&program); + log_debug( + tt::LogMetalTrace, "{}: cq_global_id: {} program_global_id: {}", __FUNCTION__, cq_global_id, program_global_id); + + auto cmd = tt::tt_metal::flatbuffer::CreateEnqueueProgramCommand( + ctx.get_builder(), cq_global_id, program_global_id, blocking); + CaptureCommand(tt::tt_metal::flatbuffer::CommandType::EnqueueProgramCommand, cmd.Union()); +} + +void CaptureCreateKernel( + KernelHandle kernel_id, + Program& program, + const std::string& file_name, + const std::variant& core_spec, + const std::variant& config) { + auto& ctx = LightMetalCaptureContext::get(); + + std::shared_ptr kernel = program.get_kernel(kernel_id); + uint32_t kernel_global_id = ctx.add_to_map(kernel.get()); + uint32_t program_global_id = ctx.get_global_id(&program); + log_debug( + tt::LogMetalTrace, + "{}: file_name: {} kernel_global_id: {} (kernel_id: {}) program_global_id: {}", + __FUNCTION__, + file_name, + kernel_global_id, + kernel_id, + program_global_id); + + auto& fbb = ctx.get_builder(); + auto filename_offset = fbb.CreateString(file_name); + auto [core_spec_type, core_spec_offset] = to_flatbuffer(fbb, core_spec); + auto [kernel_config_type, kernel_config_offset] = to_flatbuffer(fbb, config); + + auto cmd = tt::tt_metal::flatbuffer::CreateCreateKernelCommand( + fbb, + kernel_global_id, + program_global_id, + filename_offset, + core_spec_type, + core_spec_offset, + kernel_config_type, + kernel_config_offset); + CaptureCommand(tt::tt_metal::flatbuffer::CommandType::CreateKernelCommand, cmd.Union()); +} + +void CaptureSetRuntimeArgsUint32( + const Program& program, + KernelHandle kernel_id, + const std::variant& core_spec, + tt::stl::Span runtime_args) { + auto& ctx = LightMetalCaptureContext::get(); + + std::shared_ptr kernel = program.get_kernel(kernel_id); + uint32_t program_global_id = ctx.get_global_id(&program); + uint32_t kernel_global_id = ctx.get_global_id(kernel.get()); + log_debug( + tt::LogMetalTrace, + "{}(uint32): kernel_global_id: {} program_global_id: {} rt_args: {}", + __FUNCTION__, + kernel_global_id, + program_global_id, + runtime_args.size()); + + auto& fbb = ctx.get_builder(); + auto [core_spec_type, core_spec_offset] = to_flatbuffer(fbb, core_spec); + auto rt_args_offset = fbb.CreateVector(runtime_args.data(), runtime_args.size()); + + auto cmd = tt::tt_metal::flatbuffer::CreateSetRuntimeArgsUint32Command( + fbb, program_global_id, kernel_global_id, core_spec_type, core_spec_offset, rt_args_offset); + CaptureCommand(tt::tt_metal::flatbuffer::CommandType::SetRuntimeArgsUint32Command, cmd.Union()); +} + +void CaptureSetRuntimeArgs( + IDevice* device, + const std::shared_ptr& kernel, + const std::variant& core_spec, + const std::shared_ptr& runtime_args) { + auto& ctx = LightMetalCaptureContext::get(); + auto& fbb = ctx.get_builder(); + uint32_t kernel_global_id = ctx.get_global_id(kernel.get()); + auto [core_spec_type, core_spec_offset] = to_flatbuffer(fbb, core_spec); + auto rt_args_offset = to_flatbuffer(fbb, runtime_args); + log_debug( + tt::LogMetalTrace, + "{}(RuntimeArgs): kernel_global_id: {} rt_args_size: {}", + __FUNCTION__, + kernel_global_id, + runtime_args->size()); + + auto cmd = tt::tt_metal::flatbuffer::CreateSetRuntimeArgsCommand( + fbb, kernel_global_id, core_spec_type, core_spec_offset, rt_args_offset); + CaptureCommand(tt::tt_metal::flatbuffer::CommandType::SetRuntimeArgsCommand, cmd.Union()); +} + +void CaptureCreateCircularBuffer( + CBHandle& cb_handle, + Program& program, + const std::variant& 
core_spec, + const CircularBufferConfig& config) { + auto& ctx = LightMetalCaptureContext::get(); + auto& fbb = ctx.get_builder(); + uint32_t cb_global_id = ctx.add_to_map(cb_handle); + uint32_t program_global_id = ctx.get_global_id(&program); + auto [core_spec_type, core_spec_offset] = to_flatbuffer(fbb, core_spec); + auto cb_config_offset = to_flatbuffer(config, fbb); + log_debug( + tt::LogMetalTrace, + "{}: cb_global_id: {} program_global_id: {} ", + __FUNCTION__, + cb_global_id, + program_global_id); + + auto cmd = tt::tt_metal::flatbuffer::CreateCreateCircularBufferCommand( + fbb, cb_global_id, program_global_id, core_spec_type, core_spec_offset, cb_config_offset); + CaptureCommand(tt::tt_metal::flatbuffer::CommandType::CreateCircularBufferCommand, cmd.Union()); +} + +void CaptureLightMetalCompare( + CommandQueue& cq, + std::variant, std::shared_ptr> buffer, + void* golden_data, + bool is_user_data) { + auto& ctx = LightMetalCaptureContext::get(); + + // We don't want to use shared_ptr to extend lifetime of buffer when adding to global_id map. + Buffer* buffer_ptr = std::holds_alternative>(buffer) + ? std::get>(buffer).get() + : &std::get>(buffer).get(); + + uint32_t cq_global_id = cq.id(); // TODO (kmabee) - consider storing/getting CQ from global map instead. + uint32_t buffer_global_id = ctx.get_global_id(buffer_ptr); + + // Calculate num uint32_t elements in buffer, and convert golden void* to vector + size_t golden_data_len = buffer_ptr->size() / sizeof(uint32_t); + const uint32_t* golden_data_uint32 = static_cast(golden_data); + std::vector golden_data_vector(golden_data_uint32, golden_data_uint32 + golden_data_len); + + log_debug( + tt::LogMetalTrace, + "{}: buffer_global_id: {} is_user_data: {} golden_data_len: {}", + __FUNCTION__, + buffer_global_id, + is_user_data, + golden_data_len); + + // Serialize golden_data into FlatBuffer + auto golden_data_fb = ctx.get_builder().CreateVector(golden_data_vector); + + auto cmd = tt::tt_metal::flatbuffer::CreateLightMetalCompareCommand( + ctx.get_builder(), cq_global_id, buffer_global_id, golden_data_fb, is_user_data); + CaptureCommand(tt::tt_metal::flatbuffer::CommandType::LightMetalCompareCommand, cmd.Union()); +} + +} // namespace tt::tt_metal diff --git a/tt_metal/impl/lightmetal/host_api_capture_helpers.hpp b/tt_metal/impl/lightmetal/host_api_capture_helpers.hpp new file mode 100644 index 00000000000..3639fd3b90b --- /dev/null +++ b/tt_metal/impl/lightmetal/host_api_capture_helpers.hpp @@ -0,0 +1,133 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "flatbuffers/flatbuffers.h" +#include "lightmetal/lightmetal_capture.hpp" +#include +#include +#include +#include + +namespace tt::tt_metal { + +// Many forward decls and aliases to reduce includes. +class CommandQueue; +struct DataMovementConfig; +struct ComputeConfig; +struct EthernetConfig; + +inline namespace v0 { +class IDevice; +struct BufferConfig; +struct CircularBufferConfig; +using RuntimeArgs = std::vector>; +} // namespace v0 + +////////////////////////////////////////////////////////////// +// TRACE GUARD & LIGHT METAL TRACE MACRO // +////////////////////////////////////////////////////////////// + +// This struct will disable further tracing in current scope, and re-enable +// when scope ends. Prevents recursive tracing of host APIs. 
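// As an illustration of the re-entrancy problem this guard solves (assuming the nested host API is
// instrumented the same way), consider the EnqueueReadBuffer path shown earlier in this diff:
//
//   void EnqueueReadBuffer(...) {
//       LIGHT_METAL_TRACE_FUNCTION_ENTRY();                              // depth 0 -> 1
//       LIGHT_METAL_TRACE_FUNCTION_CALL(CaptureEnqueueReadBuffer, ...);  // depth == 1, so captured
//       EnqueueReadSubBuffer(...);                                       // nested host API
//   }
//   void EnqueueReadSubBuffer(...) {
//       LIGHT_METAL_TRACE_FUNCTION_ENTRY();                              // depth 1 -> 2
//       LIGHT_METAL_TRACE_FUNCTION_CALL(...);                            // depth == 2, skipped: no double capture
//   }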
+struct TraceScope { + // Provide an inline definition in the header + static inline thread_local int depth = 0; + // Increment depth on entering scope, decrement on exiting + TraceScope() { ++depth; } + ~TraceScope() { --depth; } +}; + +} // namespace tt::tt_metal + +#if defined(TT_ENABLE_LIGHT_METAL_TRACE) && (TT_ENABLE_LIGHT_METAL_TRACE == 1) + +#define LIGHT_METAL_TRACE_FUNCTION_ENTRY() tt::tt_metal::TraceScope __traceScopeGuard + +#define LIGHT_METAL_TRACE_FUNCTION_CALL(capture_func, ...) \ + do { \ + log_trace( \ + tt::LogMetalTrace, \ + "LIGHT_METAL_TRACE_FUNCTION_CALL: {} via {} istracing: {} depth: {}", \ + #capture_func, \ + __FUNCTION__, \ + LightMetalCaptureContext::get().is_tracing(), \ + tt::tt_metal::TraceScope::depth); \ + if (LightMetalCaptureContext::get().is_tracing() && tt::tt_metal::TraceScope::depth == 1) { \ + capture_func(__VA_ARGS__); \ + } \ + } while (0) +#else + +#define LIGHT_METAL_TRACE_FUNCTION_ENTRY() +#define LIGHT_METAL_TRACE_FUNCTION_CALL(capture_func, ...) \ + do { \ + } while (0) + +#endif + +namespace tt::tt_metal { + +// Per Command type capture helper functions +void CaptureReplayTrace(IDevice* device, uint8_t cq_id, uint32_t tid, bool blocking); + +void CaptureEnqueueTrace(CommandQueue& cq, uint32_t tid, bool blocking); + +void CaptureLoadTrace(IDevice* device, const uint8_t cq_id, const uint32_t tid); + +void CaptureReleaseTrace(IDevice* device, uint32_t tid); + +void CaptureCreateBuffer(const std::shared_ptr& buffer, const InterleavedBufferConfig& config); + +void CaptureDeallocateBuffer(Buffer& buffer); + +void CaptureEnqueueWriteBuffer( + CommandQueue& cq, + std::variant, std::shared_ptr> buffer, + HostDataType src, + bool blocking); + +void CaptureEnqueueReadBuffer( + CommandQueue& cq, + std::variant, std::shared_ptr> buffer, + void* dst, + bool blocking); + +void CaptureFinish(CommandQueue& cq, tt::stl::Span sub_device_ids); +void CaptureCreateProgram(Program& program); +void CaptureEnqueueProgram(CommandQueue& cq, Program& program, bool blocking); + +void CaptureCreateKernel( + KernelHandle kernel_id, + Program& program, + const std::string& file_name, + const std::variant& core_spec, + const std::variant& config); + +void CaptureSetRuntimeArgsUint32( + const Program& program, + KernelHandle kernel_id, + const std::variant& core_spec, + tt::stl::Span runtime_args); + +void CaptureSetRuntimeArgs( + IDevice* device, + const std::shared_ptr& kernel, + const std::variant& core_spec, + const std::shared_ptr& runtime_args); + +void CaptureCreateCircularBuffer( + CBHandle& cb_handle, + Program& program, + const std::variant& core_spec, + const CircularBufferConfig& config); + +void CaptureLightMetalCompare( + CommandQueue& cq, + std::variant, std::shared_ptr> buffer, + void* golden_data, + bool is_user_data); + +} // namespace tt::tt_metal diff --git a/tt_metal/impl/lightmetal/lightmetal_capture.cpp b/tt_metal/impl/lightmetal/lightmetal_capture.cpp new file mode 100644 index 00000000000..c1c7d4e4dee --- /dev/null +++ b/tt_metal/impl/lightmetal/lightmetal_capture.cpp @@ -0,0 +1,234 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include +#include "lightmetal/lightmetal_capture.hpp" +#include "flatbuffers/flatbuffers.h" +#include "command_generated.h" +#include "light_metal_binary_generated.h" +#include +#include +#include +#include + +namespace tt::tt_metal { +inline namespace v0 { + +LightMetalCaptureContext::LightMetalCaptureContext() : is_tracing_(false), builder_() {} + +// Singleton instance accessor +LightMetalCaptureContext& LightMetalCaptureContext::get() { + static LightMetalCaptureContext instance; + return instance; +} + +bool LightMetalCaptureContext::is_tracing() const { return is_tracing_; } + +void LightMetalCaptureContext::set_tracing(bool is_tracing) { is_tracing_ = is_tracing; } + +flatbuffers::FlatBufferBuilder& LightMetalCaptureContext::get_builder() { return builder_; } + +std::vector>& LightMetalCaptureContext::get_cmds_vector() { + return cmds_vec_; +} + +void LightMetalCaptureContext::capture_trace_descriptor(const TraceDescriptor& trace_desc, const uint32_t tid) { + trace_descs_vec_.push_back(to_flatbuffer(builder_, trace_desc, tid)); +} + +// Create final flatbuffer binary from the built up data and return to caller as blob. +// If light_metal_binary itself (flatbuffer object) is of interest, could return it instead. +LightMetalBinary LightMetalCaptureContext::create_light_metal_binary() { + auto cmds_vec_fb = builder_.CreateVector(cmds_vec_); + auto sorted_trace_descs = builder_.CreateVectorOfSortedTables(&trace_descs_vec_); + auto light_metal_binary = + tt::tt_metal::flatbuffer::CreateLightMetalBinary(builder_, cmds_vec_fb, sorted_trace_descs); + builder_.Finish(light_metal_binary); + + const uint8_t* buffer_ptr = builder_.GetBufferPointer(); + size_t buffer_size = builder_.GetSize(); + + std::vector binary_data(buffer_ptr, buffer_ptr + buffer_size); + return LightMetalBinary(std::move(binary_data)); +} + +// Reset some internal state, and ensure tracing isn't active. Should only be called at start of tracing. 
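// For orientation, a plausible capture lifecycle around this context (sketch only; the
// LightMetalBeginCapture()/LightMetalEndCapture() bodies are not part of this excerpt):
//
//   auto& ctx = LightMetalCaptureContext::get();
//   ctx.reset();               // drop any state left over from a previous capture
//   ctx.set_tracing(true);     // traced host APIs now append commands via CaptureCommand()
//   /* ... workload issues host API calls ... */
//   ctx.set_tracing(false);
//   LightMetalBinary blob = ctx.create_light_metal_binary();   // serialized FlatBuffer blob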
+void LightMetalCaptureContext::reset() { + TT_ASSERT(!is_tracing_, "Cannot reset light metal capture context while tracing is enabled."); + builder_.Clear(); + next_global_id_ = 0; + cmds_vec_.clear(); + trace_descs_vec_.clear(); + buffer_to_global_id_map_.clear(); + program_to_global_id_map_.clear(); + kernel_to_global_id_map_.clear(); + cb_handle_to_global_id_map_.clear(); +} + +//////////////////////////////////////////// +// Object Map Public Accessors // +//////////////////////////////////////////// + +bool LightMetalCaptureContext::is_in_map(const Buffer* obj) { + return buffer_to_global_id_map_.find(obj) != buffer_to_global_id_map_.end(); +} + +uint32_t LightMetalCaptureContext::add_to_map(const Buffer* obj) { + if (is_in_map(obj)) { + log_warning(tt::LogMetalTrace, "Buffer already exists in global_id map."); + } + uint32_t global_id = next_global_id_++; + buffer_to_global_id_map_[obj] = global_id; + return global_id; +} + +void LightMetalCaptureContext::remove_from_map(const Buffer* obj) { + if (!is_in_map(obj)) { + log_warning(tt::LogMetalTrace, "Buffer not found in global_id map."); + } + buffer_to_global_id_map_.erase(obj); +} + +uint32_t LightMetalCaptureContext::get_global_id(const Buffer* obj) { + auto it = buffer_to_global_id_map_.find(obj); + if (it != buffer_to_global_id_map_.end()) { + return it->second; + } else { + TT_THROW("Buffer not found in global_id global_id map"); + } +} + +bool LightMetalCaptureContext::is_in_map(const Program* obj) { + return program_to_global_id_map_.find(obj) != program_to_global_id_map_.end(); +} + +uint32_t LightMetalCaptureContext::add_to_map(const Program* obj) { + if (is_in_map(obj)) { + log_warning(tt::LogMetalTrace, "Program already exists in global_id map."); + } + uint32_t global_id = next_global_id_++; + program_to_global_id_map_[obj] = global_id; + return global_id; +} + +void LightMetalCaptureContext::remove_from_map(const Program* obj) { + if (!is_in_map(obj)) { + log_warning(tt::LogMetalTrace, "Program not found in global_id map."); + } + program_to_global_id_map_.erase(obj); +} + +uint32_t LightMetalCaptureContext::get_global_id(const Program* obj) { + auto it = program_to_global_id_map_.find(obj); + if (it != program_to_global_id_map_.end()) { + return it->second; + } else { + TT_THROW("Program not found in global_id map."); + } +} + +bool LightMetalCaptureContext::is_in_map(const Kernel* obj) { + return kernel_to_global_id_map_.find(obj) != kernel_to_global_id_map_.end(); +} + +uint32_t LightMetalCaptureContext::add_to_map(const Kernel* obj) { + if (is_in_map(obj)) { + log_warning(tt::LogMetalTrace, "Kernel already exists in global_id map."); + } + uint32_t global_id = next_global_id_++; + kernel_to_global_id_map_[obj] = global_id; + return global_id; +} + +void LightMetalCaptureContext::remove_from_map(const Kernel* obj) { + if (!is_in_map(obj)) { + log_warning(tt::LogMetalTrace, "Kernel not found in global_id map."); + } + kernel_to_global_id_map_.erase(obj); +} + +uint32_t LightMetalCaptureContext::get_global_id(const Kernel* obj) { + auto it = kernel_to_global_id_map_.find(obj); + if (it != kernel_to_global_id_map_.end()) { + return it->second; + } else { + TT_THROW("Kernel not found in global_id map."); + } +} + +bool LightMetalCaptureContext::is_in_map(const CBHandle handle) { + return cb_handle_to_global_id_map_.find(handle) != cb_handle_to_global_id_map_.end(); +} + +uint32_t LightMetalCaptureContext::add_to_map(const CBHandle handle) { + if (is_in_map(handle)) { + log_warning(tt::LogMetalTrace, "CBHandle already 
exists in global_id map."); + } + uint32_t global_id = next_global_id_++; + cb_handle_to_global_id_map_[handle] = global_id; + return global_id; +} + +void LightMetalCaptureContext::remove_from_map(const CBHandle handle) { + if (!is_in_map(handle)) { + log_warning(tt::LogMetalTrace, "CBHandle not found in global_id map."); + } + cb_handle_to_global_id_map_.erase(handle); +} + +uint32_t LightMetalCaptureContext::get_global_id(const CBHandle handle) { + auto it = cb_handle_to_global_id_map_.find(handle); + if (it != cb_handle_to_global_id_map_.end()) { + return it->second; + } else { + TT_THROW("CBHandle not found in global_id map."); + } +} + +//////////////////////////////////////////// +// Non-Class Helper Functions // +//////////////////////////////////////////// + +// Serialize tt-metal traceDescriptor and trace_id to flatbuffer format. +TraceDescriptorByTraceIdOffset to_flatbuffer( + flatbuffers::FlatBufferBuilder& builder, const TraceDescriptor& trace_desc, const uint32_t trace_id) { + // Serialize the trace_data vector + auto trace_data_offset = builder.CreateVector(trace_desc.data); + + // Serialize the sub_device_descriptors (map) + std::vector> + sub_device_descriptor_offsets; + for (const auto& [sub_device_id, descriptor] : trace_desc.descriptors) { + auto descriptor_offset = tt::tt_metal::flatbuffer::CreateTraceDescriptorMetaData( + builder, + descriptor.num_completion_worker_cores, + descriptor.num_traced_programs_needing_go_signal_multicast, + descriptor.num_traced_programs_needing_go_signal_unicast); + auto mapping_offset = tt::tt_metal::flatbuffer::CreateSubDeviceDescriptorMapping( + builder, + sub_device_id.to_index(), // No need for static_cast; directly use uint8_t + descriptor_offset); + sub_device_descriptor_offsets.push_back(mapping_offset); + } + auto sub_device_descriptors_offset = builder.CreateVector(sub_device_descriptor_offsets); + + // Serialize the sub_device_ids vector + std::vector sub_device_ids_converted; + sub_device_ids_converted.reserve(trace_desc.sub_device_ids.size()); + for (const auto& sub_device_id : trace_desc.sub_device_ids) { + sub_device_ids_converted.push_back(sub_device_id.to_index()); + } + auto sub_device_ids_offset = builder.CreateVector(sub_device_ids_converted); + + // Create the TraceDescriptor + auto trace_descriptor_offset = tt::tt_metal::flatbuffer::CreateTraceDescriptor( + builder, trace_data_offset, sub_device_descriptors_offset, sub_device_ids_offset); + + // Create the TraceDescriptorByTraceId + return tt::tt_metal::flatbuffer::CreateTraceDescriptorByTraceId(builder, trace_id, trace_descriptor_offset); +} + +} // namespace v0 +} // namespace tt::tt_metal diff --git a/tt_metal/impl/lightmetal/lightmetal_capture.hpp b/tt_metal/impl/lightmetal/lightmetal_capture.hpp new file mode 100644 index 00000000000..3712e666108 --- /dev/null +++ b/tt_metal/impl/lightmetal/lightmetal_capture.hpp @@ -0,0 +1,92 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include +#include + +// Forward decl for command_generated.h +namespace tt::tt_metal::flatbuffer { +class Command; +} + +// Forward decl for light_metal_binary_generated.h +namespace tt::tt_metal::flatbuffer { +struct TraceDescriptor; +struct TraceDescriptorByTraceId; +} // namespace tt::tt_metal::flatbuffer + +// Forward decl for trace_buffer.hpp +namespace tt::tt_metal { +class TraceDescriptor; +} + +namespace tt::tt_metal { +inline namespace v0 { + +class Buffer; +class Program; +class Kernel; +using CBHandle = uintptr_t; +using TraceDescriptorByTraceIdOffset = flatbuffers::Offset; + +class LightMetalCaptureContext { +public: + static LightMetalCaptureContext& get(); + + bool is_tracing() const; + void set_tracing(bool tracing); + + flatbuffers::FlatBufferBuilder& get_builder(); + std::vector>& get_cmds_vector(); + void capture_trace_descriptor(const TraceDescriptor& trace_desc, uint32_t tid); + LightMetalBinary create_light_metal_binary(); + void reset(); + + // Object Map Public Accessors + bool is_in_map(const Buffer* obj); + uint32_t add_to_map(const Buffer* obj); + void remove_from_map(const Buffer* obj); + uint32_t get_global_id(const Buffer* obj); + bool is_in_map(const Program* obj); + uint32_t add_to_map(const Program* obj); + void remove_from_map(const Program* obj); + uint32_t get_global_id(const Program* obj); + bool is_in_map(const Kernel* obj); + uint32_t add_to_map(const Kernel* obj); + void remove_from_map(const Kernel* obj); + uint32_t get_global_id(const Kernel* obj); + bool is_in_map(const CBHandle handle); + uint32_t add_to_map(const CBHandle handle); + void remove_from_map(const CBHandle handle); + uint32_t get_global_id(const CBHandle handle); + +private: + LightMetalCaptureContext(); // Private constructor + + bool is_tracing_ = false; + flatbuffers::FlatBufferBuilder builder_; + std::vector> cmds_vec_; + std::vector trace_descs_vec_; + + // Object maps for associating each object with a global_id + uint32_t next_global_id_ = 0; // Shared across all object types. + std::unordered_map buffer_to_global_id_map_; + std::unordered_map program_to_global_id_map_; + std::unordered_map kernel_to_global_id_map_; + std::unordered_map cb_handle_to_global_id_map_; + // TODO (kmabee) - consider adding map for CommandQueue object. + + LightMetalCaptureContext(const LightMetalCaptureContext&) = delete; + LightMetalCaptureContext& operator=(const LightMetalCaptureContext&) = delete; +}; + +TraceDescriptorByTraceIdOffset to_flatbuffer( + flatbuffers::FlatBufferBuilder& builder, const TraceDescriptor& trace_desc, uint32_t trace_id); + +} // namespace v0 +} // namespace tt::tt_metal diff --git a/tt_metal/impl/lightmetal/lightmetal_capture_utils.cpp b/tt_metal/impl/lightmetal/lightmetal_capture_utils.cpp new file mode 100644 index 00000000000..d33250777b1 --- /dev/null +++ b/tt_metal/impl/lightmetal/lightmetal_capture_utils.cpp @@ -0,0 +1,38 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include "lightmetal/host_api_capture_helpers.hpp" +#include +#include + +namespace tt::tt_metal { + +void LightMetalCompareToCapture( + CommandQueue& cq, const std::variant, std::shared_ptr>& buffer, void* dst) { + LIGHT_METAL_TRACE_FUNCTION_ENTRY(); + + // If dst ptr is not provided, just allocate temp space for rd return capture/usage. + std::vector rd_data_tmp; + if (!dst) { + size_t buffer_size = std::holds_alternative>(buffer) + ? 
std::get>(buffer).get().size() + : std::get>(buffer)->size(); + rd_data_tmp.resize(buffer_size / sizeof(uint32_t)); + dst = rd_data_tmp.data(); + } + + EnqueueReadBuffer(cq, buffer, dst, true); // Blocking read to get golden value. + LIGHT_METAL_TRACE_FUNCTION_CALL(CaptureLightMetalCompare, cq, buffer, dst, false); +} + +void LightMetalCompareToGolden( + CommandQueue& cq, + const std::variant, std::shared_ptr>& buffer, + void* golden_data) { + LIGHT_METAL_TRACE_FUNCTION_ENTRY(); + LIGHT_METAL_TRACE_FUNCTION_CALL(CaptureLightMetalCompare, cq, buffer, golden_data, true); +} + +} // namespace tt::tt_metal diff --git a/tt_metal/impl/lightmetal/lightmetal_replay.cpp b/tt_metal/impl/lightmetal/lightmetal_replay.cpp new file mode 100644 index 00000000000..2971f438fa4 --- /dev/null +++ b/tt_metal/impl/lightmetal/lightmetal_replay.cpp @@ -0,0 +1,661 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include "light_metal_binary_generated.h" +#include "command_generated.h" +#include +#include + +#include +#include +#include +#include +#include "lightmetal/lightmetal_replay.hpp" +#include "flatbuffer/base_types_from_flatbuffer.hpp" +#include "flatbuffer/program_types_from_flatbuffer.hpp" +#include "flatbuffer/buffer_types_from_flatbuffer.hpp" + +namespace tt::tt_metal { + +////////////////////////////////////// +// Helper Functions // +////////////////////////////////////// + +TraceDescriptor from_flatbuffer(const tt::tt_metal::flatbuffer::TraceDescriptor* fb_desc) { + if (!fb_desc) { + std::cerr << "TraceDescriptor is null." << std::endl; + return {}; + } + + TraceDescriptor trace_desc; + + // Deserialize trace_data + if (auto trace_data_fb = fb_desc->trace_data()) { + trace_desc.data.assign(trace_data_fb->begin(), trace_data_fb->end()); + } + + // Deserialize sub_device_descriptors + if (auto sub_device_descriptors_fb = fb_desc->sub_device_descriptors()) { + for (const auto* mapping : *sub_device_descriptors_fb) { + if (mapping) { + TraceDescriptor::Descriptor descriptor; + descriptor.num_completion_worker_cores = mapping->descriptor()->num_completion_worker_cores(); + descriptor.num_traced_programs_needing_go_signal_multicast = + mapping->descriptor()->num_traced_programs_needing_go_signal_multicast(); + descriptor.num_traced_programs_needing_go_signal_unicast = + mapping->descriptor()->num_traced_programs_needing_go_signal_unicast(); + + // Add the descriptor to the map + trace_desc.descriptors[SubDeviceId{mapping->sub_device_id()}] = descriptor; + } + } + } + + // Deserialize sub_device_ids + if (auto sub_device_ids_fb = fb_desc->sub_device_ids()) { + for (const auto id : *sub_device_ids_fb) { + trace_desc.sub_device_ids.emplace_back(SubDeviceId{id}); + } + } + + return trace_desc; +} + +// Needs access to BufferMap, so part of LightMetalReplay class +std::shared_ptr LightMetalReplay::rt_args_from_flatbuffer( + const FlatbufferRuntimeArgVector flatbuffer_args) { + auto runtime_args = std::make_shared(); + + for (const auto& flatbuffer_arg : *flatbuffer_args) { + const auto* runtime_arg = flatbuffer_arg; + TT_FATAL(runtime_arg, "Null RuntimeArg in FlatBuffer vector"); + + // Determine the type of the RuntimeArg + switch (runtime_arg->value_type()) { + case tt::tt_metal::flatbuffer::RuntimeArgValue::UInt32Value: { + // Extract UInt32Value + const auto* uint32_value = runtime_arg->value_as_UInt32Value(); + TT_FATAL(uint32_value, "Failed to read UInt32Value"); + runtime_args->emplace_back(uint32_value->value()); + break; + } + case 
tt::tt_metal::flatbuffer::RuntimeArgValue::BufferGlobalId: { + // Extract BufferGlobalId + const auto* buffer_global_id = runtime_arg->value_as_BufferGlobalId(); + TT_FATAL(buffer_global_id, "Failed to read BufferGlobalId"); + uint32_t global_id = buffer_global_id->id(); + auto buffer = get_buffer_from_map(global_id); + TT_FATAL(buffer, "Buffer w/ global_id: {} not previously created", global_id); + runtime_args->emplace_back(buffer.get()); + break; + } + case tt::tt_metal::flatbuffer::RuntimeArgValue::NONE: { + TT_THROW("Unknown RuntimeArgValue type NONE in FlatBuffer"); + } + } + } + + return runtime_args; +} + +////////////////////////////////////// +// LightMetalReplay Class // +////////////////////////////////////// + +LightMetalReplay::LightMetalReplay(LightMetalBinary&& binary) : binary_(std::move(binary)), fb_binary_(nullptr) { + if (binary_.is_empty()) { + log_warning(tt::LogMetalTrace, "Empty LightMetalBinary provided to LightMetalReplay."); + } + + show_reads_ = parse_env("TT_LIGHT_METAL_SHOW_READS", false); + disable_checking_ = parse_env("TT_LIGHT_METAL_DISABLE_CHECKING", false); + fb_binary_ = parse_flatbuffer_binary(); // Parse and store the FlatBuffer binary +} + +const tt::tt_metal::flatbuffer::LightMetalBinary* LightMetalReplay::parse_flatbuffer_binary() { + try { + const uint8_t* data_ptr = binary_.get_data().data(); + size_t size = binary_.get_data().size(); + + // Verify the FlatBuffer data. + flatbuffers::Verifier verifier(data_ptr, size); + if (!tt::tt_metal::flatbuffer::VerifyLightMetalBinaryBuffer(verifier)) { + std::cerr << "Failed to verify FlatBuffer data." << std::endl; + return nullptr; + } + + // Parse and return the FlatBuffer object. + return tt::tt_metal::flatbuffer::GetLightMetalBinary(data_ptr); + } catch (const std::exception& e) { + std::cerr << "Exception while parsing FlatBuffer binary: " << e.what() << std::endl; + return nullptr; + } +} + +// Return a TraceDescriptor for a given trace_id from the FlatBuffer binary. +std::optional LightMetalReplay::get_trace_by_id(uint32_t target_trace_id) { + if (const auto* trace_descriptors = fb_binary_ ? fb_binary_->trace_descriptors() : nullptr) { + if (const auto* fb_trace_desc_by_id = trace_descriptors->LookupByKey(target_trace_id)) { + if (const auto* fb_desc = fb_trace_desc_by_id->desc()) { + return from_flatbuffer(fb_desc); + } + } + } + + std::cerr << "Failed to find trace_id: " << target_trace_id << " in binary." << std::endl; + return std::nullopt; +} + +////////////////////////////////////// +// Object Map Public Accessors // +////////////////////////////////////// + +void LightMetalReplay::add_buffer_to_map(uint32_t global_id, const std::shared_ptr<::tt::tt_metal::Buffer>& buffer) { + if (buffer_map_.find(global_id) != buffer_map_.end()) { + log_warning(tt::LogMetalTrace, "Buffer with global_id: {} already exists in map.", global_id); + } + buffer_map_[global_id] = buffer; // Shared ownership +} + +std::shared_ptr<::tt::tt_metal::Buffer> LightMetalReplay::get_buffer_from_map(uint32_t global_id) const { + auto it = buffer_map_.find(global_id); + return it != buffer_map_.end() ? 
it->second : nullptr; +} + +void LightMetalReplay::remove_bufer_from_map(uint32_t global_id) { buffer_map_.erase(global_id); } + +void LightMetalReplay::add_program_to_map(uint32_t global_id, const std::shared_ptr<::tt::tt_metal::Program>& program) { + if (program_map_.find(global_id) != program_map_.end()) { + log_warning(tt::LogMetalTrace, "Program with global_id: {} already exists in map.", global_id); + } + program_map_[global_id] = program; // Shared ownership +} + +std::shared_ptr<::tt::tt_metal::Program> LightMetalReplay::get_program_from_map(uint32_t global_id) const { + auto it = program_map_.find(global_id); + return it != program_map_.end() ? it->second : nullptr; +} + +void LightMetalReplay::remove_program_from_map(uint32_t global_id) { program_map_.erase(global_id); } + +void LightMetalReplay::add_kernel_handle_to_map(uint32_t global_id, ::tt::tt_metal::KernelHandle kernel_id) { + if (kernel_handle_map_.find(global_id) != kernel_handle_map_.end()) { + log_warning(tt::LogMetalTrace, "KernelHandle with global_id: {} already exists in map.", global_id); + } + kernel_handle_map_[global_id] = kernel_id; // Shared ownership +} + +::tt::tt_metal::KernelHandle LightMetalReplay::get_kernel_handle_from_map(uint32_t global_id) const { + auto it = kernel_handle_map_.find(global_id); + return it != kernel_handle_map_.end() ? it->second : UINT32_MAX; +} + +void LightMetalReplay::remove_kernel_handle_from_map(uint32_t global_id) { kernel_handle_map_.erase(global_id); } + +void LightMetalReplay::add_kernel_to_map(uint32_t global_id, const std::shared_ptr<::tt::tt_metal::Kernel>& kernel) { + if (kernel_map_.find(global_id) != kernel_map_.end()) { + log_warning(tt::LogMetalTrace, "Kernel with global_id: {} already exists in map.", global_id); + } + kernel_map_[global_id] = kernel; // Shared ownership +} + +std::shared_ptr<::tt::tt_metal::Kernel> LightMetalReplay::get_kernel_from_map(uint32_t global_id) const { + auto it = kernel_map_.find(global_id); + return it != kernel_map_.end() ? it->second : nullptr; +} + +void LightMetalReplay::remove_kernel_from_map(uint32_t global_id) { kernel_map_.erase(global_id); } + +void LightMetalReplay::add_cb_handle_to_map(uint32_t global_id, ::tt::tt_metal::CBHandle cb_handle) { + if (cb_handle_map_.find(global_id) != cb_handle_map_.end()) { + log_warning(tt::LogMetalTrace, "CBHandle with global_id: {} already exists in map.", global_id); + } + cb_handle_map_[global_id] = cb_handle; // Shared ownership +} + +::tt::tt_metal::CBHandle LightMetalReplay::get_cb_handle_from_map(uint32_t global_id) const { + auto it = cb_handle_map_.find(global_id); + return it != cb_handle_map_.end() ? it->second : UINT32_MAX; +} + +void LightMetalReplay::remove_cb_handle_from_map(uint32_t global_id) { cb_handle_map_.erase(global_id); } + +////////////////////////////////////// +// Device Setup/Teardown // +////////////////////////////////////// + +// TODO (kmabee) - Hardcode for now, eventually capture/replay "systemdesc" from binary. 
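// (The hardcoded setup below assumes a single MMIO device 0, WORKER dispatch cores and a
//  4 KB trace region; replaying a captured system descriptor would instead make these
//  parameters match the machine the binary was captured on.)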
+void LightMetalReplay::setup_devices() { + log_debug(tt::LogMetalTrace, "LightMetalReplay(setup_devices) - Using hardcoded CreateDevices() as temp hack."); + const size_t trace_region_size = 4096; // Default is 0 + const int device_id = 0; + const auto dispatch_core_type = tt_metal::DispatchCoreType::WORKER; + const chip_id_t mmio_device_id = 0; + auto devices_map = tt::tt_metal::detail::CreateDevices( + {mmio_device_id}, 1, DEFAULT_L1_SMALL_SIZE, trace_region_size, dispatch_core_type); + this->device_ = devices_map.at(mmio_device_id); +} + +// TODO (kmabee) - Hardcode for now, eventually capture/replay "systemdesc" from binary or let user call. +void LightMetalReplay::close_devices() { CloseDevice(this->device_); } + +////////////////////////////////////// +// Executor // +////////////////////////////////////// + +// execute a command by dispatching to appropriate handler based on type. +void LightMetalReplay::execute(const tt::tt_metal::flatbuffer::Command* command) { + switch (command->cmd_type()) { + case ::tt::tt_metal::flatbuffer::CommandType::EnqueueTraceCommand: { + execute(command->cmd_as_EnqueueTraceCommand()); + break; + } + case ::tt::tt_metal::flatbuffer::CommandType::ReplayTraceCommand: { + execute(command->cmd_as_ReplayTraceCommand()); + break; + } + case ::tt::tt_metal::flatbuffer::CommandType::LoadTraceCommand: { + execute(command->cmd_as_LoadTraceCommand()); + break; + } + case ::tt::tt_metal::flatbuffer::CommandType::ReleaseTraceCommand: { + execute(command->cmd_as_ReleaseTraceCommand()); + break; + } + case ::tt::tt_metal::flatbuffer::CommandType::CreateBufferCommand: { + execute(command->cmd_as_CreateBufferCommand()); + break; + } + case ::tt::tt_metal::flatbuffer::CommandType::DeallocateBufferCommand: { + execute(command->cmd_as_DeallocateBufferCommand()); + break; + } + case ::tt::tt_metal::flatbuffer::CommandType::EnqueueWriteBufferCommand: { + execute(command->cmd_as_EnqueueWriteBufferCommand()); + break; + } + case ::tt::tt_metal::flatbuffer::CommandType::EnqueueReadBufferCommand: { + execute(command->cmd_as_EnqueueReadBufferCommand()); + break; + } + case ::tt::tt_metal::flatbuffer::CommandType::FinishCommand: { + execute(command->cmd_as_FinishCommand()); + break; + } + case ::tt::tt_metal::flatbuffer::CommandType::CreateProgramCommand: { + execute(command->cmd_as_CreateProgramCommand()); + break; + } + case ::tt::tt_metal::flatbuffer::CommandType::EnqueueProgramCommand: { + execute(command->cmd_as_EnqueueProgramCommand()); + break; + } + case ::tt::tt_metal::flatbuffer::CommandType::CreateKernelCommand: { + execute(command->cmd_as_CreateKernelCommand()); + break; + } + case ::tt::tt_metal::flatbuffer::CommandType::SetRuntimeArgsUint32Command: { + execute(command->cmd_as_SetRuntimeArgsUint32Command()); + break; + } + case ::tt::tt_metal::flatbuffer::CommandType::SetRuntimeArgsCommand: { + execute(command->cmd_as_SetRuntimeArgsCommand()); + break; + } + case ::tt::tt_metal::flatbuffer::CommandType::CreateCircularBufferCommand: { + execute(command->cmd_as_CreateCircularBufferCommand()); + break; + } + case ::tt::tt_metal::flatbuffer::CommandType::LightMetalCompareCommand: { + execute(command->cmd_as_LightMetalCompareCommand()); + break; + } + case ::tt::tt_metal::flatbuffer::CommandType::NONE: + TT_THROW("LightMetalReplay execute encountered unsupported cmd type NONE"); + break; + } +} + +// Per API command handlers. 
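// Every handler below follows the same shape: decode the FlatBuffer fields, resolve any
// captured global_ids through the maps above, then invoke the matching tt-metal host API.
// Schematic only (FooCommand / Foo are placeholders, not real types in this change):
//
//   void LightMetalReplay::execute(const tt::tt_metal::flatbuffer::FooCommand* cmd) {
//       auto buffer = get_buffer_from_map(cmd->buffer_global_id());      // resolve capture-time id
//       TT_FATAL(buffer, "Buffer w/ global_id: {} not previously created", cmd->buffer_global_id());
//       Foo(this->device_->command_queue(cmd->cq_global_id()), *buffer); // replay the original call
//   }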
+void LightMetalReplay::execute(const tt::tt_metal::flatbuffer::EnqueueTraceCommand* cmd) { + log_debug( + tt::LogMetalTrace, + "LightMetalReplay(EnqueueTrace) cq_id: {} tid: {} blocking: {}", + cmd->cq_id(), + cmd->tid(), + cmd->blocking()); + CommandQueue& cq = this->device_->command_queue(cmd->cq_id()); + EnqueueTrace(cq, cmd->tid(), cmd->blocking()); +} + +void LightMetalReplay::execute(const tt::tt_metal::flatbuffer::ReplayTraceCommand* cmd) { + log_debug( + tt::LogMetalTrace, + "LightMetalReplay(ReplayTrace) cq_id: {} tid: {} blocking: {}", + cmd->cq_id(), + cmd->tid(), + cmd->blocking()); + ReplayTrace(this->device_, cmd->cq_id(), cmd->tid(), cmd->blocking()); +} + +void LightMetalReplay::execute(const tt::tt_metal::flatbuffer::LoadTraceCommand* cmd) { + log_debug(tt::LogMetalTrace, "LightMetalReplay(LoadTrace) cq_id: {} tid: {}", cmd->cq_id(), cmd->tid()); + // Get the trace descriptor from flatbuffer and load it to device. + auto trace_desc = get_trace_by_id(cmd->tid()); + LoadTrace(this->device_, cmd->cq_id(), cmd->tid(), trace_desc.value()); +} + +void LightMetalReplay::execute(const tt::tt_metal::flatbuffer::ReleaseTraceCommand* cmd) { + log_debug(tt::LogMetalTrace, "LightMetalReplay(ReleaseTrace) tid: {}", cmd->tid()); + ReleaseTrace(this->device_, cmd->tid()); +} + +void LightMetalReplay::execute(const tt::tt_metal::flatbuffer::CreateBufferCommand* cmd) { + log_debug( + tt::LogMetalTrace, + "LightMetalReplay(CreateBuffer) global_id: {} size: {} page_size: {} layout: {} buffer_type: {}", + cmd->global_id(), + cmd->config()->size(), + cmd->config()->page_size(), + EnumNameTensorMemoryLayout(cmd->config()->buffer_layout()), + EnumNameBufferType(cmd->config()->buffer_type())); + + switch (cmd->config()->buffer_layout()) { + case tt::tt_metal::flatbuffer::TensorMemoryLayout::Interleaved: { + tt::tt_metal::InterleavedBufferConfig config{ + .device = this->device_, + .size = cmd->config()->size(), + .page_size = cmd->config()->page_size(), + .buffer_type = from_flatbuffer(cmd->config()->buffer_type())}; + + auto buffer = CreateBuffer(config); + add_buffer_to_map(cmd->global_id(), buffer); + break; + } + default: + // TODO (kmabee) - Add support for other buffer_layouts. + TT_THROW( + "Unsupported buffer_layout: {}", + std::string(EnumNameTensorMemoryLayout(cmd->config()->buffer_layout()))); + } +} + +void LightMetalReplay::execute(const tt::tt_metal::flatbuffer::DeallocateBufferCommand* cmd) { + auto buffer = get_buffer_from_map(cmd->global_id()); + TT_FATAL( + buffer, + "Attempted to DeallocateBuffer() buffer w/ global_id: {} that was not previously created.", + cmd->global_id()); + + log_debug(tt::LogMetalTrace, "LightMetalReplay(DeallocateBuffer) global_id: {}", cmd->global_id()); + DeallocateBuffer(*buffer); // Buffer& expected. + remove_bufer_from_map(cmd->global_id()); +} + +void LightMetalReplay::execute(const tt::tt_metal::flatbuffer::EnqueueWriteBufferCommand* cmd) { + auto buffer = get_buffer_from_map(cmd->buffer_global_id()); + TT_FATAL( + buffer, + "Attempted to EnqueueWriteBuffer() buffer w/ global_id: {} that was not previously created.", + cmd->buffer_global_id()); + + log_debug( + tt::LogMetalTrace, + "LightMetalReplay(EnqueueWriteBuffer) cq_global_id: {} buffer_global_id: {} addr: 0x{:x}", + cmd->cq_global_id(), + cmd->buffer_global_id(), + buffer->address()); + + // TODO (kmabee) - consider storing/getting CQ from global map instead. 
+ CommandQueue& cq = this->device_->command_queue(cmd->cq_global_id()); + EnqueueWriteBuffer(cq, buffer, cmd->src()->data(), cmd->blocking()); +} + +void LightMetalReplay::execute(const tt::tt_metal::flatbuffer::EnqueueReadBufferCommand* cmd) { + auto buffer = get_buffer_from_map(cmd->buffer_global_id()); + TT_FATAL( + buffer, + "Attempted to EnqueueReadBuffer() buffer w/ global_id: {} that was not previously created.", + cmd->buffer_global_id()); + + log_debug( + tt::LogMetalTrace, + "LightMetalReplay(EnqueueReadBuffer) cq_global_id: {} buffer_global_id: {} addr: 0x{:x} buf_size: {}", + cmd->cq_global_id(), + cmd->buffer_global_id(), + buffer->address(), + buffer->size()); + + // TODO (kmabee) - consider storing/getting CQ from global map instead. + CommandQueue& cq = this->device_->command_queue(cmd->cq_global_id()); + std::vector readback_data(buffer->size() / sizeof(uint32_t), 0); + EnqueueReadBuffer(cq, buffer, readback_data.data(), cmd->blocking()); + + // TODO (kmabee) - TBD what to do with readback data. For now, optionally print. + // One idea is to store in map by global_read_id that caller can access. + if (show_reads_) { + for (size_t i = 0; i < readback_data.size(); i++) { + log_info(tt::LogMetalTrace, " rd_data i: {:3d} => data: {} ({:x})", i, readback_data[i], readback_data[i]); + } + } +} + +void LightMetalReplay::execute(const tt::tt_metal::flatbuffer::FinishCommand* cmd) { + log_debug(tt::LogMetalTrace, "LightMetalReplay(Finish) cq_global_id: {}", cmd->cq_global_id()); + CommandQueue& cq = this->device_->command_queue(cmd->cq_global_id()); + auto sub_device_ids = from_flatbuffer(cmd->sub_device_ids()); + Finish(cq, sub_device_ids); +} + +void LightMetalReplay::execute(const tt::tt_metal::flatbuffer::CreateProgramCommand* cmd) { + log_debug(tt::LogMetalTrace, "LightMetalReplay(CreateProgram) global_id: {} ", cmd->global_id()); + auto program = CreateProgram(); + add_program_to_map(cmd->global_id(), std::make_shared(std::move(program))); +} + +void LightMetalReplay::execute(const tt::tt_metal::flatbuffer::EnqueueProgramCommand* cmd) { + auto program = get_program_from_map(cmd->program_global_id()); + TT_FATAL( + program, + "Attempted to EnqueueProgram() program w/ global_id: {} that was not previously created.", + cmd->program_global_id()); + + log_debug( + tt::LogMetalTrace, + "LightMetalReplay(EnqueueProgram) program_global_id: {} cq_global_id: {}", + cmd->program_global_id(), + cmd->cq_global_id()); + + // TODO (kmabee) - consider storing/getting CQ from global map instead. + CommandQueue& cq = this->device_->command_queue(cmd->cq_global_id()); + EnqueueProgram(cq, *program, cmd->blocking()); +} + +void LightMetalReplay::execute(const tt::tt_metal::flatbuffer::CreateKernelCommand* cmd) { + log_debug( + tt::LogMetalTrace, + "LightMetalReplay(CreateKernel) global_id: {} program_global_id: {}", + cmd->global_id(), + cmd->program_global_id()); + auto program = get_program_from_map(cmd->program_global_id()); + TT_FATAL( + program, + "Attempted to CreateKernel() using a program w/ global_id: {} that was not previously created.", + cmd->program_global_id()); + + auto core_spec = core_spec_from_flatbuffer(cmd); + auto kernel_config = kernel_config_from_flatbuffer(cmd); + auto kernel_id = CreateKernel(*program, cmd->file_name()->c_str(), core_spec, kernel_config); + add_kernel_handle_to_map(cmd->global_id(), kernel_id); + // Some APIs use Kernel, so convert to and store Kernel. 
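// (Both forms are stored because the SetRuntimeArgsUint32Command handler below looks up the
//  KernelHandle, while the SetRuntimeArgsCommand handler needs the Kernel shared_ptr.)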
+ std::shared_ptr kernel = program->get_kernel(kernel_id); + add_kernel_to_map(cmd->global_id(), kernel); +} + +void LightMetalReplay::execute(const tt::tt_metal::flatbuffer::SetRuntimeArgsUint32Command* cmd) { + log_debug( + tt::LogMetalTrace, + "LightMetalReplay(SetRuntimeArgs). program_global_id: {} kernel_global_id: {}", + cmd->program_global_id(), + cmd->kernel_global_id()); + auto program = get_program_from_map(cmd->program_global_id()); + auto kernel_id = get_kernel_handle_from_map(cmd->kernel_global_id()); + TT_FATAL( + program, + "Attempted to SetRuntimeArgs() using a program w/ global_id: {} that was not previously created.", + cmd->program_global_id()); + TT_FATAL( + kernel_id != UINT32_MAX, + "Attempted to SetRuntimeArgs() using a kernel w/ global_id: {} that was not previously created.", + cmd->kernel_global_id()); + + // API expects a span so create from flatbuffer vector. + stl::Span args_span(cmd->args()->data(), cmd->args()->size()); + auto core_spec = core_spec_from_flatbuffer(cmd); + SetRuntimeArgs(*program, kernel_id, core_spec, args_span); +} + +void LightMetalReplay::execute(const tt::tt_metal::flatbuffer::SetRuntimeArgsCommand* cmd) { + log_debug( + tt::LogMetalTrace, + "LightMetalReplay(SetRuntimeArgs). kernel_global_id: {} rt_args_size: {}", + cmd->kernel_global_id(), + cmd->args()->size()); + auto core_spec = core_spec_from_flatbuffer(cmd); + auto runtime_args = rt_args_from_flatbuffer(cmd->args()); + auto kernel = get_kernel_from_map(cmd->kernel_global_id()); + TT_FATAL( + kernel, + "Attempted to SetRuntimeArgs() using a Kernel w/ global_id: {} that was not previously created.", + cmd->kernel_global_id()); + SetRuntimeArgs(this->device_, kernel, core_spec, runtime_args); +} + +void LightMetalReplay::execute(const tt::tt_metal::flatbuffer::CreateCircularBufferCommand* cmd) { + log_debug( + tt::LogMetalTrace, + "LightMetalReplay(CreateCircularBuffer) global_id: {} program_global_id: {}", + cmd->global_id(), + cmd->program_global_id()); + auto program = get_program_from_map(cmd->program_global_id()); + TT_FATAL( + program, + "Attempted to CreateCircularBuffer() using a Program w/ global_id: {} that was not previously created.", + cmd->program_global_id()); + auto core_spec = core_spec_from_flatbuffer(cmd); + + // Convert global_id to optional Shadow Buffer here to keep from_flatbuffer standalone function. + ::tt::tt_metal::Buffer* shadow_global_buffer = nullptr; + auto shadow_buf_global_id = cmd->config()->shadow_buf_global_id(); + + if (shadow_buf_global_id != 0) { + auto shadow_buf = get_buffer_from_map(shadow_buf_global_id); + TT_FATAL( + shadow_buf, + "Attempted to CreateCircularBuffer() using a shadow Buffer w/ global_id: {} that was not previously " + "created.", + shadow_buf_global_id); + shadow_global_buffer = shadow_buf.get(); // Set the raw pointer + } + + auto config = from_flatbuffer(cmd->config(), shadow_global_buffer); + auto cb_handle = CreateCircularBuffer(*program, core_spec, config); + add_cb_handle_to_map(cmd->global_id(), cb_handle); +} + +// Verification command to compare readback of a buffer with golden from either capture or user expected values. 
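// Capture-side sketch of how the golden data checked below is produced (see
// lightmetal_capture_utils.cpp; the cq/buffer variable names are illustrative):
//
//   // Golden recorded from a device readback at capture time (is_user_data = false):
//   LightMetalCompareToCapture(cq, buffer, /*dst=*/nullptr);
//
//   // Golden supplied directly by the caller (is_user_data = true):
//   std::vector<uint32_t> expected(buffer->size() / sizeof(uint32_t), 0);
//   LightMetalCompareToGolden(cq, buffer, expected.data());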
+void LightMetalReplay::execute(const ::tt::tt_metal::flatbuffer::LightMetalCompareCommand* cmd) { + log_debug( + tt::LogMetalTrace, + "LightMetalReplay(LightMetalCompare) cq_global_id: {} buffer_global_id: {} is_user_data: {}", + cmd->cq_global_id(), + cmd->buffer_global_id(), + cmd->is_user_data()); + + auto buffer = get_buffer_from_map(cmd->buffer_global_id()); + TT_FATAL( + buffer, + "Attempted to run LightMetalCompareCommand using a Buffer w/ global_id: {} that was not previously created.", + cmd->buffer_global_id()); + + // TODO (kmabee) - consider storing/getting CQ from global map instead. + CommandQueue& cq = this->device_->command_queue(cmd->cq_global_id()); + std::vector rd_data(buffer->size() / sizeof(uint32_t), 0); + EnqueueReadBuffer(cq, buffer, rd_data.data(), true); + + if (disable_checking_) { + log_debug( + tt::LogMetalTrace, "Skipping LightMetalCompareCommand for buffer_global_id: {}.", cmd->buffer_global_id()); + } else { + if (rd_data.size() != cmd->golden_data()->size()) { + TT_THROW( + "Readback data size: {} does not match golden data size: {}", + rd_data.size(), + cmd->golden_data()->size()); + } + + // Optional debug to show verbose comparison + if (show_reads_) { + for (size_t i = 0; i < rd_data.size(); i++) { + bool match = rd_data[i] == cmd->golden_data()->Get(i); + log_info( + tt::LogMetalTrace, + "LightMetalCompare i: {:3d} match: {} RdData: {:x} Golden: {:x}", + i, + match, + rd_data[i], + cmd->golden_data()->Get(i)); + } + } + + // Do simple equality comparison between two vectors + if (!std::equal(rd_data.begin(), rd_data.end(), cmd->golden_data()->begin())) { + TT_THROW("Golden vs rd_data mismatch for buffer_global_id: {}", cmd->buffer_global_id()); + } + } +} + +// Main entry point to execute a light metal binary blob, return true if pass. +bool LightMetalReplay::execute_binary() { + if (!fb_binary_) { + std::cerr << "Cannot Replay empty/uninitialized Light Metal Binary." << std::endl; + return false; + } + + try { + const auto* trace_descs = fb_binary_->trace_descriptors(); + const auto* commands = fb_binary_->commands(); + if (!commands) { + std::cerr << "Nothing to run, no commands in binary." << std::endl; + return false; + } + + setup_devices(); + log_info( + tt::LogMetalTrace, + "Running LightMetal Binary with {} cmds, {} traces.", + commands->size(), + trace_descs->size()); + + // Just loop over all commands, and execute. This is purposely kept simple for prototyping v0. + // TODO (kmabee) - should expand to cover, multiple devices, cqs, etc. + uint32_t idx = 1; + for (const auto* cmd : *commands) { + auto str_name = std::string(EnumNameCommandType(cmd->cmd_type())); + log_trace(tt::LogMetalTrace, "Executing Binary CMD {}/{} (Type: {})", idx++, commands->size(), str_name); + execute(cmd); + } + + close_devices(); + + return true; + } catch (const std::exception& e) { + close_devices(); + log_fatal(e.what()); + return false; + } +} + +} // namespace tt::tt_metal diff --git a/tt_metal/impl/lightmetal/lightmetal_replay.hpp b/tt_metal/impl/lightmetal/lightmetal_replay.hpp new file mode 100644 index 00000000000..a2c96ecdbe8 --- /dev/null +++ b/tt_metal/impl/lightmetal/lightmetal_replay.hpp @@ -0,0 +1,133 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include +#include +#include + +#include + +// Forward decl for trace_buffer.hpp +namespace tt::tt_metal { +class TraceDescriptor; +} + +// Forward decl for command_generated.h / light_metal_binary_generated.h +namespace tt::tt_metal::flatbuffer { +struct Command; +struct ReplayTraceCommand; +struct EnqueueTraceCommand; +struct LoadTraceCommand; +struct ReleaseTraceCommand; +struct CreateBufferCommand; +struct DeallocateBufferCommand; +struct EnqueueWriteBufferCommand; +struct EnqueueReadBufferCommand; +struct FinishCommand; +struct CreateProgramCommand; +struct EnqueueProgramCommand; +struct CreateKernelCommand; +struct SetRuntimeArgsUint32Command; +struct SetRuntimeArgsCommand; +struct CreateCircularBufferCommand; +struct LightMetalCompareCommand; +struct RuntimeArg; + +struct TraceDescriptor; +struct TraceDescriptorByTraceId; +struct LightMetalBinary; +} // namespace tt::tt_metal::flatbuffer + +using FlatbufferRuntimeArgVector = + const flatbuffers::Vector>*; +using RuntimeArgs = std::vector>; + +namespace tt::tt_metal { +inline namespace v0 { + +class LightMetalReplay { +public: + // Constructor that initializes the class with a binary blob and transfers ownership of the blob. + explicit LightMetalReplay(LightMetalBinary&& binary); + + // Execute the stored LightMetal binary by looping over all commands, and execting them. + // Returns true if passed. Currently has no side-effects/artifacts returned to user, + // may change in the future. + bool execute_binary(); + +private: + // Executor functions for all traced host API calls (commands) + void execute(const tt::tt_metal::flatbuffer::Command* command); + void execute(const tt::tt_metal::flatbuffer::EnqueueTraceCommand* command); + void execute(const tt::tt_metal::flatbuffer::ReplayTraceCommand* command); + void execute(const tt::tt_metal::flatbuffer::LoadTraceCommand* command); + void execute(const tt::tt_metal::flatbuffer::ReleaseTraceCommand* command); + void execute(const tt::tt_metal::flatbuffer::CreateBufferCommand* command); + void execute(const tt::tt_metal::flatbuffer::DeallocateBufferCommand* command); + void execute(const tt::tt_metal::flatbuffer::EnqueueWriteBufferCommand* command); + void execute(const tt::tt_metal::flatbuffer::EnqueueReadBufferCommand* command); + void execute(const tt::tt_metal::flatbuffer::FinishCommand* command); + void execute(const tt::tt_metal::flatbuffer::CreateProgramCommand* command); + void execute(const tt::tt_metal::flatbuffer::EnqueueProgramCommand* command); + void execute(const tt::tt_metal::flatbuffer::CreateKernelCommand* command); + void execute(const tt::tt_metal::flatbuffer::SetRuntimeArgsUint32Command* command); + void execute(const tt::tt_metal::flatbuffer::SetRuntimeArgsCommand* command); + void execute(const tt::tt_metal::flatbuffer::CreateCircularBufferCommand* command); + void execute(const tt::tt_metal::flatbuffer::LightMetalCompareCommand* command); + + // Object maps public accessors + void add_buffer_to_map(uint32_t global_id, const std::shared_ptr<::tt::tt_metal::Buffer>& buffer); + std::shared_ptr<::tt::tt_metal::Buffer> get_buffer_from_map(uint32_t global_id) const; + void remove_bufer_from_map(uint32_t global_id); + + void add_program_to_map(uint32_t global_id, const std::shared_ptr<::tt::tt_metal::Program>& program); + std::shared_ptr<::tt::tt_metal::Program> get_program_from_map(uint32_t global_id) const; + void remove_program_from_map(uint32_t global_id); + + void 
add_kernel_handle_to_map(uint32_t global_id, ::tt::tt_metal::KernelHandle kernel_id); + ::tt::tt_metal::KernelHandle get_kernel_handle_from_map(uint32_t global_id) const; + void remove_kernel_handle_from_map(uint32_t global_id); + + void add_kernel_to_map(uint32_t global_id, const std::shared_ptr<::tt::tt_metal::Kernel>& kernel); + std::shared_ptr<::tt::tt_metal::Kernel> get_kernel_from_map(uint32_t global_id) const; + void remove_kernel_from_map(uint32_t global_id); + + void add_cb_handle_to_map(uint32_t global_id, ::tt::tt_metal::CBHandle cb_handle); + ::tt::tt_metal::CBHandle get_cb_handle_from_map(uint32_t global_id) const; + void remove_cb_handle_from_map(uint32_t global_id); + + // Return the TraceDescriptor for a given trace_id from flatbuffer. + std::optional get_trace_by_id(uint32_t target_trace_id); + + // fromFlatBuffer that need class state + std::shared_ptr rt_args_from_flatbuffer(const FlatbufferRuntimeArgVector flatbuffer_args); + + // Workload related members -------------------- + const tt::tt_metal::flatbuffer::LightMetalBinary* parse_flatbuffer_binary(); + + LightMetalBinary binary_; // Stored binary blob + const tt::tt_metal::flatbuffer::LightMetalBinary* fb_binary_; // Parsed FlatBuffer binary + bool show_reads_ = false; // Flag to show read buffer contents + bool disable_checking_ = false; // Optionally disable equality checking in Compare command. + + // System related members ---------------------- + void setup_devices(); + void close_devices(); + + tt::tt_metal::IDevice* device_ = nullptr; + + // Object maps for storing objects by global_id + std::unordered_map> buffer_map_; + std::unordered_map> program_map_; + std::unordered_map kernel_handle_map_; + std::unordered_map> kernel_map_; + std::unordered_map cb_handle_map_; +}; + +} // namespace v0 +} // namespace tt::tt_metal diff --git a/tt_metal/impl/program/program.cpp b/tt_metal/impl/program/program.cpp index 612ad45bb73..0e4f20b137c 100644 --- a/tt_metal/impl/program/program.cpp +++ b/tt_metal/impl/program/program.cpp @@ -851,7 +851,7 @@ void detail::Program_::allocate_circular_buffers(const IDevice* device) { } } } - computed_addr = tt::align(computed_addr, device->allocator()->get_config().alignment); + computed_addr = align(computed_addr, device->allocator()->get_alignment(BufferType::DRAM)); for (const CoreRange &core_range : circular_buffer->core_ranges().ranges()) { for (CircularBufferAllocator &cb_allocator : this->cb_allocators_) { if (cb_allocator.core_range.intersects(core_range)) { diff --git a/tt_metal/impl/sub_device/sub_device_manager.cpp b/tt_metal/impl/sub_device/sub_device_manager.cpp index 17b1bfde09a..042e46ae828 100644 --- a/tt_metal/impl/sub_device/sub_device_manager.cpp +++ b/tt_metal/impl/sub_device/sub_device_manager.cpp @@ -262,6 +262,7 @@ void SubDeviceManager::populate_sub_allocators() { .dram_bank_size = 0, .dram_bank_offsets = global_allocator_config.dram_bank_offsets, .dram_unreserved_base = global_allocator_config.dram_unreserved_base, + .dram_alignment = global_allocator_config.dram_alignment, .l1_unreserved_base = global_allocator_config.l1_unreserved_base, .worker_grid = compute_cores, .worker_l1_size = global_allocator_config.l1_unreserved_base + local_l1_size_, @@ -273,7 +274,7 @@ void SubDeviceManager::populate_sub_allocators() { .worker_log_to_virtual_routing_y = global_allocator_config.worker_log_to_virtual_routing_y, .l1_bank_remap = std::move(l1_bank_remap), .compute_grid = compute_cores, - .alignment = global_allocator_config.alignment, + .l1_alignment = 
global_allocator_config.l1_alignment, .disable_interleaved = true}); TT_FATAL( config.l1_small_size < (config.storage_core_bank_size.has_value() @@ -281,9 +282,9 @@ void SubDeviceManager::populate_sub_allocators() { : config.worker_l1_size - config.l1_unreserved_base), "Reserved size must be less than bank size"); TT_FATAL( - config.l1_small_size % config.alignment == 0, - "Reserved size must be aligned to allocator alignment {}", - config.alignment); + config.l1_small_size % config.l1_alignment == 0, + "Reserved size must be aligned to allocator L1 alignment {}", + config.l1_alignment); // sub_devices only have compute cores for allocation for (const CoreCoord& core : corerange_to_cores(compute_cores)) { diff --git a/tt_metal/include/compute_kernel_api/bcast.h b/tt_metal/include/compute_kernel_api/bcast.h index 6e1b33f2186..5b72efe8508 100644 --- a/tt_metal/include/compute_kernel_api/bcast.h +++ b/tt_metal/include/compute_kernel_api/bcast.h @@ -9,6 +9,7 @@ #include "llk_math_binary_api.h" #include "llk_math_matmul_api.h" #include "llk_math_common.h" +#include "llk_math_unary_datacopy_api.h" #endif #ifdef TRISC_UNPACK #include "llk_unpack_AB_api.h" @@ -21,6 +22,74 @@ namespace ckernel { +template +ALWI void unary_bcast_init(uint32_t icb, uint32_t ocb) { + // Pass through uses A2D and potentially direct unpack to dest. + const auto data_copy_type = (bcast_type == BroadcastType::NONE) ? A2D : B2D; + const bool enable_unpack_to_dest = data_copy_type == A2D; + + // Will configure A & B in similar way + UNPACK((llk_unpack_A_hw_configure_disaggregated(icb))); + UNPACK((llk_unpack_A_init( + false, false /*transpose within 16x16 face*/, icb))); + + MATH((llk_math_eltwise_unary_datacopy_init( + false /*transpose of faces*/, false /*transpose within 16x16 face*/, icb))); + MATH((llk_math_pack_sync_init())); + MATH((llk_math_hw_configure_disaggregated(icb, icb))); + + PACK((llk_pack_hw_configure_disaggregated(ocb))); + PACK((llk_pack_init(ocb))); + PACK((llk_pack_dest_init())); +} + +template +ALWI void unary_bcast(uint32_t icb, uint32_t in_tile_index, uint32_t dst_tile_index) { + // Pass through uses A2D and potentially direct unpack to dest. + const auto data_copy_type = (bcast_type == BroadcastType::NONE) ? A2D : B2D; + const bool enable_unpack_to_dest = data_copy_type == A2D; + + UNPACK( + (llk_unpack_A(icb, in_tile_index))); + MATH((llk_math_eltwise_unary_datacopy( + dst_tile_index, icb))); +} + +template +void reconfigure_unary_bcast(uint32_t old_icb, uint32_t new_icb, uint32_t old_ocb, uint32_t new_ocb) { +#if defined(TRISC_MATH) || defined(TRISC_UNPACK) + // Pass through uses A2D and potentially direct unpack to dest. + const auto data_copy_type = (new_bcast_type == BroadcastType::NONE) ? 
A2D : B2D; + const bool enable_unpack_to_dest = data_copy_type == A2D; + const std::uint32_t new_operand_id = get_operand_id(new_icb); + const std::uint32_t old_operand_id = get_operand_id(old_icb); + bool unpacker_src_format_change = unpack_src_format[new_operand_id] != unpack_src_format[old_operand_id]; + bool unpacker_dst_format_change = unpack_dst_format[new_operand_id] != unpack_dst_format[old_operand_id]; + bool bcast_type_change = (old_bcast_type != new_bcast_type); + + if (unpacker_src_format_change || unpacker_dst_format_change) { + // Will configure A & B in similar way + UNPACK((llk_unpack_A_hw_configure_disaggregated(new_icb))); + } + + if (unpacker_src_format_change || unpacker_dst_format_change || bcast_type_change) { + UNPACK((llk_unpack_A_init( + false, false /*transpose within 16x16 face*/, new_icb))); + } + + if (unpacker_dst_format_change) { + MATH((llk_math_hw_configure_disaggregated(new_icb, new_icb))); + } + + if (unpacker_dst_format_change || bcast_type_change) { + MATH((llk_math_eltwise_unary_datacopy_init( + false /*transpose of faces*/, false /*transpose within 16x16 face*/, new_icb))); + } +#endif + + PACK((llk_pack_reconfig_data_format(old_ocb, new_ocb))); +} + /** * Shorthand template instantiation of sub_tiles_bcast. */ @@ -193,14 +262,14 @@ ALWI void any_tiles_bcast(uint32_t icb0, uint32_t icb1, uint32_t itile0, uint32_ * * | Argument | Description | Type | Valid Range | Required | * |----------------|----------------------------------------------------------|---------------|------------------------------------------------|----------| - * | tBcastDim | Broadcast dimension | BroadcastType | One of Dim::R, Dim::C, Dim::RC. | True | - * | in0_cb_id | The identifier of the circular buffer (CB) containing A | uint32_t | 0 to 31 | True | - * | in1_cb_id | The indentifier of the circular buffer (CB) containing B | uint32_t | 0 to 31 | True | - * | in0_tile_index | The index of tile A within the first CB | uint32_t | Must be less than the size of the CB | True | - * | in1_tile_index | The index of tile B within the second CB | uint32_t | Must be less than the size of the CB | True | + * | tBcastDim | Broadcast dimension | BroadcastType | One of Dim::R, Dim::C, Dim::RC. 
| True | + * | in0_cb_id | The identifier of the circular buffer (CB) containing A | uint32_t | 0 to 31 | True | + * | in1_cb_id | The indentifier of the circular buffer (CB) containing B | uint32_t | 0 to 31 | True | + * | in0_tile_index | The index of tile A within the first CB | uint32_t | Must be less than the size of the CB | True | + * | in1_tile_index | The index of tile B within the second CB | uint32_t | Must be less than the size of the CB | True | * | dst_tile_index | The index of the tile in DST REG for the result C | uint32_t | Must be less than the acquired size of DST REG | True | */ - // clang-format on +// clang-format on template ALWI void add_tiles_bcast(uint32_t icb0, uint32_t icb1, uint32_t itile0, uint32_t itile1, uint32_t idst) { any_tiles_bcast(icb0, icb1, itile0, itile1, idst); diff --git a/tt_metal/include/compute_kernel_api/eltwise_binary.h b/tt_metal/include/compute_kernel_api/eltwise_binary.h index 2432515b14b..7d6254ec541 100644 --- a/tt_metal/include/compute_kernel_api/eltwise_binary.h +++ b/tt_metal/include/compute_kernel_api/eltwise_binary.h @@ -26,8 +26,8 @@ namespace ckernel { * | icb1 | The identifier of the circular buffer (CB) containing B | uint32_t | 0 to 31 | True | * | ocb | The identifier of the circular buffer (CB) containing output | uint32_t | 0 to 31, defaults to CB 16 | True | */ - // clang-format on -ALWI void binary_op_init_common(uint32_t icb0, uint32_t icb1, uint32_t ocb = 16) { +// clang-format on +ALWI void binary_op_init_common(uint32_t icb0, uint32_t icb1, uint32_t ocb) { UNPACK((llk_unpack_AB_hw_configure_disaggregated(icb0, icb1))); UNPACK((llk_unpack_AB_init(icb0, icb1))); @@ -49,7 +49,7 @@ ALWI void mul_tiles_init_f() { MATH((llk_math_eltwise_binary_init())); UNPACK((llk_unpack_AB_init(icb0, icb1))); } @@ -71,8 +71,8 @@ ALWI void add_tiles_init_nof() { MATH((llk_math_eltwise_binary_init(0 /*transpose*/, acc_to_dest))); UNPACK((llk_unpack_AB_init(icb0, icb1, 0 /*transpose*/, acc_to_dest))); } @@ -94,8 +94,8 @@ ALWI void sub_tiles_init_nof() { MATH((llk_math_eltwise_binary_init(0 /*transpose*/, acc_to_dest))); UNPACK((llk_unpack_AB_init(icb0, icb1, 0 /*transpose*/, acc_to_dest))); } @@ -188,15 +188,15 @@ ALWI void sub_tiles(uint32_t icb0, uint32_t icb1, uint32_t itile0, uint32_t itil * eltwise_binary_op_type: the binary operation type */ template -ALWI void binary_op_specific_init() // TODO(AP): better naming +ALWI void binary_op_specific_init(uint32_t icb0, uint32_t icb1) // TODO(AP): better naming { if constexpr (full_init) { if constexpr (eltwise_binary_op_type == ELWADD) { - add_tiles_init(); + add_tiles_init(icb0, icb1); } else if constexpr (eltwise_binary_op_type == ELWSUB) { - sub_tiles_init(); + sub_tiles_init(icb0, icb1); } else if constexpr (eltwise_binary_op_type == ELWMUL) { - mul_tiles_init(); + mul_tiles_init(icb0, icb1); } } else { if constexpr (eltwise_binary_op_type == ELWADD) { diff --git a/tt_metal/include/compute_kernel_api/pack_untilize.h b/tt_metal/include/compute_kernel_api/pack_untilize.h index b50057c7b00..4520936fc3d 100644 --- a/tt_metal/include/compute_kernel_api/pack_untilize.h +++ b/tt_metal/include/compute_kernel_api/pack_untilize.h @@ -76,6 +76,11 @@ template < bool narrow_row = false, std::uint32_t row_num_datums = TILE_C_DIM> ALWI void pack_untilize_dst_init_short(uint32_t ocb, uint32_t face_r_dim = 16, uint32_t num_faces = 4) { +#ifndef ARCH_GRAYSKULL + // A workaround for tt-metal#17132. This is not needed for Grayskull, + // as it breaks the packer. Should be addressed more systematically. 
+ PACK((llk_pack_untilize_hw_configure_disaggregated(ocb, face_r_dim, num_faces))); +#endif PACK((llk_pack_untilize_init( ocb, face_r_dim, num_faces))); PACK((llk_init_packer_dest_offset_registers())); diff --git a/tt_metal/jit_build/build.cpp b/tt_metal/jit_build/build.cpp index 0e333824a95..8876c9a6915 100644 --- a/tt_metal/jit_build/build.cpp +++ b/tt_metal/jit_build/build.cpp @@ -168,12 +168,10 @@ void JitBuildEnv::init( "tt_metal/hw/inc/debug " + "-I" + this->root_ + "tt_metal/hw/inc/" + this->aliased_arch_name_ + " " + "-I" + this->root_ + "tt_metal/hw/inc/" + this->aliased_arch_name_ + "/" + this->arch_name_ + "_defines " + "-I" + this->root_ + "tt_metal/hw/inc/" + - this->aliased_arch_name_ + "/noc " + "-I" + this->root_ + "tt_metal/third_party/umd/device/api " + - "-I" + this->root_ + "tt_metal/third_party/umd/device/" + this->arch_name_ + " " + // TODO(fixme) - "-I" + this->root_ + "tt_metal/hw/ckernels/" + this->arch_name_ + "/metal/common " + "-I" + - this->root_ + "tt_metal/hw/ckernels/" + this->arch_name_ + "/metal/llk_io " + "-I" + this->root_ + - "tt_metal/third_party/tt_llk_" + this->arch_name_ + - "/common/inc " + // TODO(fixme) datamovement fw shouldn't read this + this->aliased_arch_name_ + "/noc " + "-I" + this->root_ + "tt_metal/hw/ckernels/" + + this->arch_name_ + "/metal/common " + "-I" + this->root_ + "tt_metal/hw/ckernels/" + + this->arch_name_ + "/metal/llk_io " + "-I" + this->root_ + "tt_metal/third_party/tt_llk_" + + this->arch_name_ + "/common/inc " + // TODO(fixme) datamovement fw shouldn't read this "-I" + this->root_ + "tt_metal/api/" + this->aliased_arch_name_ + " " + "-I" + this->root_ + "tt_metal/api/" + this->aliased_arch_name_ + "/tt-metalium " + "-I" + this->root_ + "tt_metal/api/tt-metalium/ " + "-I" + this->root_ + "tt_metal/api/ " + "-I" + this->root_ + diff --git a/tt_metal/kernels/compute/eltwise_binary.cpp b/tt_metal/kernels/compute/eltwise_binary.cpp index a45bc0efcc1..55a190b796f 100644 --- a/tt_metal/kernels/compute/eltwise_binary.cpp +++ b/tt_metal/kernels/compute/eltwise_binary.cpp @@ -26,9 +26,9 @@ void MAIN { #if not defined ELTWISE_DEST_REUSE_TYPE #ifdef FULL_INIT - binary_op_specific_init(); + binary_op_specific_init(cb_in0, cb_in1); #else - binary_op_specific_init(); + binary_op_specific_init(cb_in0, cb_in1); #endif #endif diff --git a/tt_metal/programming_examples/add_2_integers_in_compute/kernels/compute/add_2_tiles.cpp b/tt_metal/programming_examples/add_2_integers_in_compute/kernels/compute/add_2_tiles.cpp index 87953016271..31824737be9 100644 --- a/tt_metal/programming_examples/add_2_integers_in_compute/kernels/compute/add_2_tiles.cpp +++ b/tt_metal/programming_examples/add_2_integers_in_compute/kernels/compute/add_2_tiles.cpp @@ -13,7 +13,7 @@ void MAIN { constexpr auto cb_out0 = tt::CBIndex::c_16; binary_op_init_common(cb_in0, cb_in1, cb_out0); - add_tiles_init(); + add_tiles_init(cb_in0, cb_in1); // wait for a block of tiles in each of input CBs cb_wait_front(cb_in0, 1); diff --git a/tt_metal/programming_examples/contributed/vecadd/kernels/add.cpp b/tt_metal/programming_examples/contributed/vecadd/kernels/add.cpp index 1e6ea4da2a7..100320cd0d0 100644 --- a/tt_metal/programming_examples/contributed/vecadd/kernels/add.cpp +++ b/tt_metal/programming_examples/contributed/vecadd/kernels/add.cpp @@ -32,7 +32,7 @@ void MAIN { // And we are going to add tiles. This function is only called if we ever need to // switch operation to something else. Since we are only adding tiles, this function // is only called once before the loop. 
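// Note: the hunks below apply the new init signatures mechanically; under this change a
// typical binary-eltwise kernel prologue reads (illustrative):
//
//   binary_op_init_common(cb_in0, cb_in1, cb_out0);  // output CB is now passed explicitly (no default)
//   add_tiles_init(cb_in0, cb_in1);                  // input CBs are named at init time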
- add_tiles_init(); + add_tiles_init(cb_in0, cb_in1); // Loop over all the tiles and perform the computation for (uint32_t i = 0; i < n_tiles; i++) { diff --git a/tt_metal/programming_examples/sharding/shard_data_rm.cpp b/tt_metal/programming_examples/sharding/shard_data_rm.cpp index 47857f0f083..e5c12f3dd16 100644 --- a/tt_metal/programming_examples/sharding/shard_data_rm.cpp +++ b/tt_metal/programming_examples/sharding/shard_data_rm.cpp @@ -45,7 +45,7 @@ int main(int argc, char** argv) { uint32_t input_unit_size = sizeof(uint32_t); uint32_t shard_width_bytes = shard_width * data_size; uint32_t num_units_per_row = shard_width * input_unit_size; - uint32_t padded_offset_bytes = align(input_unit_size, device->allocator()->get_config().alignment); + uint32_t padded_offset_bytes = align(input_unit_size, device->allocator()->get_alignment(BufferType::DRAM)); // configure and create interleaved DRAM buffer to insert source data into uint32_t src_buffer_size = input_unit_size * num_values / data_size; diff --git a/tt_metal/programming_examples/vecadd_multi_core/kernels/add_multi_core.cpp b/tt_metal/programming_examples/vecadd_multi_core/kernels/add_multi_core.cpp index 41771678c2d..d38a6d2e30a 100644 --- a/tt_metal/programming_examples/vecadd_multi_core/kernels/add_multi_core.cpp +++ b/tt_metal/programming_examples/vecadd_multi_core/kernels/add_multi_core.cpp @@ -33,7 +33,7 @@ void MAIN { // And we are going to add tiles. This function is only called if we ever // need to switch operation to something else. Since we are only adding // tiles, this function is only called once before the loop. - add_tiles_init(); + add_tiles_init(cb_in0, cb_in1); // Calculate the range of tiles this core should process const uint32_t tiles_per_core = n_tiles; diff --git a/tt_metal/programming_examples/vecadd_sharding/kernels/add_sharding.cpp b/tt_metal/programming_examples/vecadd_sharding/kernels/add_sharding.cpp index f607877abfa..1cbc9679841 100644 --- a/tt_metal/programming_examples/vecadd_sharding/kernels/add_sharding.cpp +++ b/tt_metal/programming_examples/vecadd_sharding/kernels/add_sharding.cpp @@ -32,7 +32,7 @@ void MAIN { // And we are going to add tiles. This function is only called if we ever // need to switch operation to something else. Since we are only adding // tiles, this function is only called once before the loop. 
- add_tiles_init(); + add_tiles_init(cb_in0, cb_in1); // Loop over the assigned tiles and perform the computation for (uint32_t i = 0; i < num_tile; i++) { diff --git a/tt_metal/third_party/tt_llk_blackhole b/tt_metal/third_party/tt_llk_blackhole index 9cc87b5f76b..9fd3e2d93d1 160000 --- a/tt_metal/third_party/tt_llk_blackhole +++ b/tt_metal/third_party/tt_llk_blackhole @@ -1 +1 @@ -Subproject commit 9cc87b5f76baccdb6d1a96e7b2731b13cae41060 +Subproject commit 9fd3e2d93d1532373f52e11e963de40c1cdf9a55 diff --git a/tt_metal/third_party/tt_llk_wormhole_b0 b/tt_metal/third_party/tt_llk_wormhole_b0 index 61f51430ac4..0ec3177bfc2 160000 --- a/tt_metal/third_party/tt_llk_wormhole_b0 +++ b/tt_metal/third_party/tt_llk_wormhole_b0 @@ -1 +1 @@ -Subproject commit 61f51430ac48c4b8efc60c1bbc19d87b1420269b +Subproject commit 0ec3177bfc262f7edf6cfc19531ecb8f669895d2 diff --git a/tt_metal/tools/CMakeLists.txt b/tt_metal/tools/CMakeLists.txt index c4a51b0ea1a..3509710519a 100644 --- a/tt_metal/tools/CMakeLists.txt +++ b/tt_metal/tools/CMakeLists.txt @@ -1,5 +1,6 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/profiler) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/watcher_dump) +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/lightmetal_runner) set(TOOLS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/memset.cpp) diff --git a/tt_metal/tools/lightmetal_runner/CMakeLists.txt b/tt_metal/tools/lightmetal_runner/CMakeLists.txt new file mode 100644 index 00000000000..7d540adeb81 --- /dev/null +++ b/tt_metal/tools/lightmetal_runner/CMakeLists.txt @@ -0,0 +1,22 @@ +add_executable(lightmetal_runner ${CMAKE_CURRENT_SOURCE_DIR}/lightmetal_runner.cpp) +target_link_libraries( + lightmetal_runner + PRIVATE + tt_metal + FlatBuffers::FlatBuffers +) +target_include_directories( + lightmetal_runner + PRIVATE + ${PROJECT_SOURCE_DIR} + ${PROJECT_SOURCE_DIR}/tt_metal + ${PROJECT_SOURCE_DIR}/tt_metal/common + ${CMAKE_CURRENT_SOURCE_DIR} +) + +set_target_properties( + lightmetal_runner + PROPERTIES + RUNTIME_OUTPUT_DIRECTORY + ${PROJECT_BINARY_DIR}/tools +) diff --git a/tt_metal/tools/lightmetal_runner/lightmetal_runner.cpp b/tt_metal/tools/lightmetal_runner/lightmetal_runner.cpp new file mode 100644 index 00000000000..c65b3a912d6 --- /dev/null +++ b/tt_metal/tools/lightmetal_runner/lightmetal_runner.cpp @@ -0,0 +1,45 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "impl/lightmetal/lightmetal_replay.hpp" +#include +#include +#include + +using namespace tt; + +// This is a standalone tool for executing Light Metal Binary files. Light Metal Binary files +// are generated by the LightMetalBeginCapture() and LightMetalEndCapture() APIs and contain +// a serialized representation of +// - Host API calls +// - Device CQ workload/traces +// - (Future support) Precompiled programs/kernels for fast deployment +// +// Usage: +// lightmetal_runner +// +// Arguments: +// - Path to the Light Metal Binary file to be executed. +// +// This tool reads the specified binary file, transfers ownership of its contents to the +// light metal binary replay executor, and runs it, returning pass (0) or fail (1). + +int main(int argc, char* argv[]) { + // Process cmdline arguments + std::string program_filename = argv[0]; + TT_FATAL(argc == 2, "Invalid number of supplied arguments. Usage: {} ", program_filename.c_str()); + std::string binary_filename = argv[1]; + + // Read the Light Metal Binary file into blob, transfer ownership and execute it. 
+ LightMetalBinary binary = LightMetalBinary::load_from_file(binary_filename); + tt::tt_metal::LightMetalReplay lm_replay(std::move(binary)); + + if (!lm_replay.execute_binary()) { + log_fatal("Light Metal Binary {} failed to execute or encountered errors.", binary_filename); + return 1; + } else { + log_info(tt::LogMetalTrace, "Light Metal Binary {} executed successfully", binary_filename); + return 0; + } +} diff --git a/tt_metal/tools/profiler/device_post_proc_config.py b/tt_metal/tools/profiler/device_post_proc_config.py index 5303ec07e4a..b3e5f8c1697 100644 --- a/tt_metal/tools/profiler/device_post_proc_config.py +++ b/tt_metal/tools/profiler/device_post_proc_config.py @@ -24,6 +24,28 @@ class default_setup(metaclass=MergeMetaclass): ] timerAnalysis = { + "device_kernel_first_to_last_start": { + "across": "ops", + "type": "op_first_last", + "start": { + "core": "ANY", + "risc": "ANY", + "zone_phase": "ZONE_START", + "zone_name": [f"{risc}-KERNEL" for risc in riscTypes], + }, + "end": { + "core": "ANY", + "risc": "ANY", + "zone_phase": "ZONE_START", + "zone_name": [f"{risc}-KERNEL" for risc in riscTypes], + }, + }, + "device_kernel_duration_per_core": { + "across": "ops", + "type": "op_core_first_last", + "start": {"core": "ANY", "risc": "ANY", "zone_name": [f"{risc}-KERNEL" for risc in riscTypes]}, + "end": {"core": "ANY", "risc": "ANY", "zone_name": [f"{risc}-KERNEL" for risc in riscTypes]}, + }, "device_fw_duration": { "across": "ops", "type": "op_first_last", diff --git a/tt_metal/tools/profiler/process_device_log.py b/tt_metal/tools/profiler/process_device_log.py index baf986abbd7..3a86210b9c5 100755 --- a/tt_metal/tools/profiler/process_device_log.py +++ b/tt_metal/tools/profiler/process_device_log.py @@ -356,8 +356,8 @@ def get_ops(timeseries): opCores[core] = (timerID,) if len(ts) == 4: timerID, tsValue, statData, risc = ts - if (risc == "BRISC" and timerID["zone_name"] == "BRISC-FW" and timerID["type"] == "ZONE_START") or ( - risc == "ERISC" and timerID["zone_name"] == "ERISC-FW" and timerID["type"] == "ZONE_START" + if (risc == "BRISC" and timerID["zone_name"] == "BRISC-FW" and timerID["type"] == "ZONE_END") or ( + risc == "ERISC" and timerID["zone_name"] == "ERISC-FW" and timerID["type"] == "ZONE_END" ): opIsDone = True ops[-1]["timeseries"].append(ts) @@ -436,16 +436,19 @@ def translate_metaData(metaData, core, risc): def determine_conditions(timerID, metaData, analysis): currCore = analysis["start"]["core"] if "core" in analysis["start"].keys() else None currRisc = analysis["start"]["risc"] - currStart = (timerID["zone_name"],) + translate_metaData(metaData, currCore, currRisc) + currPhase = (timerID["type"],) if "zone_phase" in analysis["start"].keys() else (None,) + currStart = (timerID["zone_name"],) + currPhase + translate_metaData(metaData, currCore, currRisc) currCore = analysis["end"]["core"] if "core" in analysis["end"].keys() else None currRisc = analysis["end"]["risc"] - currEnd = (timerID["zone_name"],) + translate_metaData(metaData, currCore, currRisc) + currPhase = (timerID["type"],) if "zone_phase" in analysis["end"].keys() else (None,) + currEnd = (timerID["zone_name"],) + currPhase + translate_metaData(metaData, currCore, currRisc) if type(analysis["start"]["zone_name"]) == list: desStart = [ ( zoneName, + analysis["start"]["zone_phase"] if "zone_phase" in analysis["start"].keys() else None, analysis["start"]["core"] if "core" in analysis["start"].keys() else None, analysis["start"]["risc"], ) @@ -455,6 +458,7 @@ def determine_conditions(timerID, metaData, 
analysis): desStart = [ ( analysis["start"]["zone_name"], + analysis["start"]["zone_phase"] if "zone_phase" in analysis["start"].keys() else None, analysis["start"]["core"] if "core" in analysis["start"].keys() else None, analysis["start"]["risc"], ) @@ -464,6 +468,7 @@ def determine_conditions(timerID, metaData, analysis): desEnd = [ ( zoneName, + analysis["end"]["zone_phase"] if "zone_phase" in analysis["end"].keys() else None, analysis["end"]["core"] if "core" in analysis["end"].keys() else None, analysis["end"]["risc"], ) @@ -473,6 +478,7 @@ def determine_conditions(timerID, metaData, analysis): desEnd = [ ( analysis["end"]["zone_name"], + analysis["end"]["zone_phase"] if "zone_phase" in analysis["end"].keys() else None, analysis["end"]["core"] if "core" in analysis["end"].keys() else None, analysis["end"]["risc"], ) @@ -506,7 +512,6 @@ def first_last_analysis(timeseries, analysis): ) ) break - return durations @@ -518,6 +523,22 @@ def op_first_last_analysis(riscData, analysis): return first_last_analysis(riscData["timeseries"], analysis) +def op_core_first_last_analysis(riscData, analysis): + core_ops = {} + durations = [] + for ts in riscData["timeseries"]: + assert len(ts) == 5 + core = ts[4] + if core in core_ops: + core_ops[core].append(ts) + else: + core_ops[core] = [ts] + for core, timeseries in core_ops.items(): + durations.append(first_last_analysis(timeseries, analysis)[0]) + + return durations + + def get_duration(riscData, analysis): totalDuration = 0 for index, (timerID, timestamp, statData, risc, core) in enumerate(riscData["timeseries"]): @@ -564,6 +585,8 @@ def timeseries_analysis(riscData, name, analysis): tmpList = session_first_last_analysis(riscData, analysis) elif analysis["type"] == "op_first_last": tmpList = op_first_last_analysis(riscData, analysis) + elif analysis["type"] == "op_core_first_last": + tmpList = op_core_first_last_analysis(riscData, analysis) elif analysis["type"] == "sum": tmpList = get_duration(riscData, analysis) else: diff --git a/tt_metal/tools/profiler/process_ops_logs.py b/tt_metal/tools/profiler/process_ops_logs.py index 8c3c5de1318..c7e5970cdd3 100755 --- a/tt_metal/tools/profiler/process_ops_logs.py +++ b/tt_metal/tools/profiler/process_ops_logs.py @@ -53,6 +53,10 @@ "OP TO OP LATENCY [ns]", "DEVICE FW DURATION [ns]", "DEVICE KERNEL DURATION [ns]", + "DEVICE KERNEL DURATION PER CORE MIN [ns]", + "DEVICE KERNEL DURATION PER CORE MAX [ns]", + "DEVICE KERNEL DURATION PER CORE AVG [ns]", + "DEVICE KERNEL FIRST TO LAST START [ns]", "DEVICE BRISC KERNEL DURATION [ns]", "DEVICE NCRISC KERNEL DURATION [ns]", "DEVICE TRISC0 KERNEL DURATION [ns]", @@ -349,10 +353,11 @@ def append_device_data(ops, traceReplays, logFolder): cores.add(core) deviceOp["core_usage"] = {"count": len(cores), "cores": [str(core) for core in cores]} deviceOp["device_time"] = { - analysis: data["series"] for analysis, data in deviceOpTime["analysis"].items() + analysis: {"series": data["series"], "stats": data["stats"]} + for analysis, data in deviceOpTime["analysis"].items() } for analysis, data in deviceOp["device_time"].items(): - for sample in data: + for sample in data["series"]: sample["duration_ns"] = sample["duration_cycles"] * 1000 / freq traceOps = {} @@ -422,7 +427,8 @@ def get_device_data_generate_report( cores.add(core) deviceOp["core_usage"] = {"count": len(cores), "cores": [str(core) for core in cores]} deviceOp["device_time"] = { - analysis: data["series"] for analysis, data in deviceOpTime["analysis"].items() + analysis: {"series": data["series"], "stats": 
data["stats"]} + for analysis, data in deviceOpTime["analysis"].items() } if "run_host_id" in timeID.keys(): @@ -430,15 +436,26 @@ def get_device_data_generate_report( else: deviceOp["global_call_count"] = i for analysis, data in deviceOp["device_time"].items(): - for sample in data: + for sample in data["series"]: sample["duration_ns"] = sample["duration_cycles"] * 1000 / freq deviceOps[device].append(deviceOp) rowDict = {csv_header_format("global_call_count"): deviceOp["global_call_count"]} - for analysis, analysisData in deviceOp["device_time"].items(): - headerField = f"{csv_header_format(analysis)} [ns]" - assert len(analysisData) == 1, "Unexpected device data format" - rowDict[headerField] = f"{analysisData[0]['duration_ns']:.0f}" + for analysis, data in deviceOp["device_time"].items(): + analysisData = data["series"] + analysisStats = data["stats"] + if "core" in analysis: + assert len(analysisData) >= 1, "Unexpected device data format" + headerField = f"{csv_header_format(analysis)} MIN [ns]" + rowDict[headerField] = f"{analysisStats['Min']:.0f}" + headerField = f"{csv_header_format(analysis)} MAX [ns]" + rowDict[headerField] = f"{analysisStats['Max']:.0f}" + headerField = f"{csv_header_format(analysis)} AVG [ns]" + rowDict[headerField] = f"{analysisStats['Average']:.0f}" + else: + headerField = f"{csv_header_format(analysis)} [ns]" + assert len(analysisData) == 1, "Unexpected device data format" + rowDict[headerField] = f"{analysisData[0]['duration_ns']:.0f}" if analysis == "device_fw_duration": rowDict["DEVICE FW START CYCLE"] = analysisData[0]["start_cycle"] rowDict["DEVICE FW END CYCLE"] = analysisData[0]["end_cycle"] @@ -646,10 +663,21 @@ def row_compare(row): if "device_time" in opData.keys(): assert "device_id" in opData.keys(), "Op has device data without device_id" deviceID = opData["device_id"] - for analysis, analysisData in opData["device_time"].items(): - headerField = f"{csv_header_format(analysis)} [ns]" - assert len(analysisData) == 1, "Unexpected device data format" - rowDict[headerField] = f"{analysisData[0]['duration_ns']:.0f}" + for analysis, data in opData["device_time"].items(): + analysisData = data["series"] + analysisStats = data["stats"] + if "core" in analysis: + assert len(analysisData) >= 1, "Unexpected device data format" + headerField = f"{csv_header_format(analysis)} MIN [ns]" + rowDict[headerField] = f"{analysisStats['Min']:.0f}" + headerField = f"{csv_header_format(analysis)} MAX [ns]" + rowDict[headerField] = f"{analysisStats['Max']:.0f}" + headerField = f"{csv_header_format(analysis)} AVG [ns]" + rowDict[headerField] = f"{analysisStats['Average']:.0f}" + else: + headerField = f"{csv_header_format(analysis)} [ns]" + assert len(analysisData) == 1, "Unexpected device data format" + rowDict[headerField] = f"{analysisData[0]['duration_ns']:.0f}" if analysis == "device_fw_duration": rowDict["DEVICE FW START CYCLE"] = analysisData[0]["start_cycle"] rowDict["DEVICE FW END CYCLE"] = analysisData[0]["end_cycle"] diff --git a/tt_metal/tt_metal.cpp b/tt_metal/tt_metal.cpp index 2ccb761ed09..f1a36ce8f7a 100644 --- a/tt_metal/tt_metal.cpp +++ b/tt_metal/tt_metal.cpp @@ -35,6 +35,7 @@ #include "tracy/Tracy.hpp" #include +#include "lightmetal/host_api_capture_helpers.hpp" #include "llrt.hpp" @@ -933,7 +934,12 @@ bool CloseDevice(IDevice* device) { return tt::DevicePool::instance().close_device(device_id); } -Program CreateProgram() { return Program(); } +Program CreateProgram() { + LIGHT_METAL_TRACE_FUNCTION_ENTRY(); + auto program = Program(); + 
LIGHT_METAL_TRACE_FUNCTION_CALL(CaptureCreateProgram, program); + return program; +} KernelHandle CreateDataMovementKernel( Program& program, @@ -1019,7 +1025,8 @@ KernelHandle CreateKernel( const std::string& file_name, const std::variant& core_spec, const std::variant& config) { - return std::visit( + LIGHT_METAL_TRACE_FUNCTION_ENTRY(); + KernelHandle kernel = std::visit( [&](auto&& cfg) -> KernelHandle { CoreRangeSet core_ranges = GetCoreRangeSet(core_spec); KernelSource kernel_src(file_name, KernelSource::FILE_PATH); @@ -1033,6 +1040,9 @@ KernelHandle CreateKernel( } }, config); + + LIGHT_METAL_TRACE_FUNCTION_CALL(CaptureCreateKernel, kernel, program, file_name, core_spec, config); + return kernel; } KernelHandle CreateKernelFromString( @@ -1060,8 +1070,11 @@ CBHandle CreateCircularBuffer( Program& program, const std::variant& core_spec, const CircularBufferConfig& config) { + LIGHT_METAL_TRACE_FUNCTION_ENTRY(); CoreRangeSet core_ranges = GetCoreRangeSet(core_spec); - return program.add_circular_buffer(core_ranges, config); + auto cb_handle = program.add_circular_buffer(core_ranges, config); + LIGHT_METAL_TRACE_FUNCTION_CALL(CaptureCreateCircularBuffer, cb_handle, program, core_spec, config); + return cb_handle; } const CircularBufferConfig& GetCircularBufferConfig(Program& program, CBHandle cb_handle) { @@ -1141,7 +1154,8 @@ GlobalSemaphore CreateGlobalSemaphore( } std::shared_ptr CreateBuffer(const InterleavedBufferConfig& config) { - return Buffer::create( + LIGHT_METAL_TRACE_FUNCTION_ENTRY(); + auto buffer = Buffer::create( config.device, config.size, config.page_size, @@ -1150,6 +1164,9 @@ std::shared_ptr CreateBuffer(const InterleavedBufferConfig& config) { std::nullopt, std::nullopt, std::nullopt); + + LIGHT_METAL_TRACE_FUNCTION_CALL(CaptureCreateBuffer, buffer, config); + return buffer; } std::shared_ptr CreateBuffer(const InterleavedBufferConfig& config, DeviceAddr address) { return Buffer::create( @@ -1208,7 +1225,11 @@ std::shared_ptr CreateBuffer(const ShardedBufferConfig& config, SubDevic sub_device_id); } -void DeallocateBuffer(Buffer& buffer) { buffer.deallocate(); } +void DeallocateBuffer(Buffer& buffer) { + LIGHT_METAL_TRACE_FUNCTION_ENTRY(); + LIGHT_METAL_TRACE_FUNCTION_CALL(CaptureDeallocateBuffer, buffer); + buffer.deallocate(); +} void AssignGlobalBufferToProgram(const std::shared_ptr& buffer, Program& program) { detail::DispatchStateCheck(not buffer->device()->using_slow_dispatch()); @@ -1220,6 +1241,8 @@ void SetRuntimeArgs( KernelHandle kernel_id, const std::variant& core_spec, stl::Span runtime_args) { + LIGHT_METAL_TRACE_FUNCTION_ENTRY(); + LIGHT_METAL_TRACE_FUNCTION_CALL(CaptureSetRuntimeArgsUint32, program, kernel_id, core_spec, runtime_args); ZoneScoped; std::visit([&](auto&& core_spec) { SetRuntimeArgsImpl(program, kernel_id, core_spec, runtime_args); }, core_spec); } @@ -1246,7 +1269,9 @@ void SetRuntimeArgs( const std::shared_ptr& kernel, const std::variant& core_spec, const std::shared_ptr& runtime_args) { + LIGHT_METAL_TRACE_FUNCTION_ENTRY(); detail::DispatchStateCheck(not device->using_slow_dispatch()); + LIGHT_METAL_TRACE_FUNCTION_CALL(CaptureSetRuntimeArgs, device, kernel, core_spec, runtime_args); SetRuntimeArgsImpl(kernel, core_spec, std::move(runtime_args), false); } @@ -1289,22 +1314,51 @@ uint32_t BeginTraceCapture(IDevice* device, const uint8_t cq_id) { return tid; } -void EndTraceCapture(IDevice* device, const uint8_t cq_id, const uint32_t tid) { device->end_trace(cq_id, tid); } +void EndTraceCapture(IDevice* device, const uint8_t cq_id, 
const uint32_t tid) { + LIGHT_METAL_TRACE_FUNCTION_ENTRY(); + device->end_trace(cq_id, tid); + // When light metal tracing is enabled, TraceDescriptor will be serialized via end_trace() and this + // will serialize the LightMetalLoadTraceId call to be used during replay to load trace back to device. + LIGHT_METAL_TRACE_FUNCTION_CALL(CaptureLoadTrace, device, cq_id, tid); + LIGHT_METAL_TRACE_FUNCTION_CALL(CaptureReplayTrace, device, cq_id, tid, true); // blocking=true +} void ReplayTrace(IDevice* device, const uint8_t cq_id, const uint32_t tid, const bool blocking) { + LIGHT_METAL_TRACE_FUNCTION_ENTRY(); + LIGHT_METAL_TRACE_FUNCTION_CALL(CaptureReplayTrace, device, cq_id, tid, blocking); device->replay_trace(cq_id, tid, blocking); } -void ReleaseTrace(IDevice* device, const uint32_t tid) { device->release_trace(tid); } +void ReleaseTrace(IDevice* device, const uint32_t tid) { + LIGHT_METAL_TRACE_FUNCTION_ENTRY(); + LIGHT_METAL_TRACE_FUNCTION_CALL(CaptureReleaseTrace, device, tid); + device->release_trace(tid); +} -// Light Metal Begin/End Capture APIs are stubs for now, filled in soon. +// This is a nop if the compile-time define is not set. void LightMetalBeginCapture() { - log_warning(tt::LogMetalTrace, "Begin LightMetalBinary Capture - not yet implemented."); +#if defined(TT_ENABLE_LIGHT_METAL_TRACE) && (TT_ENABLE_LIGHT_METAL_TRACE == 1) + log_debug(tt::LogMetalTrace, "Begin LightMetalBinary Capture"); + auto& lm_capture_ctx = LightMetalCaptureContext::get(); + lm_capture_ctx.reset(); // Clear previous traces if any, ensure tracing disabled + lm_capture_ctx.set_tracing(true); // Enable tracing +#else + log_warning(tt::LogMetalTrace, "TT_ENABLE_LIGHT_METAL_TRACE!=1, ignoring LightMetalBeginCapture()"); +#endif } +// This is a nop if the compile-time define is not set; returns an empty LightMetalBinary.
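Taken together with the LightMetalReplay path used by the new lightmetal_runner tool above, the intended host-side flow is roughly the hedged sketch below; device setup and the captured workload are elided, and in-process replay is shown only for illustration:

    // Requires a build with TT_ENABLE_LIGHT_METAL_TRACE=1 and assumes the usual
    // using-declarations for the tt::tt_metal host API.
    LightMetalBeginCapture();
    // ... issue the host API calls / device trace captures to be recorded ...
    LightMetalBinary binary = LightMetalEndCapture();

    // The binary can be persisted and executed later with lightmetal_runner,
    // or replayed directly (LightMetalReplay comes from
    // "impl/lightmetal/lightmetal_replay.hpp", as in the runner tool).
    tt::tt_metal::LightMetalReplay lm_replay(std::move(binary));
    bool ok = lm_replay.execute_binary();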
LightMetalBinary LightMetalEndCapture() { - log_warning(tt::LogMetalTrace, "End LightMetalBinary Capture - not yet implemented."); +#if defined(TT_ENABLE_LIGHT_METAL_TRACE) && (TT_ENABLE_LIGHT_METAL_TRACE == 1) + log_debug(tt::LogMetalTrace, "End LightMetalBinary Capture"); + auto& lm_capture_ctx = LightMetalCaptureContext::get(); + TT_ASSERT(lm_capture_ctx.is_tracing(), "Light Metal Capture was not enabled."); + lm_capture_ctx.set_tracing(false); // Disable tracing + return lm_capture_ctx.create_light_metal_binary(); +#else + log_warning(tt::LogMetalTrace, "TT_ENABLE_LIGHT_METAL_TRACE!=1, ignoring LightMetalEndCapture()"); return {}; +#endif } void LoadTrace(IDevice* device, const uint8_t cq_id, const uint32_t trace_id, const TraceDescriptor& trace_desc) { diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index c22ff66ceb9..e9e3e010ef1 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -17,6 +17,7 @@ set(TTNN_BASE_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/graph/graph_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/creation.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/sharding_utilities.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/sharding_addrgen_helper.cpp ) set(TTNN_OP_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/core/compute_kernel/compute_kernel_config.cpp @@ -78,8 +79,8 @@ set(TTNN_OP_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/permute/device/permute_device_operation.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/permute/device/permute_rm_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/permute/device/permute_tiled_program_factory.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/repeat/device/repeat_op.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/repeat/device/repeat_program_factory.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/repeat/device/repeat_device_operation.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/repeat/device/host/repeat_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/repeat/repeat.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/repeat/repeat_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/repeat_interleave/repeat_interleave.cpp @@ -588,8 +589,6 @@ set(TTNN_OP_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/expand/expand.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/expand/expand_pybind.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/expand/device/expand_rm_program_factory.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/expand/device/expand_device_operation.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/dropout/device/dropout_device_operation.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/dropout/device/dropout_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/dropout/dropout.cpp @@ -684,7 +683,6 @@ set(TTNN_PUBLIC_INCLUDE_DIRS set(TTNN_PUBLIC_LINK_LIBRARIES metal_common_libs Metalium::Metal - Boost::container xtensor xtensor-blas xtl @@ -759,11 +757,6 @@ function(add_ttnn_sublibrary SUBLIBRARY_NAME) add_library(${SUBLIBRARY_NAME} OBJECT ${ARGN}) endif() TT_ENABLE_UNITY_BUILD(${SUBLIBRARY_NAME}) - if(WITH_PYTHON_BINDINGS) - 
target_compile_definitions(${SUBLIBRARY_NAME} PUBLIC TTNN_WITH_PYTHON_BINDINGS=1) - else() - target_compile_definitions(${SUBLIBRARY_NAME} PUBLIC TTNN_WITH_PYTHON_BINDINGS=0) - endif() target_include_directories(${SUBLIBRARY_NAME} PUBLIC ${TTNN_PUBLIC_INCLUDE_DIRS}) target_link_libraries(${SUBLIBRARY_NAME} PUBLIC ${TTNN_PUBLIC_LINK_LIBRARIES}) target_link_directories(${SUBLIBRARY_NAME} PUBLIC ${TTNN_PUBLIC_LINK_DIRS}) @@ -827,12 +820,6 @@ target_compile_options( -fno-var-tracking ) -if(WITH_PYTHON_BINDINGS) - target_compile_definitions(ttnn PUBLIC TTNN_WITH_PYTHON_BINDINGS=1) -else() - target_compile_definitions(ttnn PUBLIC TTNN_WITH_PYTHON_BINDINGS=0) -endif() - if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") target_compile_definitions(ttnn PUBLIC DISABLE_NAMESPACE_STATIC_ASSERT) endif() diff --git a/ttnn/cpp/pybind11/decorators.hpp b/ttnn/cpp/pybind11/decorators.hpp index 245ed71d5cb..00153d8b791 100644 --- a/ttnn/cpp/pybind11/decorators.hpp +++ b/ttnn/cpp/pybind11/decorators.hpp @@ -6,10 +6,10 @@ #include #include - #include #include "ttnn/decorators.hpp" +#include "small_vector_caster.hpp" // NOLINT - for pybind11 SmallVector binding support. #include "ttnn/types.hpp" namespace py = pybind11; diff --git a/ttnn/cpp/pybind11/device.cpp b/ttnn/cpp/pybind11/device.cpp index 7a9fd64519a..0a3e9b6c1dd 100644 --- a/ttnn/cpp/pybind11/device.cpp +++ b/ttnn/cpp/pybind11/device.cpp @@ -8,6 +8,7 @@ #include #include +#include "small_vector_caster.hpp" // NOLINT - for pybind11 SmallVector binding support. #include #include #include diff --git a/ttnn/cpp/pybind11/pytensor.cpp b/ttnn/cpp/pybind11/pytensor.cpp index 9536f3173e3..23c47b0f8c3 100644 --- a/ttnn/cpp/pybind11/pytensor.cpp +++ b/ttnn/cpp/pybind11/pytensor.cpp @@ -9,6 +9,7 @@ #include #include +#include "small_vector_caster.hpp" // NOLINT - for pybind11 SmallVector binding support. #include "ttnn/tensor/tensor.hpp" #include #include @@ -340,7 +341,7 @@ Tensor convert_python_tensor_to_tt_tensor( py_data_ptr, tensor_spec, device, force_disable_borrow, on_creation_callback, on_destruction_callback); if (device) { - output = output.to(device, memory_config); + output = output.to_device(device, memory_config); } output = tt::tt_metal::set_tensor_id(output); GraphTracker::instance().track_function_end(output); @@ -643,10 +644,10 @@ auto parse_external_operation( } // namespace detail void pytensor_module_types(py::module& m_tensor) { - // Tensor constructors that accept device and .to(device) function use keep alive call policy to communicate that + // Tensor constructors that accept device and .to_device() function use keep alive call policy to communicate that // Device needs to outlive Tensor. This is because when tensors on device are destroyed they need to deallocate // their buffers via device. keep_alive increases the ref count of the Device object being passed into the - // constructor and .to() function. For additional info see: + // constructor and .to_device() function. 
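The bindings below re-point the Python "to" overloads at the renamed C++ methods; on the C++ side the call-site change looks roughly like this hedged sketch, where host_tensor, device, and mem_config are placeholder names:

    // Before the rename, both conversions went through overloads of Tensor::to().
    // auto on_device = host_tensor.to(device, mem_config);
    // auto tiled     = host_tensor.to(Layout::TILE);

    // After: the target of the conversion is explicit in the method name.
    auto on_device = host_tensor.to_device(device, mem_config);
    auto tiled     = host_tensor.to_layout(Layout::TILE);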
For additional info see: // https://pybind11.readthedocs.io/en/stable/advanced/functions.html#keep-alive auto pyTensor = py::class_(m_tensor, "Tensor", R"doc( @@ -968,7 +969,7 @@ void pytensor_module(py::module& m_tensor) { )doc") .def( "to", - py::overload_cast(&Tensor::to, py::const_), + py::overload_cast(&Tensor::to_device, py::const_), py::arg("device").noconvert(), py::arg("mem_config").noconvert() = MemoryConfig{.memory_layout = TensorMemoryLayout::INTERLEAVED}, py::arg("cq_id") = ttnn::DefaultQueueId, @@ -1002,7 +1003,7 @@ void pytensor_module(py::module& m_tensor) { )doc") .def( "to", - py::overload_cast(&Tensor::to, py::const_), + py::overload_cast(&Tensor::to_device, py::const_), py::arg("mesh_device").noconvert(), py::arg("mem_config").noconvert() = MemoryConfig{.memory_layout = TensorMemoryLayout::INTERLEAVED}, py::arg("cq_id") = ttnn::DefaultQueueId, @@ -1089,7 +1090,7 @@ void pytensor_module(py::module& m_tensor) { )doc") .def( "to", - py::overload_cast(&Tensor::to, py::const_), + py::overload_cast(&Tensor::to_layout, py::const_), py::arg("target_layout").noconvert(), py::arg("worker") = nullptr, R"doc( @@ -1113,7 +1114,7 @@ void pytensor_module(py::module& m_tensor) { )doc") .def( "to", - py::overload_cast(&Tensor::to, py::const_), + py::overload_cast(&Tensor::to_layout, py::const_), py::arg("target_layout").noconvert(), py::arg("mesh_device") = nullptr, R"doc( diff --git a/ttnn/cpp/pybind11/small_vector_caster.hpp b/ttnn/cpp/pybind11/small_vector_caster.hpp new file mode 100644 index 00000000000..37dd3d478ec --- /dev/null +++ b/ttnn/cpp/pybind11/small_vector_caster.hpp @@ -0,0 +1,19 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include +#include + +#include + +namespace PYBIND11_NAMESPACE { +namespace detail { +template +struct type_caster> : list_caster, T> { +}; +} // namespace detail +} // namespace PYBIND11_NAMESPACE diff --git a/ttnn/cpp/pybind11/types.hpp b/ttnn/cpp/pybind11/types.hpp index 4c5d454a912..3ab9a55eadc 100644 --- a/ttnn/cpp/pybind11/types.hpp +++ b/ttnn/cpp/pybind11/types.hpp @@ -8,7 +8,10 @@ #include #include +#include + #include "export_enum.hpp" +#include "small_vector_caster.hpp" // NOLINT - for pybind11 SmallVector binding support. 
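The header added above supplies a pybind11 type_caster so that parameters typed as ttnn's SmallVector accept ordinary Python sequences. A hedged illustration follows; the helper function, module name, and element type are hypothetical and not part of this patch:

    #include <pybind11/pybind11.h>
    #include <pybind11/stl.h>
    #include "small_vector_caster.hpp"  // must be visible where the binding is declared
    // (plus the ttnn header that declares ttnn::SmallVector - assumed here)

    // Hypothetical helper taking a SmallVector argument.
    static int rank_of(const ttnn::SmallVector<uint32_t>& dims) { return static_cast<int>(dims.size()); }

    PYBIND11_MODULE(example_module, m) {
        // With the caster in scope, Python may call example_module.rank_of([1, 2, 3]).
        m.def("rank_of", &rank_of);
    }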
#include "ttnn/tensor/tensor.hpp" #include "ttnn/types.hpp" #include "ttnn/operations/data_movement/bcast/bcast_types.hpp" diff --git a/ttnn/cpp/ttnn/distributed/api.cpp b/ttnn/cpp/ttnn/distributed/api.cpp index 34f77e9276e..831c1f4cbd5 100644 --- a/ttnn/cpp/ttnn/distributed/api.cpp +++ b/ttnn/cpp/ttnn/distributed/api.cpp @@ -116,7 +116,7 @@ Tensor aggregate_as_tensor( } } auto storage = - MultiDeviceStorage{config, ordered_device_ids, std::move(device_buffers), specs, /*mesh_buffer_=*/nullptr}; + MultiDeviceStorage{config, ordered_device_ids, std::move(device_buffers), specs, /*mesh_buffer=*/nullptr}; return Tensor(std::move(storage), reference_shard.get_tensor_spec()); } } @@ -211,6 +211,11 @@ bool is_multi_device_tensor(const Tensor& tensor) { tensor.storage_type() == StorageType::MULTI_DEVICE_HOST; } +bool is_mesh_buffer_tensor(const Tensor& tensor) { + auto* multi_device_storage = std::get_if(&tensor.get_storage()); + return multi_device_storage != nullptr && multi_device_storage->mesh_buffer != nullptr; +} + std::vector get_tensors_from_multi_device_storage(const Tensor& multi_device_tensor) { std::vector tensors; if (multi_device_tensor.storage_type() == StorageType::MULTI_DEVICE) { @@ -263,7 +268,7 @@ Tensor create_multi_device_tensor( specs.insert({device_id, tensor.get_tensor_spec()}); } return Tensor{ - MultiDeviceStorage{strategy, ordered_device_ids, device_buffers, specs, /*mesh_buffer_=*/nullptr}, + MultiDeviceStorage{strategy, ordered_device_ids, device_buffers, specs, /*mesh_buffer=*/nullptr}, TensorSpec( tensors.at(0).get_logical_shape(), TensorLayout::fromPaddedShape( diff --git a/ttnn/cpp/ttnn/distributed/api.hpp b/ttnn/cpp/ttnn/distributed/api.hpp index 868aa553d73..da1758a16e2 100644 --- a/ttnn/cpp/ttnn/distributed/api.hpp +++ b/ttnn/cpp/ttnn/distributed/api.hpp @@ -45,6 +45,10 @@ Tensor get_device_tensor(const Tensor& multi_device_tensor, const int device_id) // Returns true has MultiDeviceHost/MultiDevice Storage bool is_multi_device_tensor(const Tensor& tensor); +// Returns true if tensor has MultiDevice storage type and is allocated on a mesh buffer. +// TODO: remove when the infrastructure uniformly works with mesh buffer backed tensors. +bool is_mesh_buffer_tensor(const Tensor& tensor); + // Given a multi-device tensor and a device, returns a list of per-device tensors. 
std::vector get_tensors_from_multi_device_storage(const Tensor& multi_device_tensor); diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp index a46e66ff35f..3d82d24714f 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp @@ -190,7 +190,7 @@ Tensor distribute_tensor( std::vector tensors = mapper.map(tensor); Tensor output = aggregate_as_tensor(tensors, mapper.config()); if (mesh_device.has_value()) { - return output.to(&(mesh_device->get())); + return output.to_device(&(mesh_device->get())); } return output; } diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_receive_writer.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_receive_writer.cpp index 4be6e65ebe5..10b1dd22718 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_receive_writer.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_receive_writer.cpp @@ -142,7 +142,7 @@ void kernel_main() { output_page_idx < output_tensor_shard_pages_per_shard_y * output_tensor_shard_pages_per_shard_x * output_tensor_shard_grid_height * output_tensor_shard_grid_width); #endif - write_chunk( + write_chunk_legacy( output_page_idx, col_idx, row_idx, @@ -165,7 +165,7 @@ void kernel_main() { output_page_idx < output_tensor_shard_pages_per_shard_y * output_tensor_shard_pages_per_shard_x * output_tensor_shard_grid_height * output_tensor_shard_grid_width); #endif - write_chunk( + write_chunk_legacy( output_page_idx, col_idx, row_idx, diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp index 66a36d92c82..68d54579828 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp @@ -4,6 +4,7 @@ #pragma once #include "dataflow_api.h" +#include "ttnn/cpp/ttnn/operations/ccl/kernel_common/sharding_addrgen.hpp" #include "debug/assert.h" #include "cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp" #include "cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" @@ -166,6 +167,70 @@ FORCE_INLINE void write_chunk( uint32_t l1_read_addr = get_read_ptr(cb_id); int32_t contig_pages = 1; + for (int32_t pages_remaining = num_pages; pages_remaining != 0; pages_remaining -= contig_pages) { + contig_pages = 1; +#ifdef ROW_MAJOR_LAYOUT + #ifdef INTERLEAVED_MEM_LAYOUT + std::pair dst_noc_addr_retval = get_contiguous_noc_addr(output_page_idx, d); + uint64_t dst_noc_addr = dst_noc_addr_retval.first; + contig_pages = dst_noc_addr_retval.second; + noc_async_write(l1_read_addr, dst_noc_addr, page_size); + #elif defined SHARDED_MEM_LAYOUT + ASSERT(false); // untested && unimplemented + #endif + output_page_idx++; + row_idx++; + if (row_idx == num_rows) { + row_idx = 0; + output_page_idx += row_offset; + } +#elif defined TILED_LAYOUT + #ifdef INTERLEAVED_MEM_LAYOUT + noc_async_write_tile(output_page_idx, d, l1_read_addr); + #elif defined SHARDED_MEM_LAYOUT + std::pair dst_noc_addr_retval = get_contiguous_noc_addr(output_page_idx, d); + uint64_t dst_noc_addr = dst_noc_addr_retval.first; + contig_pages = std::min(pages_remaining, 
std::min(dst_noc_addr_retval.second, num_cols - col_idx)); + ASSERT(((dst_noc_addr >> 32) & 0xF) == 0); + + noc_async_write(l1_read_addr, dst_noc_addr, page_size * contig_pages); + #endif + output_page_idx += contig_pages; + col_idx += contig_pages; + if (col_idx == num_cols) { + output_page_idx += col_offset; + col_idx = 0; + row_idx++; + if (row_idx == num_rows) { + row_idx = 0; + output_page_idx += row_offset; + } + } +#endif + l1_read_addr += page_size * contig_pages; + } + noc_async_write_barrier(); + cb_pop_front(cb_id, num_pages); +} + + +template +FORCE_INLINE void write_chunk_legacy( + uint32_t& output_page_idx, + uint32_t& col_idx, + uint32_t& row_idx, + const uint32_t& cb_id, + const AddrGen& d, + const uint32_t& num_cols, + const uint32_t& num_rows, + const uint32_t& col_offset, + const uint32_t& row_offset, + const uint32_t& num_pages, + const uint32_t& page_size) { + cb_wait_front(cb_id, num_pages); + uint32_t l1_read_addr = get_read_ptr(cb_id); + int32_t contig_pages = 1; + for (int32_t pages_remaining = num_pages; pages_remaining != 0; pages_remaining -= contig_pages) { contig_pages = 1; #ifdef ROW_MAJOR_LAYOUT @@ -525,9 +590,26 @@ FORCE_INLINE void read_wrapped_chunk_from_output_tensor( } - template FORCE_INLINE std::pair get_noc_addr_and_contiguous_pages( + uint32_t curr_page_idx, + const uint32_t offset_into_worker_slice, + const ttnn::ccl::Shape4D& offset_worker_slice, + const AddrGen& address_generator, + const ttnn::ccl::Shape4D& tensor_slice_shape, + uint8_t noc_id = noc_index) { + constexpr uint32_t offset = 0; + std::pair ret_val = + get_contiguous_noc_addr(curr_page_idx,address_generator,offset,noc_id); + uint32_t flattened_offset_worker_slice = ttnn::ccl::v2::flattened_index(tensor_slice_shape, offset_worker_slice); + uint32_t contig_until_edge_of_tensor_slice = tensor_slice_shape.x - ((flattened_offset_worker_slice + offset_into_worker_slice) % tensor_slice_shape.x); + size_t contig_pages = std::min(ret_val.second, contig_until_edge_of_tensor_slice); + return {ret_val.first, contig_pages}; +} + + +template +FORCE_INLINE std::pair legacy_get_noc_addr_and_contiguous_pages( uint32_t curr_page_idx, const uint32_t offset_into_worker_slice, const ttnn::ccl::Shape4D& offset_worker_slice, @@ -573,6 +655,17 @@ FORCE_INLINE std::pair get_noc_addr_and_contiguous_pages_for curr_page_idx, offset_into_worker_slice, offset_worker_slice, address_generator, tensor_slice_shape, 0); } +template +FORCE_INLINE std::pair legacy_get_noc_addr_and_contiguous_pages_for_fabric_write( + uint32_t curr_page_idx, + const uint32_t offset_into_worker_slice, + const ttnn::ccl::Shape4D& offset_worker_slice, + const AddrGen& address_generator, + const ttnn::ccl::Shape4D& tensor_slice_shape) { + return legacy_get_noc_addr_and_contiguous_pages( + curr_page_idx, offset_into_worker_slice, offset_worker_slice, address_generator, tensor_slice_shape, 0); +} + namespace v2 { template FORCE_INLINE void write_wrapped_chunk( diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp index 0b3f8b9640e..6951764459f 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp @@ -227,19 +227,13 @@ static bool shard_grid_is_transposed(Tensor const& t) { return shard_grid_transposed; } -static void emit_sharded_tensor_kernel_ct_args( - IDevice* d, - Tensor 
const& tensor, - std::vector& args, - std::size_t pages_per_shard_y, - std::size_t pages_per_shard_x) { +static void emit_sharded_tensor_kernel_ct_args(IDevice* d, const Tensor& tensor, std::vector& args) { std::ranges::copy( std::vector{static_cast(tensor.memory_config().memory_layout)}, std::back_inserter(args)); std::ranges::copy(ShardedAddrGenArgBuilder::emit_ct_args(tensor), std::back_inserter(args)); }; -static void log_sharded_tensor_kernel_args( - Tensor const& tensor, std::size_t pages_per_shard_y, std::size_t pages_per_shard_x, std::string const& prefix) { +static void log_sharded_tensor_kernel_args(const Tensor& tensor, const std::string& prefix) { ShardedAddrGenArgBuilder::log_sharded_tensor_kernel_args(tensor, prefix); } @@ -341,12 +335,9 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper( uint32_t input_page_size = input_tensor_config->get_page_size(); uint32_t output_page_size = output_tensor_config->get_page_size(); - auto const& [input_pages_per_shard_y, input_pages_per_shard_x] = - is_sharded ? input_tensor.buffer()->shard_spec().shape_in_pages() : std::array{0, 0}; auto const& [output_pages_per_shard_y, output_pages_per_shard_x] = is_sharded ? output_tensor.buffer()->shard_spec().shape_in_pages() : std::array{0, 0}; if (is_sharded) { - TT_ASSERT(input_pages_per_shard_y > 0 && input_pages_per_shard_x > 0); TT_ASSERT(output_pages_per_shard_y > 0 && output_pages_per_shard_x > 0); log_trace(tt::LogOp, "input_buffer->page_size: {}", input_page_size); log_trace( @@ -501,14 +492,8 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper( static_cast(input_tensor_config->get_tile_size()), static_cast(output_tensor_config->get_tile_size())}; if (is_sharded) { - emit_sharded_tensor_kernel_ct_args( - device, input_tensor, worker_reader_sender_ct_args, input_pages_per_shard_y, input_pages_per_shard_x); - emit_sharded_tensor_kernel_ct_args( - device, - output_tensor, - worker_reader_sender_ct_args, - output_pages_per_shard_y, - output_pages_per_shard_x); + emit_sharded_tensor_kernel_ct_args(device, input_tensor, worker_reader_sender_ct_args); + emit_sharded_tensor_kernel_ct_args(device, output_tensor, worker_reader_sender_ct_args); }; log_trace(tt::LogOp, "Worker SR CT args"); @@ -520,8 +505,8 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper( log_trace(tt::LogOp, "\tsender_worker_reader_semaphore_id: {}", sender_worker_reader_semaphore_id); if (is_sharded) { - log_sharded_tensor_kernel_args(input_tensor, input_pages_per_shard_y, input_pages_per_shard_x, "input"); - log_sharded_tensor_kernel_args(output_tensor, output_pages_per_shard_y, output_pages_per_shard_x, "output"); + log_sharded_tensor_kernel_args(input_tensor, "input"); + log_sharded_tensor_kernel_args(output_tensor, "output"); } return worker_reader_sender_ct_args; @@ -553,12 +538,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper( static_cast(output_tensor_config->get_tile_size())}; if (is_sharded) { - emit_sharded_tensor_kernel_ct_args( - device, - output_tensor, - worker_writer_sender_ct_args, - output_pages_per_shard_y, - output_pages_per_shard_x); + emit_sharded_tensor_kernel_ct_args(device, output_tensor, worker_writer_sender_ct_args); } log_trace(tt::LogOp, "Worker SW CT args"); log_trace(tt::LogOp, "\tall_gather_config.is_output_dram(): {}", all_gather_config.is_output_dram()); @@ -569,7 +549,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper( log_trace(tt::LogOp, "\thalf_cb_num_pages: 
{}", max_pages_per_chunk); if (is_sharded) { - log_sharded_tensor_kernel_args(output_tensor, output_pages_per_shard_y, output_pages_per_shard_x, "output"); + log_sharded_tensor_kernel_args(output_tensor, "output"); } return worker_writer_sender_ct_args; }; @@ -623,12 +603,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper( static_cast(output_tensor_config->get_tile_size())}; if (is_sharded) { - emit_sharded_tensor_kernel_ct_args( - device, - output_tensor, - worker_writer_receiver_ct_args, - output_pages_per_shard_y, - output_pages_per_shard_x); + emit_sharded_tensor_kernel_ct_args(device, output_tensor, worker_writer_receiver_ct_args); } log_trace(tt::LogOp, "Worker RW ct args"); @@ -641,7 +616,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper( log_trace(tt::LogOp, "\tfuse_op: {}", fuse_op); if (is_sharded) { - log_sharded_tensor_kernel_args(output_tensor, output_pages_per_shard_y, output_pages_per_shard_x, "output"); + log_sharded_tensor_kernel_args(output_tensor, "output"); } return worker_writer_receiver_ct_args; diff --git a/ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.cpp b/ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.cpp index 72a48f32827..57eebb6f0d7 100644 --- a/ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.cpp @@ -1113,7 +1113,7 @@ void generate_multi_input_command_stream_kernel_rt_args( TT_FATAL(page_sizes.size() == tensors.size(), "Number of page sizes must match with the number of tensors"); auto command_stream_start_arg_indices = std::vector(num_command_streams, 0); std::vector rt_args; - rt_args.reserve(100); + rt_args.reserve(200); for (size_t i = 0; i < tensors.size(); i++) { if (tensors[i]) { if (fill_args_overrider) { @@ -1235,7 +1235,7 @@ void generate_multi_command_stream_kernel_rt_args( for (size_t i = 0; i < num_command_streams; i++) { std::ranges::copy( - ttnn::ccl::emit_address_generator_runtime_args(device, *tensors[i]), std::back_inserter(rt_args)); + ttnn::ccl::legacy_emit_address_generator_runtime_args(device, *tensors[i]), std::back_inserter(rt_args)); } // TODO: Handle teardown signalling @@ -1455,7 +1455,7 @@ std::vector CCLWorkerArgBuilder::generate_sender_reader_kernel_rt_args log_trace(tt::LogOp, "ccl_send_reader arg[{}]: page_size {}", logged_arg_idx, args[logged_arg_idx]); logged_arg_idx++; - auto const& addr_gen_rt_args = ttnn::ccl::emit_address_generator_runtime_args(this->device, input_tensor); + auto const& addr_gen_rt_args = ttnn::ccl::legacy_emit_address_generator_runtime_args(this->device, input_tensor); std::ranges::copy(addr_gen_rt_args, std::back_inserter(args)); for (auto const& arg : addr_gen_rt_args) { log_trace(tt::LogOp, "ccl_send_reader arg[{}]: addr_gen_rt_args[] {}", logged_arg_idx, args[logged_arg_idx]); @@ -1616,7 +1616,7 @@ std::vector CCLWorkerArgBuilder::generate_sender_writer_kernel_rt_args } } - auto const& addr_gen_rt_args = ttnn::ccl::emit_address_generator_runtime_args(this->device, output_tensor); + auto const& addr_gen_rt_args = ttnn::ccl::legacy_emit_address_generator_runtime_args(this->device, output_tensor); std::ranges::copy(addr_gen_rt_args, std::back_inserter(args)); for (auto const& arg : addr_gen_rt_args) { log_trace(tt::LogOp, "ccl_send_writer arg[{}]: addr_gen_rt_args[] {}", logged_arg_idx, args[logged_arg_idx]); diff --git a/ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_reader_two_input.cpp 
b/ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_reader_two_input.cpp index 8fe14287998..370be920c8c 100644 --- a/ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_reader_two_input.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_reader_two_input.cpp @@ -57,118 +57,79 @@ constexpr uint32_t cb1_id = get_compile_time_arg_val(9); #endif #endif -struct sharded_addrgen_fields { - bool is_sharded = false; - uint8_t tensor_shard_grid_height = 0; - uint8_t tensor_shard_grid_width = 0; - uint8_t tensor_shard_grid_start_y_logical = 0; - uint8_t tensor_shard_grid_start_x_logical = 0; - uint32_t tensor_shard_pages_per_shard_y = 0; - uint32_t tensor_shard_pages_per_shard_x = 0; - bool tensor_shard_grid_transposed = 0; -}; - #ifdef TENSOR0_SHARDED_MEM_LAYOUT #ifdef SINGLE_TENSOR // SINGLE INPUT MODE - SHARDED -constexpr sharded_addrgen_fields in0_sharded_addrgen_fields = { - true, - get_compile_time_arg_val(6), - get_compile_time_arg_val(7), - get_compile_time_arg_val(8), - get_compile_time_arg_val(9), - get_compile_time_arg_val(10), - get_compile_time_arg_val(11), - get_compile_time_arg_val(12) != 0}; + using Tensor0ShardInfo = ShardedInfo< + get_compile_time_arg_val(6), + get_compile_time_arg_val(7), + get_compile_time_arg_val(8), + get_compile_time_arg_val(9), + get_compile_time_arg_val(10), + get_compile_time_arg_val(11), + get_compile_time_arg_val(12)>; #else // TWO INPUT MODE -constexpr sharded_addrgen_fields in0_sharded_addrgen_fields = { - true, - get_compile_time_arg_val(10), - get_compile_time_arg_val(11), - get_compile_time_arg_val(12), - get_compile_time_arg_val(13), - get_compile_time_arg_val(14), - get_compile_time_arg_val(15), - get_compile_time_arg_val(16) != 0}; + using Tensor0ShardInfo = ShardedInfo< + get_compile_time_arg_val(10), + get_compile_time_arg_val(11), + get_compile_time_arg_val(12), + get_compile_time_arg_val(13), + get_compile_time_arg_val(14), + get_compile_time_arg_val(15), + get_compile_time_arg_val(16)>; #endif -static_assert( - in0_sharded_addrgen_fields.tensor_shard_grid_height > 0, - "Misconfigured sharded addrgen fields for tensor0. Field \"tensor_shard_grid_height\" was resolved to 0 but it " - "must not be 0."); -static_assert( - in0_sharded_addrgen_fields.tensor_shard_grid_width > 0, - "Misconfigured sharded addrgen fields for tensor0. Field \"tensor_shard_grid_width\" was resolved to 0 but it must " - "not be 0."); -static_assert( - in0_sharded_addrgen_fields.tensor_shard_pages_per_shard_y > 0, - "Misconfigured sharded addrgen fields for tensor0. Field \"tensor_shard_pages_per_shard_y\" was resolved to 0 but " - "it must not be 0."); -static_assert( - in0_sharded_addrgen_fields.tensor_shard_pages_per_shard_x > 0, - "Misconfigured sharded addrgen fields for tensor0. Field \"tensor_shard_pages_per_shard_x\" was resolved to 0 but " - "it must not be 0."); +constexpr Tensor0ShardInfo test_object {}; +static_assert(test_object.number_of_cores > 0, "Misconfigured sharded addrgen fields for tensor0. Field \"number_of_cores\" was resolved to 0 but it must not be 0."); +static_assert(test_object.page_size_jump > 0, "Misconfigured sharded addrgen fields for tensor0. Field \"page_size_jump\" was resolved to 0 but it must not be 0."); +static_assert(test_object.pages_per_tensor_row > 0, "Misconfigured sharded addrgen fields for tensor0. 
Field \"pages_per_tensor_row\" was resolved to 0 but it must not be 0."); #else -constexpr sharded_addrgen_fields in0_sharded_addrgen_fields = {false, 0, 0, 0, 0, 0, 0, 0}; +using Tensor0ShardInfo = ShardedInfo<0,0,0,0,0,0,0>; #endif #ifndef SINGLE_TENSOR #if defined(TENSOR1_SHARDED_MEM_LAYOUT) #if defined(TENSOR0_SHARDED_MEM_LAYOUT) -constexpr sharded_addrgen_fields in1_sharded_addrgen_fields = { - true, - get_compile_time_arg_val(17), - get_compile_time_arg_val(18), - get_compile_time_arg_val(19), - get_compile_time_arg_val(20), - get_compile_time_arg_val(21), - get_compile_time_arg_val(22), - get_compile_time_arg_val(23) != 0}; + using Tensor1ShardInfo = ShardedInfo< + get_compile_time_arg_val(17), + get_compile_time_arg_val(18), + get_compile_time_arg_val(19), + get_compile_time_arg_val(20), + get_compile_time_arg_val(21), + get_compile_time_arg_val(22), + get_compile_time_arg_val(23)>; #else // Then we are only consuming ct args for second operand and we resume from operation 8 -constexpr sharded_addrgen_fields in1_sharded_addrgen_fields = { - true, - get_compile_time_arg_val(10), - get_compile_time_arg_val(11), - get_compile_time_arg_val(12), - get_compile_time_arg_val(13), - get_compile_time_arg_val(14), - get_compile_time_arg_val(15), - get_compile_time_arg_val(16) != 0}; + using Tensor1ShardInfo = ShardedInfo< + get_compile_time_arg_val(10), + get_compile_time_arg_val(11), + get_compile_time_arg_val(12), + get_compile_time_arg_val(13), + get_compile_time_arg_val(14), + get_compile_time_arg_val(15), + get_compile_time_arg_val(16)>; #endif -static_assert( - in1_sharded_addrgen_fields.tensor_shard_grid_height > 0, - "Misconfigured sharded addrgen fields for tensor1. Field \"tensor_shard_grid_height\" was resolved to 0 but it " - "must not be 0."); -static_assert( - in1_sharded_addrgen_fields.tensor_shard_grid_width > 0, - "Misconfigured sharded addrgen fields for tensor1. Field \"tensor_shard_grid_width\" was resolved to 0 but it must " - "not be 0."); -static_assert( - in1_sharded_addrgen_fields.tensor_shard_pages_per_shard_y > 0, - "Misconfigured sharded addrgen fields for tensor1. Field \"tensor_shard_pages_per_shard_y\" was resolved to 0 but " - "it must not be 0."); -static_assert( - in1_sharded_addrgen_fields.tensor_shard_pages_per_shard_x > 0, - "Misconfigured sharded addrgen fields for tensor1. Field \"tensor_shard_pages_per_shard_x\" was resolved to 0 but " - "it must not be 0."); +constexpr Tensor1ShardInfo test_object2 {}; +static_assert(test_object2.number_of_cores > 0, "Misconfigured sharded addrgen fields for tensor1. Field \"number_of_cores\" was resolved to 0 but it must not be 0."); +static_assert(test_object2.page_size_jump > 0, "Misconfigured sharded addrgen fields for tensor1. Field \"page_size_jump\" was resolved to 0 but it must not be 0."); +static_assert(test_object2.pages_per_tensor_row > 0, "Misconfigured sharded addrgen fields for tensor1. 
Field \"pages_per_tensor_row\" was resolved to 0 but it must not be 0."); #else -constexpr sharded_addrgen_fields in1_sharded_addrgen_fields = {0, 0, 0, 0, 0, 0, 0, 0}; +typedef ShardedInfo<0,0,0,0,0,0,0> Tensor1ShardInfo; #endif #endif template < tt::tt_metal::TensorMemoryLayout tensor_layout, tt::tt_metal::BufferType buffer_type, - tt::tt_metal::Layout page_layout> + tt::tt_metal::Layout page_layout, + typename ShardingInfoType> FORCE_INLINE auto build_source_address_generator( std::size_t& arg_idx, address_t tensor_address, std::size_t page_size, - const sharded_addrgen_fields& tensor_sharded_addrgen_fields, - uint32_t cb_id_in) -> typename source_tensor_addrgen::type { + uint32_t cb_id_in) { constexpr bool is_sharded = is_sharded_tensor_layout(tensor_layout); constexpr bool is_interleaved = tensor_layout == tt::tt_metal::TensorMemoryLayout::INTERLEAVED; constexpr bool is_tile_page_layout = page_layout == tt::tt_metal::Layout::TILE; @@ -178,50 +139,34 @@ FORCE_INLINE auto build_source_address_generator( "Only sharded and interleaved tensor layouts are supported but the unified address generator. A tensor layout " "not matching TensorMemoryLayout::WIDTH_SHARDED, TensorMemoryLayout::HEIGHT_SHARDED, " "TensorMemoryLayout::BLOCK_SHARDED, or TensorMemoryLayout::INTERLEAVED was specified."); - - using addrgen_type = typename source_tensor_addrgen::type; - bool addrgen_enabled = get_arg_val(arg_idx++) != 0; if constexpr (tensor_layout == tt::tt_metal::TensorMemoryLayout::INTERLEAVED) { if constexpr (is_row_major_layout) { - return addrgen_type{.bank_base_address = tensor_address, .page_size = page_size}; + InterleavedAddrGen ret_val = { + .bank_base_address = tensor_address, .page_size = page_size}; + return ret_val; } else { - return addrgen_type{ + InterleavedAddrGenFast ret_val = { .bank_base_address = tensor_address, .page_size = page_size, .data_format = get_dataformat(cb_id_in)}; + return ret_val; } } else if constexpr ( tensor_layout == tt::tt_metal::TensorMemoryLayout::BLOCK_SHARDED || tensor_layout == tt::tt_metal::TensorMemoryLayout::HEIGHT_SHARDED || tensor_layout == tt::tt_metal::TensorMemoryLayout::WIDTH_SHARDED) { - // We don't use these args at the moment but we keep them here for now to avoid a rewrite in the very - // near future where we'll want to support custom shard grid. 
- uint8_t input_shard_grid_nrows = 0; - uint8_t input_shard_grid_ncols = 0; - uint32_t* input_shard_grid_row_map = nullptr; - uint32_t* input_shard_grid_col_map = nullptr; - if (addrgen_enabled) { - input_shard_grid_nrows = get_arg_val(arg_idx++); - input_shard_grid_row_map = reinterpret_cast(get_arg_addr(arg_idx)); - arg_idx += input_shard_grid_nrows; - input_shard_grid_ncols = get_arg_val(arg_idx++); - input_shard_grid_col_map = reinterpret_cast(get_arg_addr(arg_idx)); - arg_idx += input_shard_grid_ncols; + const auto [mapping_table, rt_increment] = experimental::shard_addr_gen_utils::get_shard_map(get_arg_addr(arg_idx)); + if (addrgen_enabled) + { + arg_idx += rt_increment; } - - return tt::tt_metal::address_generators::build_sharded_addr_gen( - tt::tt_metal::address_generators::VirtualCoordWormholeWorkerToNocLookup(), - typename tt::tt_metal::address_generators::DeviceShardSpecTypeGetter::type( - tensor_sharded_addrgen_fields.tensor_shard_pages_per_shard_y, - tensor_sharded_addrgen_fields.tensor_shard_pages_per_shard_x, - tensor_sharded_addrgen_fields.tensor_shard_grid_height, - tensor_sharded_addrgen_fields.tensor_shard_grid_width, - tensor_sharded_addrgen_fields.tensor_shard_grid_start_y_logical, - tensor_sharded_addrgen_fields.tensor_shard_grid_start_x_logical, - tensor_sharded_addrgen_fields.tensor_shard_grid_transposed), - page_size, - tensor_address); + experimental::ShardedAddrGen ret_val = { + .bank_base_address = tensor_address, .shard_array=mapping_table}; + return ret_val; } else { ASSERT(false); + InterleavedAddrGen ret_val = { + .bank_base_address = tensor_address, .page_size = page_size}; + return ret_val; } } @@ -588,7 +533,7 @@ FORCE_INLINE void try_advance_read_tensor_to_cb(command_context_t& cmd_ uint32_t l1_write_addr = l1_write_addr_base; for (uint16_t i = 0; i < max_pages_readable; i += contig_pages_advanced) { - const auto [noc_addr, contig_pages_] = get_noc_addr_and_contiguous_pages( + const auto [noc_addr, contig_pages_] = get_noc_addr_and_contiguous_pages( cmd_specific_ctx.curr_tile_id, cmd_specific_ctx.offset_into_worker_slice, cmd_ctx.command_tensor.worker_start_offset_in_slice, @@ -748,13 +693,13 @@ FORCE_INLINE void try_advance_write_tensor_from_cb(command_context_t& c // However, if we're writing locally, then we need to actually write using `noc_index` based coordinates. 
// This can lead to a discrepancy, so to stay consistent, we always generate noc0 based addresses here // so we can reliably translate to `noc_index` based addresses writing locally, inside the write function - const auto [noc0_dest_noc_addr, contig_pages_] = - get_noc_addr_and_contiguous_pages_for_fabric_write( + auto const [noc0_dest_noc_addr, contig_pages_] = + get_noc_addr_and_contiguous_pages_for_fabric_write( cmd_specific_ctx.curr_tile_id, - cmd_specific_ctx.offset_into_worker_slice, - cmd_ctx.command_tensor.worker_start_offset_in_slice, - cmd_ctx.tensor_addrgen, - cmd_ctx.command_tensor.tensor_slice_shape); + cmd_specific_ctx.offset_into_worker_slice, + cmd_ctx.command_tensor.worker_start_offset_in_slice, + cmd_ctx.tensor_addrgen, + cmd_ctx.command_tensor.tensor_slice_shape); contig_pages_advanced = std::min(contig_pages_, max_pages_writable); contig_pages_advanced = std::min(cmd_ctx.packet_size_in_pages - i, contig_pages_); @@ -983,8 +928,9 @@ void kernel_main() { auto tensor0_addrgen = #ifndef NO_TENSOR_MODE - build_source_address_generator( - arg_idx, tensor_address0, tensor0_page_size, in0_sharded_addrgen_fields, cb0_id); + build_source_address_generator + + (arg_idx, tensor_address0, tensor0_page_size, cb0_id); #else no_addrgen{}; #endif @@ -992,8 +938,9 @@ void kernel_main() { #if !defined(SINGLE_INPUT_MODE) auto tensor1_addrgen = #if !defined(NO_TENSOR_MODE) && !defined(SINGLE_TENSOR) - build_source_address_generator( - arg_idx, tensor_address1, tensor1_page_size, in1_sharded_addrgen_fields, cb1_id); + build_source_address_generator + + (arg_idx, tensor_address1, tensor1_page_size, cb1_id); #else no_addrgen{}; #endif diff --git a/ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_utils.hpp b/ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_utils.hpp index 1017c837583..9fe68098a7b 100644 --- a/ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_utils.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/common/kernels/ccl_send_utils.hpp @@ -16,7 +16,7 @@ //------------------------------------------------------------------------------ template -std::pair get_noc_addr_and_contiguous_pages( +std::pair legacy_get_noc_addr_and_contiguous_pages( uint32_t curr_page_idx, const uint32_t offset_into_worker_slice, const ttnn::ccl::Shape4D& offset_worker_slice, @@ -61,13 +61,13 @@ std::pair get_noc_addr_and_contiguous_pages( } template -FORCE_INLINE std::pair get_noc_addr_and_contiguous_pages_for_fabric_write( +FORCE_INLINE std::pair legacy_get_noc_addr_and_contiguous_pages_for_fabric_write( uint32_t curr_page_idx, const uint32_t offset_into_worker_slice, const ttnn::ccl::Shape4D& offset_worker_slice, const AddrGen& address_generator, const ttnn::ccl::Shape4D& tensor_slice_shape) { - return get_noc_addr_and_contiguous_pages( + return legacy_get_noc_addr_and_contiguous_pages( curr_page_idx, offset_into_worker_slice, offset_worker_slice, address_generator, tensor_slice_shape, 0); } @@ -160,7 +160,7 @@ void mcast_payload_chunk_to_output_tensor_address( for (size_t i = 0; i < n_pages; i += contig_pages_advanced) { auto const [noc_addr, contig_pages] = - get_noc_addr_and_contiguous_pages_for_fabric_write( + legacy_get_noc_addr_and_contiguous_pages_for_fabric_write( curr_page_idx, offset_into_worker_slice, worker_slice_offset, diff --git a/ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_args_emitters.cpp b/ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_args_emitters.cpp index 2d22f36ff2c..1fc8ee92045 100644 --- a/ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_args_emitters.cpp +++ 
b/ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_args_emitters.cpp @@ -7,6 +7,7 @@ #include "cpp/ttnn/tensor/tensor.hpp" #include "ttnn/tensor/types.hpp" #include +#include "cpp/ttnn/operations/ccl/sharding_addrgen_helper.hpp" using namespace tt::tt_metal; @@ -24,7 +25,8 @@ args_list_t emit_runtime_args(WorkerEdmInterfaceArgs const& edm_interface_args) args_list_t emit_compile_time(WorkerEdmInterfaceArgs const& edm_interface_args) { return {}; } -args_list_t emit_address_generator_runtime_args(tt::tt_metal::IDevice const* const d, tt::tt_metal::Tensor const& t) { +args_list_t legacy_emit_address_generator_runtime_args( + const tt::tt_metal::IDevice* const d, const tt::tt_metal::Tensor& t) { args_list_t args; switch (t.buffer()->buffer_layout()) { case tt::tt_metal::TensorMemoryLayout::WIDTH_SHARDED: @@ -52,7 +54,33 @@ args_list_t emit_address_generator_runtime_args(tt::tt_metal::IDevice const* con }; } -args_list_t emit_address_generator_compile_time_args(tt::tt_metal::Tensor const& t) { +args_list_t emit_address_generator_runtime_args(const tt::tt_metal::IDevice* const d, const tt::tt_metal::Tensor& t) { + args_list_t args; + switch (t.buffer()->buffer_layout()) { + case tt::tt_metal::TensorMemoryLayout::WIDTH_SHARDED: + case tt::tt_metal::TensorMemoryLayout::HEIGHT_SHARDED: + case tt::tt_metal::TensorMemoryLayout::BLOCK_SHARDED: return shard_builder::generate_run_time_args(t); break; + + case tt::tt_metal::TensorMemoryLayout::INTERLEAVED: + TT_ASSERT(t.buffer()->page_size() != 1024); + // For now we won't emit args for interleaved here... assume these are passed in elsewhere + // This is during some transitionary period + return {}; + + break; + + case tt::tt_metal::TensorMemoryLayout::SINGLE_BANK: + default: + TT_ASSERT( + false, + "Tried emitting address generator args for an unsupported type{}. Consider adding the missing support " + "or using a supported tensor memory layout (width sharded, height sharded, block sharded, interleaved", + t.buffer()->buffer_layout()); + return {}; + }; +} + +args_list_t legacy_emit_address_generator_compile_time_args(const tt::tt_metal::Tensor& t) { switch (t.buffer()->buffer_layout()) { case tt::tt_metal::TensorMemoryLayout::WIDTH_SHARDED: case tt::tt_metal::TensorMemoryLayout::HEIGHT_SHARDED: @@ -72,6 +100,28 @@ args_list_t emit_address_generator_compile_time_args(tt::tt_metal::Tensor const& TT_ASSERT(false); } +args_list_t emit_address_generator_compile_time_args(const tt::tt_metal::Tensor& t) { + switch (t.buffer()->buffer_layout()) { + case tt::tt_metal::TensorMemoryLayout::WIDTH_SHARDED: + case tt::tt_metal::TensorMemoryLayout::HEIGHT_SHARDED: + case tt::tt_metal::TensorMemoryLayout::BLOCK_SHARDED: + return shard_builder::generate_compile_time_args(t); + break; + + case tt::tt_metal::TensorMemoryLayout::INTERLEAVED: return {}; break; + + case tt::tt_metal::TensorMemoryLayout::SINGLE_BANK: + default: + TT_ASSERT( + false, + "Tried emitting address generator args for an unsupported type{}. 
Consider adding the missing support " + "or using a supported tensor memory layout (width sharded, height sharded, block sharded, interleaved", + t.buffer()->buffer_layout()); + return {}; + } + TT_ASSERT(false); +} + std::pair shard_grid_from_shard_spec(const ShardSpec& shard_spec) { auto const& core_range = shard_spec.grid.bounding_box(); log_trace( diff --git a/ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_args_emitters.hpp b/ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_args_emitters.hpp index 74c5db70167..5c15a109d61 100644 --- a/ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_args_emitters.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_args_emitters.hpp @@ -49,7 +49,10 @@ args_list_t emit_compile_time(Shape4D const& shape) { args_list_t emit_address_generator_runtime_args( tt::tt_metal::IDevice const* const d, tt::tt_metal::Tensor const& tensor); -args_list_t emit_address_generator_compile_time_args(tt::tt_metal::Tensor const& tensor); +args_list_t legacy_emit_address_generator_runtime_args( + const tt::tt_metal::IDevice* const d, const tt::tt_metal::Tensor& tensor); +args_list_t emit_address_generator_compile_time_args(const tt::tt_metal::Tensor& t); +args_list_t legacy_emit_address_generator_compile_time_args(const tt::tt_metal::Tensor& tensor); std::pair shard_grid_from_shard_spec(const tt::tt_metal::ShardSpec& shard_spec); diff --git a/ttnn/cpp/ttnn/operations/ccl/common/types/sharding_common.hpp b/ttnn/cpp/ttnn/operations/ccl/common/types/sharding_common.hpp new file mode 100644 index 00000000000..00feb494773 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/common/types/sharding_common.hpp @@ -0,0 +1,23 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +// File contains enumerations that are common to both kernel and program factories with regards to sharding + +#pragma once + +#include + +namespace shard_addr_gen_consts { + +enum class ContiguityType { + // Indicates logical sharding placed padding between pages so no contiguous pages exist + PADDING_BETWEEN_PAGES = 0, + // Indicates some padding exists in the rightmost shard since the pages did not divide evenly into shards + PADDING_IN_RIGHTMOST_SHARD, + // Indicates no sharding based padding exists so all pages within the same shard are contiguous + // This is useful for height sharded tensors as multiple rows of the tensor can be contiguous. + NO_SHARD_PADDING, +}; + +} // namespace shard_addr_gen_consts diff --git a/ttnn/cpp/ttnn/operations/ccl/kernel_common/sharding_addrgen.hpp b/ttnn/cpp/ttnn/operations/ccl/kernel_common/sharding_addrgen.hpp new file mode 100644 index 00000000000..37fa871a969 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/kernel_common/sharding_addrgen.hpp @@ -0,0 +1,343 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +// The intent is to merge this file into dataflow_api.h and then refactor it into multiple files. 
+// It is currently here while its reliability is proven + +#pragma once +#if defined(KERNEL_BUILD) || defined(FW_BUILD) + +#include "dataflow_api.h" + +#endif + +#include "ttnn/cpp/ttnn/operations/ccl/common/types/sharding_common.hpp" + +using mapping_table_t = uint32_t; + +template < + uint32_t SHARD_TYPE, + uint32_t NUMBER_OF_CORES, + uint32_t PAGE_SIZE_JUMP, + uint32_t PAGES_PER_TENSOR_ROW, + uint32_t CONTIGUITY, + uint32_t PAGES_PER_SHARD_WIDTH, + uint32_t ROWS_PER_SHARD_HEIGHT> +struct ShardedInfo { +public: + // The template parameters carry the compile-time sharding information for this tensor's shard + // grid + constexpr static tt::tt_metal::TensorMemoryLayout shard_type = + static_cast<tt::tt_metal::TensorMemoryLayout>(SHARD_TYPE); + constexpr static uint32_t number_of_cores = NUMBER_OF_CORES; + constexpr static uint32_t page_size_jump = PAGE_SIZE_JUMP; + constexpr static uint32_t pages_per_tensor_row = PAGES_PER_TENSOR_ROW; + constexpr static shard_addr_gen_consts::ContiguityType contiguity = + static_cast<shard_addr_gen_consts::ContiguityType>(CONTIGUITY); + constexpr static uint32_t pages_per_shard_width = PAGES_PER_SHARD_WIDTH; + constexpr static uint32_t rows_per_shard_height = ROWS_PER_SHARD_HEIGHT; +}; +namespace experimental { +namespace shard_addr_gen_utils { + +struct ShardCoordInfo { + uint32_t core_num; + uint32_t page_num; + uint32_t num_contiguous_pages; +}; + +template <uint32_t columns_per_shard, uint32_t total_pages_last_dim, shard_addr_gen_consts::ContiguityType contiguity> +struct ShardCoordInfo get_width_sharded_coordinates(uint32_t page_num) { + // Returns core index followed by the page number + struct ShardCoordInfo coord_info; + uint32_t page_row = page_num / total_pages_last_dim; + uint32_t page_col = page_num - page_row * total_pages_last_dim; + uint32_t w_core_id = page_col / columns_per_shard; + uint32_t w_offset = page_col - w_core_id * columns_per_shard; + coord_info.core_num = w_core_id; + coord_info.page_num = page_row * columns_per_shard + w_offset; + if constexpr (contiguity != shard_addr_gen_consts::ContiguityType::PADDING_BETWEEN_PAGES) { + uint32_t space_left_in_shard = columns_per_shard - w_offset; + uint32_t space_left_in_tensor = total_pages_last_dim - page_col; + coord_info.num_contiguous_pages = + space_left_in_shard < space_left_in_tensor ?
space_left_in_shard : space_left_in_tensor; + } else { + coord_info.num_contiguous_pages = 1; + } + return coord_info; +} + +template +struct ShardCoordInfo get_height_sharded_coordinates(uint32_t page_num) { + // Returns core index followed by the page number + struct ShardCoordInfo coord_info; + constexpr uint32_t num_pages_per_core = total_pages_last_dim * rows_per_shard; + coord_info.core_num = page_num / num_pages_per_core; + coord_info.page_num = page_num - coord_info.core_num * num_pages_per_core; + if constexpr (contiguity == shard_addr_gen_consts::ContiguityType::PADDING_BETWEEN_PAGES) { + coord_info.num_contiguous_pages = 1; + } else if constexpr (contiguity == shard_addr_gen_consts::ContiguityType::PADDING_IN_RIGHTMOST_SHARD) { + coord_info.num_contiguous_pages = total_pages_last_dim - page_num % total_pages_last_dim; + } else { + coord_info.num_contiguous_pages = num_pages_per_core - coord_info.page_num; + } + return coord_info; +} + +template < + uint32_t columns_per_shard, + uint32_t rows_per_shard, + uint32_t total_pages_last_dim, + shard_addr_gen_consts::ContiguityType contiguity> +experimental::shard_addr_gen_utils::ShardCoordInfo get_block_sharded_coordinates(uint32_t page_num) { + // Returns core index followed by the page number + // Calculate how many cores are in the sharding grid + constexpr uint32_t cores_per_block_row = (total_pages_last_dim - 1) / columns_per_shard + 1; + experimental::shard_addr_gen_utils::ShardCoordInfo coord_info; + // Get row and column ID of this page + uint32_t page_row = page_num / total_pages_last_dim; + uint32_t page_col = page_num - page_row * total_pages_last_dim; + // Find the w direction core and the offset within it + uint32_t w_core_id = page_col / columns_per_shard; + uint32_t w_offset = page_col - w_core_id * columns_per_shard; + // Find the h direction core and the offset within it + uint32_t h_core_id = page_row / rows_per_shard; + uint32_t h_offset = page_row - h_core_id * rows_per_shard; + // Find the coord_info + coord_info.core_num = w_core_id + h_core_id * cores_per_block_row; + coord_info.page_num = w_offset + h_offset * columns_per_shard; + if constexpr (contiguity != shard_addr_gen_consts::ContiguityType::PADDING_BETWEEN_PAGES) { + uint32_t space_left_in_shard = columns_per_shard - w_offset; + uint32_t space_left_in_tensor = total_pages_last_dim - page_col; + coord_info.num_contiguous_pages = + space_left_in_shard < space_left_in_tensor ? space_left_in_shard : space_left_in_tensor; + } else { + coord_info.num_contiguous_pages = 1; + } + return coord_info; +} + +/* + * Returns a 16 bit compressed representation of the core number for each core in the shard grid + * ShardedAddrGen::get_sharded_addr can extract the noc address from this compressed core representation + * Representation is placed in a uint32_t array in a big endian ordering + */ + +template +std::pair get_shard_map(uint32_t L1_address) { + // Gets the shard_array from the runtime arguments + // returns a pair where .first holds the shard array map + // and .second holds the size of the map + constexpr ShardingInfoType CONSTANT_ARGS{}; + const mapping_table_t* const map = reinterpret_cast(L1_address); + constexpr uint32_t incrementation = (CONSTANT_ARGS.number_of_cores - 1) / 2 + 1; + return std::pair(map, incrementation); +} + +} // namespace shard_addr_gen_utils + +/* +* ShardedAddrGen requires the type definition of a ShardedInfo class object whose templates hold the CT information + ex. 
+ typedef ShardedInfo < + SHARD_TYPE, + NUMBER_OF_CORES, + PAGE_SIZE_JUMP, + PAGES_PER_TENSOR_ROW, + CONTIGUITY, + PAGES_PER_SHARD_WIDTH, + ROWS_PER_SHARD_HEIGHT> tensor_1_shard_info; + + The above parameters are usually obtained using get_compile_time_arg_val. + In the program factory you can create an vector containing the above parameters in order using the function + shard_builder:generate_compile_time_args(const tt::tt_metal::Tensor& t) + defined in ttnn/cpp/ttnn/operations/ccl/sharding_addrgen_helper.cpp + + It also needs a shard array map which can be extracted from the RT args using shard_addr_gen_utils::get_shard_map +function which requires the ShardedInfo class object ex. auto mapping = +experimental::shard_addr_gen_utils::get_shard_map(get_arg_addr(rt_index)); const mapping_table_t* +const shard_array_map = mapping.first; +//Contains the shard array map +rt_index += mapping.second;//contains the size of the map hence how much to increment the rt values + +In the program factory you can create an vector containing the runtime arguments extracted by this function using the +function shard_builder:generate_run_time_args(const tt::tt_metal::Tensor& t) + defined in ttnn/cpp/ttnn/operations/ccl/sharding_addrgen_helper.cpp + + + + Lastly you need the bank_base_address from the Tensor object just like interleaved addr gen + + You can then create a sharded addrgen as follows: + s = ShardedAddrGen {.bank_base_address = bank_base_address, .shard_array=shard_array_map}; + This object can then be used by the get_noc_addr api. +*/ +template +struct ShardedAddrGen { + // Use this address generator for sharded tensors + + constexpr static ShardingInfoType CONSTANT_ARGS{}; + // Sharded Info Class is a ShardedInfo class object that is appropriately templated + // including all the compile time parameters + uint32_t bank_base_address; + const mapping_table_t* const shard_array; + + FORCE_INLINE + std::uint64_t get_sharded_addr( + const uint32_t l1_addr, const uint32_t sharding_coordinates, const uint32_t noc = noc_index) const { + // Extracts the X and Y value and using the l1 address gets the noc address + return NOC_XY_ADDR( + DYNAMIC_NOC_X(noc, ((sharding_coordinates >> 8) & 0xFF)), + DYNAMIC_NOC_Y(noc, (sharding_coordinates & 0xFF)), + l1_addr); + } + + std::uint32_t get_sharded_l1_addr(const uint32_t core_page, const uint32_t offset = 0) const { + // Get the L1 address + return this->bank_base_address + (core_page * CONSTANT_ARGS.page_size_jump) + offset; + } + + FORCE_INLINE + std::uint64_t get_noc_addr(const uint32_t id, const uint32_t offset = 0, uint8_t noc = noc_index) const { + return this->get_contiguous_noc_addr(id, offset, noc).first; + } + FORCE_INLINE + std::pair get_contiguous_noc_addr( + const uint32_t id, const uint32_t offset = 0, uint8_t noc = noc_index) const { + // Returns the noc address AND the number of contiguous pages after. 
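+        // Illustrative kernel-side usage sketch, tying together the steps described in the
+        // file-level comment above. The compile-time argument indices, the runtime-argument
+        // ordering, and `num_pages` are assumptions made only for the example; real kernels
+        // must match whatever the program factory emitted via
+        // shard_builder::generate_compile_time_args / shard_builder::generate_run_time_args.
+        //
+        //   using Tensor0ShardInfo = ShardedInfo<
+        //       get_compile_time_arg_val(0), get_compile_time_arg_val(1), get_compile_time_arg_val(2),
+        //       get_compile_time_arg_val(3), get_compile_time_arg_val(4), get_compile_time_arg_val(5),
+        //       get_compile_time_arg_val(6)>;
+        //   size_t arg_idx = 0;
+        //   uint32_t bank_base_address = get_arg_val<uint32_t>(arg_idx++);
+        //   auto [shard_array_map, rt_increment] =
+        //       experimental::shard_addr_gen_utils::get_shard_map<Tensor0ShardInfo>(get_arg_addr(arg_idx));
+        //   arg_idx += rt_increment;
+        //   experimental::ShardedAddrGen<Tensor0ShardInfo> addrgen{
+        //       .bank_base_address = bank_base_address, .shard_array = shard_array_map};
+        //   for (uint32_t page = 0; page < num_pages;) {
+        //       auto [noc_addr, contig] = get_contiguous_noc_addr(page, addrgen);
+        //       // issue one NOC read covering `contig` contiguous pages, then advance
+        //       page += contig;
+        //   }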
+ + // Resolve linear core id/bank address, the page offset in the core, + // and the number of contiguous pages within that core + experimental::shard_addr_gen_utils::ShardCoordInfo sharding_coordinates{}; + if constexpr (CONSTANT_ARGS.shard_type == tt::tt_metal::TensorMemoryLayout::WIDTH_SHARDED) { + sharding_coordinates = experimental::shard_addr_gen_utils::get_width_sharded_coordinates< + CONSTANT_ARGS.pages_per_shard_width, + CONSTANT_ARGS.pages_per_tensor_row, + CONSTANT_ARGS.contiguity>(id); + } else if constexpr (CONSTANT_ARGS.shard_type == tt::tt_metal::TensorMemoryLayout::HEIGHT_SHARDED) { + sharding_coordinates = experimental::shard_addr_gen_utils::get_height_sharded_coordinates< + CONSTANT_ARGS.rows_per_shard_height, + CONSTANT_ARGS.pages_per_tensor_row, + CONSTANT_ARGS.contiguity>(id); + } else { + sharding_coordinates = experimental::shard_addr_gen_utils::get_block_sharded_coordinates< + CONSTANT_ARGS.pages_per_shard_width, + CONSTANT_ARGS.rows_per_shard_height, + CONSTANT_ARGS.pages_per_tensor_row, + CONSTANT_ARGS.contiguity>(id); + } + // Get the value from the resolved core location containing the core x and y each 8 bits + // Note we are stripping this from a 32 bit array hence the floor division by 2 and in + // odd numbered cores a right shift by 16 and a masking + uint32_t sharding_coordinate_value = + (shard_array[(sharding_coordinates.core_num) >> 1] >> ((sharding_coordinates.core_num & 1) == 1 ? 0 : 16)) & + 0xFFFF; + // Find the L1 address within the resolved core + auto resolved_l1_addr = get_sharded_l1_addr(sharding_coordinates.page_num, offset); + // Find the noc address using the x,y core information + auto resolved_sharded_addr = get_sharded_addr(resolved_l1_addr, sharding_coordinate_value, noc); + // Return the core info and the number of contiguous cores + std::pair return_val(resolved_sharded_addr, sharding_coordinates.num_contiguous_pages); + return return_val; + } + + FORCE_INLINE + void noc_async_read_page( + const uint32_t id, const uint32_t dest_addr, const uint32_t offset = 0, uint8_t noc = noc_index) const { + noc_async_read(this->get_noc_addr(id, offset), dest_addr, CONSTANT_ARGS.page_size_jump, noc); + } +}; +} // namespace experimental + +/** + * gets the noc address from the addrgen object for a given page. + * This tells the user the address of the given page + * Can accept ShardedAddrGen, InterleavedAddrGen, InterleavedPow2AddrGen, + * InterleavedAddrGenFast, or InterleavedPow2AddrGenFast objects + * Return value: A uint64_t object with the noc address of the object + * + * | Argument | Description | Type | Valid Range | Required | + * |-------------------|---------------------------------------------------------|----------|----------------------------------------------------------------|----------| + * | id | The page or tile number to be accessed | uint32_t | 0..1MB | True | | + * | AddrGenObj | The address generator object to use | see above| N/A | True | | + * | offset | The offset within the page or tile to access | uint32_t | 0..page size| False | | + * | noc | Which noc to use, defaults to noc_index | uint8_t | 0 or 1 | False | | + */ + +template +FORCE_INLINE std::uint64_t get_noc_addr( + const uint32_t id, const experimental::ShardedAddrGen& s, uint32_t offset = 0, uint8_t noc = noc_index) { + return s.get_noc_addr(id, offset, noc); +} + +// Interleaved versions of get_noc_addr are implemented in dataflow_api.h + +/** + * gets the contiguous noc address from the addrgen object for a given page. 
+ * This tells the user both the address of the given page and how many subsequent + * pages are contiguously located sequentially in the same memory. + * Can accept ShardedAddrGen, InterleavedAddrGen, InterleavedPow2AddrGen, + * InterleavedAddrGenFast, or InterleavedPow2AddrGenFast objects + * Return value: An std::pair object where the .first value is the noc address of the object + * and .second is the number of sequentially contiguous pages starting at id. + * + * | Argument | Description | Type | Valid Range | Required | + * |-------------------|---------------------------------------------------------|----------|----------------------------------------------------------------|----------| + * | id | The page or tile number to be accessed | uint32_t | 0..1MB | True | | + * | AddrGenObj | The address generator object to use | see above| N/A | True | | + * | offset | The offset within the page or tile to access | uint32_t | 0..page size| False | | + * | noc | Which noc to use, defaults to noc_index | uint8_t | 0 or 1 | False | | + */ + +template +FORCE_INLINE std::pair get_contiguous_noc_addr( + const uint32_t id, const experimental::ShardedAddrGen& s, uint32_t offset = 0, uint8_t noc = noc_index) { + return s.get_contiguous_noc_addr(id, offset, noc); +} + +#if defined(KERNEL_BUILD) || defined(FW_BUILD) + +template +FORCE_INLINE std::pair get_contiguous_noc_addr( + const uint32_t id, const InterleavedAddrGen& s, uint32_t offset = 0, uint8_t noc = noc_index) { + std::pair ret_val(s.get_noc_addr(id, offset, noc), 1); + return ret_val; +} + +template +FORCE_INLINE std::pair get_contiguous_noc_addr( + const uint32_t id, const InterleavedPow2AddrGen& s, uint32_t offset = 0, uint8_t noc = noc_index) { + std::pair ret_val(s.get_noc_addr(id, offset, noc), 1); + return ret_val; +} + +template +FORCE_INLINE std::pair get_contiguous_noc_addr( + const uint32_t id, const InterleavedAddrGenFast& s, uint32_t offset = 0, uint8_t noc = noc_index) { + std::pair ret_val(s.get_noc_addr(id, offset, noc), 1); + return ret_val; +} + +template +FORCE_INLINE std::pair get_contiguous_noc_addr( + const uint32_t id, const InterleavedPow2AddrGenFast& s, uint32_t offset = 0, uint8_t noc = noc_index) { + std::pair ret_val(s.get_noc_addr(id, offset, noc), 1); + return ret_val; +} + +#endif + +template +FORCE_INLINE void noc_async_read_page( + const uint32_t id, + const experimental::ShardedAddrGen& s, + std::uint32_t dst_local_l1_addr, + uint32_t offset = 0, + uint8_t noc = noc_index) { + /* + Read requests - use static VC + Read responses - assigned VCs dynamically + */ + s.noc_async_read_page(id, dst_local_l1_addr, offset, noc); +} diff --git a/ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp b/ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp index 236e209e705..b374000953a 100644 --- a/ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_utils.hpp @@ -50,6 +50,17 @@ FORCE_INLINE void fetch_chunk( cb_push_back(cb_id, num_pages); } +template +FORCE_INLINE void send_chunk_from_address_with_trid( + const uint32_t& local_l1_address, const uint32_t& num_pages, const uint32_t& page_size, uint64_t remote_l1_write_addr, uint8_t trid) { + noc_async_write_one_packet_with_trid(local_l1_address, remote_l1_write_addr, page_size * num_pages, trid); + if constexpr (blocking_mode == ttnn::ccl::EDM_IO_BLOCKING_MODE::FLUSH_BLOCKING) { + noc_async_writes_flushed(); + } else if constexpr (blocking_mode == 
ttnn::ccl::EDM_IO_BLOCKING_MODE::BLOCKING) { + noc_async_write_barrier(); + } +} + template FORCE_INLINE void send_chunk_from_address( const uint32_t& local_l1_address, const uint32_t& num_pages, const uint32_t& page_size, uint64_t remote_l1_write_addr) { diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp index 30aba536630..e6b2253c277 100644 --- a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_worker_adapters.hpp @@ -165,6 +165,9 @@ struct WorkerToFabricEdmSender { FORCE_INLINE void send_payload_non_blocking_from_address(uint32_t source_address, size_t size_bytes) { send_payload_from_address_impl(source_address, size_bytes); } + FORCE_INLINE void send_payload_non_blocking_from_address_with_trid(uint32_t source_address, size_t size_bytes, uint8_t trid) { + send_payload_from_address_with_trid_impl(source_address, size_bytes, trid); + } static constexpr size_t edm_sender_channel_field_stride_bytes = 16; @@ -243,7 +246,6 @@ struct WorkerToFabricEdmSender { FORCE_INLINE void update_edm_buffer_slot_wrptr() { uint64_t const noc_sem_addr = get_noc_addr(this->edm_noc_x, this->edm_noc_y, this->edm_buffer_slot_wrptr_addr); - noc_inline_dw_write(noc_sem_addr, *this->buffer_slot_wrptr_ptr); } @@ -260,47 +262,60 @@ struct WorkerToFabricEdmSender { return wrptr - (normalize * this->num_buffers_per_channel); } - template - FORCE_INLINE void send_packet_header_and_notify_fabric(uint32_t source_address) { + FORCE_INLINE uint32_t compute_dest_buffer_slot_bank_address() const { + return this->edm_buffer_addr + (this->get_buffer_slot_index() * (this->buffer_size_bytes + sizeof(eth_channel_sync_t))); + } - uint64_t buffer_address = get_noc_addr(this->edm_noc_x, this->edm_noc_y, this->edm_buffer_addr) + - (this->get_buffer_slot_index() * (this->buffer_size_bytes + sizeof(eth_channel_sync_t))); + FORCE_INLINE uint64_t compute_dest_buffer_slot_noc_addr() const { + return get_noc_addr(this->edm_noc_x, this->edm_noc_y, this->compute_dest_buffer_slot_bank_address()); + } - send_chunk_from_address(source_address, 1, sizeof(tt::fabric::PacketHeader), buffer_address); + FORCE_INLINE void post_send_payload_increment_pointers() { this->advance_buffer_slot_wrptr(); this->update_edm_buffer_slot_wrptr(); } + template + FORCE_INLINE void send_packet_header_and_notify_fabric(uint32_t source_address) { + uint64_t buffer_address = this->compute_dest_buffer_slot_noc_addr(); + + send_chunk_from_address(source_address, 1, sizeof(tt::fabric::PacketHeader), buffer_address); + post_send_payload_increment_pointers(); + } + template FORCE_INLINE void send_payload_without_header_from_address_impl(uint32_t source_address, size_t size_bytes) { - uint64_t buffer_address = get_noc_addr(this->edm_noc_x, this->edm_noc_y, this->edm_buffer_addr) + - (this->get_buffer_slot_index() * (this->buffer_size_bytes + sizeof(eth_channel_sync_t))); + uint64_t buffer_address = this->compute_dest_buffer_slot_noc_addr(); // skip past the first part of the buffer which will be occupied by the packet header send_chunk_from_address(source_address, 1, size_bytes, buffer_address + sizeof(tt::fabric::PacketHeader)); } - template FORCE_INLINE void send_payload_from_address_impl(uint32_t source_address, size_t size_bytes) { - uint64_t buffer_address = get_noc_addr(this->edm_noc_x, this->edm_noc_y, this->edm_buffer_addr) + - 
(this->get_buffer_slot_index() * (this->buffer_size_bytes + sizeof(eth_channel_sync_t))); + uint64_t buffer_address = this->compute_dest_buffer_slot_noc_addr(); ASSERT(size_bytes <= this->buffer_size_bytes); ASSERT(tt::fabric::is_valid(*const_cast( reinterpret_cast(source_address)))); send_chunk_from_address(source_address, 1, size_bytes, buffer_address); + post_send_payload_increment_pointers(); + } + template + FORCE_INLINE void send_payload_from_address_with_trid_impl(uint32_t source_address, size_t size_bytes, uint8_t trid) { + uint64_t buffer_address = this->compute_dest_buffer_slot_noc_addr(); - this->advance_buffer_slot_wrptr(); - this->update_edm_buffer_slot_wrptr(); + ASSERT(size_bytes <= this->buffer_size_bytes); + ASSERT(tt::fabric::is_valid(*const_cast( + reinterpret_cast(source_address)))); + send_chunk_from_address_with_trid(source_address, 1, size_bytes, buffer_address, trid); + post_send_payload_increment_pointers(); } template FORCE_INLINE void send_payload_impl(uint32_t cb_id, uint32_t num_pages, uint32_t page_size) { - uint64_t buffer_address = get_noc_addr(this->edm_noc_x, this->edm_noc_y, this->edm_buffer_addr) + - (this->get_buffer_slot_index() * (this->buffer_size_bytes + sizeof(eth_channel_sync_t))); + uint64_t buffer_address = this->compute_dest_buffer_slot_noc_addr(); ASSERT(num_pages * page_size <= this->buffer_size_bytes); send_chunk(cb_id, num_pages, page_size, buffer_address); - this->advance_buffer_slot_wrptr(); - this->update_edm_buffer_slot_wrptr(); + post_send_payload_increment_pointers(); } }; diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp index 8c4d073aeb8..28771d3e9e7 100644 --- a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp @@ -118,7 +118,10 @@ struct PacketHeader { CommandType command_type : 2; ChipSendType chip_send_type : 1; NocSendType noc_send_type : 1; - uint8_t reserved : 4; + // Used only by the EDM sender and receiver channels. Populated by EDM sender channel to + // indicate to the receiver channel what channel was the source of this packet. Reserved + // otherwise. 
+ uint8_t src_ch_id : 4; RoutingFields routing_fields; uint16_t reserved2; // can be tagged with src device for debug @@ -261,6 +264,7 @@ struct PacketHeader { return this; } + inline void set_src_ch_id(uint8_t ch_id) volatile { this->src_ch_id = ch_id; } }; diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_transmission.hpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_transmission.hpp index 272a1ca4d7d..edde4791916 100644 --- a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_transmission.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_transmission.hpp @@ -16,14 +16,6 @@ static constexpr size_t DESTINATION_HOP_COUNT = 1; // TODO: make 0 and the associated field to num mcast destinations static constexpr size_t LAST_MCAST_DESTINATION = 1; - -void write_unicast_blocking(uint32_t local_address, uint64_t dest_address, uint32_t size_bytes) { - // TODO - PERF: noc_async_write - // Don't do it yet because we want to sweep perf on buffer size - noc_async_write(local_address, dest_address, size_bytes); - noc_async_write_barrier(); -} - void print_pkt_hdr_routing_fields(volatile tt::fabric::PacketHeader *const packet_start) { switch (packet_start->chip_send_type) { case tt::fabric::CHIP_UNICAST: { @@ -77,7 +69,7 @@ void print_pkt_header(volatile tt::fabric::PacketHeader *const packet_start) { // Since we unicast to local, we must omit the packet header -void execute_chip_unicast_to_local_chip(volatile tt::fabric::PacketHeader *const packet_start) { +void execute_chip_unicast_to_local_chip(volatile tt::fabric::PacketHeader *const packet_start, uint32_t transaction_id) { auto const& header = *packet_start; uint32_t payload_start_address = reinterpret_cast(packet_start) + sizeof(tt::fabric::PacketHeader); @@ -94,7 +86,7 @@ void execute_chip_unicast_to_local_chip(volatile tt::fabric::PacketHeader *const header.command_fields.unicast_write.noc_y, header.command_fields.unicast_write.address); auto const size = header.command_fields.unicast_write.size - sizeof(tt::fabric::PacketHeader); - write_unicast_blocking(payload_start_address, dest_address, size); + noc_async_write_one_packet_with_trid(payload_start_address, dest_address, size, transaction_id); }break; case tt::fabric::NocSendType::NOC_MULTICAST: { @@ -107,8 +99,7 @@ void execute_chip_unicast_to_local_chip(volatile tt::fabric::PacketHeader *const header.command_fields.mcast_write.address); auto const num_dests = header.command_fields.mcast_write.mcast_rect_size_x * header.command_fields.mcast_write.mcast_rect_size_y; auto const size = header.command_fields.mcast_write.size - sizeof(tt::fabric::PacketHeader); - noc_async_write_multicast_one_packet(payload_start_address, mcast_dest_address, size, num_dests); - noc_async_write_barrier(); + noc_async_write_one_packet_with_trid(payload_start_address, mcast_dest_address, size, num_dests, transaction_id); }break; default: { @@ -183,7 +174,8 @@ void update_packet_header_for_next_hop(volatile tt::fabric::PacketHeader * packe // !!!WARNING!!! 
void forward_payload_to_downstream_edm( volatile tt::fabric::PacketHeader *packet_header, - tt::fabric::WorkerToFabricEdmSender &downstream_edm_interface + tt::fabric::WorkerToFabricEdmSender &downstream_edm_interface, + uint8_t transaction_id ) { DPRINT << "Fwding pkt to downstream\n"; // TODO: PERF - this should already be getting checked by the caller so this should be redundant make it an ASSERT @@ -192,9 +184,10 @@ void forward_payload_to_downstream_edm( // This is a good place to print the packet header for debug if you are trying to inspect packets // because it is before we start manipulating the header for forwarding update_packet_header_for_next_hop(packet_header); - downstream_edm_interface.send_payload_blocking_from_address( + downstream_edm_interface.send_payload_non_blocking_from_address_with_trid( reinterpret_cast(packet_header), - packet_header->get_payload_size_including_header()); + packet_header->get_payload_size_including_header(), + transaction_id); } diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover.cpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover.cpp index 5e46f93e0e5..f296601f2a3 100644 --- a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover.cpp @@ -2,9 +2,6 @@ // // SPDX-License-Identifier: Apache-2.0 -#include -#include -#include #include "dataflow_api.h" #include "tt_metal/hw/inc/ethernet/dataflow_api.h" @@ -16,8 +13,13 @@ #include "cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover_channels.hpp" #include "cpp/ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include "noc_overlay_parameters.h" + #include "ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/edm_fabric_counters.hpp" +#include +#include +#include using ttnn::ccl::WorkerXY; @@ -166,39 +168,6 @@ Packets 0, 2, and 3 are smaller than the full buffer size, while packet 1 is the buf 0 buf 1 buf 2 buf 3 -A detail of the channel structure is omitted from the above diagram, namely the EDM <-> EDM flow control region for each buffer. -Each buffer really looks something like this: - - - &header-> |----------------| channel_base_address - | header | - &payload-> |----------------| - | | - | payload | - | | - &channel_sync-> |----------------| - | channel_sync | // This is new - ------------------ - -The "channel_sync" is an `eth_channel_sync_t` and is internal to the EDM implementation and is used to indicate packet -transmission state between sender and receiver EDMs. - -The protocol for its use is: -1) Sender updates the field indicating new data: - - set `bytes_sent` to a non-zero value indicating new data - - clear `receiver_ack` to 0 - - set `src_id` to the sender channel id so the receiver knows who the sender was (and where the ack should go) -2) Sender sends this channel sync to the corresponding location in the receiver channel (either in the same transmission - as the packet or separately) -3) Receiver sees that `bytes_sent` is non-zero, indicating a new packet. 
It sends back an acknowledgement (first level): - - set `receiver_ack` to non-zero - *NOTE* IMPORTANT: To avoid a race, the receiver must be sure to send its channel_sync_t from a different address it uses - as for the second level acknowledgement - 3b) When sender receives an ack, it understands it can overwrite its local copy of the packet with new data -4) After receiver properly writes out its packet, it sends a second level acknowledgement, indicating it can receive new - data into this specific buffer index: - - clear the bytes_sent and receiver_ack fields and send back the `channel_sync` to the sender - ## Sending Packets @@ -216,6 +185,22 @@ The flow control protocol between EDM channels is built on a rd/wr ptr based pro to buffer slots within the channel (as opposed so something else like byte or word offset). Ptrs are free to advance independently from each other as long as there is no overflow or underflow. +The flow control is implemented through the use of several stream registers: one per conceptual pointer being tracked. +In total there are 5 such counters: +1) to receiver channel packets sent + - Incremented by sender (via eth_reg_write) by the number of buffer slots written. In practice, this means it is + incremented once per packet +2) to sender 0 packets acked + - Incremented by receiver for every new packet from channel 0 that it sees +3) to sender 1 packets acked + - Incremented by receiver for every new packet from channel 1 that it sees +4) to sender 0 packets completed + - Incremented by receiver for every packet from channel 0 that it completes processing for +5) to sender 1 packets completed + - Incremented by receiver for every packet from channel 1 that it completes processing for + +See calls to `increment_local_update_ptr_val`, `remote_update_ptr_val`, `init_ptr_val` for more on implementation. + ### Sender Channel Flow Control Both sender channels share the same flow control view into the receiver channel. This is because both channels write to the same receiver channel. @@ -257,6 +242,125 @@ write to the same receiver channel. 
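+A rough sketch of one round trip through these counters (the helper and stream-id names are
+declared further down in this file; this is illustrative only, not the exact control flow of the
+sender/receiver step functions):
+
+  // sender, once a packet is staged and the receiver has a free buffer slot:
+  //   copy the packet over ethernet, then credit the receiver
+  remote_update_ptr_val<to_receiver_pkts_sent_id>(1);
+
+  // receiver, on seeing a non-zero credit:
+  increment_local_update_ptr_val<to_receiver_pkts_sent_id>(-1);           // consume the credit
+  auto src = pkt_header->src_ch_id;                                       // which sender channel to notify
+  remote_update_ptr_val(to_sender_packets_acked_streams[src], 1);         // first-level ack
+  // ... write the payload locally and/or forward it downstream ...
+  remote_update_ptr_val(to_sender_packets_completed_streams[src], 1);     // completion
+
+  // sender, on its next step:
+  auto done = get_ptr_val(to_sender_packets_completed_streams[sender_channel_index]);
+  // advance its read/completion pointers by `done`, then clear the counter
+  increment_local_update_ptr_val(to_sender_packets_completed_streams[sender_channel_index], -done);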
// Data structures, types, enums, and constants //////////////////////////////////////////////// +// Transaction ID related constants/types +constexpr uint8_t NUM_TRANSACTION_IDS = 4; + +template +struct TransactionIdCounter { + void increment() { + this->next_trid = tt::fabric::wrap_increment(this->next_trid); + } + + uint8_t get() const { + return this->next_trid; + } + + private: + uint8_t next_trid = 0; +}; + +template +struct WriteTransactionIdTracker { + static constexpr uint8_t INVALID_TRID = MAX_TRANSACTION_IDS; + WriteTransactionIdTracker() { + for (size_t i = 0; i < NUM_CHANNELS; i++) { + this->buffer_slot_trids[i] = INVALID_TRID; + } + } + FORCE_INLINE void set_buffer_slot_trid(uint8_t trid, tt::fabric::BufferIndex buffer_index) { + this->buffer_slot_trids[buffer_index] = trid; + } + + FORCE_INLINE void advance_trid_counter() { + this->trid_counter.increment(); + } + + FORCE_INLINE uint8_t update_buffer_slot_to_next_trid_and_advance_trid_counter(tt::fabric::BufferIndex buffer_index) { + uint8_t next_trid = this->trid_counter.get(); + this->buffer_slot_trids[buffer_index] = next_trid; + this->trid_counter.increment(); + return next_trid; + } + + FORCE_INLINE void clear_trid_at_buffer_slot(tt::fabric::BufferIndex buffer_index) { + this->buffer_slot_trids[buffer_index] = INVALID_TRID; + } + + FORCE_INLINE uint8_t get_buffer_slot_trid(tt::fabric::BufferIndex buffer_index) const { + return this->buffer_slot_trids[buffer_index]; + } + FORCE_INLINE bool transaction_flushed(tt::fabric::BufferIndex buffer_index) const { + auto trid = this->get_buffer_slot_trid(buffer_index); + return trid == INVALID_TRID || ncrisc_noc_nonposted_write_with_transaction_id_flushed(noc_index, trid); + } + private: + std::array buffer_slot_trids; + TransactionIdCounter trid_counter; +}; + + +// senders update this stream +constexpr uint32_t to_receiver_pkts_sent_id = 0; +// receivers updates the reg on this stream +constexpr uint32_t to_sender_0_pkts_acked_id = 1; +// receivers updates the reg on this stream +constexpr uint32_t to_sender_1_pkts_acked_id = 2; +// receivers updates the reg on this stream +constexpr uint32_t to_sender_0_pkts_completed_id = 3; +// receivers updates the reg on this stream +constexpr uint32_t to_sender_1_pkts_completed_id = 4; + + +// This will be an atomic register read to the register +template +int32_t get_ptr_val() { + return NOC_STREAM_READ_REG(stream_id, STREAM_REMOTE_DEST_BUF_SPACE_AVAILABLE_REG_INDEX); + constexpr uint32_t addr = STREAM_REG_ADDR(stream_id, STREAM_REMOTE_DEST_BUF_SPACE_AVAILABLE_REG_INDEX); + return *reinterpret_cast(addr); +} +int32_t get_ptr_val(uint8_t stream_id) { + return NOC_STREAM_READ_REG(stream_id, STREAM_REMOTE_DEST_BUF_SPACE_AVAILABLE_REG_INDEX); + const uint32_t addr = STREAM_REG_ADDR(stream_id, STREAM_REMOTE_DEST_BUF_SPACE_AVAILABLE_REG_INDEX); + return *reinterpret_cast(addr); +} + +// Writing to this register will leverage the built-in stream hardware which will automatically perform an atomic increment +// on the register. This can save precious erisc cycles by offloading a lot of pointer manipulation. 
+// Additionally, these registers are accessible via eth_reg_write calls which can be used to write a value, +// inline the eth command (without requiring source L1) +template +void increment_local_update_ptr_val(int32_t val) { + NOC_STREAM_WRITE_REG_FIELD(stream_id, STREAM_REMOTE_DEST_BUF_SPACE_AVAILABLE_UPDATE_REG_INDEX, REMOTE_DEST_BUF_WORDS_FREE_INC, val); +} +void increment_local_update_ptr_val(uint8_t stream_id, int32_t val) { + NOC_STREAM_WRITE_REG_FIELD(stream_id, STREAM_REMOTE_DEST_BUF_SPACE_AVAILABLE_UPDATE_REG_INDEX, REMOTE_DEST_BUF_WORDS_FREE_INC, val); +} + +template +void remote_update_ptr_val(int32_t val) { + constexpr uint32_t addr = STREAM_REG_ADDR(stream_id, STREAM_REMOTE_DEST_BUF_SPACE_AVAILABLE_UPDATE_REG_INDEX); + eth_write_remote_reg(addr, val << REMOTE_DEST_BUF_WORDS_FREE_INC); +} +void remote_update_ptr_val(uint32_t stream_id, int32_t val) { + const uint32_t addr = STREAM_REG_ADDR(stream_id, STREAM_REMOTE_DEST_BUF_SPACE_AVAILABLE_UPDATE_REG_INDEX); + eth_write_remote_reg(addr, val << REMOTE_DEST_BUF_WORDS_FREE_INC); +} + +template +void init_ptr_val(int32_t val) { + NOC_STREAM_WRITE_REG(stream_id, STREAM_REMOTE_DEST_BUF_SIZE_REG_INDEX, val); +} + +constexpr std::array to_sender_packets_acked_streams = {{ + to_sender_0_pkts_acked_id, + to_sender_1_pkts_acked_id +}}; + +constexpr std::array to_sender_packets_completed_streams = {{ + to_sender_0_pkts_completed_id, + to_sender_1_pkts_completed_id +}}; + /* * Tracks receiver channel pointers (from sender side) */ @@ -376,6 +480,7 @@ static constexpr size_t num_messages_to_move_ctor_value = 1; static constexpr size_t receiver_channel_id = NUM_SENDER_CHANNELS; static constexpr size_t worker_info_offset_past_connection_semaphore = 32; + ///////////////////////////////////////////// // SENDER SIDE HELPERS ///////////////////////////////////////////// @@ -402,18 +507,14 @@ void send_next_data( tt::fabric::EthChannelBuffer &sender_buffer_channel, tt::fabric::EdmChannelWorkerInterface &sender_worker_interface, OutboundReceiverChannelPointers &outbound_to_receiver_channel_pointers, - tt::fabric::EthChannelBuffer &receiver_buffer_channel) { + tt::fabric::EthChannelBuffer &receiver_buffer_channel, + uint8_t sender_channel_index) { auto &remote_receiver_wrptr = outbound_to_receiver_channel_pointers.wrptr; auto &local_sender_wrptr = sender_worker_interface.local_wrptr; auto local_sender_wrptr_buffer_index = local_sender_wrptr.get_buffer_index(); ASSERT(!eth_txq_is_busy()); - ASSERT( - reinterpret_cast(sender_buffer_channel.get_bytes_sent_address(local_sender_wrptr_buffer_index)) == - (reinterpret_cast(sender_buffer_channel.get_buffer_address(local_sender_wrptr_buffer_index)) + - reinterpret_cast(sender_buffer_channel.get_max_eth_payload_size()) - - (uint32_t)sizeof(eth_channel_sync_t))); // TODO: TUNING - experiment with only conditionally breaking the transfer up into multiple packets if we are // a certain threshold less than full packet @@ -421,12 +522,12 @@ void send_next_data( // compare // NOTE: if we always send full packet, then we don't need the second branch below dedicated for // channel sync - ASSERT(tt::fabric::is_valid(*const_cast(reinterpret_cast(sender_buffer_channel.get_buffer_address(local_sender_wrptr_buffer_index))))); - const size_t payload_size = sender_buffer_channel.get_payload_plus_channel_sync_size(local_sender_wrptr_buffer_index); - *sender_buffer_channel.get_bytes_sent_address(local_sender_wrptr_buffer_index) = payload_size; - *sender_buffer_channel.get_bytes_acked_address(local_sender_wrptr_buffer_index) 
= 0; - *sender_buffer_channel.get_src_id_address(local_sender_wrptr_buffer_index) = sender_buffer_channel.get_id(); - ASSERT(*sender_buffer_channel.get_src_id_address(local_sender_wrptr_buffer_index) < 2); + auto volatile *pkt_header = + reinterpret_cast(sender_buffer_channel.get_buffer_address(local_sender_wrptr_buffer_index)); + ASSERT(tt::fabric::is_valid(*const_cast(pkt_header))); + size_t payload_size = 0; + payload_size = pkt_header->get_payload_size_including_header(); + pkt_header->src_ch_id = sender_channel_index; auto src_addr = sender_buffer_channel.get_buffer_address(local_sender_wrptr_buffer_index); auto dest_addr = receiver_buffer_channel.get_buffer_address(remote_receiver_wrptr.get_buffer_index()); @@ -437,20 +538,13 @@ void send_next_data( payload_size, payload_size >> ETH_BYTES_TO_WORDS_SHIFT); - bool sent_payload_and_channel_sync_in_one_shot = - payload_size == sender_buffer_channel.get_max_eth_payload_size(); - if (!sent_payload_and_channel_sync_in_one_shot) { - // We weren't able to send the channel_sync_t in one shot with the payload so we need to send a second - // packet - // TODO: TUNING - consider busy waiting for a maximum amount of time - while (eth_txq_is_busy()) {} - send_channel_sync( - sender_buffer_channel, local_sender_wrptr, receiver_buffer_channel, remote_receiver_wrptr); - } // Note: We can only advance to the next buffer index if we have fully completed the send (both the payload and sync // messages) local_sender_wrptr.increment(); + // update the remote reg + static constexpr uint32_t words_to_forward = 1; + remote_update_ptr_val(words_to_forward); remote_receiver_wrptr.increment(); } @@ -476,38 +570,9 @@ void receiver_send_received_ack( // Set the acknowledgement bits. We have a different location than the auto receiver_buffer_index = receiver_channel_ptr.get_buffer_index(); - const auto src_id = *local_receiver_buffer_channel.get_src_id_address(receiver_buffer_index); - auto &sender_ackptr = remote_eth_sender_ackptrs[src_id]; - - ASSERT(src_id < NUM_SENDER_CHANNELS); - const size_t local_ack_channel_sync_src_addr = - local_receiver_buffer_channel.get_eth_transaction_ack_word_addr() + (src_id * sizeof(eth_channel_sync_t)); - reinterpret_cast(local_ack_channel_sync_src_addr)->bytes_sent = 1; // *local_receiver_buffer_channel.get_bytes_sent_address(); - reinterpret_cast(local_ack_channel_sync_src_addr)->receiver_ack = 1; - reinterpret_cast(local_ack_channel_sync_src_addr)->src_id = src_id; - reinterpret_cast(local_ack_channel_sync_src_addr)->reserved_2 = 0xc0ffee2; - - auto &sender_buffer_channel = remote_sender_channels[src_id]; - auto sender_ackptr_buffer_index = sender_ackptr.get_buffer_index(); - ASSERT(src_id < NUM_SENDER_CHANNELS); - ASSERT( - reinterpret_cast(sender_buffer_channel.get_bytes_sent_address(sender_ackptr_buffer_index)) == - reinterpret_cast(sender_buffer_channel.get_buffer_address(sender_ackptr_buffer_index)) + - reinterpret_cast(sender_buffer_channel.get_max_eth_payload_size()) - - sizeof(eth_channel_sync_t)); - // Make sure we don't alias the erisc_info eth_channel_sync_t - ASSERT( - reinterpret_cast(local_receiver_buffer_channel.get_bytes_sent_address(receiver_buffer_index)) - ->bytes_sent != 0); - ASSERT( - reinterpret_cast(local_receiver_buffer_channel.get_bytes_sent_address(receiver_buffer_index)) - ->receiver_ack == 0); - ASSERT(!eth_txq_is_busy()); - internal_::eth_send_packet_unsafe( - 0, - local_ack_channel_sync_src_addr >> 4, - ((uint32_t)(sender_buffer_channel.get_bytes_sent_address(sender_ackptr_buffer_index))) >> 4, - 
1); + auto volatile *pkt_header = reinterpret_cast(local_receiver_buffer_channel.get_buffer_address(receiver_buffer_index)); + const auto src_id = pkt_header->src_ch_id; + remote_update_ptr_val(to_sender_packets_acked_streams[src_id], 1); } // MUST CHECK !is_eth_txq_busy() before calling @@ -519,25 +584,12 @@ FORCE_INLINE void receiver_send_completion_ack( tt::fabric::EthChannelBuffer &local_receiver_buffer_channel) { auto receiver_buffer_index = receiver_channel_ptr.get_buffer_index(); - volatile auto local_bytes_sent_addr = local_receiver_buffer_channel.get_bytes_sent_address(receiver_buffer_index); - volatile auto local_src_id_ptr = local_receiver_buffer_channel.get_src_id_address(receiver_buffer_index); - *(local_bytes_sent_addr) = 0; - *(local_receiver_buffer_channel.get_bytes_acked_address(receiver_buffer_index)) = 0; - - auto src_sender_channel = *local_src_id_ptr; - auto &remote_sender_channel = remote_sender_channels[src_sender_channel]; - auto &remote_sender_completion_ptr = remote_eth_sender_completion_ptrs[src_sender_channel]; - - ASSERT(src_sender_channel < NUM_SENDER_CHANNELS); - ASSERT(!eth_txq_is_busy()); - - internal_::eth_send_packet_unsafe( - 0, - (uint32_t)(local_bytes_sent_addr) >> 4, - (uint32_t)(remote_sender_channel.get_bytes_sent_address(remote_sender_completion_ptr.get_buffer_index())) >> 4, - 1); + auto volatile *pkt_header = reinterpret_cast(local_receiver_buffer_channel.get_buffer_address(receiver_buffer_index)); + const auto src_id = pkt_header->src_ch_id; + remote_update_ptr_val(to_sender_packets_completed_streams[src_id], 1); receiver_channel_ptr.increment(); + auto &remote_sender_completion_ptr = remote_eth_sender_completion_ptrs[src_id]; remote_sender_completion_ptr.increment(); } @@ -566,25 +618,25 @@ FORCE_INLINE bool can_forward_packet_completely( // !!!WARNING!!! - MAKE SURE CONSUMER HAS SPACE BEFORE CALLING void receiver_forward_packet( - volatile tt::fabric::PacketHeader *packet_start, tt::fabric::WorkerToFabricEdmSender &downstream_edm_interface) { + volatile tt::fabric::PacketHeader *packet_start, tt::fabric::WorkerToFabricEdmSender &downstream_edm_interface, uint8_t transaction_id) { // Just cache the packet_header - we don't really expect (or care) if contents change during this function. 
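+    // `transaction_id` is the NOC transaction id the caller allocated (via
+    // WriteTransactionIdTracker) for the receiver buffer slot holding this packet. Every write
+    // issued below is tagged with it, so run_receiver_channel_step can later poll
+    // ncrisc_noc_nonposted_write_with_transaction_id_flushed() for that slot and retire the
+    // buffer without a global write barrier.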
volatile tt::fabric::PacketHeader const &packet_header = *packet_start; ASSERT(tt::fabric::is_valid(const_cast(packet_header))); auto forward_status = get_packet_local_forward_type(packet_header); switch (forward_status) { case PACKET_FORWARD_LOCAL_ONLY: { - execute_chip_unicast_to_local_chip(packet_start); + execute_chip_unicast_to_local_chip(packet_start, transaction_id); } break; case PACKET_FORWARD_REMOTE_ONLY: { - forward_payload_to_downstream_edm(packet_start, downstream_edm_interface); + forward_payload_to_downstream_edm(packet_start, downstream_edm_interface, transaction_id); } break; case PACKET_FORWARD_LOCAL_AND_REMOTE: { ASSERT(packet_header.chip_send_type == tt::fabric::ChipSendType::CHIP_MULTICAST); // TODO: make local chip write non-blocking - execute_chip_unicast_to_local_chip(packet_start); - forward_payload_to_downstream_edm(packet_start, downstream_edm_interface); + execute_chip_unicast_to_local_chip(packet_start, transaction_id); + forward_payload_to_downstream_edm(packet_start, downstream_edm_interface, transaction_id); } break; case PACKET_FORWARD_INVALID: @@ -605,7 +657,6 @@ bool run_sender_channel_step( tt::fabric::EthChannelBuffer &remote_receiver_channel, volatile tt::fabric::EdmFabricSenderChannelCounters* sender_channel_counters, PacketHeaderRecorder &packet_header_recorder, - bool graceful_termination_mode, bool &channel_connection_established, uint8_t sender_channel_index) { bool did_something = false; @@ -613,81 +664,54 @@ bool run_sender_channel_step( // If the receiver has space, and we have one or more packets unsent from producer, then send one // TODO: convert to loop to send multiple packets back to back (or support sending multiple packets in one shot) // when moving to stream regs to manage rd/wr ptrs + // TODO: update to be stream reg based. 
Initialize to space available and simply check for non-zero bool receiver_has_space_for_packet = outbound_to_receiver_channel_pointers.has_space_for_packet(); if (receiver_has_space_for_packet && !eth_txq_is_busy()) { bool has_unsent_packet = local_sender_channel_worker_interface.has_unsent_payload(); if (has_unsent_packet) { bool sender_backpressured_from_sender_side = !(local_sender_channel_worker_interface.local_rdptr.distance_behind(local_sender_channel_worker_interface.local_wrptr) < SENDER_NUM_BUFFERS); if (!sender_backpressured_from_sender_side) { - ASSERT(local_sender_channel.eth_is_receiver_channel_send_done(local_sender_channel_worker_interface.local_wrptr.get_buffer_index())); did_something = true; auto packet_header = reinterpret_cast(local_sender_channel.get_buffer_address(local_sender_channel_worker_interface.local_wrptr.get_buffer_index())); - tt::fabric::validate(*packet_header); if constexpr (enable_packet_header_recording) { + tt::fabric::validate(*packet_header); packet_header_recorder.record_packet_header(packet_header); } + print_pkt_header(packet_header); send_next_data( local_sender_channel, local_sender_channel_worker_interface, outbound_to_receiver_channel_pointers, - remote_receiver_channel); + remote_receiver_channel, + sender_channel_index); } } } - bool has_unacknowledged_eth_packets = outbound_to_receiver_channel_pointers.has_unacknowledged_or_incomplete_eth_packets(); - if (has_unacknowledged_eth_packets) { - { - auto& sender_ackptr = local_sender_channel_worker_interface.local_ackptr; - auto old_ackptr = sender_ackptr; - // Only check for acks first - bool check_next = !local_sender_channel_worker_interface.all_eth_packets_acked(); - while (check_next) { - // TODO: change how ack is represented so we can check both at once without - // having to worry about races (i.e. right now we don't have monotonicity - // but if we did we could safely (check ack || completed)) - tt::fabric::BufferIndex rd_buffer_index = sender_ackptr.get_buffer_index(); - - bool acked_or_completed = local_sender_channel.eth_is_acked_or_completed(rd_buffer_index); - if (acked_or_completed) { - local_sender_channel.eth_clear_sender_channel_ack(rd_buffer_index); - sender_ackptr.increment(); - local_sender_channel_worker_interface.propagate_ackptr_to_connection_info(); - did_something = true; - outbound_to_receiver_channel_pointers.ack_ptr.increment(); - } - check_next = acked_or_completed && !local_sender_channel_worker_interface.all_eth_packets_acked(); - } + // Process COMPLETIONs from receiver + int32_t completions_since_last_check = get_ptr_val(to_sender_packets_completed_streams[sender_channel_index]); + if (completions_since_last_check > 0) { + auto& sender_rdptr = local_sender_channel_worker_interface.local_rdptr; + outbound_to_receiver_channel_pointers.completion_ptr.increment_n(completions_since_last_check); + sender_rdptr.increment_n(completions_since_last_check); + increment_local_update_ptr_val(to_sender_packets_completed_streams[sender_channel_index], -completions_since_last_check); + } - bool advanced = old_ackptr.get_ptr() != sender_ackptr.get_ptr(); - if (advanced && channel_connection_established) { - local_sender_channel_worker_interface.update_worker_copy_of_read_ptr(); - } - } + // Process ACKs from receiver + // ACKs are processed second to avoid any sort of races. 
If we process acks second, + // we are guaranteed to see equal to or greater the number of acks than completions + auto acks_since_last_check = get_ptr_val(to_sender_packets_acked_streams[sender_channel_index]); - { - // stupid implementation but keeps things simple to bootstrap - auto& sender_rdptr = local_sender_channel_worker_interface.local_rdptr; - bool check_next = !local_sender_channel_worker_interface.all_eth_packets_completed(); - while (check_next) { - bool completed = local_sender_channel.eth_is_receiver_channel_send_done(sender_rdptr.get_buffer_index()); - if (completed) { - did_something = true; - if (local_sender_channel_worker_interface.local_ackptr.get_ptr() == sender_rdptr.get_ptr()) { - // If ackptr is also here, then we need to increment it too - outbound_to_receiver_channel_pointers.ack_ptr.increment(); - local_sender_channel_worker_interface.propagate_ackptr_to_connection_info(); - if (channel_connection_established) { - local_sender_channel_worker_interface.update_worker_copy_of_read_ptr(); - } - } - outbound_to_receiver_channel_pointers.completion_ptr.increment(); - sender_rdptr.increment(); - } - check_next = completed && !local_sender_channel_worker_interface.all_eth_packets_completed(); - } + auto& sender_ackptr = local_sender_channel_worker_interface.local_ackptr; + if (acks_since_last_check > 0) { + sender_ackptr.increment_n(acks_since_last_check); + if (channel_connection_established) { + local_sender_channel_worker_interface.update_worker_copy_of_read_ptr(); } + increment_local_update_ptr_val(to_sender_packets_acked_streams[sender_channel_index], -acks_since_last_check); } + did_something = did_something || (completions_since_last_check + acks_since_last_check) > 0; + if (!channel_connection_established) { // Can get rid of one of these two checks if we duplicate the logic above here in the function @@ -727,16 +751,17 @@ void run_receiver_channel_step( std::array, NUM_SENDER_CHANNELS> &remote_eth_sender_wrptrs, ReceiverChannelPointers &receiver_channel_pointers, PacketHeaderRecorder &packet_header_recorder, + WriteTransactionIdTracker &receiver_channel_trid_tracker, ReceiverState *const receiver_state_out) { - // Optimization: - // 1. 
Let wrptr advance ahead of ackptr auto &ack_ptr = receiver_channel_pointers.ack_ptr; - auto ack_ptr_buffer_index = ack_ptr.get_buffer_index(); - bool packet_received = local_receiver_channel.eth_bytes_are_available_on_channel(ack_ptr_buffer_index) && - receiver_channel_pointers.completion_ptr.distance_behind(ack_ptr) < RECEIVER_NUM_BUFFERS; + auto pkts_received_since_last_check = get_ptr_val(); + bool pkts_received = pkts_received_since_last_check > 0; bool can_send_over_eth = !eth_txq_is_busy(); - if (packet_received && can_send_over_eth) { + ASSERT(receiver_channel_pointers.completion_ptr.distance_behind(ack_ptr) < RECEIVER_NUM_BUFFERS); + if (pkts_received && can_send_over_eth) { + // currently only support processing one packet at a time, so we only decrement by 1 + increment_local_update_ptr_val(-1); receiver_send_received_ack( remote_eth_sender_wrptrs, remote_sender_channnels, @@ -750,10 +775,12 @@ void run_receiver_channel_step( if (unwritten_packets) { auto receiver_buffer_index = wr_sent_ptr.get_buffer_index(); volatile auto packet_header = local_receiver_channel.get_packet_header(receiver_buffer_index); + print_pkt_header(packet_header); bool can_send_to_all_local_chip_receivers = can_forward_packet_completely(packet_header, downstream_edm_interface); if (can_send_to_all_local_chip_receivers) { - receiver_forward_packet(packet_header, downstream_edm_interface); + uint8_t trid = receiver_channel_trid_tracker.update_buffer_slot_to_next_trid_and_advance_trid_counter(receiver_buffer_index); + receiver_forward_packet(packet_header, downstream_edm_interface, trid); wr_sent_ptr.increment(); } } @@ -761,11 +788,12 @@ void run_receiver_channel_step( auto &wr_flush_ptr = receiver_channel_pointers.wr_flush_ptr; bool unflushed_writes = !wr_flush_ptr.is_caught_up_to(wr_sent_ptr); if (unflushed_writes) { - bool writes_flushed = ncrisc_noc_nonposted_writes_sent(noc_index); - if (writes_flushed) { - auto receiver_buffer_index = wr_flush_ptr.get_buffer_index(); + auto receiver_buffer_index = wr_flush_ptr.get_buffer_index(); + bool next_trid_flushed = receiver_channel_trid_tracker.transaction_flushed(receiver_buffer_index); + if (next_trid_flushed) { local_receiver_channel.eth_clear_sender_channel_ack(receiver_buffer_index); wr_flush_ptr.increment(); + receiver_channel_trid_tracker.clear_trid_at_buffer_slot(receiver_buffer_index); } } @@ -800,17 +828,22 @@ FORCE_INLINE bool got_termination_signal(volatile tt::fabric::TerminationSignal template bool all_channels_drained(tt::fabric::EthChannelBuffer &local_receiver_channel, std::array, NUM_SENDER_CHANNELS> &local_sender_channels, - std::array, NUM_SENDER_CHANNELS> &local_sender_channel_worker_interfaces) { + std::array, NUM_SENDER_CHANNELS> &local_sender_channel_worker_interfaces, + ReceiverChannelPointers &receiver_channel_pointers) { bool eth_buffers_drained = - !local_sender_channel_worker_interfaces[0].has_unacked_sends() && - !local_sender_channel_worker_interfaces[1].has_unacked_sends() && - local_receiver_channel.all_buffers_drained(); - - bool sender0_has_unsent_packets = local_sender_channel_worker_interfaces[0].has_unsent_payload(); - bool sender1_has_unsent_packets = local_sender_channel_worker_interfaces[1].has_unsent_payload(); - - return eth_buffers_drained && !sender0_has_unsent_packets && !sender1_has_unsent_packets; + local_sender_channel_worker_interfaces[0].all_eth_packets_completed() && + local_sender_channel_worker_interfaces[1].all_eth_packets_completed() && + !local_sender_channel_worker_interfaces[0].has_unsent_payload() && + 
!local_sender_channel_worker_interfaces[1].has_unsent_payload() && + receiver_channel_pointers.completion_ptr.is_caught_up_to(receiver_channel_pointers.ack_ptr) && + get_ptr_val() == 0 && + get_ptr_val() == 0 && + get_ptr_val() == 0 && + get_ptr_val() == 0 && + get_ptr_val() == 0; + + return eth_buffers_drained; } /* @@ -851,12 +884,14 @@ void run_fabric_edm_main_loop( ReceiverChannelPointers receiver_channel_pointers; std::array channel_connection_established = {false, false}; + WriteTransactionIdTracker receiver_channel_trid_tracker; + while (!got_immediate_termination_signal(termination_signal_ptr)) { bool got_graceful_termination = got_graceful_termination_signal(termination_signal_ptr); if (got_graceful_termination) { DPRINT << "EDM Graceful termination\n"; bool all_drained = all_channels_drained( - local_receiver_channel, local_sender_channels, local_sender_channel_worker_interfaces); + local_receiver_channel, local_sender_channels, local_sender_channel_worker_interfaces, receiver_channel_pointers); if (all_drained) { return; @@ -875,7 +910,6 @@ void run_fabric_edm_main_loop( remote_receiver_channel, sender_channel_counters_ptrs[sender_channel_index], sender_channel_packet_recorders[sender_channel_index], - got_graceful_termination, channel_connection_established[sender_channel_index], sender_channel_index); @@ -885,7 +919,9 @@ void run_fabric_edm_main_loop( local_receiver_channel, remote_sender_channels, downstream_edm_noc_interface, receiver_channel_counters_ptr, remote_eth_sender_wrptrs, receiver_channel_pointers, - receiver_channel_packet_recorder, &receiver_state); + receiver_channel_packet_recorder, + receiver_channel_trid_tracker, + &receiver_state); bool did_something = did_something_sender || old_recv_state != receiver_state; @@ -910,6 +946,15 @@ void kernel_main() { *reinterpret_cast(handshake_addr) = 0; auto eth_transaction_ack_word_addr = handshake_addr + sizeof(eth_channel_sync_t); + // Initialize stream register state for credit management across the Ethernet link. + // We make sure to do this before we handshake to guarantee that the registers are + // initialized before the other side has any possibility of modifying them. 
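On the receiver side, the forwarding path above now tags each buffer slot's outgoing write with a NOC transaction id and recycles the slot only once that specific transaction has flushed, rather than waiting on a global non-posted-write barrier. A rough sketch of the per-slot bookkeeping follows; the type and method names here are invented placeholders for the WriteTransactionIdTracker used in the kernel, and the flush query itself is left abstract.

#include <array>
#include <cstddef>
#include <cstdint>

// Hypothetical per-slot transaction-id tracker: assign a fresh trid when a
// slot's write is issued, remember it so the flush check can target that
// transaction, and clear it when the slot is recycled.
template <std::size_t NUM_SLOTS, uint8_t NUM_TRIDS>
struct SlotTridTracker {
    std::array<uint8_t, NUM_SLOTS> slot_trid{};
    uint8_t next_trid = 0;

    uint8_t assign(std::size_t slot) {            // on forward: slot -> fresh trid
        uint8_t trid = next_trid;
        next_trid = static_cast<uint8_t>((next_trid + 1) % NUM_TRIDS);
        slot_trid[slot] = trid;
        return trid;
    }
    uint8_t trid_for(std::size_t slot) const { return slot_trid[slot]; }
    void clear(std::size_t slot) { slot_trid[slot] = 0; }
};

int main() {
    SlotTridTracker<8, 4> tracker;
    uint8_t trid = tracker.assign(2);             // issue write for slot 2
    // ... in the kernel: poll the NOC for "transaction trid flushed" ...
    tracker.clear(2);                             // slot 2 may be reused
    return tracker.trid_for(2) == 0 && trid == 0 ? 0 : 1;
}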
+ init_ptr_val(0); + init_ptr_val(0); + init_ptr_val(0); + init_ptr_val(0); + init_ptr_val(0); + static constexpr size_t DEFAULT_HANDSHAKE_CONTEXT_SWITCH_TIMEOUT = 0; if constexpr (is_handshake_sender) { erisc::datamover::handshake::sender_side_start(handshake_addr, DEFAULT_HANDSHAKE_CONTEXT_SWITCH_TIMEOUT); diff --git a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover_channels.hpp b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover_channels.hpp index d7981f23407..a5d8298bbff 100644 --- a/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover_channels.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_erisc_datamover_channels.hpp @@ -11,6 +11,7 @@ #include "debug/dprint.h" #include "dataflow_api.h" #include "tt_metal/hw/inc/ethernet/tunneling.h" +#include "tt_metal/hw/inc/utils/utils.h" #include "risc_attribs.h" #include "cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_packet_header.hpp" #include "cpp/ttnn/operations/ccl/kernels/edm_fabric/fabric_edm_types.hpp" @@ -42,7 +43,7 @@ using BufferPtr = NamedType; template auto wrap_increment(T val) -> T { static_assert(LIMIT != 0, "wrap_increment called with limit of 0; it must be greater than 0"); - constexpr bool is_pow2 = (LIMIT & (LIMIT - 1)) == 0; + constexpr bool is_pow2 = is_power_of_2(LIMIT); if constexpr (LIMIT == 1) { return val; } else if constexpr (LIMIT == 2) { @@ -53,6 +54,22 @@ auto wrap_increment(T val) -> T { return (val == static_cast(LIMIT - 1)) ? static_cast(0) : static_cast(val + 1); } } +template +auto wrap_increment_n(T val, uint8_t increment) -> T { + static_assert(LIMIT != 0, "wrap_increment called with limit of 0; it must be greater than 0"); + constexpr bool is_pow2 = is_power_of_2(LIMIT); + if constexpr (LIMIT == 1) { + return val; + } else if constexpr (LIMIT == 2) { + return 1 - val; + } else if constexpr (is_pow2) { + return (val + increment) & (LIMIT - 1); + } else { + T new_unadjusted_val = val + increment; + bool wraps = new_unadjusted_val >= LIMIT; + return wraps ? 
static_cast(new_unadjusted_val - LIMIT) : static_cast(new_unadjusted_val); + } +} template auto normalize_ptr(BufferPtr ptr) -> BufferIndex { @@ -113,6 +130,9 @@ class ChannelBufferPointer { return BufferIndex{normalize_ptr(this->ptr)}; } + void increment_n(uint8_t n) { + this->ptr = BufferPtr{wrap_increment_n<2*NUM_BUFFERS>(this->ptr.get(), n)}; + } void increment() { this->ptr = wrap_increment<2*NUM_BUFFERS>(this->ptr); } diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/host/reduce_scatter_worker_builder.cpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/host/reduce_scatter_worker_builder.cpp index 7cada45446f..1d11547e123 100644 --- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/host/reduce_scatter_worker_builder.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/host/reduce_scatter_worker_builder.cpp @@ -610,7 +610,7 @@ std::vector ReduceScatterWorkerArgBuilder::generate_line_start_sender_ std::ranges::copy(std::vector{this->op_config.get_page_size()}, std::back_inserter(args)); log_trace(tt::LogOp, "ccl_send arg[{}]: page_size {}", logged_arg_idx, args[logged_arg_idx]);logged_arg_idx++; - auto const& addr_gen_rt_args = ttnn::ccl::emit_address_generator_runtime_args(this->device, input_tensor); + auto const& addr_gen_rt_args = ttnn::ccl::legacy_emit_address_generator_runtime_args(this->device, input_tensor); std::ranges::copy(addr_gen_rt_args, std::back_inserter(args)); for (auto const& arg : addr_gen_rt_args) { log_trace(tt::LogOp, "ccl_send arg[{}]: addr_gen_rt_args[] {}", logged_arg_idx, args[logged_arg_idx]);logged_arg_idx++; @@ -639,7 +639,7 @@ std::vector ReduceScatterWorkerArgBuilder::generate_line_start_sender_ }; auto const& input_tensor = this->op_config.get_input_tensor(0); - auto const& addr_gen_rt_args = ttnn::ccl::emit_address_generator_compile_time_args(input_tensor); + auto const& addr_gen_rt_args = ttnn::ccl::legacy_emit_address_generator_compile_time_args(input_tensor); std::ranges::copy(addr_gen_rt_args, std::back_inserter(args)); return args; diff --git a/ttnn/cpp/ttnn/operations/ccl/sharding_addrgen_helper.cpp b/ttnn/cpp/ttnn/operations/ccl/sharding_addrgen_helper.cpp new file mode 100644 index 00000000000..1bb57fa6e51 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/sharding_addrgen_helper.cpp @@ -0,0 +1,179 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
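The wrap_increment_n helper introduced above generalizes the single-step wrap to a jump of n slots: power-of-two limits reduce to a mask, while other limits fall back to one conditional subtraction, which is only sufficient when the increment is smaller than the limit. A standalone sketch of the same arithmetic, with that assumption made explicit via an assert:

#include <cassert>
#include <cstdint>

// Wrap `val + inc` into [0, LIMIT). Mirrors the header's strategy: mask for
// power-of-two limits, single conditional subtraction otherwise.
template <uint32_t LIMIT>
uint32_t wrap_add(uint32_t val, uint32_t inc) {
    static_assert(LIMIT > 0, "limit must be non-zero");
    assert(inc < LIMIT);                      // one subtraction can only undo one wrap
    if constexpr ((LIMIT & (LIMIT - 1)) == 0) {
        return (val + inc) & (LIMIT - 1);     // power of two: cheap mask
    } else {
        uint32_t sum = val + inc;
        return sum >= LIMIT ? sum - LIMIT : sum;
    }
}

int main() {
    return (wrap_add<8>(6, 3) == 1 && wrap_add<6>(5, 2) == 1) ? 0 : 1;
}

Note that ChannelBufferPointer::increment_n wraps at 2*NUM_BUFFERS rather than NUM_BUFFERS, the usual trick that lets a full ring be distinguished from an empty one when comparing pointers.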
+// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include "cpp/ttnn/tensor/tensor.hpp" +#include "cpp/ttnn/operations/ccl/sharding_addrgen_helper.hpp" + +namespace shard_builder { + +uint32_t get_sharding_core_count(const tt::tt_metal::Tensor& t) { + uint32_t core_count = 0; + const auto core_ranges = t.buffer()->shard_spec().grid().ranges(); + for (uint32_t cr = 0; cr < core_ranges.size(); cr++) { + TT_FATAL( + core_ranges.at(cr).start_coord.x <= core_ranges.at(cr).end_coord.x, + "end coordinates left of start coordinates in shard"); + TT_FATAL( + core_ranges.at(cr).start_coord.y <= core_ranges.at(cr).end_coord.y, + "end coordinates above of start coordinates in shard"); + core_count += (core_ranges.at(cr).end_coord.x - core_ranges.at(cr).start_coord.x + 1) * + (core_ranges.at(cr).end_coord.y - core_ranges.at(cr).start_coord.y + 1); + } + return core_count; +} + +std::vector get_shard_cores(const tt::tt_metal::Tensor& t) { + std::vector coordinates; + const tt::tt_metal::IDevice* device = t.device(); + struct ShardSpec shard_spec = t.shard_spec().value(); + const auto core_ranges = t.buffer()->shard_spec().grid().ranges(); + bool shard_grid_transposed = + ((t.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED && + shard_spec.orientation == ShardOrientation::ROW_MAJOR) || + ((t.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED || + t.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED) && + shard_spec.orientation == ShardOrientation::COL_MAJOR)); + bool last = false; + uint32_t held_value = 0; + uint32_t concatenated_core = 0; + for (uint32_t cr = 0; cr < core_ranges.size(); cr++) { + TT_FATAL( + core_ranges.at(cr).start_coord.x <= core_ranges.at(cr).end_coord.x, + "end coordinates left of start coordinates in shard"); + TT_FATAL(core_ranges.at(cr).end_coord.x <= 0xFF, "sharding coordinates out of range"); + TT_FATAL( + core_ranges.at(cr).start_coord.y <= core_ranges.at(cr).end_coord.y, + "end coordinates above of start coordinates in shard"); + TT_FATAL(core_ranges.at(cr).end_coord.y <= 0xFF, "sharding coordinates out of range"); + if (shard_grid_transposed) { + for (uint32_t x_index = core_ranges.at(cr).start_coord.x; x_index <= core_ranges.at(cr).end_coord.x; + x_index++) { + for (uint32_t y_index = core_ranges.at(cr).start_coord.y; y_index <= core_ranges.at(cr).end_coord.y; + y_index++) { + CoreCoord noc_core = device->worker_core_from_logical_core(CoreCoord(x_index, y_index)); + coordinates.push_back(noc_core); + } + } + } else { + for (uint32_t y_index = core_ranges.at(cr).start_coord.y; y_index <= core_ranges.at(cr).end_coord.y; + y_index++) { + for (uint32_t x_index = core_ranges.at(cr).start_coord.x; x_index <= core_ranges.at(cr).end_coord.x; + x_index++) { + CoreCoord noc_core = device->worker_core_from_logical_core(CoreCoord(x_index, y_index)); + coordinates.push_back(noc_core); + } + } + } + } + return coordinates; +} + +std::vector generate_run_time_args(const tt::tt_metal::Tensor& t) { + std::vector args; + const tt::tt_metal::IDevice* device = t.device(); + struct ShardSpec shard_spec = t.shard_spec().value(); + const auto core_ranges = t.buffer()->shard_spec().grid().ranges(); + bool shard_grid_transposed = + ((t.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED && + shard_spec.orientation == ShardOrientation::ROW_MAJOR) || + ((t.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED || + t.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED) && + shard_spec.orientation == 
ShardOrientation::COL_MAJOR)); + bool last = false; + uint32_t held_value = 0; + uint32_t concatenated_core = 0; + for (uint32_t cr = 0; cr < core_ranges.size(); cr++) { + TT_FATAL( + core_ranges.at(cr).start_coord.x <= core_ranges.at(cr).end_coord.x, + "end coordinates left of start coordinates in shard"); + TT_FATAL(core_ranges.at(cr).end_coord.x <= 0xFF, "sharding coordinates out of range"); + TT_FATAL( + core_ranges.at(cr).start_coord.y <= core_ranges.at(cr).end_coord.y, + "end coordinates above of start coordinates in shard"); + TT_FATAL(core_ranges.at(cr).end_coord.y <= 0xFF, "sharding coordinates out of range"); + if (shard_grid_transposed) { + for (uint32_t x_index = core_ranges.at(cr).start_coord.x; x_index <= core_ranges.at(cr).end_coord.x; + x_index++) { + for (uint32_t y_index = core_ranges.at(cr).start_coord.y; y_index <= core_ranges.at(cr).end_coord.y; + y_index++) { + CoreCoord noc_core = device->worker_core_from_logical_core(CoreCoord(x_index, y_index)); + concatenated_core = (noc_core.x & 0xFF) << 8 | (noc_core.y & 0xFF); + if (last) { + args.push_back(concatenated_core | (held_value << 16)); + } else { + held_value = concatenated_core; + } + last = !last; + } + } + } else { + for (uint32_t y_index = core_ranges.at(cr).start_coord.y; y_index <= core_ranges.at(cr).end_coord.y; + y_index++) { + for (uint32_t x_index = core_ranges.at(cr).start_coord.x; x_index <= core_ranges.at(cr).end_coord.x; + x_index++) { + CoreCoord noc_core = device->worker_core_from_logical_core(CoreCoord(x_index, y_index)); + concatenated_core = (noc_core.x & 0xFF) << 8 | (noc_core.y & 0xFF); + if (last) { + args.push_back(concatenated_core | (held_value << 16)); + } else { + held_value = concatenated_core; + } + last = !last; + } + } + } + } + if (last) { + args.push_back((held_value << 16)); + } + return args; +} + +void extend_sharding_run_time_args(const tt::tt_metal::Tensor& t, std::vector& args) { + const auto& new_args = generate_run_time_args(t); + std::copy(std::begin(new_args), std::end(new_args), std::back_inserter(args)); +} + +std::vector generate_compile_time_args(const tt::tt_metal::Tensor& t) { + std::vector args; + const tt::tt_metal::IDevice* device = t.device(); + TT_ASSERT(t.is_sharded()); + TT_FATAL( + t.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED || + t.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED || + t.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED, + "ShardedAddrGenArgBuilder::emit_ct_args was invoked with a tensor containing an unsupported (Sharded) Tensor " + "Memory Layout: {}", + t.memory_config().memory_layout); + ShardSpec shard_spec = t.shard_spec().value(); + ShardSpecBuffer buf_shard_spec = t.buffer()->shard_spec(); + const auto& [pages_per_shard_y, pages_per_shard_x] = buf_shard_spec.shape_in_pages(); + // contiguity is 0 if there is padding between unaligned page, 1 if there is padding in the rightmost shard, and 2 + // otherwise + shard_addr_gen_consts::ContiguityType contiguity = + (t.buffer()->aligned_page_size() != t.buffer()->page_size()) + ? shard_addr_gen_consts::ContiguityType::PADDING_BETWEEN_PAGES + : (buf_shard_spec.tensor2d_shape[1] == (pages_per_shard_x * get_sharding_core_count(t))) + ? 
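generate_run_time_args above packs two NOC core coordinates into each 32-bit runtime argument: every core becomes a 16-bit (x << 8 | y) value, pairs are concatenated with the earlier core in the high half, and a lone trailing core is padded into the high half. A small sketch of the packing plus the matching kernel-side unpacking; the helper names are illustrative only.

#include <cstdint>
#include <utility>
#include <vector>

// Pack (x, y) pairs two-per-word, as the sharding addrgen helper does.
std::vector<uint32_t> pack_cores(const std::vector<std::pair<uint8_t, uint8_t>>& cores) {
    std::vector<uint32_t> args;
    uint32_t held = 0;
    bool have_held = false;
    for (auto [x, y] : cores) {
        uint32_t c = (static_cast<uint32_t>(x) << 8) | y;   // 16-bit core id
        if (have_held) {
            args.push_back((held << 16) | c);               // earlier core in high half
        } else {
            held = c;
        }
        have_held = !have_held;
    }
    if (have_held) {
        args.push_back(held << 16);                         // odd count: pad low half
    }
    return args;
}

// Kernel-side view: core i lives in arg i/2, high half for even i, low half for odd i.
inline std::pair<uint8_t, uint8_t> unpack_core(const std::vector<uint32_t>& args, std::size_t i) {
    uint32_t half = (i % 2 == 0) ? (args[i / 2] >> 16) : (args[i / 2] & 0xFFFF);
    return {static_cast<uint8_t>((half >> 8) & 0xFF), static_cast<uint8_t>(half & 0xFF)};
}

int main() {
    std::vector<std::pair<uint8_t, uint8_t>> cores = {{1, 2}, {3, 4}, {5, 6}};
    auto args = pack_cores(cores);
    auto c2 = unpack_core(args, 2);
    return (args.size() == 2 && c2.first == 5 && c2.second == 6) ? 0 : 1;
}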
shard_addr_gen_consts::ContiguityType::NO_SHARD_PADDING + : shard_addr_gen_consts::ContiguityType::PADDING_IN_RIGHTMOST_SHARD; + args.push_back(static_cast(t.memory_config().memory_layout)); // Memory layout + args.push_back(static_cast(get_sharding_core_count(t))); // The number of sharding cores + args.push_back(static_cast(t.buffer()->aligned_page_size())); // The page size we offset each write to + TT_FATAL(t.buffer()->aligned_page_size() > 0, "aligned page size is 0"); + TT_FATAL(buf_shard_spec.tensor2d_shape[1] > 0, "the page is empty"); + args.push_back(static_cast( + buf_shard_spec.tensor2d_shape[1])); // The number of pages in each sharding row not including padding pages + args.push_back(static_cast(contiguity)); // This defines times when contiguous pages can't be calculated + args.push_back(pages_per_shard_x); + args.push_back(pages_per_shard_y); + return args; +} + +void extend_sharding_compile_time_args(const tt::tt_metal::Tensor& t, std::vector& args) { + const auto& new_args = generate_compile_time_args(t); + std::copy(std::begin(new_args), std::end(new_args), std::back_inserter(args)); +} + +} // namespace shard_builder diff --git a/ttnn/cpp/ttnn/operations/ccl/sharding_addrgen_helper.hpp b/ttnn/cpp/ttnn/operations/ccl/sharding_addrgen_helper.hpp new file mode 100644 index 00000000000..ab12f4a733b --- /dev/null +++ b/ttnn/cpp/ttnn/operations/ccl/sharding_addrgen_helper.hpp @@ -0,0 +1,16 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "cpp/ttnn/operations/ccl/common/types/sharding_common.hpp" + +namespace shard_builder { +void extend_sharding_compile_time_args(const tt::tt_metal::Tensor& t, std::vector& args); +void extend_sharding_run_time_args(const tt::tt_metal::Tensor& t, std::vector& args); +std::vector generate_run_time_args(const tt::tt_metal::Tensor& t); +uint32_t get_sharding_core_count(const tt::tt_metal::Tensor& t); +std::vector generate_compile_time_args(const tt::tt_metal::Tensor& t); +std::vector get_shard_cores(const tt::tt_metal::Tensor& t); +} // namespace shard_builder diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp index 824e12deb75..426f6e52151 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp @@ -361,9 +361,7 @@ OptimizedConvBlockConfig determine_per_core_conv_block_config( } auto grid_size = parallel_config.grid.bounding_box().grid_size(); - uint32_t act_c_num_blocks = parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED ? 1 - : parallel_config.shard_orientation == ShardOrientation::COL_MAJOR ? 
grid_size.y - : grid_size.x; + uint32_t act_c_num_blocks = get_num_cores_channels_from_parallel_config(parallel_config); TT_ASSERT(padded_in_channels % act_c_num_blocks == 0); uint32_t act_block_w = parallel_config.shard_scheme == TensorMemoryLayout::HEIGHT_SHARDED @@ -995,9 +993,6 @@ conv_op_l1_usage conv2d::calculate_L1_usage( (per_core_out_matrix_height_ntiles + act_block_h_ntiles - 1) / act_block_h_ntiles; uint32_t out_block_h_ntiles_padded = num_blocks_act_h_per_core * act_block_h_ntiles; - // TODO: this needs to be changed - needs to be independent from dram alignment - const uint32_t alignment_bytes = std::max(hal.get_alignment(HalMemType::L1), hal.get_alignment(HalMemType::DRAM)); - TensorMemoryLayout sharding_scheme = conv_config.shard_layout.value(); if (sharding_scheme == TensorMemoryLayout::WIDTH_SHARDED) { uint32_t conv_output_c_per_core = per_core_out_matrix_width_ntiles * tt::constants::TILE_WIDTH; @@ -1078,7 +1073,7 @@ conv_op_l1_usage conv2d::calculate_L1_usage( } else if (conv_config.dtype == DataType::FLOAT32) { per_core_out_width_aligned *= 4; } - output_size = round_up(per_core_out_width_aligned, alignment_bytes) * pconfig.per_core_out_matrix_height; + output_size = round_up(per_core_out_width_aligned, hal.get_alignment(HalMemType::L1)) * pconfig.per_core_out_matrix_height; } else { output_size = per_core_out_matrix_height_ntiles * per_core_out_matrix_width_ntiles * output_tile_size; } @@ -1182,7 +1177,7 @@ conv_op_l1_usage conv2d::calculate_L1_usage( } else if (conv_config.dtype == DataType::FLOAT32) { per_core_out_width_aligned *= 4; } - output_size = round_up(per_core_out_width_aligned, alignment_bytes) * pconfig.per_core_out_matrix_height; + output_size = round_up(per_core_out_width_aligned, hal.get_alignment(HalMemType::L1)) * pconfig.per_core_out_matrix_height; } else { output_size = per_core_out_matrix_height_ntiles * per_core_out_matrix_width_ntiles * output_tile_size; } diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp index c6c072767a3..a7f1c2a774a 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp @@ -348,18 +348,22 @@ operation::ProgramWithCallbacks OptimizedConvNew::create_program( kernel_dims[1], sliding_window_config.get_output_shape()[2])); - TT_FATAL( - actual_cb_size == l1_usage.CB_allocation_size, - "Calculated CB size {} does not match with the actual CB size {}", - l1_usage.CB_allocation_size, - actual_cb_size); - - TT_FATAL( - post_op_l1_allocation_size == (this->pre_op_l1_allocation_size_bytes + l1_usage.tensor_allocation_size), - "Mismatch!! L1 Allocation Pre Op = {}, Post Op = {} Calculated Size = {}", - this->pre_op_l1_allocation_size_bytes, - post_op_l1_allocation_size, - l1_usage.tensor_allocation_size); + if (device->arch() != tt::ARCH::BLACKHOLE) { + // FIXME: This L1 calculation is not accurate for Blackhole due to different alignment. + // https://github.com/tenstorrent/tt-metal/issues/17216 + TT_FATAL( + actual_cb_size == l1_usage.CB_allocation_size, + "Calculated CB size {} does not match with the actual CB size {}", + l1_usage.CB_allocation_size, + actual_cb_size); + + TT_FATAL( + post_op_l1_allocation_size == (this->pre_op_l1_allocation_size_bytes + l1_usage.tensor_allocation_size), + "Mismatch!! 
L1 Allocation Pre Op = {}, Post Op = {} Calculated Size = {}", + this->pre_op_l1_allocation_size_bytes, + post_op_l1_allocation_size, + l1_usage.tensor_allocation_size); + } return program_with_cbs; } diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/compute_depthwise_conv1d.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/compute_depthwise_conv1d.cpp index 4fae5459cae..0c08a6edf0e 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/compute_depthwise_conv1d.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/compute_depthwise_conv1d.cpp @@ -71,7 +71,7 @@ inline void eltwise_mul_and_add_block_v2( cb_push_back(out_cb_id, 1); cb_pop_front(eltwise_mul_partials_cb_cb_id, 1); } else { - add_tiles_init(); + add_tiles_init(eltwise_mul_partials_cb_cb_id, out_cb_id); cb_wait_front(eltwise_mul_partials_cb_cb_id, 1); cb_wait_front(out_cb_id, 1); ACQ(); diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp index 623dfb38f04..2678a4ce2af 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp @@ -69,7 +69,7 @@ Tensor create_tensor_from_owned_buffer( if (output_dtype == DataType::BFLOAT8_B || output_dtype == DataType::BFLOAT4_B) { auto tensor = Tensor(std::move(OwnedStorage{std::move(buf)}), output_shape, DataType::FLOAT32, Layout::ROW_MAJOR) - .to(Layout::TILE); + .to_layout(Layout::TILE); auto output_float_data = owned_buffer::get_as(tensor).get(); auto output_packed_data = output_dtype == DataType::BFLOAT8_B @@ -85,7 +85,7 @@ Tensor create_tensor_from_owned_buffer( "Unsupported output datatype"); } auto rm_tensor = Tensor(std::move(OwnedStorage{std::move(buf)}), output_shape, output_dtype, Layout::ROW_MAJOR); - return rm_tensor.to(Layout::TILE); + return rm_tensor.to_layout(Layout::TILE); } template diff --git a/ttnn/cpp/ttnn/operations/core/core.cpp b/ttnn/cpp/ttnn/operations/core/core.cpp index cb7be8ae0b6..21d90d6cf46 100644 --- a/ttnn/cpp/ttnn/operations/core/core.cpp +++ b/ttnn/cpp/ttnn/operations/core/core.cpp @@ -53,10 +53,10 @@ ttnn::Tensor to_device( const ttnn::Tensor& tensor, IDevice* device, const std::optional& memory_config, uint8_t cq_id) { auto mem_config = memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG); if (mem_config.is_sharded() and (device->arch() == tt::ARCH::BLACKHOLE)) { - auto interleaved_tensor = tensor.to(device, ttnn::DRAM_MEMORY_CONFIG, cq_id); + auto interleaved_tensor = tensor.to_device(device, ttnn::DRAM_MEMORY_CONFIG, cq_id); return ttnn::interleaved_to_sharded(ttnn::DefaultQueueId, interleaved_tensor, mem_config, std::nullopt); } else { - return tensor.to(device, memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG), cq_id); + return tensor.to_device(device, memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG), cq_id); } } @@ -68,10 +68,10 @@ ttnn::Tensor to_device( auto mem_config = memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG); // Currently no direct sharded write support in BLACKHOLE due to alignment issue if (mem_config.is_sharded() and (mesh_device->arch() == tt::ARCH::BLACKHOLE)) { - auto interleaved_tensor = tensor.to(mesh_device, ttnn::DRAM_MEMORY_CONFIG, cq_id); + auto interleaved_tensor = tensor.to_device(mesh_device, ttnn::DRAM_MEMORY_CONFIG, cq_id); return ttnn::interleaved_to_sharded(ttnn::DefaultQueueId, interleaved_tensor, mem_config, std::nullopt); } else { - return tensor.to(mesh_device, mem_config, cq_id); + return 
tensor.to_device(mesh_device, mem_config, cq_id); } } diff --git a/ttnn/cpp/ttnn/operations/core/to_dtype/to_dtype_op.hpp b/ttnn/cpp/ttnn/operations/core/to_dtype/to_dtype_op.hpp index 1bfc7337d6a..ded9501cc3d 100644 --- a/ttnn/cpp/ttnn/operations/core/to_dtype/to_dtype_op.hpp +++ b/ttnn/cpp/ttnn/operations/core/to_dtype/to_dtype_op.hpp @@ -144,27 +144,27 @@ inline Tensor create_tensor_from_buffer( case DataType::UINT16: { auto data = cast(input_buffer); return create_owned_tensor(std::move(data), logical_shape, padded_shape, dtype, Layout::ROW_MAJOR) - .to(input_layout); + .to_layout(input_layout); } case DataType::INT32: { auto data = cast(input_buffer); return create_owned_tensor(std::move(data), logical_shape, padded_shape, dtype, Layout::ROW_MAJOR) - .to(input_layout); + .to_layout(input_layout); } case DataType::UINT32: { auto data = cast(input_buffer); return create_owned_tensor(std::move(data), logical_shape, padded_shape, dtype, Layout::ROW_MAJOR) - .to(input_layout); + .to_layout(input_layout); } case DataType::FLOAT32: { auto data = cast(input_buffer); return create_owned_tensor(std::move(data), logical_shape, padded_shape, dtype, Layout::ROW_MAJOR) - .to(input_layout); + .to_layout(input_layout); } case DataType::BFLOAT16: { auto data = cast<::bfloat16, T>(input_buffer); return create_owned_tensor(std::move(data), logical_shape, padded_shape, dtype, Layout::ROW_MAJOR) - .to(input_layout); + .to_layout(input_layout); } case DataType::BFLOAT8_B: case DataType::BFLOAT4_B: { @@ -176,7 +176,7 @@ inline Tensor create_tensor_from_buffer( padded_shape, DataType::FLOAT32, Layout::ROW_MAJOR) - .to(Layout::TILE); + .to_layout(Layout::TILE); auto output_float_data = tt::tt_metal::owned_buffer::get_as(tensor).get(); auto output_packed_data = dtype == DataType::BFLOAT8_B @@ -244,7 +244,7 @@ struct ToDtype { return input_tensor; } - auto row_major_input_tensor = input_tensor.to(ttnn::ROW_MAJOR_LAYOUT); + auto row_major_input_tensor = input_tensor.to_layout(ttnn::ROW_MAJOR_LAYOUT); auto intermediate_tensor = distributed::is_multi_device_tensor(row_major_input_tensor) ? transform(row_major_input_tensor, detail::convert_to_cpp_supported_dtype) : detail::convert_to_cpp_supported_dtype(row_major_input_tensor); diff --git a/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp b/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp index aa52310c413..83fdad149f5 100644 --- a/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp +++ b/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp @@ -194,9 +194,9 @@ Tensor to_layout_impl( } else { TT_ASSERT(not dtype.has_value(), "dtype cannot be specified when converting layout on host!"); if (not requires_padding_change(tensor, layout)) { - return device ? tensor.to(layout, device) : tensor.to(layout); + return device ? tensor.to_layout(layout, device) : tensor.to_layout(layout); } else if (layout == ttnn::ROW_MAJOR_LAYOUT) { - tensor = device ? tensor.to(layout, device) : tensor.to(layout); + tensor = device ? tensor.to_layout(layout, device) : tensor.to_layout(layout); tensor = tensor.unpad_from_tile(tensor.get_logical_shape()); return ttnn::reshape(tensor, ttnn::Shape{output_shape}); } else if (layout == ttnn::TILE_LAYOUT) { @@ -205,7 +205,7 @@ Tensor to_layout_impl( padded_input_start.push_back(0); } tensor = tensor.pad(ttnn::Shape(padded_output_shape), ttnn::Shape(std::move(padded_input_start)), 0); - tensor = device ? tensor.to(layout, device) : tensor.to(layout); + tensor = device ? 
tensor.to_layout(layout, device) : tensor.to_layout(layout); return ttnn::experimental::view(tensor, output_shape, padded_output_shape); } else { TT_THROW("ttnn::to_layout: Unsupported output layout: {}!", layout); diff --git a/ttnn/cpp/ttnn/operations/creation.hpp b/ttnn/cpp/ttnn/operations/creation.hpp index b6373f29228..80cd7e023ad 100644 --- a/ttnn/cpp/ttnn/operations/creation.hpp +++ b/ttnn/cpp/ttnn/operations/creation.hpp @@ -99,9 +99,9 @@ static Tensor arange_impl( auto output = Tensor( OwnedStorage{owned_buffer}, ttnn::Shape{1, 1, 1, static_cast(size)}, data_type, Layout::ROW_MAJOR) - .to(layout); + .to_layout(layout); if (device.has_value()) { - output = output.to(device->get_devices(), output_mem_config); + output = output.to_device(device->get_devices(), output_mem_config); } return output; } @@ -125,7 +125,7 @@ static Tensor full_impl( if (!optional_output_tensor.has_value()) { auto output = Tensor(OwnedStorage{owned_buffer}, shape, data_type, layout); if (!devices.empty()) { - output = output.to(devices, output_mem_config); + output = output.to_device(devices, output_mem_config); } return output; } else { diff --git a/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp b/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp index 3ae2b66b4d1..397afbdfcdd 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp @@ -177,4 +177,9 @@ FORCE_INLINE void transpose_2d( } } +template +FORCE_INLINE uint32_t align_address(const uint32_t address, const uint64_t mask) { + return (address & mask) + AlignReq; +} + } // namespace tt::data_movement::common diff --git a/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_program_factory.cpp index f3872ff1581..4ed94d2be03 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_program_factory.cpp @@ -525,10 +525,10 @@ tt_metal::operation::ProgramWithCallbacks concat_multi_core( uint32_t num_output_pages; uint32_t single_page_size; + uint32_t common_align_len = std::max(input_tensors[0].buffer()->alignment(), output.buffer()->alignment()); if (rm_layout) { num_output_pages = output.volume() / output.get_padded_shape()[-1]; - single_page_size = - tt::align(output.element_size() * output.get_padded_shape()[-1], output.buffer()->alignment()); + single_page_size = tt::align(output.element_size() * output.get_padded_shape()[-1], common_align_len); } else { num_output_pages = output.volume() / TILE_HW; single_page_size = tt_metal::detail::TileSize(cb_data_format); diff --git a/ttnn/cpp/ttnn/operations/data_movement/data_movement_pybind.hpp b/ttnn/cpp/ttnn/operations/data_movement/data_movement_pybind.hpp index 7161d3cca4d..d490b6ff7fe 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/data_movement_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/data_movement_pybind.hpp @@ -72,7 +72,7 @@ void py_module(py::module& module) { detail::py_bind_bcast(module); detail::py_bind_copy(module); detail::py_bind_move(module); - detail::py_bind_expand(module); + py_bind_expand(module); py_bind_interleaved_to_sharded(module); py_bind_interleaved_to_sharded_partial(module); py_bind_repeat(module); diff --git a/ttnn/cpp/ttnn/operations/data_movement/data_transfer/data_transfer.cpp b/ttnn/cpp/ttnn/operations/data_movement/data_transfer/data_transfer.cpp 
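Much of the remaining churn in this change is a mechanical rename of the overloaded Tensor::to into the intent-revealing to_device and to_layout calls seen above. A hedged usage sketch of the new spellings follows; it only compiles against the ttnn headers, the function and parameter names are placeholders, and the exact namespace of MemoryConfig is an assumption based on the calls visible in this diff.

#include "ttnn/tensor/tensor.hpp"

// Illustrative only: tilize on host, then move to the device, using the
// renamed member functions instead of the old overloaded Tensor::to.
ttnn::Tensor tilize_then_move(
    const ttnn::Tensor& host_tensor,
    tt::tt_metal::IDevice* device,
    const ttnn::MemoryConfig& memory_config) {
    // formerly: host_tensor.to(Layout::TILE).to(device, memory_config)
    auto tiled = host_tensor.to_layout(ttnn::TILE_LAYOUT);
    return tiled.to_device(device, memory_config);
}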
index 837b54c32b9..cca84c20ed8 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/data_transfer/data_transfer.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/data_transfer/data_transfer.cpp @@ -27,7 +27,7 @@ Tensor DataTransferToDeviceOperation::invoke( return {input_tensor}; } - return input_tensor.to(device, memory_config); + return input_tensor.to_device(device, memory_config); } } // namespace ttnn::operations::data_movement diff --git a/ttnn/cpp/ttnn/operations/data_movement/expand/device/expand_device_operation.cpp b/ttnn/cpp/ttnn/operations/data_movement/expand/device/expand_device_operation.cpp deleted file mode 100644 index 4791243f488..00000000000 --- a/ttnn/cpp/ttnn/operations/data_movement/expand/device/expand_device_operation.cpp +++ /dev/null @@ -1,84 +0,0 @@ -// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "expand_device_operation.hpp" - -#include - -#include "ttnn/tensor/tensor.hpp" -#include "ttnn/tensor/types.hpp" - -namespace ttnn::operations::expand { -ExpandOperation::program_factory_t ExpandOperation::select_program_factory( - const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { - const auto& input = tensor_args.input; - - switch (input.get_layout()) { - case Layout::ROW_MAJOR: return ExpandRowMajorFactory{}; - default: TT_THROW("Expand: Unsupported input layout"); - } -} - -void validate( - const ExpandOperation::operation_attributes_t& operation_attributes, - const ExpandOperation::tensor_args_t& tensor_args) { - // We need to assert that the input and output are ROW_MAJOR. (unfortunately) - - const auto& input = tensor_args.input; - const auto& output = tensor_args.output; - - TT_FATAL(input.get_layout() == Layout::ROW_MAJOR, "Expand: Input tensor layout must be ROW_MAJOR"); - TT_FATAL(tensor_args.input.storage_type() == StorageType::DEVICE, "Expand: Input tensor need to be on device"); - TT_FATAL(tensor_args.input.buffer() != nullptr, "Expand: Input tensor need to be allocated in buffers on device"); - if (output.has_value()) { - TT_FATAL( - output->get_logical_shape() == operation_attributes.output_shape, - "Expand: Output shape must match operation attributes"); - TT_FATAL(input.get_layout() == output->get_layout(), "Expand: Input and output must have same layout"); - TT_FATAL(input.get_dtype() == output->get_dtype(), "Expand: Input and output must have same dtype"); - TT_FATAL(input.device() == output->device(), "Expand: Input and output must be on the same device"); - } -} - -void ExpandOperation::validate_on_program_cache_miss( - const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { - validate(operation_attributes, tensor_args); -}; - -void ExpandOperation::validate_on_program_cache_hit( - const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { - validate(operation_attributes, tensor_args); -}; - -ExpandOperation::spec_return_value_t ExpandOperation::compute_output_specs( - const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { - if (tensor_args.output.has_value()) { - return tensor_args.output->get_tensor_spec(); - } - return TensorSpec( - Shape{operation_attributes.output_shape}, - TensorLayout( - tensor_args.input.get_dtype(), - PageConfig(tensor_args.input.get_layout()), - operation_attributes.memory_config)); -}; - -ExpandOperation::tensor_return_value_t ExpandOperation::create_output_tensors( - const operation_attributes_t& operation_attributes, const 
tensor_args_t& tensor_args) { - // Let's just require it to be allocated ahead of time for now - if (tensor_args.output.has_value()) { - return {tensor_args.output.value()}; - } - - return create_device_tensor(compute_output_specs(operation_attributes, tensor_args), tensor_args.input.device()); -} - -std::tuple ExpandOperation::invoke( - const Tensor& input, - const SmallVector& output_shape, - const std::optional& output, - const std::optional& memory_config) { - return {{output_shape, memory_config.value_or(input.memory_config())}, {input, output}}; -} -} // namespace ttnn::operations::expand diff --git a/ttnn/cpp/ttnn/operations/data_movement/expand/device/expand_device_operation.hpp b/ttnn/cpp/ttnn/operations/data_movement/expand/device/expand_device_operation.hpp deleted file mode 100644 index 310c633b7e4..00000000000 --- a/ttnn/cpp/ttnn/operations/data_movement/expand/device/expand_device_operation.hpp +++ /dev/null @@ -1,69 +0,0 @@ -// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#pragma once - -#include -#include -#include - -#include "ttnn/decorators.hpp" -#include "ttnn/device_operation.hpp" - -namespace ttnn::operations::expand { -struct ExpandOperation { - struct operation_attributes_t { - const SmallVector output_shape = {0}; - const MemoryConfig memory_config; - }; - - struct tensor_args_t { - const Tensor& input; - const std::optional& output; - }; - - using spec_return_value_t = TensorSpec; - using tensor_return_value_t = Tensor; - - struct ExpandRowMajorFactory { - struct shared_variables_t { - KernelHandle reader_kernel_id; - KernelHandle writer_kernel_id; - std::vector cores; - }; - - using cached_program_t = ttnn::device_operation::CachedProgram; - - static cached_program_t create( - const operation_attributes_t& operation_attributes, - const tensor_args_t& tensor_args, - tensor_return_value_t& output); - - static void override_runtime_arguments( - cached_program_t& cached_program, - const operation_attributes_t& operation_attributes, - const tensor_args_t& tensor_args, - tensor_return_value_t& output); - }; - - using program_factory_t = std::variant; - - static program_factory_t select_program_factory(const operation_attributes_t&, const tensor_args_t&); - static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&); - static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&); - static spec_return_value_t compute_output_specs(const operation_attributes_t&, const tensor_args_t&); - static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&); - - static std::tuple invoke( - const Tensor& input, - const SmallVector& output_shape, - - const std::optional& output, - const std::optional& memory_config); -}; -} // namespace ttnn::operations::expand - -namespace ttnn::prim { -constexpr auto expand = ttnn::register_operation<"ttnn::prim::expand", ttnn::operations::expand::ExpandOperation>(); -} diff --git a/ttnn/cpp/ttnn/operations/data_movement/expand/device/expand_rm_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/expand/device/expand_rm_program_factory.cpp deleted file mode 100644 index 69dc8a539b8..00000000000 --- a/ttnn/cpp/ttnn/operations/data_movement/expand/device/expand_rm_program_factory.cpp +++ /dev/null @@ -1,213 +0,0 @@ -// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
-// -// SPDX-License-Identifier: Apache-2.0 - -#include -#include -#include - -#include -#include -#include "expand_device_operation.hpp" -#include -#include "hostdevcommon/kernel_structs.h" -#include -#include -#include -#include "ttnn/tensor/types.hpp" - -using namespace tt::tt_metal; - -namespace ttnn::operations::expand { -ExpandOperation::ExpandRowMajorFactory::cached_program_t ExpandOperation::ExpandRowMajorFactory::create( - const operation_attributes_t& operation_attributes, - const tensor_args_t& tensor_args, - tensor_return_value_t& output) { - auto input = tensor_args.input; - - // Device Setup - auto* device = input.device(); - Program program = CreateProgram(); - - // Initialize data - const auto& input_shape_tmp = input.get_logical_shape(); - std::vector input_shape; - - // Strip empty leading dimensions (for what we are doing next, this spell P-A-I-N) - for (int i = 0; i < input_shape_tmp.size(); i++) { - if (input_shape_tmp[i] > 1) { - // Push the rest of the shape - for (int j = i; j < input_shape_tmp.size(); j++) { - input_shape.push_back(input_shape_tmp[j]); - } - break; - } - } - - // If it's LITERALLY {1}, handle it - if (input_shape.size() == 0) { - input_shape.push_back(1); - } - - const auto& output_shape = output.get_logical_shape(); - uint32_t data_size = input.element_size(); - tt::DataFormat data_format = datatype_to_dataformat_converter(input.get_dtype()); - - // These are needed for the 2d case where the page size actually changes - uint32_t input_tsr_rank = input_shape.size(); - uint32_t output_tsr_rank = output_shape.size(); - uint32_t n_rows = input_tsr_rank == 1 ? 1 : input_shape[input_tsr_rank - 2]; - - uint32_t unexpanded_row_size = input_shape[input_tsr_rank - 1] * data_size; - uint32_t expanded_row_size = output_shape[output_tsr_rank - 1] * data_size; - uint32_t horz_expand_count = expanded_row_size / unexpanded_row_size; - - uint32_t nd_expand_count = output.get_logical_volume() / input.get_logical_volume() / horz_expand_count; - -#ifdef DEBUG - tt::log_debug("Data size = %d\n", data_size); - - tt::log_debug("Input Page size = %lu\n", input.buffer()->page_size()); - tt::log_debug("Output Page size = %lu\n", output.buffer()->page_size()); - - std::stringstream debug_stream; - - debug_stream << "Input Shape = "; - for (auto i = 0; i < input_shape.size(); i++) { - debug_stream << input_shape[i] << " "; - } - debug_stream << std::endl; - - debug_stream << "Output Shape = "; - for (auto i = 0; i < output_shape.size(); i++) { - debug_stream << output_shape[i] << " "; - } - debug_stream << std::endl; - - tt::log_debug("%s", debug_stream.str().c_str()); - - tt::log_debug("Horz Expand Ratio = %d\n", horz_expand_count); - tt::log_debug("Vert Expand Ratio = %d\n", nd_expand_count); -#endif - - auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); - uint32_t num_cores_x = compute_with_storage_grid_size.x; - uint32_t num_cores_y = compute_with_storage_grid_size.y; - auto [num_cores, all_cores, core_group_1, core_group_2, num_copies_per_core_group_1, num_copies_per_core_group_2] = - split_work_to_cores(compute_with_storage_grid_size, nd_expand_count); - -#ifdef DEBUG - tt::log_debug("Num Cores = %d\n", num_cores); - tt::log_debug("Num Rows Per Core Group 1 = %d\n", num_copies_per_core_group_1); - tt::log_debug("Num Rows Per Core Group 2 = %d\n", num_copies_per_core_group_2); -#endif - - const auto src_is_dram = static_cast(input.buffer()->is_dram()); - const auto dst_is_dram = static_cast(output.buffer()->is_dram()); - - const auto 
sram_buffer_length = 32; - - // Scratch SRAM buffer - uint32_t scratch_buf_id = tt::CBIndex::c_24; - auto scratch_config = - CircularBufferConfig(unexpanded_row_size * sram_buffer_length, {{scratch_buf_id, data_format}}) - .set_page_size(scratch_buf_id, unexpanded_row_size); - auto scratch_handle = CreateCircularBuffer(program, all_cores, scratch_config); - - // IO SRAM Buffer - uint32_t io_buf_id = tt::CBIndex::c_16; - auto io_config = CircularBufferConfig(expanded_row_size * sram_buffer_length, {{io_buf_id, data_format}}) - .set_page_size(io_buf_id, expanded_row_size); - auto io_handle = CreateCircularBuffer(program, all_cores, io_config); - - std::vector reader_compile_runtime_args = { - src_is_dram, - scratch_buf_id, - io_buf_id, - data_size, - }; - std::vector writer_compile_runtime_args = { - dst_is_dram, - io_buf_id, - }; - - KernelHandle reader_id = CreateKernel( - program, - "ttnn/cpp/ttnn/operations/data_movement/expand/device/kernels/reader_rm_expand.cpp", - all_cores, - ReaderDataMovementConfig(reader_compile_runtime_args)); - - KernelHandle writer_id = CreateKernel( - program, - "ttnn/cpp/ttnn/operations/data_movement/expand/device/kernels/writer_rm_expand.cpp", - all_cores, - WriterDataMovementConfig(writer_compile_runtime_args)); - - uint32_t rows_offset = 0; - uint32_t group1_cores = core_group_1.num_cores(); - auto cores = grid_to_cores(num_cores, num_cores_x, num_cores_y); - - uint32_t num_copies_this_core; - for (auto core : cores) { - if (core_group_1.contains(core)) { - num_copies_this_core = num_copies_per_core_group_1; - } else if (core_group_2.contains(core)) { - num_copies_this_core = num_copies_per_core_group_2; - } - - SetRuntimeArgs( - program, - reader_id, - core, - { - input.buffer()->address(), - n_rows, - input_shape[input_tsr_rank - 1], - horz_expand_count, - input.buffer()->page_size(), - }); - - SetRuntimeArgs( - program, - writer_id, - core, - { - output.buffer()->address(), - n_rows, - output.buffer()->page_size(), - num_copies_this_core, - rows_offset, - }); - - // Buffer page size is exactly one row in ROW_MAJOR mode - rows_offset += num_copies_this_core * n_rows; - } - return {std::move(program), {reader_id, writer_id, cores}}; -} - -void ExpandOperation::ExpandRowMajorFactory::override_runtime_arguments( - cached_program_t& cached_program, - const operation_attributes_t& operation_attributes, - const tensor_args_t& tensor_args, - tensor_return_value_t& output) { - auto& program = cached_program.program; - auto& reader_kernel_id = cached_program.shared_variables.reader_kernel_id; - auto& writer_kernel_id = cached_program.shared_variables.writer_kernel_id; - auto& cores = cached_program.shared_variables.cores; - - auto input = tensor_args.input; - - for (const auto& core : cores) { - { - // reader - auto runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = input.buffer()->address(); - } - { - // writer - auto runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = output.buffer()->address(); - } - } -} -} // namespace ttnn::operations::expand diff --git a/ttnn/cpp/ttnn/operations/data_movement/expand/device/kernels/reader_rm_expand.cpp b/ttnn/cpp/ttnn/operations/data_movement/expand/device/kernels/reader_rm_expand.cpp deleted file mode 100644 index 533a8a50178..00000000000 --- a/ttnn/cpp/ttnn/operations/data_movement/expand/device/kernels/reader_rm_expand.cpp +++ /dev/null @@ -1,52 +0,0 @@ -// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
-// -// SPDX-License-Identifier: Apache-2.0 - -#include -#include - -#include "dataflow_api.h" - -void kernel_main() { - std::uint32_t mem_buffer_src_addr = get_arg_val(0); - - std::uint32_t num_rows = get_arg_val(1); - std::uint32_t element_per_row = get_arg_val(2); - std::uint32_t horz_expand_count = get_arg_val(3); - - std::uint32_t dram_page_size = get_arg_val(4); - - constexpr bool src_is_dram = get_compile_time_arg_val(0) == 1; - constexpr uint32_t scratch_cb_id = get_compile_time_arg_val(1); - constexpr uint32_t io_cb_id = get_compile_time_arg_val(2); - constexpr uint32_t datasize_bytes = get_compile_time_arg_val(3); - - InterleavedAddrGen src_generator = { - .bank_base_address = mem_buffer_src_addr, - .page_size = dram_page_size, - }; - - cb_reserve_back(scratch_cb_id, 1); - auto tmp_buf = get_write_ptr(scratch_cb_id); - - for (uint32_t i = 0; i < num_rows; i++) { - cb_reserve_back(io_cb_id, 1); - - auto l1_addr = get_write_ptr(io_cb_id); - auto noc_addr = get_noc_addr(i, src_generator); - - // Read the entire row into scratch buffer - noc_async_read(noc_addr, tmp_buf, dram_page_size); - noc_async_read_barrier(); - - auto l1_ptr = reinterpret_cast(l1_addr); - auto tmp_buf_ptr = reinterpret_cast(tmp_buf); - - for (uint32_t k = 0; k < horz_expand_count; k++) { - memcpy(l1_ptr, tmp_buf_ptr, element_per_row * datasize_bytes); - l1_ptr += element_per_row * datasize_bytes; - } - - cb_push_back(io_cb_id, 1); - } -} diff --git a/ttnn/cpp/ttnn/operations/data_movement/expand/device/kernels/writer_rm_expand.cpp b/ttnn/cpp/ttnn/operations/data_movement/expand/device/kernels/writer_rm_expand.cpp deleted file mode 100644 index 3789baaf102..00000000000 --- a/ttnn/cpp/ttnn/operations/data_movement/expand/device/kernels/writer_rm_expand.cpp +++ /dev/null @@ -1,36 +0,0 @@ -// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
-// -// SPDX-License-Identifier: Apache-2.0 - -#include - -#include "dataflow_api.h" - -void kernel_main() { - std::uint32_t mem_buffer_dst_addr = get_arg_val(0); - - std::uint32_t num_rows = get_arg_val(1); - std::uint32_t dram_page_size = get_arg_val(2); - - std::uint32_t vert_expand_count = get_arg_val(3); - std::uint32_t skipped_pages = get_arg_val(4); - - constexpr bool dst_is_dram = get_compile_time_arg_val(0) == 1; - constexpr uint32_t io_cb_id = get_compile_time_arg_val(1); - - InterleavedAddrGen dst_generator = { - .bank_base_address = mem_buffer_dst_addr, - .page_size = dram_page_size, - }; - - for (uint32_t i = 0; i < num_rows; i++) { - cb_wait_front(io_cb_id, 1); - auto l1_addr = get_read_ptr(io_cb_id); - for (uint32_t j = 0; j < vert_expand_count; j++) { - auto noc_addr = get_noc_addr(skipped_pages + j * num_rows + i, dst_generator); - noc_async_write(l1_addr, noc_addr, dram_page_size); - noc_async_write_barrier(); - } - cb_pop_front(io_cb_id, 1); - } -} diff --git a/ttnn/cpp/ttnn/operations/data_movement/expand/expand.cpp b/ttnn/cpp/ttnn/operations/data_movement/expand/expand.cpp index 2cf85de8442..7bc7afaadbb 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/expand/expand.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/expand/expand.cpp @@ -2,81 +2,42 @@ // // SPDX-License-Identifier: Apache-2.0 +#include "ttnn/common/constants.hpp" +#include "ttnn/run_operation.hpp" #include "expand.hpp" - -#include - -#include "ttnn/operations/core/core.hpp" -#include "ttnn/operations/data_movement/expand/device/expand_device_operation.hpp" -#include "ttnn/tensor/tensor_impl.hpp" -#include "ttnn/tensor/tensor_impl_wrapper.hpp" -#include "ttnn/tensor/tensor_ops.hpp" +#include +#include +#include +#include "ttnn/operations/data_movement/repeat/repeat.hpp" namespace ttnn::operations::expand { -auto infer_size(const Tensor& input, const std::vector& sizes) { - const auto& input_shape = input.get_logical_shape(); - auto output_shape = SmallVector(sizes.size()); - TT_FATAL( - input_shape.rank() <= sizes.size(), - "Input tensor shape {}({}) must be at least as large as the expansion size {}({}), which it is not", - input_shape, - input_shape.rank(), - sizes, - sizes.size()); - - int in_idx = static_cast(input_shape.rank()) - 1; - for (int i = static_cast(output_shape.size()) - 1; i >= 0; --i) { - if (in_idx >= 0) { - TT_FATAL( - input_shape[in_idx] == sizes[i] || input_shape[in_idx] == 1 || sizes[i] == -1, - "The size of tensor a ({}) must match the size of tensor b ({}) at non-singleton dimension {}", - input_shape[in_idx], - sizes[i], - in_idx); - - if (input_shape[in_idx] == sizes[i] || sizes[i] == -1) { - output_shape[i] = input_shape[in_idx]; - } else if (input_shape[in_idx] == 1) { - output_shape[i] = sizes[i]; - } - --in_idx; +ttnn::SmallVector create_repetition_vector(const Tensor& tensor, std::span shape) { + ttnn::SmallVector expansion_vector(shape.size()); + auto tensor_shape = tensor.get_logical_shape(); + const auto source_rank = tensor_shape.rank(); + const auto new_rank = shape.size(); + TT_FATAL(source_rank <= new_rank, "Only size 1 dimensions can be expanded in the output shape"); + for (auto index = 0; index < new_rank; ++index) { + if (index >= source_rank) { + expansion_vector[index] = shape[index]; + } else if ((shape[index] == -1) || (shape[index] == tensor_shape[index])) { + expansion_vector[index] = 1; } else { - TT_FATAL(sizes[i] != -1, "The expanded size of the tensor (-1) is not allowed in a leading dimension"); - output_shape[i] = sizes[i]; + 
TT_FATAL(tensor_shape[index] == 1, "Only size 1 dimensions can be expanded in the output shape"); + expansion_vector[index] = shape[index]; } } - -#ifdef DEBUG - tt::log_debug("inferred output shape: "); - for (int i = 0; i < output_shape.size(); ++i) { - tt::log_debug("%d ", output_shape[i]); - } - tt::log_debug("\n"); -#endif - - return output_shape; + return expansion_vector; } -Tensor Expand::invoke( - const Tensor& input, - const std::vector& sizes, - - const std::optional& output, - const std::optional& memory_config) { - auto output_shape = infer_size(input, sizes); - - // Convert tile tensor to row major (lmfao) - if (input.get_layout() == Layout::TILE) { - // untilize/tilize is way too inaccurate for us to even remotely use. - Tensor rm_input_dev = core::to_device(input.cpu(true).to(Layout::ROW_MAJOR), input.device(), std::nullopt); - - Tensor rm_output_dev = ttnn::prim::expand(rm_input_dev, output_shape, std::nullopt, std::nullopt); - - return core::to_device( - rm_output_dev.cpu(true).pad_to_tile(0).to(Layout::TILE), rm_output_dev.device(), std::nullopt); - } - - return ttnn::prim::expand(input, output_shape, output, memory_config); +ttnn::Tensor ExpandOperation::invoke( + const ttnn::Tensor& tensor, + const tt::stl::Span shape_vector, + const std::optional& memory_config, + const std::optional& queue_id) { + const uint32_t queue_id_value = queue_id.value_or(0); + return ttnn::repeat(tensor, create_repetition_vector(tensor, shape_vector), memory_config, queue_id_value); } + } // namespace ttnn::operations::expand diff --git a/ttnn/cpp/ttnn/operations/data_movement/expand/expand.hpp b/ttnn/cpp/ttnn/operations/data_movement/expand/expand.hpp index 6d57516b7f2..b172769e54f 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/expand/expand.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/expand/expand.hpp @@ -8,17 +8,15 @@ #include "ttnn/decorators.hpp" namespace ttnn::operations::expand { -struct Expand { +struct ExpandOperation { static Tensor invoke( - const Tensor& input, - const std::vector& sizes, - - const std::optional& output, - const std::optional& memory_config); + const ttnn::Tensor& input, + const tt::stl::Span shape_vector, + const std::optional& memory_config, + const std::optional& queue_id); }; } // namespace ttnn::operations::expand namespace ttnn { -constexpr auto expand = - ttnn::register_operation_with_auto_launch_op<"ttnn::expand", ttnn::operations::expand::Expand>(); +constexpr auto expand = ttnn::register_operation<"ttnn::expand", ttnn::operations::expand::ExpandOperation>(); } diff --git a/ttnn/cpp/ttnn/operations/data_movement/expand/expand_pybind.cpp b/ttnn/cpp/ttnn/operations/data_movement/expand/expand_pybind.cpp index 8b5c1bd7bc6..bfe4b5a357b 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/expand/expand_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/expand/expand_pybind.cpp @@ -2,33 +2,60 @@ // // SPDX-License-Identifier: Apache-2.0 +#include +#include + +#include "cpp/pybind11/decorators.hpp" + +#include "expand.hpp" #include "expand_pybind.hpp" -#include "pybind11/decorators.hpp" -#include "ttnn/operations/data_movement/expand/expand.hpp" +namespace ttnn::operations::data_movement { +namespace py = pybind11; + +namespace detail { +template +void py_bind_expand(py::module& module, const data_movement_operation_t& operation, const char* doc) { + ttnn::bind_registered_operation( + module, + operation, + doc, + ttnn::pybind_overload_t{ + [](const data_movement_operation_t& self, + const ttnn::Tensor& input_tensor, + const ttnn::SmallVector 
output_shape, + const std::optional& memory_config, + const uint8_t queue_id) { return self(input_tensor, output_shape, memory_config, queue_id); }, + py::arg("input_tensor"), + py::arg("output_shape"), + py::kw_only(), + py::arg("memory_config") = std::nullopt, + py::arg("queue_id") = 0, + }); +} + +} // namespace detail -namespace ttnn::operations::data_movement::detail { void py_bind_expand(py::module& module) { - const auto* doc = - R"doc(expand(input: ttnn.Tensor, sizes: List[int], output: Optional[ttnn.Tensor] = None, memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor + auto doc = + R"doc(expand(input: ttnn.Tensor, output_shape: List[int], memory_config: Optional[ttnn.MemoryConfig] = None) -> ttnn.Tensor Returns a new tensor where singleton dimensions are expanded to a larger size. Unlike :func:`torch.expand`, this function is not zero-cost and performs a memory copy to create the expanded tensor. This is due to `ttnn.Tensor`'s lack of strided tensor support. Args: * :attr:`input`: The tensor to be expanded. - * :attr:`sizes`: The desired expanded size. - * :attr:`output`: An optional tensor to store the expanded result. + * :attr:`output_shape`: The desired output shape. * :attr:`memory_config`: The memory configuration for the expanded tensor. + + Requirements: + like torch.expand: + only size 1 dimensions can be expanded in the output shape + -1 or the original shape size can be used to indicate that a dimension should not be expanded + The output shape must have at least as many dimensions as the input shape + )doc"; - bind_registered_operation( - module, - ttnn::expand, - doc, - ttnn::pybind_arguments_t{ - py::arg("input"), - py::arg("sizes"), - py::kw_only(), - py::arg("output") = std::nullopt, - py::arg("memory_config") = std::nullopt}); + + detail::py_bind_expand(module, ttnn::expand, doc); } -} // namespace ttnn::operations::data_movement::detail + +} // namespace ttnn::operations::data_movement diff --git a/ttnn/cpp/ttnn/operations/data_movement/expand/expand_pybind.hpp b/ttnn/cpp/ttnn/operations/data_movement/expand/expand_pybind.hpp index a4a4712a47b..656000a263f 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/expand/expand_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/expand/expand_pybind.hpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
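With the dedicated expand device op removed, ttnn::expand above is now a thin wrapper that converts the requested output shape into per-dimension repetition counts and defers to ttnn::repeat. A simplified, standalone model of that shape-to-repeats mapping follows; it uses plain standard-library types rather than the ttnn ones and compares dimensions position by position from the front.

#include <cstdint>
#include <stdexcept>
#include <vector>

// For each output dim: keep (1) when it matches the input or is -1, repeat when
// the input dim is 1, and treat dims beyond the input rank as counts to materialize.
std::vector<uint32_t> repeats_for_expand(const std::vector<uint32_t>& in_shape,
                                         const std::vector<int64_t>& out_shape) {
    if (in_shape.size() > out_shape.size()) {
        throw std::invalid_argument("output rank must be >= input rank");
    }
    std::vector<uint32_t> reps(out_shape.size(), 1);
    for (std::size_t i = 0; i < out_shape.size(); ++i) {
        if (i >= in_shape.size()) {
            reps[i] = static_cast<uint32_t>(out_shape[i]);
        } else if (out_shape[i] == -1 || out_shape[i] == static_cast<int64_t>(in_shape[i])) {
            reps[i] = 1;
        } else if (in_shape[i] == 1) {
            reps[i] = static_cast<uint32_t>(out_shape[i]);
        } else {
            throw std::invalid_argument("only size-1 dimensions can be expanded");
        }
    }
    return reps;
}

int main() {
    auto reps = repeats_for_expand({2, 1, 4}, {2, 3, 4});   // expand the middle dim 3x
    return (reps == std::vector<uint32_t>{1, 3, 1}) ? 0 : 1;
}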
// // SPDX-License-Identifier: Apache-2.0 @@ -6,8 +6,8 @@ #include "pybind11/pybind_fwd.hpp" -namespace py = pybind11; +namespace ttnn::operations::data_movement { -namespace ttnn::operations::data_movement::detail { -void py_bind_expand(py::module& module); -} // namespace ttnn::operations::data_movement::detail +void py_bind_expand(pybind11::module& module); + +} // namespace ttnn::operations::data_movement diff --git a/ttnn/cpp/ttnn/operations/data_movement/fold/device/fold_multi_core_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/fold/device/fold_multi_core_program_factory.cpp index 3ab6978d611..369aa1a7118 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/fold/device/fold_multi_core_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/fold/device/fold_multi_core_program_factory.cpp @@ -9,6 +9,7 @@ #include "fold_device_op.hpp" #include "ttnn/operations/math.hpp" +#include using namespace tt::tt_metal; @@ -39,7 +40,7 @@ Fold::MultiCore::cached_program_t fold_multi_core( // input CB uint32_t cb_src0_index = tt::CBIndex::c_0; - uint32_t aligned_pixel_size = round_up_to_mul32(pixel_size); + uint32_t aligned_pixel_size = tt::align(pixel_size, hal.get_alignment(HalMemType::L1)); auto src_cb_config = CircularBufferConfig(num_pixels * aligned_pixel_size, {{cb_src0_index, cb_data_format}}) .set_page_size(cb_src0_index, aligned_pixel_size) .set_globally_allocated_address(*input.buffer()); @@ -47,7 +48,7 @@ Fold::MultiCore::cached_program_t fold_multi_core( // output CB uint32_t cb_dst0_index = tt::CBIndex::c_16; - uint32_t aligned_dst_pixel_size = round_up_to_mul32(dst_pixel_size); + uint32_t aligned_dst_pixel_size = tt::align(dst_pixel_size, hal.get_alignment(HalMemType::L1)); auto dst_cb_config = CircularBufferConfig(num_dst_pixels * aligned_dst_pixel_size, {{cb_dst0_index, cb_data_format}}) .set_page_size(cb_dst0_index, aligned_dst_pixel_size) diff --git a/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_program_factory.cpp index 7ddce80cc97..a009d7d00aa 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/pad/device/pad_program_factory.cpp @@ -46,7 +46,8 @@ operation::ProgramWithCallbacks pad_rm_reader_writer( ttnn::Shape({1, 1, 1, pad_value_const_buffer_size}), DataType::BFLOAT16, Layout::ROW_MAJOR) - .to(device, MemoryConfig{.memory_layout = TensorMemoryLayout::INTERLEAVED, .buffer_type = BufferType::L1}); + .to_device( + device, MemoryConfig{.memory_layout = TensorMemoryLayout::INTERLEAVED, .buffer_type = BufferType::L1}); auto pad_value_const_tensor_addr = pad_value_const_tensor.buffer()->address(); Buffer* src0_buffer = a.buffer(); @@ -477,7 +478,8 @@ operation::ProgramWithCallbacks pad_rm_reader_writer_multi_core( ttnn::Shape({1, 1, 1, pad_value_const_buffer_size}), DataType::BFLOAT16, Layout::ROW_MAJOR) - .to(device, MemoryConfig{.memory_layout = TensorMemoryLayout::INTERLEAVED, .buffer_type = BufferType::L1}); + .to_device( + device, MemoryConfig{.memory_layout = TensorMemoryLayout::INTERLEAVED, .buffer_type = BufferType::L1}); auto pad_value_const_tensor_addr = pad_value_const_tensor.buffer()->address(); // uint32_t ntiles_h = output_tensor_shape[0] * output_tensor_shape[1] * output_tensor_shape[2] / TILE_HEIGHT; @@ -1435,14 +1437,12 @@ operation::ProgramWithCallbacks pad_rm_sharded_width_only( TT_THROW("ttnn.pad: unsupported data type for pad_rm_sharded_stickwise"); } - // FIXME: assumes that 
this was sharded using DRAM alignment so that gaps are left in the tensor. - // if this changes, we should change the stick step to be 16B (L1 alignment). - auto dram_alignment_bytes = tt::tt_metal::hal.get_alignment(tt::tt_metal::HalMemType::DRAM); + auto l1_alignment_bytes = tt::tt_metal::hal.get_alignment(tt::tt_metal::HalMemType::L1); uint32_t padded_stick_step = tt::round_up( - padded_stick_bytes, dram_alignment_bytes); // round padded_stick bytes to a multiple of dram_alignment_bytes + padded_stick_bytes, l1_alignment_bytes); // round padded_stick bytes to a multiple of l1_alignment_bytes uint32_t unpadded_stick_step = tt::round_up( unpadded_stick_bytes, - dram_alignment_bytes); // round unpadded_stick bytes to a multiple of dram_alignment_bytes + l1_alignment_bytes); // round unpadded_stick bytes to a multiple of l1_alignment_bytes std::vector reader_ct_args = { unpadded_stick_bytes, diff --git a/ttnn/cpp/ttnn/operations/data_movement/repeat/device/device/repeat_higher_dim_rm.cpp b/ttnn/cpp/ttnn/operations/data_movement/repeat/device/device/repeat_higher_dim_rm.cpp new file mode 100644 index 00000000000..1d62e9418ee --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/repeat/device/device/repeat_higher_dim_rm.cpp @@ -0,0 +1,118 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include "dataflow_api.h" +#include "ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp" + +using namespace tt::data_movement::common; + +void kernel_main() { + // We are guranteed to be in 4D going to 4D + // + + const uint32_t src_addr = get_arg_val(0); + const uint32_t dst_addr = get_arg_val(1); + // Program factory can control the start and end of each of the 3 dims + const uint32_t higher_dim_start = get_arg_val(2); + const uint32_t higher_dim_end = get_arg_val(3); + const uint32_t lower_dim_start = get_arg_val(4); + const uint32_t lower_dim_end = get_arg_val(5); + const uint32_t repetitions = get_arg_val(6); + // nop lets you intentionally not use this core if the dims don't divide nicely + const uint32_t nop = get_arg_val(7); + + constexpr bool tensor_is_dram = get_compile_time_arg_val(0) == 1; + constexpr uint32_t original_page_size_bytes = get_compile_time_arg_val(1); + constexpr uint32_t cb_id_in0 = get_compile_time_arg_val(2); + constexpr uint32_t cb_id_in1 = get_compile_time_arg_val(3); +#define page_is_pow_2 (get_compile_time_arg_val(4) == 1) + constexpr uint32_t page_pow_2 = get_compile_time_arg_val(5); + //(higher_dim,rep_dim,lower_dim,page_size) + // cb_id_in0 and cb_id_in1 is each 1 page of size: + // 128 + page size in bytes + constexpr uint32_t LOWER_DIMS = get_compile_time_arg_val(6); + constexpr uint32_t REP_DIM = get_compile_time_arg_val(7); + + constexpr uint32_t LOWER_DIMS_TIMES_REP_DIM = LOWER_DIMS * REP_DIM; + + // Since we need to operate on a grid of cores but sometimes pages don't split properly, if nop then don't use this + // core + if (nop == 1) { + return; + } + +#if page_is_pow_2 + // TODO: add CCL sharded native support + const InterleavedPow2AddrGen s = { + .bank_base_address = src_addr, .log_base_2_of_page_size = page_pow_2}; + const InterleavedPow2AddrGen d = { + .bank_base_address = dst_addr, .log_base_2_of_page_size = page_pow_2}; +#else + const InterleavedAddrGen s = {.bank_base_address = src_addr, .page_size = original_page_size_bytes}; + const InterleavedAddrGen d = {.bank_base_address = dst_addr, .page_size = original_page_size_bytes}; +#endif + + // alignments pre-calculations + 
constexpr uint64_t r_mask_to_use = tensor_is_dram ? MASK_64 : MASK_16; + constexpr uint64_t r_offset_to_use = tensor_is_dram ? OFFSET_64 : OFFSET_16; + + constexpr uint32_t r_alignment_requirement = tensor_is_dram ? 64 : 16; + constexpr uint32_t w_alignment_requirement = 16; + const uint64_t w_mask_to_use = MASK_16; + const uint64_t w_offset_to_use = OFFSET_16; + + cb_reserve_back(cb_id_in0, 1); + cb_reserve_back(cb_id_in1, 1); + uint32_t input_buffer = get_write_ptr(cb_id_in0); + uint32_t alignment_buffer = get_write_ptr(cb_id_in1); + cb_push_back(cb_id_in1, 1); + cb_push_back(cb_id_in0, 1); + + alignment_buffer = align_address(alignment_buffer, w_mask_to_use); // aligned for writes + input_buffer = align_address(input_buffer, r_mask_to_use); // aligned for reads + + uint64_t src_noc_addr = 0; + uint32_t data_location = 0; + + for (uint32_t h = higher_dim_start; h < higher_dim_end; h++) { + uint32_t h_offset = h * LOWER_DIMS_TIMES_REP_DIM; + uint32_t h_offset_rep = h_offset * repetitions; + for (uint32_t r = 0; r < REP_DIM; r++) { + uint32_t r_offset = r * LOWER_DIMS; + for (uint32_t l = lower_dim_start; l < lower_dim_end; l++) { + uint32_t read_offset = h_offset + r_offset + l; + src_noc_addr = s.get_noc_addr(read_offset, 0); + data_location = input_buffer + (src_noc_addr & r_offset_to_use); // Guaranteed aligned to src_noc_addr + enhanced_noc_async_read( + src_noc_addr, data_location, original_page_size_bytes); + noc_async_read_barrier(); + + for (uint32_t n = 0; n < repetitions; n++) { + // Perform the writes + uint32_t write_offset = h_offset_rep + n * LOWER_DIMS_TIMES_REP_DIM + r_offset + l; + const uint64_t dst_noc_addr = d.get_noc_addr(write_offset, 0); + if ((data_location & w_offset_to_use) != (dst_noc_addr & w_offset_to_use)) { + // Can't directly copy + const uint32_t target_align_buffer = + alignment_buffer + + (dst_noc_addr & w_offset_to_use); // Guaranteed aligned to target page addr + tt_memmove( + target_align_buffer, + data_location, + original_page_size_bytes); // Data is copied to align buffer + data_location = alignment_buffer + + (dst_noc_addr & w_offset_to_use); // Update data location to use write buffer + } + // Now we are ensured the data is at write_buffer and it is aligned for the write + // Orchestrate the write + enhanced_noc_async_write( + data_location, dst_noc_addr, original_page_size_bytes); + } + noc_async_write_barrier(); + } + } + } + return; +} diff --git a/ttnn/cpp/ttnn/operations/data_movement/repeat/device/device/repeat_last_dim_rm.cpp b/ttnn/cpp/ttnn/operations/data_movement/repeat/device/device/repeat_last_dim_rm.cpp new file mode 100644 index 00000000000..01b7949e333 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/repeat/device/device/repeat_last_dim_rm.cpp @@ -0,0 +1,129 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +/* +Function reads from RM and writes to RM repeating the last dimension +*/ +#include +#include "dataflow_api.h" +#include "ttnn/cpp/ttnn/operations/data_movement/common/kernels/common.hpp" + +using namespace tt::data_movement::common; + +void kernel_main() { + // We are guranteed to be in 2D going to 2D + + const uint32_t src_addr = get_arg_val(0); + const uint32_t dst_addr = get_arg_val(1); + // Which set of pages to deal with + const uint32_t page_start = get_arg_val(2); + const uint32_t page_end = get_arg_val(3); + // If work is not divided up nicely between the cores/ tensor too small we can use this to not run this core. 
+ const uint32_t nop = get_arg_val(4); + + constexpr bool tensor_is_dram = get_compile_time_arg_val(0) == 1; + constexpr uint32_t original_page_size_bytes = get_compile_time_arg_val(1); + constexpr uint32_t num_repeats = get_compile_time_arg_val(2); + // cb_id_in0 and cb_id_in1 is each 1 page of size: + // if original_page_size_bytes is a multiple of 16, equal to original_page_size_bytes + 128 + // else if original_page_size_bytes is a multiple of 8, equal to original_page_size_bytes * 2 + 128 + // else if original_page_size_bytes is a multiple of 4, equal to original_page_size_bytes * 4 + 128 + // else if original_page_size_bytes is a multiple of 2, equal to original_page_size_bytes * 8 + 128 + // if it is an odd number equal to original_page_size_bytes * 16 + 128 + constexpr uint32_t cb_id_in0 = get_compile_time_arg_val(3); + constexpr uint32_t cb_id_in1 = get_compile_time_arg_val(4); +#define source_page_is_pow_2 (get_compile_time_arg_val(5) == 1) + constexpr uint32_t source_page_pow_2 = get_compile_time_arg_val(6); +#define dest_page_is_pow_2 (get_compile_time_arg_val(7) == 1) + constexpr uint32_t dest_page_pow_2 = get_compile_time_arg_val(8); + constexpr uint32_t dest_page_size_bytes = original_page_size_bytes * num_repeats; + // Number of times we must double the input page to make it write aligned + constexpr uint32_t num_doublings = ((original_page_size_bytes % 16) == 0) ? 0 + : ((original_page_size_bytes % 8) == 0) ? 1 + : ((original_page_size_bytes % 4) == 0) ? 2 + : ((original_page_size_bytes % 2) == 0) ? 3 + : 4; + + // Since we need to operate on a grid of cores but sometimes pages don't split properly, if nop then don't use this + // core + if (nop == 1) { + return; + } +#if source_page_is_pow_2 + // TODO: add CCL sharded native support + const InterleavedPow2AddrGen s = { + .bank_base_address = src_addr, .log_base_2_of_page_size = source_page_pow_2}; +#else + const InterleavedAddrGen s = {.bank_base_address = src_addr, .page_size = original_page_size_bytes}; +#endif +#if dest_page_is_pow_2 + const InterleavedPow2AddrGen d = { + .bank_base_address = dst_addr, .log_base_2_of_page_size = dest_page_pow_2}; +#else + const InterleavedAddrGen d = {.bank_base_address = dst_addr, .page_size = dest_page_size_bytes}; +#endif + + // Get scratchpads guaranteed to be allocated until the function terminates + cb_reserve_back(cb_id_in0, 1); + cb_reserve_back(cb_id_in1, 1); + uint32_t input_buffer = get_write_ptr(cb_id_in0); + uint32_t alignment_buffer = get_write_ptr(cb_id_in1); + cb_push_back(cb_id_in1, 1); + cb_push_back(cb_id_in0, 1); + + constexpr uint64_t r_mask_to_use = tensor_is_dram ? MASK_64 : MASK_16; + constexpr uint64_t r_offset_to_use = tensor_is_dram ? OFFSET_64 : OFFSET_16; + constexpr uint32_t r_alignment_requirement = tensor_is_dram ? 
64 : 16; + constexpr uint32_t w_alignment_requirement = 16; + constexpr uint64_t w_mask_to_use = MASK_16; + constexpr uint64_t w_offset_to_use = OFFSET_16; + + alignment_buffer = + align_address(alignment_buffer, w_mask_to_use); // Guaranteed aligned for write + input_buffer = align_address(input_buffer, r_mask_to_use); // Guaranteed aligned for reads + + uint32_t cur_page_size = original_page_size_bytes; + for (uint32_t i = page_start; i < page_end; i++) { + // Read from source + uint64_t src_noc_addr = s.get_noc_addr(i, 0); + uint64_t dst_noc_addr = d.get_noc_addr(i, 0); + uint32_t data_location = + input_buffer + (src_noc_addr & r_offset_to_use); // Guaranteed to be aligned for our read + enhanced_noc_async_read(src_noc_addr, data_location, original_page_size_bytes); + cur_page_size = original_page_size_bytes; + noc_async_read_barrier(); + if constexpr (num_doublings != 0) { + // The if is not strictly needed; it is only a performance guard, since num_doublings is 0 in the vast + // majority of cases and we want to avoid allocating target_offset and computing the loop bounds + uint32_t target_offset = original_page_size_bytes; + for (uint32_t j = 0; j < num_doublings; j++) { + // This ensures cur_page_size will be aligned to 16B so the subsequent walk retains alignment + tt_memmove( + data_location + target_offset, data_location, cur_page_size); + target_offset += cur_page_size; + cur_page_size *= 2; + } + } + // Write to destination + // data is at data_location and there is cur_page_size bytes worth of data there + if ((data_location & w_offset_to_use) != (dst_noc_addr & w_offset_to_use)) { + // Can't directly copy due to alignment + tt_memmove( + alignment_buffer + (dst_noc_addr & w_offset_to_use), data_location, cur_page_size); + data_location = alignment_buffer + (dst_noc_addr & w_offset_to_use); + } + + uint64_t num_written = 0; + while (num_written < dest_page_size_bytes) { + // Either write out the whole input buffer or however much is left + uint32_t to_write = (dest_page_size_bytes - num_written) > cur_page_size + ? cur_page_size + : (dest_page_size_bytes - num_written); + enhanced_noc_async_write(data_location, dst_noc_addr + num_written, to_write); + num_written += to_write; + } + noc_async_write_barrier(); + } + return; +} diff --git a/ttnn/cpp/ttnn/operations/data_movement/repeat/device/host/repeat_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/repeat/device/host/repeat_program_factory.cpp new file mode 100644 index 00000000000..e8266b2ee50 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/repeat/device/host/repeat_program_factory.cpp @@ -0,0 +1,262 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
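// Editor's note (illustrative, not part of the patch): the num_doublings rule in repeat_last_dim_rm.cpp above
// duplicates a page in place until the staged block length becomes a multiple of 16 B, so every chunk that
// walks the destination page keeps the same 16 B alignment. A host-side sketch of that selection:
#include <cstdint>

constexpr uint32_t doublings_for(uint32_t page_size_bytes) {
    return (page_size_bytes % 16 == 0) ? 0
           : (page_size_bytes % 8 == 0) ? 1
           : (page_size_bytes % 4 == 0) ? 2
           : (page_size_bytes % 2 == 0) ? 3
                                        : 4;
}
// Example: a 6 B page needs 3 doublings (6 -> 12 -> 24 -> 48) before 16 B-aligned chunks line up.
static_assert(doublings_for(6) == 3, "6 B pages are doubled three times");
static_assert(doublings_for(32) == 0, "16 B-aligned pages are left as-is");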
+// +// SPDX-License-Identifier: Apache-2.0 +#include +#include +#include + +#include +#include +#include + +#include "ttnn/core.hpp" +#include "ttnn/decorators.hpp" +#include "ttnn/device_operation.hpp" +#include "ttnn/operations/cb_utils.hpp" +#include "ttnn/operations/math.hpp" +#include "ttnn/operation.hpp" +#include "ttnn/operations/core/work_split/work_split_tilize.hpp" +#include "ttnn/tensor/tensor.hpp" +#include "ttnn/types.hpp" + +constexpr uint32_t READ_ALIGNMENT = 64; + +namespace ttnn::operations::data_movement::repeat { + +tt::tt_metal::operation::ProgramWithCallbacks rm_repeater_last_dim( + // We are repeating the last dim on a 2D shape + const Tensor& input, + uint32_t num_repeats, + const Tensor& output) { + tt::tt_metal::Program program = tt::tt_metal::CreateProgram(); + // get datum size + tt::DataFormat cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(input.get_dtype()); + const uint32_t data_size = input.element_size(); + tt::tt_metal::IDevice* device = input.device(); + // Multi device pre-computation + auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); + uint32_t num_cores_x = compute_with_storage_grid_size.x; + uint32_t num_cores_y = compute_with_storage_grid_size.y; + uint32_t num_cores_total = num_cores_x * num_cores_y; + CoreRange total_cores({0, 0}, {num_cores_x - 1, num_cores_y - 1}); + ttnn::Shape input_log_shape = ttnn::Shape(input.get_logical_shape().view()); + ttnn::Shape output_log_shape = ttnn::Shape(output.get_logical_shape().view()); + tt::log_debug("row major reshape"); + tt::log_debug("input shape: {}", input_log_shape); + tt::log_debug("output shape: {}", output_log_shape); + tt::log_debug("data size: {}", data_size); + uint32_t source_page_size_bytes = input_log_shape[-1] * data_size; + uint32_t dest_page_size_bytes = source_page_size_bytes * num_repeats; + TT_FATAL( + dest_page_size_bytes == output_log_shape[-1] * data_size, + "Data size of output does not match requirement for repeat last dim"); + uint32_t read_start_page = 0; + tt::tt_metal::Buffer* src_buffer = input.buffer(); + tt::tt_metal::Buffer* dst_buffer = output.buffer(); + TT_FATAL(dst_buffer != nullptr, "Output buffer should be allocated on device!"); + // Find how many input pages each core is responsible for so that we always start at the begining of a read and + // write page Since the logical volumes match, we are guaranteed that the very last page is aligned + uint32_t number_of_pages = input_log_shape[-2]; + uint32_t responsibility = ((number_of_pages - 1) / num_cores_total) + 1; + uint32_t src0_is_dram = src_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; + uint32_t cb_size_bytes = READ_ALIGNMENT * 2 + (source_page_size_bytes & 0xF) == 0 ? source_page_size_bytes + : (source_page_size_bytes & 0x7) == 0 ? source_page_size_bytes * 2 + : (source_page_size_bytes & 0x3) == 0 ? source_page_size_bytes * 4 + : (source_page_size_bytes & 0x1) == 0 ? 
source_page_size_bytes * 8 + : source_page_size_bytes * 16; + uint32_t src0_cb_index = 0; + uint32_t src1_cb_index = 1; + tt::tt_metal::CircularBufferConfig cb_src0_config = + tt::tt_metal::CircularBufferConfig(cb_size_bytes, {{src0_cb_index, cb_data_format}}) + .set_page_size(src0_cb_index, cb_size_bytes); + auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, total_cores, cb_src0_config); + tt::tt_metal::CircularBufferConfig cb_src1_config = + tt::tt_metal::CircularBufferConfig(cb_size_bytes, {{src1_cb_index, cb_data_format}}) + .set_page_size(src1_cb_index, cb_size_bytes); + auto cb_src1 = tt::tt_metal::CreateCircularBuffer(program, total_cores, cb_src1_config); + bool source_page_is_pow_2 = tt::tt_metal::is_power_of_two_at_least_32(source_page_size_bytes); + uint32_t source_page_pow_2 = source_page_is_pow_2 ? (std::uint32_t)std::log2(source_page_size_bytes) : 0; + bool dest_page_is_pow_2 = tt::tt_metal::is_power_of_two_at_least_32(dest_page_size_bytes); + uint32_t dest_page_pow_2 = dest_page_is_pow_2 ? (std::uint32_t)std::log2(dest_page_size_bytes) : 0; + std::vector compile_time_args = { + (std::uint32_t)src0_is_dram, + (std::uint32_t)source_page_size_bytes, + (std::uint32_t)num_repeats, + src0_cb_index, + src1_cb_index, + source_page_is_pow_2, + source_page_pow_2, + dest_page_is_pow_2, + dest_page_pow_2}; + + tt::tt_metal::KernelHandle reader_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/data_movement/repeat/device/device/repeat_last_dim_rm.cpp", + total_cores, + tt::tt_metal::ReaderDataMovementConfig(compile_time_args)); + uint32_t done = 0; + for (int core_x = 0; core_x < num_cores_x; core_x++) { + for (int core_y = 0; core_y < num_cores_y; core_y++) { + CoreCoord core = {core_x, core_y}; + if (done == 1) { + const std::vector reader_runtime_args = { + src_buffer->address(), dst_buffer->address(), 0, 0, 1}; + tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_runtime_args); + } else { + // set the runtime args + // set the compile time args + const uint32_t start_of_read = read_start_page; + uint32_t end_of_read = read_start_page + responsibility; + end_of_read = end_of_read < number_of_pages ? end_of_read : number_of_pages; + + const std::vector reader_runtime_args = { + src_buffer->address(), dst_buffer->address(), start_of_read, end_of_read, 0 + + }; + read_start_page = end_of_read; + done = (end_of_read == input_log_shape[-2]) ? 
1 : 0; + tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_runtime_args); + } + } + } + return {.program = std::move(program)}; +} + +tt::tt_metal::operation::ProgramWithCallbacks rm_repeater( + // We are repeating the second dim on a 4D shape + const Tensor& input, + uint32_t num_repeats, + const Tensor& output) { + tt::tt_metal::Program program = tt::tt_metal::CreateProgram(); + // get datum size + tt::DataFormat cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(input.get_dtype()); + const uint32_t data_size = input.element_size(); + tt::tt_metal::IDevice* device = input.device(); + // Multi device pre-computation + auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); + uint32_t num_cores_x = compute_with_storage_grid_size.x; + uint32_t num_cores_y = compute_with_storage_grid_size.y; + uint32_t num_cores_total = num_cores_x * num_cores_y; + CoreRange total_cores({0, 0}, {num_cores_x - 1, num_cores_y - 1}); + + ttnn::Shape input_log_shape = ttnn::Shape(input.get_logical_shape().view()); + ttnn::Shape output_log_shape = ttnn::Shape(output.get_logical_shape().view()); + tt::log_debug("row major reshape"); + tt::log_debug("input shape: {}", input_log_shape); + tt::log_debug("output shape: {}", output_log_shape); + tt::log_debug("data size: {}", data_size); + uint32_t page_size_bytes = input_log_shape[3] * data_size; + TT_ASSERT( + page_size_bytes == output_log_shape[3] * data_size, + "Data size of output does not match requirement for repeat last dim"); + uint32_t read_start_page = 0; + tt::tt_metal::Buffer* src_buffer = input.buffer(); + tt::tt_metal::Buffer* dst_buffer = output.buffer(); + TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); + // Find how many input pages each core is responsible for so that we always start at the begining of a read and + // write page Since the logical volumes match, we are guaranteed that the very last page is aligned + uint32_t number_of_higher_pages = input_log_shape[0]; + uint32_t number_of_lower_pages = input_log_shape[2]; + uint32_t number_of_rep_dim_pages = input_log_shape[1]; + uint32_t src0_is_dram = src_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; + uint32_t cb_size_bytes = READ_ALIGNMENT * 2 + page_size_bytes; + uint32_t src0_cb_index = 0; + uint32_t src1_cb_index = 1; + tt::tt_metal::CircularBufferConfig cb_src0_config = + tt::tt_metal::CircularBufferConfig(cb_size_bytes, {{src0_cb_index, cb_data_format}}) + .set_page_size(src0_cb_index, cb_size_bytes); + auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, total_cores, cb_src0_config); + tt::tt_metal::CircularBufferConfig cb_src1_config = + tt::tt_metal::CircularBufferConfig(cb_size_bytes, {{src1_cb_index, cb_data_format}}) + .set_page_size(src1_cb_index, cb_size_bytes); + + auto cb_src1 = tt::tt_metal::CreateCircularBuffer(program, total_cores, cb_src1_config); + bool page_is_pow_2 = tt::tt_metal::is_power_of_two_at_least_32(page_size_bytes); + uint32_t page_pow_2 = page_is_pow_2 ? 
(std::uint32_t)std::log2(page_size_bytes) : 0; + std::vector compile_time_args = { + (std::uint32_t)src0_is_dram, + (std::uint32_t)page_size_bytes, + src0_cb_index, + src1_cb_index, + page_is_pow_2, + page_pow_2, + number_of_lower_pages, + number_of_rep_dim_pages}; + + tt::tt_metal::KernelHandle reader_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/data_movement/repeat/device/device/repeat_higher_dim_rm.cpp", + total_cores, + tt::tt_metal::ReaderDataMovementConfig(compile_time_args)); + uint32_t done = 0; + // Determine runtime argumens + bool divide_on_higher = number_of_higher_pages > number_of_lower_pages; + + uint32_t responsibility_chunk = + (divide_on_higher ? number_of_higher_pages : number_of_lower_pages) / num_cores_total; + uint32_t responsibility_mod = (divide_on_higher ? number_of_higher_pages : number_of_lower_pages) % num_cores_total; + uint32_t core_count = 0; + for (int core_x = 0; core_x < num_cores_x; core_x++) { + for (int core_y = 0; core_y < num_cores_y; core_y++) { + uint32_t responsibility = + core_count++ < responsibility_mod ? responsibility_chunk + 1 : responsibility_chunk; + CoreCoord core = {core_x, core_y}; + if (done == 1) { + const std::vector reader_runtime_args = {0, 0, 0, 0, 0, 0, 0, 1}; + tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_runtime_args); + } else if (divide_on_higher) { + // set the runtime args + // set the compile time args + const uint32_t start_of_read = read_start_page; + uint32_t end_of_read = read_start_page + responsibility; + end_of_read = end_of_read < number_of_higher_pages ? end_of_read : number_of_higher_pages; + + const std::vector reader_runtime_args = { + src_buffer->address(), + dst_buffer->address(), + start_of_read, + end_of_read, + 0, + number_of_lower_pages, + num_repeats, + 0}; + read_start_page = end_of_read; + done = (end_of_read == number_of_higher_pages) ? 1 : 0; + tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_runtime_args); + } else { + // set the runtime args + // set the compile time args + const uint32_t start_of_read = read_start_page; + uint32_t end_of_read = read_start_page + responsibility; + end_of_read = end_of_read < number_of_lower_pages ? end_of_read : number_of_lower_pages; + + const std::vector reader_runtime_args = { + src_buffer->address(), + dst_buffer->address(), + 0, + number_of_higher_pages, + start_of_read, + end_of_read, + num_repeats, + 0}; + read_start_page = end_of_read; + done = (end_of_read == number_of_lower_pages) ? 1 : 0; + tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_runtime_args); + } + } + } + return {.program = std::move(program)}; +} + +tt::tt_metal::operation::ProgramWithCallbacks rm_repeat_program_factory( + const Tensor& input, uint32_t num_repeats, const Tensor& output, bool is_last_dim) { + // We are repeating the second dim. If is_last_dim then the tensor is 2D. + // otherwise it is 4D. 
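// Editor's sketch (not part of the patch) of the work split used by rm_repeater above: N pages are spread over
// C cores, the first N % C cores take one extra page, and cores beyond the available work receive an empty
// range (they are later flagged as nop). Names are illustrative.
#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

std::vector<std::pair<uint32_t, uint32_t>> split_pages(uint32_t num_pages, uint32_t num_cores) {
    std::vector<std::pair<uint32_t, uint32_t>> ranges;  // [start, end) page range per core
    const uint32_t chunk = num_pages / num_cores;
    const uint32_t remainder = num_pages % num_cores;
    uint32_t start = 0;
    for (uint32_t core = 0; core < num_cores; ++core) {
        const uint32_t responsibility = core < remainder ? chunk + 1 : chunk;
        const uint32_t end = std::min(start + responsibility, num_pages);
        ranges.emplace_back(start, end);
        start = end;
    }
    return ranges;
}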
+ if (is_last_dim) { + return rm_repeater_last_dim(input, num_repeats, output); + } else { + return rm_repeater(input, num_repeats, output); + } +} + +}; // namespace ttnn::operations::data_movement::repeat diff --git a/ttnn/cpp/ttnn/operations/data_movement/repeat/device/host/repeat_program_factory.hpp b/ttnn/cpp/ttnn/operations/data_movement/repeat/device/host/repeat_program_factory.hpp new file mode 100644 index 00000000000..f97c29ea716 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/repeat/device/host/repeat_program_factory.hpp @@ -0,0 +1,13 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttnn/operation.hpp" +#include "ttnn/tensor/tensor.hpp" + +namespace ttnn::operations::data_movement::repeat { + +tt::tt_metal::operation::ProgramWithCallbacks rm_repeat_program_factory( + const Tensor& input, uint32_t num_repeats, const Tensor& output, bool is_last_dim); + +} // namespace ttnn::operations::data_movement::repeat diff --git a/ttnn/cpp/ttnn/operations/data_movement/repeat/device/kernels/dataflow/reader_repeat_interleaved_start_id.cpp b/ttnn/cpp/ttnn/operations/data_movement/repeat/device/kernels/dataflow/reader_repeat_interleaved_start_id.cpp deleted file mode 100644 index ec1c4b5466e..00000000000 --- a/ttnn/cpp/ttnn/operations/data_movement/repeat/device/kernels/dataflow/reader_repeat_interleaved_start_id.cpp +++ /dev/null @@ -1,58 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include -#include "dataflow_api.h" - -// This repeat kernel is forked off the concat kernel -// so this kernel currently does one read per tile in output -// but since this is the same tensor and not unique tensors like concat we can -// reduce reads by reusing the tiles we have read instead of constantly -// rereading them - -// Make n reads defined by num_reads -// Writes to Specified Circular Buffers in L1 -// Expects n provided src_addr, src_noc_x, src_noc_y, and cb_id_in -void kernel_main() { - const uint32_t src_addr = get_arg_val(0); - const uint32_t num_tiles = get_arg_val(1); - const uint32_t num_tiles_per_block = get_arg_val(2); - uint32_t curr_repeat_idx = get_arg_val(3); - uint32_t curr_idx_in_block = get_arg_val(4); - uint32_t curr_block_start_id = get_arg_val(5); - uint32_t curr_id = get_arg_val(6); - - constexpr uint32_t cb_id_in = get_compile_time_arg_val(0); - constexpr uint32_t src_is_dram = get_compile_time_arg_val(1); - constexpr uint32_t num_repeats = get_compile_time_arg_val(2); - - // ublocks size defined in tiles - constexpr uint32_t ublock_size_tiles = 1; - const uint32_t tile_size_bytes = get_tile_size(cb_id_in); - const DataFormat data_format = get_dataformat(cb_id_in); - - InterleavedAddrGenFast src_addr_gen = { - .bank_base_address = src_addr, .page_size = tile_size_bytes, .data_format = data_format}; - - for (uint32_t i = 0; i < num_tiles; ++i) { - cb_reserve_back(cb_id_in, ublock_size_tiles); - uint32_t l1_write_addr = get_write_ptr(cb_id_in); - noc_async_read_tile(curr_id, src_addr_gen, l1_write_addr); - curr_id++; - curr_idx_in_block++; - noc_async_read_barrier(); - cb_push_back(cb_id_in, ublock_size_tiles); - - if (curr_idx_in_block == num_tiles_per_block) { - curr_idx_in_block = 0; - curr_repeat_idx++; - if (curr_repeat_idx == num_repeats) { - curr_repeat_idx = 0; - curr_block_start_id = curr_id; - } else { - curr_id = curr_block_start_id; - } - } - } -} diff --git 
a/ttnn/cpp/ttnn/operations/data_movement/repeat/device/kernels/dataflow/reader_repeat_stick_layout_interleaved_start_id.cpp b/ttnn/cpp/ttnn/operations/data_movement/repeat/device/kernels/dataflow/reader_repeat_stick_layout_interleaved_start_id.cpp deleted file mode 100644 index ea1f0492c3c..00000000000 --- a/ttnn/cpp/ttnn/operations/data_movement/repeat/device/kernels/dataflow/reader_repeat_stick_layout_interleaved_start_id.cpp +++ /dev/null @@ -1,68 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include -#include "dataflow_api.h" - -// This repeat kernel is forked off the concat kernel -// so this kernel currently does one read per tile in output -// but since this is the same tensor and not unique tensors like concat we can -// reduce reads by reusing the tiles we have read instead of constantly -// rereading them - -// Make n reads defined by num_reads -// Writes to Specified Circular Buffers in L1 -// Expects n provided src_addr, src_noc_x, src_noc_y, and cb_id_in -void kernel_main() { - const uint32_t src_addr = get_arg_val(0); - const uint32_t num_pages = get_arg_val(1); - const uint32_t num_pages_per_block = get_arg_val(2); - uint32_t curr_repeat_idx = get_arg_val(3); - uint32_t curr_idx_in_block = get_arg_val(4); - uint32_t curr_block_start_id = get_arg_val(5); - uint32_t curr_id = get_arg_val(6); - const uint32_t page_size = get_arg_val(7); - - constexpr uint32_t cb_id_in = get_compile_time_arg_val(0); - constexpr uint32_t src_is_dram = get_compile_time_arg_val(1); - constexpr uint32_t num_repeats = get_compile_time_arg_val(2); - - constexpr uint32_t ublock_size_pages = 1; - - InterleavedAddrGen src_addr_gen = {.bank_base_address = src_addr, .page_size = page_size}; - - for (uint32_t i = 0; i < num_pages; ++i) { - cb_reserve_back(cb_id_in, ublock_size_pages); - uint32_t l1_write_addr = get_write_ptr(cb_id_in); -#ifdef WIDTH_REPEAT - noc_async_read_page(curr_id, src_addr_gen, l1_write_addr); - uint64_t local_read_addr = get_noc_addr(l1_write_addr); - l1_write_addr += page_size; - noc_async_read_barrier(); - for (uint32_t j = 1; j < num_repeats; ++j) { - noc_async_read(local_read_addr, l1_write_addr, page_size); - l1_write_addr += page_size; - } - curr_id++; - noc_async_read_barrier(); -#else - noc_async_read_page(curr_id, src_addr_gen, l1_write_addr); - curr_id++; - curr_idx_in_block++; - - if (curr_idx_in_block == num_pages_per_block) { - curr_idx_in_block = 0; - curr_repeat_idx++; - if (curr_repeat_idx == num_repeats) { - curr_repeat_idx = 0; - curr_block_start_id = curr_id; - } else { - curr_id = curr_block_start_id; - } - } - noc_async_read_barrier(); -#endif - cb_push_back(cb_id_in, ublock_size_pages); - } -} diff --git a/ttnn/cpp/ttnn/operations/data_movement/repeat/device/repeat_device_operation.cpp b/ttnn/cpp/ttnn/operations/data_movement/repeat/device/repeat_device_operation.cpp new file mode 100644 index 00000000000..5e38b7aa6b0 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/repeat/device/repeat_device_operation.cpp @@ -0,0 +1,70 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "ttnn/operations/data_movement/repeat/device/host/repeat_program_factory.hpp" +#include "ttnn/operations/data_movement/repeat/device/repeat_device_operation.hpp" + +namespace ttnn { + +void RepeatDeviceOperation::validate(const std::vector& input_tensors) const { + // Validate the input tensor + const Tensor& input_tensor_a = input_tensors.at(0); + TT_FATAL(input_tensor_a.storage_type() == StorageType::DEVICE, "Operands to repeat need to be on device!"); + TT_FATAL(input_tensor_a.buffer() != nullptr, "Operands need to be allocated in buffers on device!"); + TT_FATAL(input_tensor_a.get_layout() == Layout::ROW_MAJOR, "This function is for RM->RM"); + TT_FATAL( + input_tensor_a.get_dtype() == DataType::BFLOAT16 or input_tensor_a.get_dtype() == DataType::UINT32 or + input_tensor_a.get_dtype() == DataType::FLOAT32, + "Can only work with bfloat16/float32 or uint32 tensors"); + // is this relevant? + TT_FATAL( + this->m_output_mem_config.memory_layout == input_tensor_a.memory_config().memory_layout, + "Output tensor must have the same memory layout as input tensor"); +} + +std::vector RepeatDeviceOperation::compute_output_specs(const std::vector& input_tensors) const { + const auto& input_tensor_a = input_tensors.at(0); + auto output_shape = input_tensor_a.get_logical_shape(); + output_shape[m_is_last_dim ? -1 : 1] *= m_num_repeats; + + auto mem_config = this->m_output_mem_config; + if (input_tensor_a.memory_config().is_sharded()) { + auto shard_spec = input_tensor_a.shard_spec().value(); + shard_spec.shape[0] = output_shape[0]; + mem_config.shard_spec = shard_spec; + } + return {TensorSpec( + output_shape, + TensorLayout::fromPaddedShape( + input_tensor_a.get_dtype(), + PageConfig(input_tensor_a.get_layout()), + mem_config, + output_shape, + output_shape))}; // no padding required because we are RM only right now +} + +std::vector RepeatDeviceOperation::create_output_tensors(const std::vector& input_tensors) const { + // Create the output tensor + const auto& input_tensor_a = input_tensors.at(0); + const auto output_shape = this->compute_output_specs(input_tensors).at(0).logical_shape(); + + // is this relevant? + auto mem_config = this->m_output_mem_config; + if (input_tensor_a.memory_config().is_sharded()) { + auto shard_spec = input_tensor_a.shard_spec().value(); + shard_spec.shape[0] = output_shape[0]; + mem_config.shard_spec = shard_spec; + } + return {create_device_tensor( + output_shape, input_tensor_a.get_dtype(), input_tensor_a.get_layout(), input_tensor_a.device(), mem_config)}; +} + +operation::ProgramWithCallbacks RepeatDeviceOperation::create_program( + const std::vector& input_tensors, std::vector& output_tensors) const { + return operations::data_movement::repeat::rm_repeat_program_factory( + input_tensors.at(0), m_num_repeats, output_tensors.at(0), m_is_last_dim); +} +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/data_movement/repeat/device/repeat_device_operation.hpp b/ttnn/cpp/ttnn/operations/data_movement/repeat/device/repeat_device_operation.hpp new file mode 100644 index 00000000000..7ae7d881b80 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/repeat/device/repeat_device_operation.hpp @@ -0,0 +1,22 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
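// Editor's note (illustrative): compute_output_specs above scales exactly one dimension, the last one when
// m_is_last_dim is set and dimension 1 otherwise, because the host-side repeat wrapper collapses every request
// into one of two canonical layouts (2D repeating the last dim, or 4D with the repeated dim at index 1)
// before dispatching the device operation. For example, with m_num_repeats = 5:
//   m_is_last_dim == true : logical shape {2, 3, 4, 8} -> {2, 3, 4, 40}
//   m_is_last_dim == false: logical shape {2, 3, 4, 8} -> {2, 15, 4, 8}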
+// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ttnn/run_operation.hpp" +namespace ttnn { + +struct RepeatDeviceOperation { + const uint32_t m_num_repeats; + const bool m_is_last_dim; + MemoryConfig m_output_mem_config; + + // Required functions to all tensor op functions + void validate(const std::vector& input_tensors) const; + std::vector compute_output_specs(const std::vector& input_tensors) const; + std::vector create_output_tensors(const std::vector& input_tensors) const; + operation::ProgramWithCallbacks create_program( + const std::vector& input_tensors, std::vector& output_tensors) const; +}; +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/data_movement/repeat/device/repeat_op.cpp b/ttnn/cpp/ttnn/operations/data_movement/repeat/device/repeat_op.cpp deleted file mode 100644 index 7fcf2e0dc36..00000000000 --- a/ttnn/cpp/ttnn/operations/data_movement/repeat/device/repeat_op.cpp +++ /dev/null @@ -1,49 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "repeat_op.hpp" -#include "repeat_program_factory.hpp" -#include "ttnn/tensor/tensor_utils.hpp" - -using namespace tt::constants; -using namespace tt::tt_metal; - -namespace ttnn::operations::data_movement { - -void RepeatDeviceOperation::validate(const std::vector& input_tensors) const { - const auto& input_tensor = input_tensors[0]; - auto input_shape = input_tensor.get_padded_shape(); - TT_FATAL(this->repeat_dim < input_shape.rank(), "Repeat dim specified is larger than input tensor rank."); - if (input_tensor.get_layout() == Layout::ROW_MAJOR && this->repeat_dim == input_shape.rank() - 1) { - TT_FATAL( - (input_shape[this->repeat_dim] * input_tensor.element_size()) % input_tensor.buffer()->alignment() == 0, - "The last dim of tensor being repeated must be 32 byte aligned for DRAM Tensor and 16 byte aligned for L1 " - "tensor"); - } - TT_FATAL(this->num_repeats > 0, "Number of repeats should be greater than 0"); - TT_FATAL(input_tensor.buffer(), "Operand to repeat needs to be allocated in a buffer on device."); - TT_FATAL(input_tensor.device(), "Operand to repeat needs to be on device."); - TT_FATAL( - input_tensor.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED, - "Input to repeat must be interleaved."); - TT_FATAL( - this->output_mem_config.memory_layout == TensorMemoryLayout::INTERLEAVED, - "Output of repeat must be interleaved."); -} - -std::vector RepeatDeviceOperation::compute_output_specs( - const std::vector& input_tensors) const { - const auto& input_tensor = input_tensors.at(0); - ttnn::Shape shape_out = input_tensor.get_logical_shape(); - shape_out[this->repeat_dim] *= this->num_repeats; - return {TensorSpec( - shape_out, TensorLayout(input_tensor.get_dtype(), PageConfig(input_tensor.get_layout()), output_mem_config))}; -} - -operation::ProgramWithCallbacks RepeatDeviceOperation::create_program( - const std::vector& input_tensors, std::vector& output_tensors) const { - return detail::repeat_multi_core(input_tensors[0], this->repeat_dim, this->num_repeats, output_tensors[0]); -} - -} // namespace ttnn::operations::data_movement diff --git a/ttnn/cpp/ttnn/operations/data_movement/repeat/device/repeat_op.hpp b/ttnn/cpp/ttnn/operations/data_movement/repeat/device/repeat_op.hpp deleted file mode 100644 index a406f33fd21..00000000000 --- a/ttnn/cpp/ttnn/operations/data_movement/repeat/device/repeat_op.hpp +++ /dev/null @@ -1,23 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
-// -// SPDX-License-Identifier: Apache-2.0 - -#pragma once - -#include "ttnn/tensor/tensor.hpp" -#include "ttnn/run_operation.hpp" - -namespace ttnn::operations::data_movement { - -struct RepeatDeviceOperation { - const uint32_t repeat_dim; - const uint32_t num_repeats; - const tt::tt_metal::MemoryConfig output_mem_config; - - void validate(const std::vector& input_tensors) const; - std::vector compute_output_specs(const std::vector& input_tensors) const; - tt::tt_metal::operation::ProgramWithCallbacks create_program( - const std::vector& input_tensors, std::vector& output_tensors) const; -}; - -} // namespace ttnn::operations::data_movement diff --git a/ttnn/cpp/ttnn/operations/data_movement/repeat/device/repeat_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/repeat/device/repeat_program_factory.cpp deleted file mode 100644 index cd809ec61f8..00000000000 --- a/ttnn/cpp/ttnn/operations/data_movement/repeat/device/repeat_program_factory.cpp +++ /dev/null @@ -1,208 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include -#include -#include -#include -#include "ttnn/operation.hpp" - -using namespace tt::constants; -using namespace tt::tt_metal; - -namespace ttnn::operations::data_movement::detail { - -operation::ProgramWithCallbacks repeat_multi_core( - const Tensor& input_tensor, const uint32_t repeat_dim, const uint32_t num_repeats, const Tensor& output) { - tt::tt_metal::Program program = tt::tt_metal::CreateProgram(); - - tt::tt_metal::IDevice* device = output.device(); - - const tt::DataFormat cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(output.get_dtype()); - - const bool rm_layout = output.get_layout() == Layout::ROW_MAJOR; - - constexpr bool rm_orientation = false; - - uint32_t num_output_pages; - uint32_t single_page_size; - if (rm_layout) { - num_output_pages = output.volume() / output.get_padded_shape()[-1]; - single_page_size = - tt::align(output.element_size() * output.get_padded_shape()[-1], output.buffer()->alignment()); - } else { - num_output_pages = output.volume() / TILE_HW; - single_page_size = tt::tt_metal::detail::TileSize(cb_data_format); - } - - auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); - uint32_t num_cores_x = compute_with_storage_grid_size.x; - uint32_t num_cores_y = compute_with_storage_grid_size.y; - auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] = - tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, num_output_pages, rm_orientation); - - tt::tt_metal::Buffer* dst_buffer = output.buffer(); - TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); - - uint32_t src0_cb_index = 0; - uint32_t num_input_pages = 2; - tt::tt_metal::CircularBufferConfig cb_src0_config = - tt::tt_metal::CircularBufferConfig(num_input_pages * single_page_size, {{src0_cb_index, cb_data_format}}) - .set_page_size(src0_cb_index, single_page_size); - auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, all_cores, cb_src0_config); - - uint32_t num_dims = output.get_padded_shape().rank(); - - auto input_buffer = input_tensor.buffer(); - uint32_t src_addr = input_buffer->address(); - uint32_t src_is_dram = input_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 
1 : 0; - uint32_t src_page_size = input_buffer->page_size(); - uint32_t num_pages_per_block; - - uint32_t num_accum_pages = 1; - uint32_t scale_factor = 1; - - // RM is special cased in the loop (dim_units = 1 for last dim else it's the dim size) - if (!rm_layout) { - if (repeat_dim == num_dims - 2) { - scale_factor = TILE_HEIGHT; - } else if (repeat_dim == num_dims - 1) { - scale_factor = TILE_WIDTH; - } - } - - for (uint32_t i = repeat_dim + 1; i < num_dims; ++i) { - num_accum_pages *= output.get_padded_shape()[i]; - } - if (rm_layout) { - if (num_dims > 1 && repeat_dim < num_dims - 1) { - num_accum_pages /= output.get_padded_shape()[-1]; - } - } else { - if (repeat_dim < num_dims - 2) { - num_accum_pages /= TILE_HW; - } else if (repeat_dim == num_dims - 2) { - num_accum_pages /= TILE_WIDTH; - } - } - - if (rm_layout) { - if (repeat_dim == num_dims - 1) { - num_pages_per_block = num_accum_pages; - } else { - uint32_t dim_pages = input_tensor.get_padded_shape()[repeat_dim]; - num_pages_per_block = num_accum_pages * dim_pages; - } - } else { - uint32_t dim_pages = input_tensor.get_padded_shape()[repeat_dim] / scale_factor; - num_pages_per_block = num_accum_pages * dim_pages; - } - - std::vector reader_kernel_args = {src_addr, 0, num_pages_per_block, 0, 0, 0, 0}; - if (rm_layout) { - reader_kernel_args.push_back(src_page_size); - } - - // Reader compile-time args - // Data is 32 byte aligned - bool dst_is_dram = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; - std::vector reader_compile_time_args = {// interleaved accessor args - (std::uint32_t)src0_cb_index, - (std::uint32_t)src_is_dram, - (std::uint32_t)num_repeats}; - - std::map repeat_defines; - - if (rm_layout && repeat_dim == num_dims - 1) { - repeat_defines["WIDTH_REPEAT"] = "1"; - } - - std::vector writer_compile_time_args = {// interleaved accessor args - (std::uint32_t)src0_cb_index, - (std::uint32_t)dst_is_dram}; - - tt::tt_metal::KernelHandle unary_reader_kernel_id = tt::tt_metal::CreateKernel( - program, - rm_layout ? "ttnn/cpp/ttnn/operations/data_movement/repeat/device/kernels/dataflow/" - "reader_repeat_stick_layout_interleaved_start_id.cpp" - : "ttnn/cpp/ttnn/operations/data_movement/repeat/device/kernels/dataflow/" - "reader_repeat_interleaved_start_id.cpp", - all_cores, - tt::tt_metal::ReaderDataMovementConfig(reader_compile_time_args, repeat_defines)); - - tt::tt_metal::KernelHandle unary_writer_kernel_id = tt::tt_metal::CreateKernel( - program, - rm_layout - ? 
"ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/writer_unary_stick_layout_interleaved_start_id.cpp" - : "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/writer_unary_interleaved_start_id.cpp", - all_cores, - tt::tt_metal::WriterDataMovementConfig(writer_compile_time_args)); - - const auto cores = grid_to_cores(num_cores, num_cores_x, num_cores_y, rm_orientation); - uint32_t g1_num_cores = core_group_1.num_cores(); - for (uint32_t i = 0, num_pages_written = 0; i < cores.size(); ++i) { - const CoreCoord& core = cores[i]; - uint32_t num_pages_per_core = 0; - if (i < g1_num_cores) { - num_pages_per_core = num_tiles_per_core_group_1; - } else { - num_pages_per_core = num_tiles_per_core_group_2; - } - uint32_t curr_repeat_idx = 0; - uint32_t curr_idx_in_block = 0; - uint32_t curr_block_start_id = 0; - uint32_t curr_id = 0; - if (rm_layout && repeat_dim == num_dims - 1) { - curr_id = num_pages_written; - } else { - curr_repeat_idx = num_pages_written / num_pages_per_block % num_repeats; - curr_idx_in_block = num_pages_written % num_pages_per_block; - curr_block_start_id = num_pages_written / (num_pages_per_block * num_repeats) * num_pages_per_block; - curr_id = curr_block_start_id + curr_idx_in_block; - } - - reader_kernel_args[1] = num_pages_per_core; - reader_kernel_args[3] = curr_repeat_idx; - reader_kernel_args[4] = curr_idx_in_block; - reader_kernel_args[5] = curr_block_start_id; - reader_kernel_args[6] = curr_id; - - std::vector writer_kernel_args; - if (rm_layout) { - writer_kernel_args = { - dst_buffer->address(), output.buffer()->page_size(), num_pages_per_core, num_pages_written}; - } else { - writer_kernel_args = {dst_buffer->address(), num_pages_per_core, num_pages_written}; - } - tt::tt_metal::SetRuntimeArgs(program, unary_reader_kernel_id, core, reader_kernel_args); - - tt::tt_metal::SetRuntimeArgs(program, unary_writer_kernel_id, core, writer_kernel_args); - num_pages_written += num_pages_per_core; - } - - auto override_runtime_args_callback = [unary_reader_kernel_id, unary_writer_kernel_id, cores]( - const Program& program, - const std::vector& input_buffers, - const std::vector& output_buffers) { - auto src_buffer = input_buffers.at(0); - - auto dst_buffer = output_buffers.at(0); - - for (const auto& core : cores) { - { - auto& runtime_args = GetRuntimeArgs(program, unary_reader_kernel_id, core); - runtime_args[0] = src_buffer->address(); - } - - { - auto& runtime_args = GetRuntimeArgs(program, unary_writer_kernel_id, core); - runtime_args[0] = dst_buffer->address(); - } - } - }; - - return {std::move(program), override_runtime_args_callback}; -} - -} // namespace ttnn::operations::data_movement::detail diff --git a/ttnn/cpp/ttnn/operations/data_movement/repeat/device/repeat_program_factory.hpp b/ttnn/cpp/ttnn/operations/data_movement/repeat/device/repeat_program_factory.hpp deleted file mode 100644 index 0e8e3260f90..00000000000 --- a/ttnn/cpp/ttnn/operations/data_movement/repeat/device/repeat_program_factory.hpp +++ /dev/null @@ -1,13 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
-// -// SPDX-License-Identifier: Apache-2.0 -#pragma once - -#include - -namespace ttnn::operations::data_movement::detail { - -tt::tt_metal::operation::ProgramWithCallbacks repeat_multi_core( - const Tensor& input_tensor, const uint32_t repeat_dim, const uint32_t num_repeats, const Tensor& output); - -} // namespace ttnn::operations::data_movement::detail diff --git a/ttnn/cpp/ttnn/operations/data_movement/repeat/repeat.cpp b/ttnn/cpp/ttnn/operations/data_movement/repeat/repeat.cpp index bb19b3a5de9..3558ee8ce63 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/repeat/repeat.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/repeat/repeat.cpp @@ -1,126 +1,218 @@ -// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. // // SPDX-License-Identifier: Apache-2.0 -#include "ttnn/operations/data_movement/repeat/repeat.hpp" +#include + +#include +#include +#include -#include "device/repeat_op.hpp" -#include #include "ttnn/common/constants.hpp" -#include "ttnn/decorators.hpp" -#include "ttnn/operations/data_movement/pad/pad.hpp" -#include "ttnn/operations/data_movement/reshape_view/reshape.hpp" -#include "ttnn/operations/data_movement/slice/slice.hpp" -#include "ttnn/operations/data_movement/tilize/tilize.hpp" -#include "ttnn/operations/data_movement/untilize/untilize.hpp" +#include "ttnn/operations/core/core.hpp" +#include "ttnn/operations/data_movement/sharded/sharded_to_interleaved/sharded_to_interleaved.hpp" +#include "ttnn/operations/data_movement/sharded/interleaved_to_sharded/interleaved_to_sharded.hpp" +#include "ttnn/operations/data_movement/view/view.hpp" +#include "ttnn/operations/functions.hpp" #include "ttnn/run_operation.hpp" +#include "ttnn/tensor/tensor_utils.hpp" +#include "device/repeat_device_operation.hpp" +#include "repeat.hpp" namespace ttnn::operations::data_movement { -ttnn::Tensor RepeatOperation::invoke( +namespace detail { + +struct UpperRepeatDims { + static constexpr uint32_t collapsed_upper = 0; + static constexpr uint32_t repeat = 1; + static constexpr uint32_t collapsed_lower = 2; + static constexpr uint32_t page_size = 3; +}; +struct LastRepeatDims { + static constexpr uint32_t collapsed_upper = 0; + static constexpr uint32_t repeat = 1; +}; + +ttnn::Tensor repeat_upper_dims_rm( + const ttnn::Tensor& tensor, + const uint32_t dim, + const uint32_t repetitions, uint8_t queue_id, - const ttnn::Tensor& input_tensor, - const ttnn::Shape& repeat_dims, - const std::optional& memory_config_arg) { - auto padded_input_shape = input_tensor.get_padded_shape(); - auto logical_input_shape = input_tensor.get_logical_shape(); - auto input_rank = logical_input_shape.rank(); - - auto repeated_logical_shape = logical_input_shape; - for (uint32_t dim = 0; dim < input_rank; ++dim) { - repeated_logical_shape[dim] *= repeat_dims[dim]; + const MemoryConfig& output_mem_config) { + // collapse upper dims to 4D or append 1s + // collapse lower dims or insert 1s + // op + // un-collaps to expected size + + // figure out the shape of the input tensor for the op. dims before and after rep dim get collapsed, not including + // page size. 
+ const auto& input_shape = tensor.get_logical_shape(); + ttnn::SmallVector collapsed_shape_vector(4); + + collapsed_shape_vector[UpperRepeatDims::collapsed_upper] = + std::accumulate(input_shape.cbegin(), input_shape.cbegin() + dim, 1, std::multiplies()); + collapsed_shape_vector[UpperRepeatDims::repeat] = input_shape[dim]; + collapsed_shape_vector[UpperRepeatDims::collapsed_lower] = + std::accumulate(input_shape.cbegin() + dim + 1, input_shape.cend() - 1, 1, std::multiplies()); + collapsed_shape_vector[UpperRepeatDims::page_size] = input_shape[-1]; + + // use ttnn::view to check logic + auto input_tensor = ttnn::view(tensor, ttnn::Shape(collapsed_shape_vector)); + + constexpr bool is_final_dim = false; + auto out_tensor = + operation::run( + RepeatDeviceOperation{repetitions, is_final_dim, output_mem_config}, {input_tensor}, {}, {}, queue_id) + .at(0); + auto expected_shape = input_shape; + expected_shape[dim] *= repetitions; + + return ttnn::view(out_tensor, ttnn::Shape(expected_shape)); +} + +ttnn::Tensor repeat_last_dim_rm( + const ttnn::Tensor& tensor, const uint32_t repetitions, uint8_t queue_id, const MemoryConfig& output_mem_config) { + // collapse to 2D + // op + // un-collapse + const auto& input_shape = tensor.get_logical_shape(); + ttnn::SmallVector collapsed_shape_vector(2); + + collapsed_shape_vector[0] = + std::accumulate(input_shape.cbegin(), input_shape.cend() - 1, 1, std::multiplies()); + collapsed_shape_vector[1] = input_shape[-1]; + + // use ttnn:view + auto input_tensor = ttnn::view(tensor, ttnn::Shape(collapsed_shape_vector)); + + constexpr bool is_final_dim = true; + auto out_tensor = + operation::run( + RepeatDeviceOperation{repetitions, is_final_dim, output_mem_config}, {input_tensor}, {}, {}, queue_id) + .at(0); + + auto expected_shape = input_shape; + expected_shape[-1] *= repetitions; + + return ttnn::view(out_tensor, ttnn::Shape(expected_shape)); +} + +std::tuple> match_input_rank( + const ttnn::Tensor& tensor, const SmallVector& repetition_vector) { + auto working_tensor = tensor; + const auto& input_shape = working_tensor.get_logical_shape(); + SmallVector working_repetition_vector; + + const auto total_reps = + std::accumulate(repetition_vector.cbegin(), repetition_vector.cend(), 1, std::multiplies()); + + if (input_shape.rank() < repetition_vector.size()) { + ttnn::SmallVector new_shape_vec(repetition_vector.size(), 1); + std::copy_backward(input_shape.cbegin(), input_shape.cend(), new_shape_vec.end()); + working_tensor = ttnn::view(working_tensor, ttnn::Shape(new_shape_vec)); + working_repetition_vector = std::move(repetition_vector); + } + // torch actually throws an error if the repetition rank is smaller than the tensor rank but it seems reasonable to + // handle it + else if (repetition_vector.size() < input_shape.rank()) { + working_repetition_vector.resize(input_shape.rank(), 1); + std::copy_backward(repetition_vector.cbegin(), repetition_vector.cend(), working_repetition_vector.end()); } - std::vector output_tensors = {Tensor(tt::tt_metal::operation::get_workers_for_op_output({input_tensor}))}; - tt::tt_metal::operation::launch_op( - [&input_rank, &input_tensor, &repeat_dims, &memory_config_arg, &padded_input_shape]( - const std::vector& input_tensors, - const std::vector>& optional_input_tensors, - const std::vector>& optional_output_tensors) -> std::vector { - auto memory_config = memory_config_arg.value_or(input_tensor.memory_config()); - TT_FATAL(repeat_dims.rank() == input_rank, "Number of repeat dims must be equal to number of tensor dims"); 
- Tensor output = input_tensor; - for (uint32_t dim = 0; dim < repeat_dims.rank(); ++dim) { - if (repeat_dims[dim] == 1) { - continue; - } - TT_FATAL(repeat_dims[dim] > 0, "Number of repetitions along a dim must be greater than 0"); - if (input_tensor.get_layout() == Layout::ROW_MAJOR && dim == input_rank - 1) { - TT_FATAL( - (padded_input_shape[dim] * input_tensor.element_size()) % input_tensor.buffer()->alignment() == - 0, - "Current repeat implementation requires last dim ({}) to be aligned to {} repeating on last " - "dim", - (padded_input_shape[dim] * input_tensor.element_size()), - input_tensor.buffer()->alignment()); - } - auto outputs = operation::run_without_autoformat( - RepeatDeviceOperation{dim, repeat_dims[dim], memory_config}, {output}); - TT_FATAL( - outputs.size() == 1, - "ttnn.repeat: expected 1 output tensor from run_without_autoformat, but got {}", - outputs.size()); - output = outputs[0]; - } - return {output}; - }, - {}, - output_tensors); - TT_FATAL(output_tensors.size() == 1, "ttnn.repeat: expected 1 output tensor, but got {}", output_tensors.size()); - if (input_tensor.get_layout() != Layout::ROW_MAJOR && logical_input_shape != padded_input_shape) { - auto zero_indices = ttnn::SmallVector(input_rank, 0); - auto end_indices = ttnn::SmallVector(repeated_logical_shape.cbegin(), repeated_logical_shape.cend()); - auto step = ttnn::SmallVector(input_rank, 1); - - if (repeated_logical_shape.volume() % tt::constants::TILE_HW != 0) { - // volume of the repeated tensor doesn't fit neatly into tiles. - // slice/tilize don't support padding to tiled on the output for - // now, so we need to perform the slice in row-major then re-tilize - // ourselves. - auto rm_output = ttnn::untilize(output_tensors[0]); - auto sliced_output = - ttnn::slice(rm_output, zero_indices, end_indices, step, input_tensor.memory_config(), std::nullopt); - - auto sliced_logical_shape = sliced_output.get_logical_shape(); - auto sliced_padded_shape = sliced_output.get_padded_shape(); - - if (sliced_padded_shape.volume() % tt::constants::TILE_HW == 0) { - // slice preserved tile padding for us, so we can just tilize now. 
- auto tiled_output = ttnn::tilize(sliced_output, input_tensor.memory_config()); - return tiled_output; - } - - auto padded_height = tt::round_up(sliced_padded_shape[-2], tt::constants::TILE_HEIGHT); - auto padded_width = tt::round_up(sliced_padded_shape[-1], tt::constants::TILE_WIDTH); - TT_ASSERT(input_rank >= 2, "ttnn.repeat: rank of tiled input tensor must be >= 2"); - uint32_t num_non_hw_dims = input_rank - 2u; - auto padding_vec = ttnn::SmallVector>(num_non_hw_dims, {0, 0}); - padding_vec.reserve(input_rank); - padding_vec.emplace_back(0, padded_height - sliced_padded_shape[-2]); - padding_vec.emplace_back(0, padded_width - sliced_padded_shape[-1]); - - constexpr bool pad_use_multicore = true; - auto padded_output = ttnn::pad(queue_id, sliced_output, padding_vec, 0.0f, pad_use_multicore, std::nullopt); - auto tiled_output = ttnn::tilize(padded_output, input_tensor.memory_config()); - - return ttnn::reshape(tiled_output, sliced_logical_shape, tiled_output.get_padded_shape()); - } else { - return ttnn::slice( - output_tensors[0], zero_indices, end_indices, step, input_tensor.memory_config(), std::nullopt); - } + else { + working_repetition_vector = std::move(repetition_vector); } - return output_tensors[0]; + + TT_ASSERT(working_tensor.get_logical_volume() == tensor.get_logical_volume()); + TT_ASSERT( + std::accumulate( + working_repetition_vector.cbegin(), + working_repetition_vector.cend(), + 1, + std::multiplies()) == total_reps); + + return std::tie(working_tensor, working_repetition_vector); } +} // namespace detail ttnn::Tensor RepeatOperation::invoke( - const ttnn::Tensor& input_tensor, - const ttnn::Shape& repeat_dims, - const std::optional& memory_config) { - return invoke(DefaultQueueId, input_tensor, repeat_dims, memory_config); + const ttnn::Tensor& tensor, + const ttnn::SmallVector& provided_repetition_vector, + const std::optional& provided_output_mem_config, + uint8_t queue_id) { + auto [working_tensor, repetition_vector] = detail::match_input_rank(tensor, provided_repetition_vector); + MemoryConfig output_mem_config = provided_output_mem_config.value_or(tensor.memory_config()); + auto working_output_mem_config = output_mem_config; + + if (std::any_of(repetition_vector.cbegin(), repetition_vector.cend(), [](auto x) { return x == 0; })) { + const auto& shape = working_tensor.get_logical_shape(); + std::transform( + shape.cbegin(), + shape.cend(), + repetition_vector.cbegin(), + repetition_vector.begin(), + std::multiplies()); + return tensor.reshape(ttnn::Shape(repetition_vector)); + } + + TT_FATAL(working_tensor.get_logical_shape().rank() > 0, "repeat does not support rank 0 tensors"); + + // nothing to do! 
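Two degenerate cases are handled before any device work is queued, as the surrounding hunk shows: if any repetition count is zero, the result is an empty tensor whose shape is the element-wise product of the input shape and the repetition vector (a reshape, no device op), and if every count is one the input is returned untouched. A small sketch of just that shape product, with made-up numbers:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <functional>
    #include <vector>

    int main() {
        std::vector<uint32_t> shape = {2, 3, 4};  // hypothetical input shape
        std::vector<uint32_t> reps  = {1, 0, 2};  // one zero repetition
        std::vector<uint32_t> out(shape.size());

        // Element-wise product, matching the std::transform call in the new invoke().
        std::transform(shape.begin(), shape.end(), reps.begin(), out.begin(), std::multiplies<uint32_t>());

        std::printf("%u %u %u\n", out[0], out[1], out[2]);  // prints "2 0 8"
        return 0;
    }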
+ if (std::all_of(repetition_vector.cbegin(), repetition_vector.cend(), [](auto x) { return x == 1; })) { + return tensor; + } + + // Sharded -> interleaved + if (tensor.memory_config().is_sharded()) { + auto working_memory_config = tensor.memory_config(); + working_memory_config.memory_layout = TensorMemoryLayout::INTERLEAVED; + working_tensor = ttnn::sharded_to_interleaved(queue_id, tensor, working_memory_config, std::nullopt); + } + if (working_output_mem_config.is_sharded()) { + working_output_mem_config.memory_layout = TensorMemoryLayout::INTERLEAVED; + } + + // tiled -> RM + if (working_tensor.layout() == ttnn::TILE_LAYOUT) { + working_tensor = + ttnn::to_layout(working_tensor, ttnn::ROW_MAJOR_LAYOUT, std::nullopt, std::nullopt, (Device*)nullptr); + } + + // loop over dims in repetition vector, backwards because repeat pages first is faster + for (auto it = repetition_vector.crbegin(); it != repetition_vector.crend(); ++it) { + // no op for unit repetitions + if (*it == 1) { + continue; + } + // if last dim + if (it == repetition_vector.crbegin()) { + working_tensor = detail::repeat_last_dim_rm(working_tensor, *it, queue_id, working_output_mem_config); + } + // if not last dim + else { + auto i = repetition_vector.crend() - it - 1; // forward index + working_tensor = detail::repeat_upper_dims_rm(working_tensor, i, *it, queue_id, working_output_mem_config); + } + } + + // RM -> OG page layout + if (tensor.layout() == ttnn::TILE_LAYOUT) { + working_tensor = + ttnn::to_layout(working_tensor, ttnn::TILE_LAYOUT, tensor.get_dtype(), std::nullopt, (Device*)nullptr); + } + + // Interleaved to OG mem layout + if (output_mem_config.is_sharded()) { + working_tensor = ttnn::interleaved_to_sharded(queue_id, working_tensor, output_mem_config, std::nullopt); + } + + return working_tensor; } ttnn::Tensor RepeatOperation::invoke(const ttnn::Tensor& input_tensor, const ttnn::Shape& repeat_dims) { - return invoke(DefaultQueueId, input_tensor, repeat_dims, std::nullopt); + return RepeatOperation::invoke( + input_tensor, SmallVector(repeat_dims.cbegin(), repeat_dims.cend()), std::nullopt, DefaultQueueId); } } // namespace ttnn::operations::data_movement diff --git a/ttnn/cpp/ttnn/operations/data_movement/repeat/repeat.hpp b/ttnn/cpp/ttnn/operations/data_movement/repeat/repeat.hpp index 3e35bcdea31..76b780faf2c 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/repeat/repeat.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/repeat/repeat.hpp @@ -4,7 +4,6 @@ #pragma once -#include "ttnn/run_operation.hpp" #include "ttnn/decorators.hpp" namespace ttnn { @@ -12,22 +11,16 @@ namespace operations::data_movement { struct RepeatOperation { static ttnn::Tensor invoke( - uint8_t queue_id, const ttnn::Tensor& input_tensor, - const ttnn::Shape& repeat_dims, - const std::optional& memory_config_arg); - - static ttnn::Tensor invoke( - const ttnn::Tensor& input_tensor, - const ttnn::Shape& repeat_dims, - const std::optional& memory_config); + const ttnn::SmallVector& repetition_vector, + const std::optional& provided_output_mem_config, + uint8_t queue_id); static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, const ttnn::Shape& repeat_dims); }; } // namespace operations::data_movement -constexpr auto repeat = - ttnn::register_operation_with_auto_launch_op<"ttnn::repeat", ttnn::operations::data_movement::RepeatOperation>(); +constexpr auto repeat = ttnn::register_operation<"ttnn::repeat", ttnn::operations::data_movement::RepeatOperation>(); } // namespace ttnn diff --git 
a/ttnn/cpp/ttnn/operations/data_movement/repeat/repeat_pybind.cpp b/ttnn/cpp/ttnn/operations/data_movement/repeat/repeat_pybind.cpp index 9ef0f5f6bc8..e2a3883c737 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/repeat/repeat_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/repeat/repeat_pybind.cpp @@ -22,9 +22,9 @@ void bind_repeat(py::module& module, const data_movement_operation_t& operation, ttnn::pybind_overload_t{ [](const data_movement_operation_t& self, const ttnn::Tensor& input_tensor, - const ttnn::Shape& repeat_dims, + const ttnn::SmallVector& repetition_vector, const std::optional& memory_config, - uint8_t queue_id) { return self(queue_id, input_tensor, repeat_dims, memory_config); }, + uint8_t queue_id) { return self(input_tensor, repetition_vector, memory_config, queue_id); }, py::arg("input_tensor"), py::arg("repeat_dims"), py::kw_only(), @@ -42,7 +42,7 @@ void py_bind_repeat(py::module& module) { Args: input_tensor (ttnn.Tensor): the input tensor. - repeat_dims (number): The number of repetitions for each element. + repetition_vector (SmallVector): The number of repetitions for each dimension. Keyword Args: memory_config (ttnn.MemoryConfig, optional): Memory configuration for the operation. Defaults to `None`. @@ -52,7 +52,7 @@ void py_bind_repeat(py::module& module) { Example: - >>> tensor = ttnn.repeat(ttnn.from_torch(torch.tensor([[1, 2], [3, 4]]), 2,)), device) + >>> tensor = ttnn.repeat(ttnn.from_torch(torch.tensor([[1, 2], [3, 4]]), [1,2],)), device) >>> print(tensor) tensor([[1, 2], [1, 2], diff --git a/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape.cpp b/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape.cpp index 95ccdf41ade..bf7422ba621 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/reshape_on_device/reshape.cpp @@ -48,9 +48,9 @@ static Tensor manual_insertion( logical_shape, TensorLayout::fromPaddedShape( DataType::BFLOAT16, PageConfig(Layout::ROW_MAJOR), MemoryConfig{}, logical_shape, padded_shape))) - .to(Layout::ROW_MAJOR); + .to_layout(Layout::ROW_MAJOR); if (device != nullptr) { - output = output.to(device, output_mem_config); + output = output.to_device(device, output_mem_config); } return output; } diff --git a/ttnn/cpp/ttnn/operations/data_movement/sharded/device/kernels/dataflow/reader_unary_stick_layout_sharded_blocks_interleaved_start_id.cpp b/ttnn/cpp/ttnn/operations/data_movement/sharded/device/kernels/dataflow/reader_unary_stick_layout_sharded_blocks_interleaved_start_id.cpp index c132e643ad5..3a832e944d6 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/sharded/device/kernels/dataflow/reader_unary_stick_layout_sharded_blocks_interleaved_start_id.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/sharded/device/kernels/dataflow/reader_unary_stick_layout_sharded_blocks_interleaved_start_id.cpp @@ -54,6 +54,7 @@ void kernel_main() { noc_async_read(src_noc_addr, scratch_l1_write_addr, aligned_block_width_bytes); noc_async_read_barrier(); noc_async_read(scratch_l1_noc_read_addr, l1_write_addr, block_width_bytes); + noc_async_read_barrier(); stick_id++; l1_write_addr += padded_block_width_bytes; } diff --git a/ttnn/cpp/ttnn/operations/data_movement/sharded/device/kernels/dataflow/reshard_same_width_reader.cpp b/ttnn/cpp/ttnn/operations/data_movement/sharded/device/kernels/dataflow/reshard_same_width_reader.cpp index 835773e8a0a..d6141524051 100644 --- 
a/ttnn/cpp/ttnn/operations/data_movement/sharded/device/kernels/dataflow/reshard_same_width_reader.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/sharded/device/kernels/dataflow/reshard_same_width_reader.cpp @@ -6,8 +6,7 @@ #include "dataflow_api.h" void kernel_main() { - - constexpr uint32_t shard_cb_id = get_compile_time_arg_val(0); + constexpr uint32_t shard_cb_id = get_compile_time_arg_val(0); constexpr bool read_from_dram = get_compile_time_arg_val(1); uint32_t src_addr = get_arg_val(0); diff --git a/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/device/interleaved_to_sharded_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/device/interleaved_to_sharded_program_factory.cpp index 4d4a9981132..913dc4cc97b 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/device/interleaved_to_sharded_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/device/interleaved_to_sharded_program_factory.cpp @@ -18,7 +18,7 @@ namespace ttnn::operations::data_movement::detail { operation::ProgramWithCallbacks interleaved_to_sharded_multi_core( const Tensor& input, const Tensor& output, bool keep_l1_aligned, uint32_t num_slices, uint32_t slice_index) { tt::tt_metal::Program program{}; - + keep_l1_aligned = true; uint32_t num_units, num_units_per_shard, input_unit_size, output_unit_size, num_units_per_shard_width, num_units_per_shard_height, num_units_offset, num_units_per_row, num_units_per_shard_height_last, num_units_per_shard_width_last, padded_offset_bytes; @@ -45,6 +45,12 @@ operation::ProgramWithCallbacks interleaved_to_sharded_multi_core( num_units = input.volume() / TILE_HW; input_unit_size = tt::tt_metal::detail::TileSize(input_cb_data_format); output_unit_size = tt::tt_metal::detail::TileSize(output_cb_data_format); + TT_FATAL( + shard_spec.shape[0] % TILE_HEIGHT == 0 && shard_spec.shape[1] % TILE_WIDTH == 0, + "Shard shape {} must be tile {}x{} sized!", + shard_spec.shape, + TILE_HEIGHT, + TILE_WIDTH); num_units_per_shard_height = shard_spec.shape[0] / TILE_HEIGHT; num_units_per_shard_width = shard_spec.shape[1] / TILE_WIDTH; num_units_per_shard = num_units_per_shard_height * num_units_per_shard_width; @@ -81,7 +87,6 @@ operation::ProgramWithCallbacks interleaved_to_sharded_multi_core( } } - auto all_cores = shard_spec.grid; uint32_t input_cb_index = tt::CBIndex::c_0; uint32_t scratch_cb_index = tt::CBIndex::c_1; @@ -252,9 +257,9 @@ operation::ProgramWithCallbacks interleaved_to_sharded_multi_core( uint32_t dram_alignment = hal.get_alignment(HalMemType::DRAM); uint32_t l1_alignment = hal.get_alignment(HalMemType::L1); - bool aligned = (src_is_dram ? curr_idx_w % dram_alignment == 0 : true); + bool aligned = (src_is_dram ? 
(curr_idx_w % dram_alignment == 0) && (padded_offset_bytes % dram_alignment == 0) : true); //for blackhole and keep_l1_aligned cases, always enforce unaligned kernel call - aligned = aligned and !(is_blackhole) and !(keep_l1_aligned); + aligned = aligned and !(is_blackhole); uint32_t aligned_width_offset, aligned_shard_width, aligned_offset; if (!aligned) { //TODO: is this right, leaving non BH case the same for now, should investigate diff --git a/ttnn/cpp/ttnn/operations/data_movement/sharded/reshard/device/reshard_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/sharded/reshard/device/reshard_program_factory.cpp index 71f59191890..27b8b185d43 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/sharded/reshard/device/reshard_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/sharded/reshard/device/reshard_program_factory.cpp @@ -338,16 +338,10 @@ operation::ProgramWithCallbacks reshard_multi_core_same_width(const Tensor& inpu bool interface_with_dram = (remote_core_type == CoreType::DRAM); tt::tt_metal::KernelHandle kernel_id_0 = tt::tt_metal::CreateKernel( - program, - kernel_name, - all_cores, - tt::tt_metal::ReaderDataMovementConfig({cb_index, interface_with_dram})); + program, kernel_name, all_cores, tt::tt_metal::ReaderDataMovementConfig({cb_index, interface_with_dram})); tt::tt_metal::KernelHandle kernel_id_1 = tt::tt_metal::CreateKernel( - program, - kernel_name, - all_cores, - tt::tt_metal::WriterDataMovementConfig({cb_index, interface_with_dram})); + program, kernel_name, all_cores, tt::tt_metal::WriterDataMovementConfig({cb_index, interface_with_dram})); tt::tt_metal::CircularBufferConfig cb_config = tt::tt_metal::CircularBufferConfig(total_size, {{cb_index, data_format}}) diff --git a/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_program_factory.cpp index 1f4a486caad..d173898bd14 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/device/sharded_to_interleaved_program_factory.cpp @@ -18,7 +18,7 @@ namespace ttnn::operations::data_movement::detail { operation::ProgramWithCallbacks sharded_to_interleaved_multi_core( const Tensor& input, const Tensor& output, bool is_l1_aligned, uint32_t num_slices, uint32_t slice_index) { tt_metal::Program program{}; - + is_l1_aligned = true; uint32_t num_units, num_units_per_shard, input_unit_size, output_unit_size, num_units_per_shard_width, num_units_per_shard_height, num_units_offset, num_units_per_row, num_units_height, num_units_per_shard_height_last, num_units_per_shard_width_last; diff --git a/ttnn/cpp/ttnn/operations/data_movement/split/split.cpp b/ttnn/cpp/ttnn/operations/data_movement/split/split.cpp index 7eb540b843f..886b2ac5b33 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/split/split.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/split/split.cpp @@ -75,10 +75,10 @@ std::vector split_dim_n_chunks_rm( output_chunk = output_chunk.pad_to_tile(0.0); } - output_chunk = output_chunk.to(layout); + output_chunk = output_chunk.to_layout(layout); if (device) { - output_chunk = output_chunk.to(*device); + output_chunk = output_chunk.to_device(*device); } output_tensors.push_back(output_chunk); diff --git 
a/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_program_factory.cpp index bc8298d5295..05d3356f383 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/transpose/device/transpose_program_factory.cpp @@ -675,8 +675,7 @@ operation::ProgramWithCallbacks transpose_hc_multi_core( // TODO: noc_async_write only require 16B alignment for both DRAM and L1 for Blackhole, so instead of reading in // face-lines from C tiles to form a single tile, we can load a single tile and then write out its face-lines to C // tiles - uint32_t alignment = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? hal::get_dram_alignment() - : hal::get_l1_alignment(); + uint32_t alignment = dst_buffer->alignment(); bool misaligned = alignment > sub_tile_line_bytes; if (row_major) { auto num_sticks = num_tiles_per_core_group_1 > num_tiles_per_core_group_2 ? num_tiles_per_core_group_1 diff --git a/ttnn/cpp/ttnn/operations/data_movement/untilize/device/kernels/compute/untilize_w.cpp b/ttnn/cpp/ttnn/operations/data_movement/untilize/device/kernels/compute/untilize_w.cpp new file mode 100644 index 00000000000..11aa2193c94 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/untilize/device/kernels/compute/untilize_w.cpp @@ -0,0 +1,26 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "compute_kernel_api/untilize.h" +#include "debug/dprint.h" + +namespace NAMESPACE { +void MAIN { + uint32_t per_core_block_cnt = get_compile_time_arg_val(0); + uint32_t per_core_block_tile_cnt = get_compile_time_arg_val(1); + uint32_t third_dim = get_compile_time_arg_val(2); + untilize_init(tt::CBIndex::c_0, tt::CBIndex::c_16); + + uint32_t onetile = 1; + for (uint32_t b = 0; b < per_core_block_cnt * per_core_block_tile_cnt * third_dim; ++b) { + cb_wait_front(tt::CBIndex::c_0, onetile); + cb_reserve_back(tt::CBIndex::c_16, onetile); + + untilize_block(tt::CBIndex::c_0, onetile, tt::CBIndex::c_16); + + cb_push_back(tt::CBIndex::c_16, onetile); + cb_pop_front(tt::CBIndex::c_0, onetile); + } +} +} // namespace NAMESPACE diff --git a/ttnn/cpp/ttnn/operations/data_movement/untilize_with_unpadding/device/kernels/dataflow/writer_unary_stick_layout_col_multicore.cpp b/ttnn/cpp/ttnn/operations/data_movement/untilize_with_unpadding/device/kernels/dataflow/writer_unary_stick_layout_col_multicore.cpp new file mode 100644 index 00000000000..1cdb04d63ce --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/untilize_with_unpadding/device/kernels/dataflow/writer_unary_stick_layout_col_multicore.cpp @@ -0,0 +1,82 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
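The writer kernel added below splits each unpadded output row across cores by column blocks; only the last core shortens its write so the tile-width padding is dropped. A self-contained sketch of that trimming arithmetic with hypothetical sizes (the variable names echo the kernel, but the values are invented):

    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t unpadded_X_size = 100;        // bytes in one unpadded output row
        const uint32_t ncores = 3;                   // cores splitting the row
        const uint32_t size_per_row_per_block = 40;  // bytes of the row each core covers
        const uint32_t width_size = 40;              // bytes a core would normally write
        const uint32_t start_id = 0;

        for (uint32_t core = 0; core < ncores; core++) {
            uint32_t total_size = core * size_per_row_per_block + start_id + width_size;
            uint32_t write_size = width_size;
            if (core == ncores - 1 && total_size > unpadded_X_size) {
                // Trim the overhang so the padded tile width never spills past the row.
                write_size = width_size - (total_size - unpadded_X_size);
            }
            std::printf("core %u writes %u bytes\n", core, write_size);  // 40, 40, 20
        }
        return 0;
    }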
+// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "dataflow_api.h" + +void kernel_main() { + constexpr uint32_t cb_id_out0 = 16; + + const uint32_t total_num_rows = get_compile_time_arg_val(3); + const uint32_t ncores = get_compile_time_arg_val(4); + const uint32_t third_dim = get_compile_time_arg_val(5); + const uint32_t tile_width = get_compile_time_arg_val(6); + + const uint32_t dst_addr = get_arg_val(0); + const uint32_t unpadded_X_size = get_arg_val(1); + const uint32_t core_number = get_arg_val(2); + + constexpr bool dst0_is_dram = get_compile_time_arg_val(0) == 1; + +#define stick_size_is_pow2 get_compile_time_arg_val(1) == 1 +#if (stick_size_is_pow2) + constexpr uint32_t log_base_2_of_page_size = get_compile_time_arg_val(2); + const InterleavedPow2AddrGen s = { + .bank_base_address = dst_addr, .log_base_2_of_page_size = log_base_2_of_page_size}; +#else + const InterleavedAddrGen s = {.bank_base_address = dst_addr, .page_size = unpadded_X_size}; +#endif + + auto write_block = [&](uint32_t num_rows, + uint32_t mul, + uint32_t size_per_row_per_block, + uint32_t start_id, + uint32_t width_size, + uint32_t size_2d) { + uint32_t onetile = 1; + bool has_rows = (num_rows) > 0; + + cb_wait_front(cb_id_out0, onetile * has_rows); + uint32_t l1_read_addr = get_write_ptr(cb_id_out0); + + for (uint32_t k = 0; k < num_rows; k++) { + uint64_t dst_noc_addr = get_noc_addr(size_2d + k, s); + + uint32_t total_size = mul * size_per_row_per_block + start_id + width_size; + uint32_t padded_size = total_size - unpadded_X_size; + uint32_t write_size = width_size; + + if (mul == ncores - 1 && padded_size > 0) { + write_size = width_size - padded_size; + } + + noc_async_write(l1_read_addr, dst_noc_addr + start_id + mul * size_per_row_per_block, write_size); + + noc_async_write_barrier(); + + if (k > 0 && (k % tile_width == 0)) { + cb_pop_front(cb_id_out0, onetile * has_rows); + cb_wait_front(cb_id_out0, onetile * has_rows); + } + l1_read_addr += width_size; + } + + cb_pop_front(cb_id_out0, onetile * has_rows); + }; + + const uint32_t size_per_row_per_block = get_arg_val(3); + const uint32_t blocks_per_core = get_arg_val(4); + const uint32_t width_size = get_arg_val(5); + + uint32_t size_2d = 0; + for (uint32_t dim3 = 0; dim3 < third_dim; dim3++) { + uint32_t start_id = 0; + for (uint32_t b = 0; b < blocks_per_core; b++) { + write_block(total_num_rows, core_number, size_per_row_per_block, start_id, width_size, size_2d); + start_id += width_size; + } + size_2d += total_num_rows; + } +} diff --git a/ttnn/cpp/ttnn/operations/data_movement/untilize_with_unpadding/device/untilize_with_unpadding_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/untilize_with_unpadding/device/untilize_with_unpadding_program_factory.cpp index 06cab19a567..fb9e98524df 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/untilize_with_unpadding/device/untilize_with_unpadding_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/untilize_with_unpadding/device/untilize_with_unpadding_program_factory.cpp @@ -206,6 +206,164 @@ operation::ProgramWithCallbacks untilize_with_unpadding_single_core( return {std::move(program), override_runtime_args_callback}; } +operation::ProgramWithCallbacks untilize_with_unpadding_multi_core_col_interleaved( + const Tensor& a, Tensor& output, bool use_pack_untilize, bool fp32_dest_acc_en) { + tt::tt_metal::Program program{}; + + tt::DataFormat input_cb_data_format = datatype_to_dataformat_converter(a.get_dtype()); + uint32_t input_single_tile_size = 
tt::tt_metal::detail::TileSize(input_cb_data_format); + tt::DataFormat output_cb_data_format = datatype_to_dataformat_converter(output.get_dtype()); + uint32_t output_single_tile_size = tt::tt_metal::detail::TileSize(output_cb_data_format); + + const auto& input_shape = a.get_padded_shape(); + const auto& output_shape = output.get_padded_shape(); + + IDevice* device = a.device(); + CoreCoord grid_size = device->compute_with_storage_grid_size(); + + uint32_t num_blocks = input_shape[-1] / TILE_WIDTH; + uint32_t num_tiles_per_row = a.get_padded_shape()[-1] / TILE_WIDTH; + uint32_t num_tiles_per_col = a.get_padded_shape()[-2] / TILE_HEIGHT; + + auto [ncores, all_cores, core_range, core_range_cliff, nblocks_per_core, nblocks_per_core_cliff] = + ttnn::split_blocks_for_tilize(grid_size, num_blocks); + + bool has_cliff = core_range_cliff.size() > 0; + + uint32_t padded_row_size_bytes; + uint32_t unpadded_row_size_bytes; + + uint32_t el_size = a.element_size(); + if (a.get_dtype() == DataType::BFLOAT8_B) { + padded_row_size_bytes = input_shape[-1] * output.element_size(); + unpadded_row_size_bytes = output_shape[-1] * output.element_size(); + el_size = output.element_size(); + } else { + padded_row_size_bytes = input_shape[-1] * a.element_size(); + unpadded_row_size_bytes = output_shape[-1] * a.element_size(); + } + + create_cb(tt::CBIndex::c_0, program, all_cores, input_single_tile_size, num_tiles_per_col, input_cb_data_format); + create_cb(tt::CBIndex::c_16, program, all_cores, output_single_tile_size, num_tiles_per_col, output_cb_data_format); + + Buffer* src0_buffer = a.buffer(); + Buffer* dst_buffer = output.buffer(); + TT_FATAL(dst_buffer != nullptr, "Output buffer should be allocated on device!"); + + // reader + + uint32_t src0_is_dram = src0_buffer->buffer_type() == BufferType::DRAM ? 1 : 0; + uint32_t num_tiles_2d = a.get_padded_shape()[-1] * a.get_padded_shape()[-2] / TILE_HW; + + auto log_shape = output.get_logical_shape(); + uint32_t third_dim = 1; + if (log_shape.rank() == 3) { + third_dim = log_shape[-3]; + } else if (log_shape.rank() >= 4) { + third_dim = log_shape[-3] * log_shape[-4]; + } + + KernelHandle unary_reader_kernel_id = CreateKernel( + program, + "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/reader_unary_interleaved_col_multicore.cpp", + all_cores, + ReaderDataMovementConfig({src0_is_dram, num_tiles_2d, third_dim, nblocks_per_core})); + + // writer + + uint32_t out_is_dram = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; + uint32_t stick_size = unpadded_row_size_bytes; + uint32_t stick_size_is_power_of_two = is_power_of_two_at_least_32(stick_size); + uint32_t log2_stick_size = stick_size_is_power_of_two ? 
(std::uint32_t)std::log2(stick_size) : 0; + + uint32_t total_num_rows = output.get_logical_shape()[-2]; + + KernelHandle unary_writer_kernel_id = CreateKernel( + program, + "ttnn/cpp/ttnn/operations/data_movement/untilize_with_unpadding/device/kernels/dataflow/" + "writer_unary_stick_layout_col_multicore.cpp", + all_cores, + WriterDataMovementConfig( + {out_is_dram, stick_size_is_power_of_two, log2_stick_size, total_num_rows, ncores, third_dim, TILE_WIDTH})); + + // compute + + std::string compute_kernel("ttnn/cpp/ttnn/operations/data_movement/untilize/device/kernels/compute/untilize_w.cpp"); + + if (core_range.size() > 0) { + auto tilize_kernel_id = CreateKernel( + program, + compute_kernel, + core_range, + ComputeConfig{ + .fp32_dest_acc_en = fp32_dest_acc_en, + .compile_args = {nblocks_per_core, num_tiles_per_col, third_dim}}); + } + if (has_cliff) { + auto tilize_cliff_kernel_id = CreateKernel( + program, + compute_kernel, + core_range_cliff, + ComputeConfig{ + .fp32_dest_acc_en = fp32_dest_acc_en, + .compile_args = {nblocks_per_core_cliff, num_tiles_per_col, third_dim}}); + } + + // RUNTIME ARGS + const auto& cores = grid_to_cores(ncores, grid_size.x, grid_size.y, true); + uint32_t number_blocks_per_core; + for (uint32_t i = 0; i < ncores; ++i) { + const auto& core = cores[i]; + + if (has_cliff && i == ncores - 1) { + number_blocks_per_core = nblocks_per_core_cliff; + } else { + number_blocks_per_core = nblocks_per_core; + } + uint32_t size_per_row_per_block = nblocks_per_core * TILE_WIDTH * el_size; + + // writer runtime args + std::vector writer_rt_args = { + dst_buffer->address(), + unpadded_row_size_bytes, + i, + size_per_row_per_block, + number_blocks_per_core, + TILE_WIDTH * el_size, + }; + + // reader runtime args + const std::array reader_rt_args = {src0_buffer->address(), i, num_tiles_per_row, number_blocks_per_core}; + SetRuntimeArgs(program, unary_reader_kernel_id, core, reader_rt_args); + SetRuntimeArgs(program, unary_writer_kernel_id, core, writer_rt_args); + } + + auto override_runtime_args_callback = + [reader_kernel_id = unary_reader_kernel_id, writer_kernel_id = unary_writer_kernel_id, cores = cores]( + const Program& program, + const std::vector& input_buffers, + const std::vector& output_buffers) { + auto src_buffer = input_buffers.at(0); + auto dst_buffer = output_buffers.at(0); + + auto& reader_runtime_args_by_core = GetRuntimeArgs(program, reader_kernel_id); + auto& writer_runtime_args_by_core = GetRuntimeArgs(program, writer_kernel_id); + + for (const auto& core : cores) { + { + auto& runtime_args = reader_runtime_args_by_core[core.x][core.y]; + runtime_args[0] = src_buffer->address(); + } + { + auto& runtime_args = writer_runtime_args_by_core[core.x][core.y]; + runtime_args[0] = dst_buffer->address(); + } + } + }; + + return {std::move(program), override_runtime_args_callback}; +} + operation::ProgramWithCallbacks untilize_with_unpadding_multi_core_interleaved( const Tensor& a, Tensor& output, bool use_pack_untilize, bool fp32_dest_acc_en) { tt::tt_metal::Program program{}; @@ -224,6 +382,11 @@ operation::ProgramWithCallbacks untilize_with_unpadding_multi_core_interleaved( uint32_t num_blocks = input_shape[-1] == 0 ? 
0 : a.volume() / input_shape[-1] / TILE_HEIGHT; uint32_t num_tiles_per_row = a.get_padded_shape()[-1] / TILE_WIDTH; + uint32_t num_tiles_per_col = a.get_padded_shape()[-2] / TILE_HEIGHT; + if (num_tiles_per_row > num_tiles_per_col) { + return untilize_with_unpadding_multi_core_col_interleaved(a, output, use_pack_untilize, fp32_dest_acc_en); + } + auto [ncores, all_cores, core_range, core_range_cliff, nblocks_per_core, nblocks_per_core_cliff] = ttnn::split_blocks_for_tilize(grid_size, num_blocks); @@ -249,6 +412,7 @@ operation::ProgramWithCallbacks untilize_with_unpadding_multi_core_interleaved( /** reader */ + uint32_t src0_is_dram = src0_buffer->buffer_type() == BufferType::DRAM ? 1 : 0; KernelHandle unary_reader_kernel_id = CreateKernel( @@ -259,6 +423,7 @@ operation::ProgramWithCallbacks untilize_with_unpadding_multi_core_interleaved( /** writer */ + uint32_t out_is_dram = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; uint32_t stick_size = unpadded_row_size_bytes; uint32_t stick_size_is_power_of_two = is_power_of_two_at_least_32(stick_size); @@ -278,6 +443,7 @@ operation::ProgramWithCallbacks untilize_with_unpadding_multi_core_interleaved( /** compute */ + std::string compute_kernel( "ttnn/cpp/ttnn/operations/data_movement/untilize/device/kernels/compute/pack_untilize.cpp"); if (num_tiles_per_row > MAX_PACK_UNTILIZE_WIDTH || !use_pack_untilize || a.get_dtype() == DataType::UINT16) { diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/kernels/compute/eltwise_binary_kernel.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/kernels/compute/eltwise_binary_kernel.cpp index 3160d25f863..c9cf7a70c6e 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/kernels/compute/eltwise_binary_kernel.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/kernels/compute/eltwise_binary_kernel.cpp @@ -34,7 +34,7 @@ void MAIN { binary_op_init_common(cb_inp0, cb_inp1, cb_out0); #if not PRE_SCALE - binary_op_specific_init(); + binary_op_specific_init(cb_inp0, cb_inp1); #endif #ifdef PACK_RELU @@ -110,7 +110,7 @@ void MAIN { cb_reserve_back(cb_out0, per_core_block_size); #if PRE_SCALE - binary_op_specific_init(); + binary_op_specific_init(cb_inp0, cb_inp1); #endif tile_regs_acquire(); diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary.cpp index 7d2bde98bf5..9e56c7fccfd 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary.cpp @@ -44,7 +44,7 @@ ALWI void process_tile( cb_reserve_back(cb_out, onetile); #if HAS_ACTIVATIONS(LHS) or HAS_ACTIVATIONS(RHS) - binary_op_specific_init(); + binary_op_specific_init(cb_post_lhs, cb_post_rhs); #endif tile_regs_acquire(); BINARY_OP(cb_post_lhs, cb_post_rhs, 0, 0, 0); @@ -83,7 +83,7 @@ void MAIN { #endif #if not(HAS_ACTIVATIONS(LHS) or HAS_ACTIVATIONS(RHS)) - binary_op_specific_init(); + binary_op_specific_init(cb_post_lhs, cb_post_rhs); #endif uint32_t complete_iterations = (num_tiles + tile_start) / tile_freq; diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_no_bcast.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_no_bcast.cpp index 30ff24df71d..069366f4bc3 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_no_bcast.cpp +++ 
b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_no_bcast.cpp @@ -27,7 +27,7 @@ void MAIN { #endif #if not(HAS_ACTIVATIONS(LHS) or HAS_ACTIVATIONS(RHS)) - binary_op_specific_init(); + binary_op_specific_init(cb_post_lhs, cb_post_rhs); #endif constexpr uint32_t onetile = 1; @@ -42,7 +42,7 @@ void MAIN { cb_reserve_back(cb_out, onetile); #if HAS_ACTIVATIONS(LHS) or HAS_ACTIVATIONS(RHS) - binary_op_specific_init(); + binary_op_specific_init(cb_post_lhs, cb_post_rhs); #endif tile_regs_acquire(); BINARY_OP(cb_post_lhs, cb_post_rhs, 0, 0, 0); diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_scalar.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_scalar.cpp index db05ebae73a..8c1b150c12d 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_scalar.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_binary_scalar.cpp @@ -26,7 +26,7 @@ void MAIN { #endif #if not(HAS_ACTIVATIONS(LHS) or HAS_ACTIVATIONS(RHS)) - binary_op_specific_init(); + binary_op_specific_init(cb_post_lhs, cb_post_rhs); #endif constexpr uint32_t onetile = 1; @@ -41,7 +41,7 @@ void MAIN { cb_reserve_back(cb_out, onetile); #if HAS_ACTIVATIONS(LHS) or HAS_ACTIVATIONS(RHS) - binary_op_specific_init(); + binary_op_specific_init(cb_post_lhs, cb_post_rhs); #endif tile_regs_acquire(); BINARY_OP(cb_post_lhs, cb_post_rhs, 0, 0, 0); diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/reader_unary_interleaved_col_multicore.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/reader_unary_interleaved_col_multicore.cpp new file mode 100644 index 00000000000..77a65709e30 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/reader_unary_interleaved_col_multicore.cpp @@ -0,0 +1,57 @@ + +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
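The reader kernel added below walks interleaved tile ids column-wise: each core starts at its own column offset and strides by a full tile row, so one core streams a vertical slice of the 2D tile grid. A hypothetical standalone illustration of the visiting order (values are invented; only the loop shape mirrors the kernel):

    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t tiles_per_row = 4;    // tiles across the padded width
        const uint32_t num_rows = 3;         // tile rows in one 2D slice
        const uint32_t num_tiles_per_2d = tiles_per_row * num_rows;
        const uint32_t blocks_per_core = 1;  // column tiles owned by this core
        const uint32_t core_number = 1;      // this core's index

        // Start at this core's first column, then walk down the column by striding a full row.
        for (uint32_t k = 0; k < blocks_per_core; k++) {
            for (uint32_t i = blocks_per_core * core_number; i < num_tiles_per_2d; i += tiles_per_row) {
                std::printf("tile %u\n", i + k);  // prints tiles 1, 5, 9 for this configuration
            }
        }
        return 0;
    }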
+// +// SPDX-License-Identifier: Apache-2.0 + +#include "dataflow_api.h" + +void kernel_main() { + uint32_t src_addr = get_arg_val(0); + uint32_t core_number = get_arg_val(1); + uint32_t tiles_per_row = get_arg_val(2); + uint32_t num_blocks = get_arg_val(3); + + constexpr uint32_t cb_id_in0 = 0; + constexpr bool src_is_dram = get_compile_time_arg_val(0) == 1; + const uint32_t num_tiles_per_2d = get_compile_time_arg_val(1); + const uint32_t third_dim = get_compile_time_arg_val(2); + const uint32_t number_blocks_per_core = get_compile_time_arg_val(3); + +#ifdef OUT_SHARDED + cb_wait_front(cb_id_in0, onetile); +#else + + // single-tile ublocks + constexpr uint32_t onetile = 1; + const uint32_t tile_bytes = get_tile_size(cb_id_in0); + const DataFormat data_format = get_dataformat(cb_id_in0); + + const InterleavedAddrGenFast s = { + .bank_base_address = src_addr, .page_size = tile_bytes, .data_format = data_format}; + +#ifdef BACKWARDS + uint32_t end_id = -num_tiles_per_2d; + for (uint32_t dim = 0; dim > -third_dim; dim--) { + for (uint32_t k = 0; k > -num_blocks; k--) { + for (uint32_t i = num_tiles_per_2d * dim - number_blocks_per_core * core_number; + i > end_id + num_tiles_per_2d * dim; + i = i - tiles_per_row) { +#else + uint32_t end_id = num_tiles_per_2d; + for (uint32_t dim = 0; dim < third_dim; dim++) { + for (uint32_t k = 0; k < num_blocks; k++) { + for (uint32_t i = num_tiles_per_2d * dim + number_blocks_per_core * core_number; + i < end_id + num_tiles_per_2d * dim; + i = i + tiles_per_row) { +#endif + cb_reserve_back(cb_id_in0, onetile); + uint32_t l1_write_addr = get_write_ptr(cb_id_in0); + noc_async_read_tile(i + k, s, l1_write_addr); + + noc_async_read_barrier(); + cb_push_back(cb_id_in0, onetile); + } + } + } +#endif +} diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp index 9c24d12d81c..92461a79793 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp @@ -828,13 +828,24 @@ Tensor _logit(const Tensor& input_a, float eps, const std::optional ExecuteUnaryBackwardProd::invoke( if (updated_grad.storage_type() != StorageType::DEVICE && updated_grad.storage_type() != StorageType::MULTI_DEVICE) { Tensor pad_updated_grad = updated_grad.pad_to_tile(1.0f); - pad_updated_grad = pad_updated_grad.to(Layout::TILE); - updated_grad = pad_updated_grad.to(input.device()); + pad_updated_grad = pad_updated_grad.to_layout(Layout::TILE); + updated_grad = pad_updated_grad.to_device(input.device()); } } else if (dim == 2 || dim == -2) { ttnn::SmallVector after_permute_dims = {0, 2, 1, 3}; diff --git a/ttnn/cpp/ttnn/operations/experimental/auto_format/auto_format.cpp b/ttnn/cpp/ttnn/operations/experimental/auto_format/auto_format.cpp index 305dacf95d1..5f60337d2ab 100644 --- a/ttnn/cpp/ttnn/operations/experimental/auto_format/auto_format.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/auto_format/auto_format.cpp @@ -129,7 +129,7 @@ Tensor AutoFormat::format_input_tensor( // Host side conversions if (pad_input) { if (formatted_input.get_layout() != Layout::ROW_MAJOR) { - formatted_input = formatted_input.to(Layout::ROW_MAJOR); + formatted_input = formatted_input.to_layout(Layout::ROW_MAJOR); convert_layout = formatted_input.get_layout() != target_layout; } formatted_input = ttnn::pad( @@ -140,7 +140,7 @@ Tensor AutoFormat::format_input_tensor( } if (convert_layout) { - formatted_input = 
formatted_input.to(target_layout); + formatted_input = formatted_input.to_layout(target_layout); } return AutoFormat::move_tensor_to_device(formatted_input, device, mem_config); @@ -225,7 +225,7 @@ Tensor AutoFormat::format_output_tensor( if (unpad_output) { // Requires RM for unpad if (formatted_output.get_layout() != Layout::ROW_MAJOR) { - formatted_output = formatted_output.to(Layout::ROW_MAJOR); + formatted_output = formatted_output.to_layout(Layout::ROW_MAJOR); convert_layout = formatted_output.get_layout() != target_layout; } auto begins = std::array({0, 0, 0, 0}); @@ -238,10 +238,10 @@ Tensor AutoFormat::format_output_tensor( // Default to RM layout if we can't match the formatted_input layout if (target_layout == Layout::TILE && !AutoFormat::legal_tile_shape(formatted_output.get_padded_shape())) { if (formatted_output.get_layout() != Layout::ROW_MAJOR) { - formatted_output = formatted_output.to(Layout::ROW_MAJOR); + formatted_output = formatted_output.to_layout(Layout::ROW_MAJOR); } } else { - formatted_output = formatted_output.to(target_layout); + formatted_output = formatted_output.to_layout(target_layout); } } diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/CMakeLists.txt b/ttnn/cpp/ttnn/operations/experimental/ccl/CMakeLists.txt index b84c8ddd5e3..e80883ac5f7 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/CMakeLists.txt +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/CMakeLists.txt @@ -16,6 +16,7 @@ set(CCL_EXPERIMENTAL_TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/all_gather_async/all_gather_async_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/all_gather_async/device/all_gather_async_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/all_gather_async/device/all_gather_async_program.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/all_gather_async/device/all_gather_async_program_minimal_variants.cpp ${CMAKE_CURRENT_SOURCE_DIR}/all_reduce_async/all_reduce_async.cpp ${CMAKE_CURRENT_SOURCE_DIR}/all_reduce_async/all_reduce_async_pybind.cpp CACHE INTERNAL diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.cpp index 13f1979d54d..f295d317f64 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.cpp @@ -107,29 +107,166 @@ std::vector AllGatherAsync::compute_output_specs(const std::ve TensorLayout(input_tensor.get_dtype(), input_tensor.get_tensor_spec().page_config(), output_mem_config))}; } +AllGatherAsyncVersion AllGatherAsync::select_version(const Tensor& input_tensor) const { + auto input_tensor_shape = input_tensor.get_padded_shape(); + auto input_tensor_buffer_layout = input_tensor.buffer()->buffer_layout(); + auto input_tensor_page_layout = input_tensor.layout(); + auto input_tensor_memory_config = input_tensor.memory_config(); + bool input_is_sharded = input_tensor_memory_config.shard_spec.has_value(); + bool output_is_sharded = output_mem_config.shard_spec.has_value(); + uint32_t input_shard_num_cores = 0; + uint32_t output_shard_num_cores = 0; + if (input_is_sharded) { + input_shard_num_cores = input_tensor_memory_config.shard_spec->grid.num_cores(); + log_trace( + tt::LogOp, + "[select_version] input_tensor_memory_config.shard_spec->shape: {}", + input_tensor_memory_config.shard_spec->shape); + } + if (output_is_sharded) { + output_shard_num_cores = output_mem_config.shard_spec->grid.num_cores(); + log_trace(tt::LogOp, 
"[select_version] output_mem_config.shard_spec->shape: {}", output_mem_config.shard_spec->shape); + } + + log_trace(tt::LogOp, "[select_version] input_tensor_shape: {}", input_tensor_shape); + log_trace(tt::LogOp, "[select_version] input_tensor_memory_config: {}", input_tensor_memory_config); + log_trace(tt::LogOp, "[select_version] output_mem_config: {}", output_mem_config); + log_trace(tt::LogOp, "[select_version] input_shard_num_cores: {}", input_shard_num_cores); + log_trace(tt::LogOp, "[select_version] output_shard_num_cores: {}", output_shard_num_cores); + + // Check for minimal interleaved case + if (input_tensor_shape[0] == 1 && input_tensor_shape[1] == 1 && input_tensor_shape[2] == 32 && + input_tensor_buffer_layout == tt::tt_metal::TensorMemoryLayout::INTERLEAVED && + input_tensor_page_layout == tt::tt_metal::Layout::TILE && this->enable_persistent_fabric_mode) { + return AllGatherAsyncVersion::MINIMAL_INTERLEAVED_32; + } + + log_trace(tt::LogOp, "[select_version] input_is_sharded: {}", input_is_sharded); + log_trace(tt::LogOp, "[select_version] output_is_sharded: {}", output_is_sharded); + + if (input_is_sharded && output_is_sharded) { + // Check for first llama post binary matmul case + if (input_tensor_shape[0] == 1 && input_tensor_shape[1] == 1 && input_tensor_shape[2] == 32 && + input_tensor_shape[3] == 960 && input_tensor_memory_config.buffer_type == BufferType::L1 && + output_mem_config.buffer_type == BufferType::L1 && + input_tensor_memory_config.memory_layout == TensorMemoryLayout::WIDTH_SHARDED && + output_mem_config.memory_layout == TensorMemoryLayout::WIDTH_SHARDED && + input_tensor_memory_config.shard_spec->shape[0] == 32 && + input_tensor_memory_config.shard_spec->shape[1] == 32 && + output_mem_config.shard_spec->shape[0] == 32 && + output_mem_config.shard_spec->shape[1] == 160 && input_shard_num_cores == 30 && + output_shard_num_cores == 24) { + return AllGatherAsyncVersion::LLAMA_POST_BINARY_MATMUL; + } + + // Check for second llama post binary matmul case + if (input_tensor_shape[0] == 1 && input_tensor_shape[1] == 8 && input_tensor_shape[2] == 32 && + input_tensor_shape[3] == 128 && input_tensor_memory_config.buffer_type == BufferType::L1 && + output_mem_config.buffer_type == BufferType::L1 && + input_tensor_memory_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED && + output_mem_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED && + input_tensor_memory_config.shard_spec->shape[0] == 32 && + input_tensor_memory_config.shard_spec->shape[1] == 128 && + output_mem_config.shard_spec->shape[0] == 32 && + output_mem_config.shard_spec->shape[1] == 128 && input_shard_num_cores == 8 && + output_shard_num_cores == 32) { + log_trace(tt::LogOp, "All conditions matched for LLAMA_POST_BINARY_MATMUL case"); + return AllGatherAsyncVersion::LLAMA_POST_BINARY_MATMUL; + } + } + log_trace(tt::LogOp, "All conditions matched for generic case"); + return AllGatherAsyncVersion::GENERIC; +} + operation::ProgramWithCallbacks AllGatherAsync::create_program( const std::vector& input_tensors, std::vector& output_tensors) const { tt::log_debug(tt::LogOp, "DEBUG: create_program is called"); - return all_gather_async_multi_core_with_workers( - input_tensors[0], - this->forward_device, - this->backward_device, - output_tensors[0], - this->dim, - this->num_links, - this->ring_size, - this->ring_index, - this->topology, - this->semaphore, - this->sub_device_id, - this->enable_persistent_fabric_mode); + + AllGatherAsyncVersion version = select_version(input_tensors[0]); + + 
log_trace(tt::LogOp, "version: {}", static_cast(version)); + + switch (version) { + case AllGatherAsyncVersion::MINIMAL_INTERLEAVED_32: + log_trace( + tt::LogOp, + "Detected all gather specialized shape. all_gather_async_minimal_interleaved_dim3_1_1_32_any is " + "called"); + return all_gather_async_minimal_interleaved_dim3_1_1_32_any( + input_tensors[0], + this->forward_device, + this->backward_device, + output_tensors[0], + this->dim, + this->num_links, + this->ring_size, + this->ring_index, + this->topology, + this->semaphore, + this->sub_device_id, + this->enable_persistent_fabric_mode); + + case AllGatherAsyncVersion::LLAMA_POST_BINARY_MATMUL: + log_trace( + tt::LogOp, + "Detected all gather specialized shape. all_gather_async_llama_post_binary_matmul is called"); + return all_gather_async_llama_post_binary_matmul( + input_tensors[0], + this->forward_device, + this->backward_device, + output_tensors[0], + this->dim, + this->num_links, + this->ring_size, + this->ring_index, + this->topology, + this->semaphore, + this->sub_device_id, + this->enable_persistent_fabric_mode); + + case AllGatherAsyncVersion::GENERIC: + default: + log_trace(tt::LogOp, "Running generic all_gather_async_multi_core_with_workers"); + return all_gather_async_multi_core_with_workers( + input_tensors[0], + this->forward_device, + this->backward_device, + output_tensors[0], + this->dim, + this->num_links, + this->ring_size, + this->ring_index, + this->topology, + this->semaphore, + this->sub_device_id, + this->enable_persistent_fabric_mode); + } } const operation::Hash AllGatherAsync::compute_program_hash(const std::vector& input_tensors) const { + log_trace(tt::LogOp, "compute_program_hash is called"); + AllGatherAsyncVersion version = select_version(input_tensors[0]); + log_trace(tt::LogOp, "version: {}", static_cast(version)); auto input_shape = input_tensors[0].get_padded_shape(); auto input_memory_layout = input_tensors[0].get_layout(); auto input_dtype = input_tensors[0].get_dtype(); auto input_memory_config = input_tensors[0].memory_config(); + if (version == AllGatherAsyncVersion::GENERIC) { + // Generic version should hash semaphore address as well + uint32_t semaphore_address = this->semaphore.address(); + return operation::hash_operation( + this->dim, + this->num_links, + this->ring_size, + this->ring_index, + this->output_mem_config, + this->topology, + input_shape, + input_memory_layout, + input_dtype, + input_memory_config, + semaphore_address); + } return operation::hash_operation( this->dim, this->num_links, @@ -143,8 +280,6 @@ const operation::Hash AllGatherAsync::compute_program_hash(const std::vector forward_device; std::optional backward_device; @@ -83,6 +89,8 @@ struct AllGatherAsync { operation::ProgramWithCallbacks create_program( const std::vector& input_tensors, std::vector& output_tensors) const; const operation::Hash compute_program_hash(const std::vector& input_tensors) const; + + AllGatherAsyncVersion select_version(const Tensor& input_tensor) const; }; namespace ccl { @@ -101,6 +109,8 @@ AllGatherAsync create_all_gather_async_struct( } // namespace ccl // All Gather Variants +std::tuple> choose_worker_cores( + size_t num_links, size_t num_workers_per_link, bool persistent_fabric_mode, IDevice* device, const std::optional& sub_device_id); operation::ProgramWithCallbacks all_gather_async_multi_core_with_workers( const Tensor& input_tensor, std::optional forward_device, @@ -114,6 +124,32 @@ operation::ProgramWithCallbacks all_gather_async_multi_core_with_workers( const GlobalSemaphore 
semaphore, const std::optional& sub_device_id, bool enable_persistent_fabric_mode); +operation::ProgramWithCallbacks all_gather_async_minimal_interleaved_dim3_1_1_32_any( + const Tensor& input_tensor, + std::optional forward_device, + std::optional backward_device, + Tensor& output_tensor, + const uint32_t dim, + const uint32_t num_links, + const uint32_t ring_size, + const uint32_t ring_index, + ccl::Topology topology, + const GlobalSemaphore& semaphore, + const std::optional& sub_device_id, + bool enable_persistent_fabric_mode); +operation::ProgramWithCallbacks all_gather_async_llama_post_binary_matmul( + const Tensor& input_tensor, + std::optional forward_device, + std::optional backward_device, + Tensor& output_tensor, + const uint32_t dim, + const uint32_t num_links, + const uint32_t ring_size, + const uint32_t ring_index, + ccl::Topology topology, + const GlobalSemaphore& semaphore, + const std::optional& sub_device_id, + bool enable_persistent_fabric_mode); namespace operations { namespace experimental { diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program_minimal_variants.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program_minimal_variants.cpp new file mode 100644 index 00000000000..ba8edc57bf6 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_program_minimal_variants.cpp @@ -0,0 +1,561 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 +/// +#include + +#include +#include +#include "ttnn/tensor/tensor_impl.hpp" +#include "ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp" +#include "ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include "ttnn/operations/ccl/ccl_host_datastructures.hpp" +#include "ttnn/operations/ccl/ccl_common.hpp" +#include "ttnn/operations/math.hpp" +#include +#include +#include +#include +#include "cpp/ttnn/operations/ccl/common/types/ccl_types_args_emitters.hpp" +#include "cpp/ttnn/operations/ccl/common/host/ccl_command_stream_builders.hpp" + +#include "cpp/ttnn/operations/ccl/common/uops/command_lowering.hpp" + +#include "cpp/ttnn/operations/ccl/common/host/ccl_worker_builder.hpp" +#include "cpp/ttnn/operations/ccl/common/host/command_backend_runtime_args_overrider.hpp" +#include +#include +#include +#include +using namespace tt::constants; + +namespace ttnn { + +using namespace ccl; + +void append_fabric_connection_rt_args( + const std::optional& connection, + const CoreCoord& core, + tt::tt_metal::Program& program, + std::vector& writer_rt_args) { + writer_rt_args.push_back(connection.has_value()); + if (connection.has_value()) { + auto sender_worker_flow_control_semaphore_id = CreateSemaphore(program, {core}, 0); + auto sender_worker_teardown_semaphore_id = CreateSemaphore(program, {core}, 0); + auto sender_worker_buffer_index_semaphore_id = CreateSemaphore(program, {core}, 0); + append_worker_to_fabric_edm_sender_rt_args( + connection.value(), + sender_worker_flow_control_semaphore_id, + sender_worker_teardown_semaphore_id, + sender_worker_buffer_index_semaphore_id, + writer_rt_args); + } +} + +operation::ProgramWithCallbacks all_gather_async_minimal_interleaved_dim3_1_1_32_any( + const Tensor& input_tensor, + std::optional forward_device, + std::optional backward_device, + Tensor& output_tensor, + const uint32_t dim, + const uint32_t num_links, + const uint32_t ring_size, + const uint32_t ring_index, + 
ccl::Topology topology, + const GlobalSemaphore& semaphore, + const std::optional& sub_device_id, + bool enable_persistent_fabric_mode) { + tt::tt_metal::Program program{}; + const bool enable_async_output_tensor = false; + TT_FATAL( + enable_persistent_fabric_mode, + "only persistent fabric mode is supported for all_gather_async_minimal_interleaved_dim3_1_1_32_any"); + + IDevice* device = input_tensor.device(); + bool is_first_chip = ring_index == 0; + bool is_last_chip = ring_index == ring_size - 1; + log_trace( + tt::LogOp, + "DEBUG: device: {}, is_first_chip: {}, is_last_chip: {}", + input_tensor.device()->id(), + is_first_chip, + is_last_chip); + + std::optional local_fabric_handle = + ttnn::ccl::EdmLineFabricOpInterface::build_program_builder_worker_connection_fabric( + device, + forward_device.value_or(nullptr), + backward_device.value_or(nullptr), + &program, + enable_persistent_fabric_mode, + num_links); + + // Get OP Config, topology config + std::vector input_tensors = {input_tensor}; + std::vector output_tensors = {output_tensor}; + const auto& op_config = ttnn::ccl::CCLOpConfig(input_tensors, output_tensors, topology); + LineTopology line_topology(ring_size, ring_index); + const size_t num_targets_forward = + line_topology.get_distance_to_end_of_line(ttnn::ccl::EdmLineFabricOpInterface::Direction::FORWARD); + const size_t num_targets_backward = + line_topology.get_distance_to_end_of_line(ttnn::ccl::EdmLineFabricOpInterface::Direction::BACKWARD); + + // Get worker cores, assuming 1 worker per link + uint32_t num_workers_per_link = 1; + const auto [sender_worker_core_range, sender_worker_cores] = + choose_worker_cores(num_links, num_workers_per_link, enable_persistent_fabric_mode, device, sub_device_id); + + // L1 Scratch CB Creation + const size_t packet_size_bytes = local_fabric_handle->get_edm_buffer_size_bytes(); + uint32_t l1_scratch_cb_page_size_bytes = op_config.get_page_size(); + uint32_t num_pages_per_packet = packet_size_bytes / l1_scratch_cb_page_size_bytes; + uint32_t cb_num_pages = 3 * num_pages_per_packet; // tripple buffering + uint32_t src0_cb_index = tt::CB::c_in0; + tt::DataFormat df = tt::tt_metal::datatype_to_dataformat_converter(input_tensor.get_dtype()); + tt::tt_metal::CircularBufferConfig cb_src0_config = + tt::tt_metal::CircularBufferConfig(cb_num_pages * l1_scratch_cb_page_size_bytes, {{src0_cb_index, df}}) + .set_page_size(src0_cb_index, l1_scratch_cb_page_size_bytes); + CBHandle cb_src0_workers = CreateCircularBuffer(program, sender_worker_core_range, cb_src0_config); + // Set aside a buffer we can use for storing packet headers in (particularly for atomic incs) + const auto reserved_packet_header_CB_index = tt::CB::c_in1; + static constexpr auto num_packet_headers_storable = 8; + static constexpr auto packet_header_size_bytes = sizeof(tt::fabric::PacketHeader); + tt::tt_metal::CircularBufferConfig cb_reserved_packet_header_config = + tt::tt_metal::CircularBufferConfig( + num_packet_headers_storable * packet_header_size_bytes * 2, + {{reserved_packet_header_CB_index, tt::DataFormat::RawUInt32}}) + .set_page_size(reserved_packet_header_CB_index, packet_header_size_bytes); + auto reserved_packet_header_CB_handle = + CreateCircularBuffer(program, sender_worker_core_range, cb_reserved_packet_header_config); + + // Tensor Info + const auto input_tensor_layout = input_tensor.buffer()->buffer_layout(); + const auto input_tensor_buffer_type = input_tensor.buffer()->buffer_type(); + const auto input_tensor_page_layout = input_tensor.layout(); + const auto 
input_tensor_num_pages = input_tensor.buffer()->num_pages(); + const auto output_tensor_layout = output_tensor.buffer()->buffer_layout(); + const auto output_tensor_buffer_type = output_tensor.buffer()->buffer_type(); + const auto output_tensor_page_layout = output_tensor.layout(); + + // KERNEL CREATION + // Reader + auto reader_kernel_config = tt::tt_metal::ReaderDataMovementConfig{}; + reader_kernel_config.compile_args = { + ring_index, // my_chip_id + static_cast(input_tensor_buffer_type), // buffer0_type + src0_cb_index, // cb0_id + num_pages_per_packet, // packet_size_in_pages + op_config.get_page_size(), // tensor0_page_size + }; + log_trace(tt::LogOp, "Reader Compile Args:"); + for (const auto& arg : reader_kernel_config.compile_args) { + log_trace(tt::LogOp, "\t{}", arg); + } + auto worker_sender_reader_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/" + "interleaved_dim3_1_1_32_any_reader.cpp", + sender_worker_core_range, + reader_kernel_config); + + // Writer + auto writer_kernel_config = tt::tt_metal::WriterDataMovementConfig{}; + writer_kernel_config.compile_args = { + ring_index, // my_chip_id + reserved_packet_header_CB_index, // reserved_packet_header_cb_id + num_packet_headers_storable, // num_packet_headers_storable + static_cast(output_tensor_buffer_type), // buffer0_type + src0_cb_index, // cb0_id + num_pages_per_packet, // packet_size_in_pages + op_config.get_page_size(), // tensor0_page_size + num_targets_forward, // num_targets_forward_direction + num_targets_backward, // num_targets_backward_direction + }; + log_trace(tt::LogOp, "Writer Compile Args:"); + for (const auto& arg : writer_kernel_config.compile_args) { + log_trace(tt::LogOp, "\t{}", arg); + } + auto worker_sender_writer_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/" + "interleaved_dim3_1_1_32_any_writer.cpp", + sender_worker_core_range, + writer_kernel_config); + + // Kernel Runtime Args + CoreCoord drain_sync_core; // the first worker of each chip is the drain sync core, which contains the output ready + // semaphore + for (uint32_t link = 0; link < num_links; link++) { + CoreCoord core = sender_worker_cores[link]; + if (link == 0) { + // drain sync core is the first worker core + drain_sync_core = device->worker_core_from_logical_core(core); + } + std::optional forward_fabric_connection = + line_topology.is_first_device_in_line(ttnn::ccl::EdmLineFabricOpInterface::Direction::BACKWARD) + ? std::nullopt + : std::optional(local_fabric_handle->uniquely_connect_worker( + device, ttnn::ccl::EdmLineFabricOpInterface::FORWARD)); + std::optional backward_fabric_connection = + line_topology.is_last_device_in_line(ttnn::ccl::EdmLineFabricOpInterface::Direction::BACKWARD) + ? 
std::nullopt + : std::optional(local_fabric_handle->uniquely_connect_worker( + device, ttnn::ccl::EdmLineFabricOpInterface::BACKWARD)); + + // Set reader runtime args + uint32_t base_pages_per_worker = input_tensor_num_pages / num_links; + uint32_t remainder = input_tensor_num_pages % num_links; + uint32_t input_tile_id_start = link * base_pages_per_worker + std::min(link, remainder); + uint32_t input_tile_id_end = (link + 1) * base_pages_per_worker + std::min(link + 1, remainder); + std::vector reader_rt_args = { + input_tensor.buffer()->address(), // tensor_address0 + input_tile_id_start, // tile_id_start + input_tile_id_end, // tile_id_end + }; + log_trace(tt::LogOp, "Reader Runtime Args:"); + for (const auto& arg : reader_rt_args) { + log_trace(tt::LogOp, "\t{}", arg); + } + tt::tt_metal::SetRuntimeArgs(program, worker_sender_reader_kernel_id, {core}, reader_rt_args); + + // Set writer runtime args + bool wait_output_semaphore = (link == 0) && !enable_async_output_tensor; + bool reset_global_semaphore = (link == 0) && !enable_async_output_tensor; + uint32_t out_ready_sem_wait_value = ring_size * num_links; + uint32_t output_tile_id_start = ring_index * input_tensor_num_pages + input_tile_id_start; + uint32_t output_tile_id_end = ring_index * input_tensor_num_pages + input_tile_id_end; + std::vector writer_rt_args = { + output_tensor.buffer()->address(), // tensor_address0 + semaphore.address(), // out_ready_sem_bank_addr (absolute address) + output_tile_id_start, // tile_id_start + output_tile_id_end, // tile_id_end + wait_output_semaphore, // wait_output_semaphore + reset_global_semaphore, // reset_global_semaphore + drain_sync_core.x, // out_ready_sem_noc0_x + drain_sync_core.y, // out_ready_sem_noc0_y + out_ready_sem_wait_value, // out_ready_sem_wait_value + }; + log_trace(tt::LogOp, "Writer Runtime Args:"); + for (const auto& arg : writer_rt_args) { + log_trace(tt::LogOp, "\t{}", arg); + } + append_fabric_connection_rt_args(forward_fabric_connection, core, program, writer_rt_args); + append_fabric_connection_rt_args(backward_fabric_connection, core, program, writer_rt_args); + tt::tt_metal::SetRuntimeArgs(program, worker_sender_writer_kernel_id, {core}, writer_rt_args); + } + + auto override_runtime_arguments_callback = + [worker_sender_reader_kernel_id, worker_sender_writer_kernel_id, semaphore, sender_worker_cores]( + const void* operation, + Program& program, + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + const std::vector& output_tensors) { + const auto& input = input_tensors[0]; + const auto& output = output_tensors[0]; + + auto semaphore = static_cast(operation)->semaphore; + + log_trace(tt::LogOp, "DEBUG: semaphore: {}", semaphore.address()); + + // update senders + auto& worker_reader_sender_runtime_args_by_core = GetRuntimeArgs(program, worker_sender_reader_kernel_id); + auto& worker_writer_sender_runtime_args_by_core = GetRuntimeArgs(program, worker_sender_writer_kernel_id); + for (const auto& core : sender_worker_cores) { + // reader + auto& worker_reader_sender_runtime_args = worker_reader_sender_runtime_args_by_core[core.x][core.y]; + worker_reader_sender_runtime_args[0] = input.buffer()->address(); + // writer + auto& worker_writer_sender_runtime_args = worker_writer_sender_runtime_args_by_core[core.x][core.y]; + worker_writer_sender_runtime_args[0] = output.buffer()->address(); + worker_writer_sender_runtime_args[1] = semaphore.address(); + } + }; + + return {.program = std::move(program), .override_runtime_arguments_callback = 
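[Editor's note] The runtime-args loop above hands each link a near-equal slice of the input pages (integer divide, then give the first `remainder` links one extra page via the std::min trick) and offsets the writer's range by ring_index * input_tensor_num_pages, so every chip deposits its slice at its own position in the gathered output; the drain core's wait target is ring_size * num_links because one increment arrives per writer per chip. A standalone sketch of that arithmetic under assumed example sizes (70 pages, 4 links, chip 2 of an 8-chip line):

#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
    // Assumed example values only.
    const uint32_t input_tensor_num_pages = 70;
    const uint32_t num_links = 4;
    const uint32_t ring_index = 2;
    const uint32_t ring_size = 8;

    const uint32_t base = input_tensor_num_pages / num_links;  // 17
    const uint32_t rem = input_tensor_num_pages % num_links;   // 2: first two links read one extra page

    for (uint32_t link = 0; link < num_links; ++link) {
        const uint32_t in_start = link * base + std::min(link, rem);
        const uint32_t in_end = (link + 1) * base + std::min(link + 1, rem);
        // Each chip writes its contribution at its ring position in the output tensor.
        const uint32_t out_start = ring_index * input_tensor_num_pages + in_start;
        const uint32_t out_end = ring_index * input_tensor_num_pages + in_end;
        std::cout << "link " << link << ": in [" << in_start << ", " << in_end
                  << ") -> out [" << out_start << ", " << out_end << ")\n";
    }

    // One increment per link per chip reaches each drain core, so it waits for:
    std::cout << "out_ready_sem_wait_value = " << ring_size * num_links << "\n";  // 32
}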
override_runtime_arguments_callback}; +} + +operation::ProgramWithCallbacks all_gather_async_llama_post_binary_matmul( + const Tensor& input_tensor, + std::optional forward_device, + std::optional backward_device, + Tensor& output_tensor, + const uint32_t dim, + const uint32_t num_links, + const uint32_t ring_size, + const uint32_t ring_index, + ccl::Topology topology, + const GlobalSemaphore& semaphore, + const std::optional& sub_device_id, + bool enable_persistent_fabric_mode) { + tt::tt_metal::Program program{}; + const bool enable_async_output_tensor = false; + TT_FATAL( + enable_persistent_fabric_mode, + "only persistent fabric mode is supported for all_gather_async_llama_post_binary_matmul"); + + IDevice* device = input_tensor.device(); + bool is_first_chip = ring_index == 0; + bool is_last_chip = ring_index == ring_size - 1; + log_trace( + tt::LogOp, + "DEBUG: device: {}, is_first_chip: {}, is_last_chip: {}", + input_tensor.device()->id(), + is_first_chip, + is_last_chip); + + std::optional local_fabric_handle = + ttnn::ccl::EdmLineFabricOpInterface::build_program_builder_worker_connection_fabric( + device, + forward_device.value_or(nullptr), + backward_device.value_or(nullptr), + &program, + enable_persistent_fabric_mode, + num_links); + + // Get OP Config, topology config + std::vector input_tensors = {input_tensor}; + std::vector output_tensors = {output_tensor}; + const auto& op_config = ttnn::ccl::CCLOpConfig(input_tensors, output_tensors, topology); + LineTopology line_topology(ring_size, ring_index); + const size_t num_targets_forward = + line_topology.get_distance_to_end_of_line(ttnn::ccl::EdmLineFabricOpInterface::Direction::FORWARD); + const size_t num_targets_backward = + line_topology.get_distance_to_end_of_line(ttnn::ccl::EdmLineFabricOpInterface::Direction::BACKWARD); + + // Get worker cores, assuming 1 worker per link + uint32_t num_workers_per_link = 1; + const auto [sender_worker_core_range, sender_worker_cores] = + choose_worker_cores(num_links, num_workers_per_link, enable_persistent_fabric_mode, device, sub_device_id); + + // Tensor Info + const auto input_tensor_num_pages = input_tensor.buffer()->num_pages(); + const auto input_tensor_cores = input_tensor.memory_config().shard_spec->grid; + const auto input_tensor_shard_shape = input_tensor.memory_config().shard_spec->shape; + const auto input_tensor_shard_num_pages = input_tensor_shard_shape[0] * input_tensor_shard_shape[1] / TILE_HW; + const auto output_tensor_cores = output_tensor.memory_config().shard_spec->grid; + const auto output_tensor_shard_shape = output_tensor.memory_config().shard_spec->shape; + const auto output_tensor_shard_num_pages = output_tensor_shard_shape[0] * output_tensor_shard_shape[1] / TILE_HW; + + tt::log_debug(tt::LogOp, "input_tensor_num_pages: {}", input_tensor_num_pages); + tt::log_debug(tt::LogOp, "input_tensor_cores: {}", input_tensor_cores); + tt::log_debug(tt::LogOp, "input_tensor_shard_shape: {}", input_tensor_shard_shape); + tt::log_debug(tt::LogOp, "input_tensor_shard_num_pages: {}", input_tensor_shard_num_pages); + tt::log_debug(tt::LogOp, "output_tensor_cores: {}", output_tensor_cores); + tt::log_debug(tt::LogOp, "output_tensor_shard_shape: {}", output_tensor_shard_shape); + tt::log_debug(tt::LogOp, "output_tensor_shard_num_pages: {}", output_tensor_shard_num_pages); + + // L1 Scratch CB Creation + const size_t packet_size_bytes = local_fabric_handle->get_edm_buffer_size_bytes(); + uint32_t l1_scratch_cb_page_size_bytes = op_config.get_page_size(); + uint32_t 
num_pages_per_packet = packet_size_bytes / l1_scratch_cb_page_size_bytes; + uint32_t cb_num_pages = + input_tensor_num_pages / num_links + + 1; // We are dealing with small shapes, so assuming all pages for a worker can be fit into the CB + uint32_t src0_cb_index = tt::CB::c_in0; + tt::DataFormat df = tt::tt_metal::datatype_to_dataformat_converter(input_tensor.get_dtype()); + tt::tt_metal::CircularBufferConfig cb_src0_config = + tt::tt_metal::CircularBufferConfig(cb_num_pages * l1_scratch_cb_page_size_bytes, {{src0_cb_index, df}}) + .set_page_size(src0_cb_index, l1_scratch_cb_page_size_bytes); + CBHandle cb_src0_workers = CreateCircularBuffer(program, sender_worker_core_range, cb_src0_config); + // Set aside a buffer we can use for storing packet headers in (particularly for atomic incs) + const auto reserved_packet_header_CB_index = tt::CB::c_in1; + static constexpr auto num_packet_headers_storable = 8; + static constexpr auto packet_header_size_bytes = sizeof(tt::fabric::PacketHeader); + tt::tt_metal::CircularBufferConfig cb_reserved_packet_header_config = + tt::tt_metal::CircularBufferConfig( + num_packet_headers_storable * packet_header_size_bytes * 2, + {{reserved_packet_header_CB_index, tt::DataFormat::RawUInt32}}) + .set_page_size(reserved_packet_header_CB_index, packet_header_size_bytes); + auto reserved_packet_header_CB_handle = + CreateCircularBuffer(program, sender_worker_core_range, cb_reserved_packet_header_config); + + // KERNEL CREATION + // Reader + auto reader_kernel_config = tt::tt_metal::ReaderDataMovementConfig{}; + reader_kernel_config.compile_args = { + ring_index, // my_chip_id + src0_cb_index, // cb0_id + op_config.get_page_size(), // tensor0_page_size + }; + log_trace(tt::LogOp, "Reader Compile Args:"); + for (const auto& arg : reader_kernel_config.compile_args) { + log_trace(tt::LogOp, "\t{}", arg); + } + auto worker_sender_reader_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/" + "llama_post_binary_matmul_shape_reader.cpp", + sender_worker_core_range, + reader_kernel_config); + + // Writer + auto writer_kernel_config = tt::tt_metal::WriterDataMovementConfig{}; + writer_kernel_config.compile_args = { + ring_index, // my_chip_id + reserved_packet_header_CB_index, // reserved_packet_header_cb_id + num_packet_headers_storable, // num_packet_headers_storable + src0_cb_index, // cb0_id + num_pages_per_packet, // packet_size_in_pages + op_config.get_page_size(), // tensor0_page_size + num_targets_forward, // num_targets_forward_direction + num_targets_backward, // num_targets_backward_direction + }; + log_trace(tt::LogOp, "Writer Compile Args:"); + for (const auto& arg : writer_kernel_config.compile_args) { + log_trace(tt::LogOp, "\t{}", arg); + } + auto worker_sender_writer_kernel_id = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/" + "llama_post_binary_matmul_shape_writer.cpp", + sender_worker_core_range, + writer_kernel_config); + + // Kernel Runtime Args + CoreCoord drain_sync_core; // the first worker of each chip is the drain sync core, which contains the output ready + // semaphore + auto input_cores_vec = corerange_to_cores(input_tensor_cores, std::nullopt, true); + auto output_cores_vec = corerange_to_cores(output_tensor_cores, std::nullopt, true); + auto cores_per_device = output_cores_vec.size() / ring_size; + TT_FATAL( + output_cores_vec.size() % ring_size == 0, + "output sharded cores must be divisible by 
num_links for this work distribution scheme"); + auto output_cores_this_device = std::vector( + output_cores_vec.begin() + ring_index * cores_per_device, + output_cores_vec.begin() + (ring_index + 1) * cores_per_device); + + for (uint32_t link = 0; link < num_links; link++) { + CoreCoord core = sender_worker_cores[link]; + + // construct input and output core x and y + uint32_t base_pages_per_worker = input_tensor_num_pages / num_links; + uint32_t remainder = input_tensor_num_pages % num_links; + uint32_t input_tile_id_start = link * base_pages_per_worker + std::min(link, remainder); + uint32_t input_tile_id_end = (link + 1) * base_pages_per_worker + std::min(link + 1, remainder); + + uint32_t worker_num_tiles_to_read = input_tile_id_end - input_tile_id_start; + uint32_t input_first_core_tile_start_offset = input_tile_id_start % input_tensor_shard_num_pages; + uint32_t output_first_core_tile_start_offset = + (input_tensor_num_pages * ring_index + input_tile_id_start) % output_tensor_shard_num_pages; + + std::vector input_tensor_cores_x; + std::vector input_tensor_cores_y; + std::vector output_tensor_cores_x; + std::vector output_tensor_cores_y; + for (uint32_t i = input_tile_id_start / input_tensor_shard_num_pages; + i < (input_tile_id_end + input_tensor_shard_num_pages - 1) / input_tensor_shard_num_pages; + i++) { + auto this_core = device->worker_core_from_logical_core(input_cores_vec[i]); + input_tensor_cores_x.push_back(this_core.x); + input_tensor_cores_y.push_back(this_core.y); + } + for (uint32_t i = input_tile_id_start / output_tensor_shard_num_pages; + i < (input_tile_id_end + output_tensor_shard_num_pages - 1) / output_tensor_shard_num_pages; + i++) { + auto this_core = device->worker_core_from_logical_core(output_cores_this_device[i]); + output_tensor_cores_x.push_back(this_core.x); + output_tensor_cores_y.push_back(this_core.y); + } + + tt::log_debug(tt::LogOp, "input_tile_id_start: {}", input_tile_id_start); + tt::log_debug(tt::LogOp, "input_tile_id_end: {}", input_tile_id_end); + tt::log_debug(tt::LogOp, "worker_num_tiles_to_read: {}", worker_num_tiles_to_read); + tt::log_debug(tt::LogOp, "input_first_core_tile_start_offset: {}", input_first_core_tile_start_offset); + tt::log_debug(tt::LogOp, "output_first_core_tile_start_offset: {}", output_first_core_tile_start_offset); + tt::log_debug(tt::LogOp, "input_tensor_cores_x: {}", input_tensor_cores_x); + tt::log_debug(tt::LogOp, "input_tensor_cores_y: {}", input_tensor_cores_y); + tt::log_debug(tt::LogOp, "output_tensor_cores_x: {}", output_tensor_cores_x); + tt::log_debug(tt::LogOp, "output_tensor_cores_y: {}", output_tensor_cores_y); + + if (link == 0) { + // drain sync core is the first worker core + drain_sync_core = device->worker_core_from_logical_core(core); + } + std::optional forward_fabric_connection = + line_topology.is_first_device_in_line(ttnn::ccl::EdmLineFabricOpInterface::Direction::BACKWARD) + ? std::nullopt + : std::optional(local_fabric_handle->uniquely_connect_worker( + device, ttnn::ccl::EdmLineFabricOpInterface::FORWARD)); + std::optional backward_fabric_connection = + line_topology.is_last_device_in_line(ttnn::ccl::EdmLineFabricOpInterface::Direction::BACKWARD) + ? 
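[Editor's note] In the sharded llama variant, a worker's tile range is mapped onto shard cores: each core holds shard_height * shard_width / TILE_HW pages, the worker walks from the shard containing its first tile to the shard containing its last, and it records the tile offset inside that first shard. A hedged host-side sketch of the mapping; the shard shape, page count, and link count below are assumed example values:

#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
    // Assumed example values only.
    constexpr uint32_t TILE_HW = 32 * 32;
    const uint32_t shard_shape[2] = {32, 128};                                   // one shard per core
    const uint32_t shard_num_pages = shard_shape[0] * shard_shape[1] / TILE_HW;  // 4 tiles per core
    const uint32_t input_tensor_num_pages = 16;                                  // i.e. 4 input cores
    const uint32_t num_links = 2;

    const uint32_t base = input_tensor_num_pages / num_links;
    const uint32_t rem = input_tensor_num_pages % num_links;
    for (uint32_t link = 0; link < num_links; ++link) {
        const uint32_t start = link * base + std::min(link, rem);
        const uint32_t end = (link + 1) * base + std::min(link + 1, rem);
        // Shard-core indices touched by this worker's tile range [start, end).
        const uint32_t first_core = start / shard_num_pages;
        const uint32_t last_core = (end + shard_num_pages - 1) / shard_num_pages;  // exclusive
        const uint32_t first_offset = start % shard_num_pages;  // tile offset inside the first core's shard
        std::cout << "link " << link << ": tiles [" << start << ", " << end
                  << "), cores [" << first_core << ", " << last_core
                  << "), first_core_tile_start_offset=" << first_offset << "\n";
    }
}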
std::nullopt + : std::optional(local_fabric_handle->uniquely_connect_worker( + device, ttnn::ccl::EdmLineFabricOpInterface::BACKWARD)); + + // Set reader runtime args + std::vector reader_rt_args = { + input_tensor.buffer()->address(), // tensor_address0 + input_tensor_shard_num_pages, // num_tiles_per_core + worker_num_tiles_to_read, // num_tiles_to_read + input_first_core_tile_start_offset, // first_core_tile_start_offset + input_tensor_cores_x.size(), // num_cores + }; + reader_rt_args.insert(reader_rt_args.end(), input_tensor_cores_x.begin(), input_tensor_cores_x.end()); + reader_rt_args.insert(reader_rt_args.end(), input_tensor_cores_y.begin(), input_tensor_cores_y.end()); + log_trace(tt::LogOp, "Reader Runtime Args:"); + for (const auto& arg : reader_rt_args) { + log_trace(tt::LogOp, "\t{}", arg); + } + tt::tt_metal::SetRuntimeArgs(program, worker_sender_reader_kernel_id, {core}, reader_rt_args); + + // Set writer runtime args + bool wait_output_semaphore = (link == 0) && !enable_async_output_tensor; + bool reset_global_semaphore = (link == 0) && !enable_async_output_tensor; + uint32_t out_ready_sem_wait_value = ring_size * num_links; + std::vector writer_rt_args = { + output_tensor.buffer()->address(), // tensor_address0 + semaphore.address(), // out_ready_sem_bank_addr (absolute address) + output_tensor_shard_num_pages, // num_tiles_per_core + worker_num_tiles_to_read, // num_tiles_to_read + output_first_core_tile_start_offset, // first_core_tile_start_offset + output_tensor_cores_x.size(), // num_cores + wait_output_semaphore, // wait_output_semaphore + reset_global_semaphore, // reset_global_semaphore + drain_sync_core.x, // out_ready_sem_noc0_x + drain_sync_core.y, // out_ready_sem_noc0_y + out_ready_sem_wait_value, // out_ready_sem_wait_value + }; + writer_rt_args.insert(writer_rt_args.end(), output_tensor_cores_x.begin(), output_tensor_cores_x.end()); + writer_rt_args.insert(writer_rt_args.end(), output_tensor_cores_y.begin(), output_tensor_cores_y.end()); + log_trace(tt::LogOp, "Writer Runtime Args:"); + for (const auto& arg : writer_rt_args) { + log_trace(tt::LogOp, "\t{}", arg); + } + append_fabric_connection_rt_args(forward_fabric_connection, core, program, writer_rt_args); + append_fabric_connection_rt_args(backward_fabric_connection, core, program, writer_rt_args); + tt::tt_metal::SetRuntimeArgs(program, worker_sender_writer_kernel_id, {core}, writer_rt_args); + } + + auto override_runtime_arguments_callback = + [worker_sender_reader_kernel_id, worker_sender_writer_kernel_id, semaphore, sender_worker_cores]( + const void* operation, + Program& program, + const std::vector& input_tensors, + const std::vector>& optional_input_tensors, + const std::vector& output_tensors) { + const auto& input = input_tensors[0]; + const auto& output = output_tensors[0]; + + auto semaphore = static_cast(operation)->semaphore; + + log_trace(tt::LogOp, "DEBUG: semaphore: {}", semaphore.address()); + + // update senders + auto& worker_reader_sender_runtime_args_by_core = GetRuntimeArgs(program, worker_sender_reader_kernel_id); + auto& worker_writer_sender_runtime_args_by_core = GetRuntimeArgs(program, worker_sender_writer_kernel_id); + for (const auto& core : sender_worker_cores) { + // reader + auto& worker_reader_sender_runtime_args = worker_reader_sender_runtime_args_by_core[core.x][core.y]; + worker_reader_sender_runtime_args[0] = input.buffer()->address(); + // writer + auto& worker_writer_sender_runtime_args = worker_writer_sender_runtime_args_by_core[core.x][core.y]; + 
worker_writer_sender_runtime_args[0] = output.buffer()->address(); + worker_writer_sender_runtime_args[1] = semaphore.address(); + } + }; + + return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback}; +} + +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/interleaved_dim3_1_1_32_any_reader.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/interleaved_dim3_1_1_32_any_reader.cpp new file mode 100644 index 00000000000..69a91f57fe2 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/interleaved_dim3_1_1_32_any_reader.cpp @@ -0,0 +1,78 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "dataflow_api.h" +#include +#include +#include + +using address_t = uint32_t; +using tt::tt_metal::BufferType; + +/////////////////////////////////////////////////// +// COMPILE TIME ARGS +/////////////////////////////////////////////////// + +constexpr uint32_t my_chip_id = get_compile_time_arg_val(0); +constexpr BufferType buffer0_type = static_cast(get_compile_time_arg_val(1)); +constexpr uint32_t cb0_id = get_compile_time_arg_val(2); +constexpr uint32_t packet_size_in_pages = get_compile_time_arg_val(3); +constexpr uint32_t tensor0_page_size = get_compile_time_arg_val(4); + +/* + * CCL Send will present various operating modes. Although there is only a single send kernel, it may (compile time) + * dispatch implementations depending on those invocation parameters. + */ +void kernel_main() { + /////////////////////////////////////////////////// + // ARGS + /////////////////////////////////////////////////// + + size_t arg_idx = 0; + // Load the input tensor spec + address_t tensor_address0 = get_arg_val(arg_idx++); + uint32_t tile_id_start = get_arg_val(arg_idx++); + uint32_t tile_id_end = get_arg_val(arg_idx++); + + // print every compile and runtime arg in uint32_t + DPRINT << "ct args: \n"; + DPRINT << "my_chip_id: " << (uint32_t)my_chip_id << "\n"; + DPRINT << "buffer0_type: " << (uint32_t)buffer0_type << "\n"; + DPRINT << "cb0_id: " << (uint32_t)cb0_id << "\n"; + DPRINT << "packet_size_in_pages: " << (uint32_t)packet_size_in_pages << "\n"; + DPRINT << "tensor0_page_size: " << (uint32_t)tensor0_page_size << "\n"; + + DPRINT << "rt args: \n"; + DPRINT << "tensor_address0: " << (uint32_t)tensor_address0 << "\n"; + DPRINT << "tile_id_start: " << (uint32_t)tile_id_start << "\n"; + DPRINT << "tile_id_end: " << (uint32_t)tile_id_end << "\n"; + + // interleaved addrgen + constexpr bool is_dram = buffer0_type == tt::tt_metal::BufferType::DRAM; + auto tensor0_addrgen = InterleavedAddrGenFast{ + .bank_base_address = tensor_address0, .page_size = tensor0_page_size, .data_format = get_dataformat(cb0_id)}; + + DPRINT << "tensor -> CB: " << (uint32_t)cb0_id << "\n"; + DPRINT << "packet size in pages: " << (uint32_t)packet_size_in_pages << "\n"; + + uint32_t tile_id = tile_id_start; + while (tile_id < tile_id_end) { + DPRINT << "tile_id: " << tile_id << "\n"; + cb_reserve_back(cb0_id, packet_size_in_pages); + const uint32_t l1_write_addr_base = get_write_ptr(cb0_id); + uint32_t l1_write_addr = l1_write_addr_base; + + uint32_t num_pages_to_read = std::min(tile_id_end - tile_id, packet_size_in_pages); + for (uint32_t j = 0; j < num_pages_to_read; j++) { + noc_async_read_tile(tile_id, tensor0_addrgen, l1_write_addr); + l1_write_addr += tensor0_page_size; + tile_id++; + } + + 
noc_async_read_barrier(); + cb_push_back(cb0_id, packet_size_in_pages); + } + + DPRINT << "DONE \n"; +} diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/interleaved_dim3_1_1_32_any_writer.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/interleaved_dim3_1_1_32_any_writer.cpp new file mode 100644 index 00000000000..003d5934ded --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/interleaved_dim3_1_1_32_any_writer.cpp @@ -0,0 +1,202 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "dataflow_api.h" +#include +#include "cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/fabric_connection_manager.hpp" +#include "cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/noc_addr.hpp" +#include "minimal_ccl_common.hpp" +#include +#include + +using address_t = uint32_t; +using tt::tt_metal::BufferType; + +/////////////////////////////////////////////////// +// COMPILE TIME ARGS +/////////////////////////////////////////////////// + +constexpr uint32_t my_chip_id = get_compile_time_arg_val(0); +constexpr uint32_t reserved_packet_header_cb_id = get_compile_time_arg_val(1); +constexpr uint32_t num_packet_headers_storable = get_compile_time_arg_val(2); +constexpr BufferType buffer0_type = static_cast(get_compile_time_arg_val(3)); +constexpr uint32_t cb0_id = get_compile_time_arg_val(4); +constexpr uint32_t packet_size_in_pages = get_compile_time_arg_val(5); +constexpr uint32_t tensor0_page_size = get_compile_time_arg_val(6); +constexpr uint32_t num_targets_forward_direction = get_compile_time_arg_val(7); +constexpr uint32_t num_targets_backward_direction = get_compile_time_arg_val(8); + +/* + * CCL Send will present various operating modes. Although there is only a single send kernel, it may (compile time) + * dispatch implementations depending on those invocation parameters. 
+ */ +void kernel_main() { + /////////////////////////////////////////////////// + // ARGS + /////////////////////////////////////////////////// + + size_t arg_idx = 0; + // Load the input tensor spec + address_t tensor_address0 = get_arg_val(arg_idx++); + const size_t out_ready_sem_bank_addr = get_arg_val(arg_idx++); + uint32_t tile_id_start = get_arg_val(arg_idx++); + uint32_t tile_id_end = get_arg_val(arg_idx++); + bool wait_output_semaphore = get_arg_val(arg_idx++); + bool reset_global_semaphore = get_arg_val(arg_idx++); + const uint8_t out_ready_sem_noc0_x = get_arg_val(arg_idx++); + const uint8_t out_ready_sem_noc0_y = get_arg_val(arg_idx++); + uint32_t out_ready_sem_wait_value = get_arg_val(arg_idx++); + size_t arg_for_fab = arg_idx; + auto fabric_connection = FabricConnectionManager::build_from_args(arg_idx); + + DPRINT << "ct args: \n"; + DPRINT << "my_chip_id: " << (uint32_t)my_chip_id << "\n"; + DPRINT << "reserved_packet_header_cb_id: " << (uint32_t)reserved_packet_header_cb_id << "\n"; + DPRINT << "num_packet_headers_storable: " << (uint32_t)num_packet_headers_storable << "\n"; + DPRINT << "buffer0_type: " << (uint32_t)buffer0_type << "\n"; + DPRINT << "cb0_id: " << (uint32_t)cb0_id << "\n"; + DPRINT << "packet_size_in_pages: " << (uint32_t)packet_size_in_pages << "\n"; + DPRINT << "tensor0_page_size: " << (uint32_t)tensor0_page_size << "\n"; + DPRINT << "num_targets_forward_direction: " << (uint32_t)num_targets_forward_direction << "\n"; + DPRINT << "num_targets_backward_direction: " << (uint32_t)num_targets_backward_direction << "\n"; + + DPRINT << "rt args: \n"; + DPRINT << "tensor_address0: " << (uint32_t)tensor_address0 << "\n"; + DPRINT << "tile_id_start: " << (uint32_t)tile_id_start << "\n"; + DPRINT << "tile_id_end: " << (uint32_t)tile_id_end << "\n"; + DPRINT << "wait_output_semaphore: " << (uint32_t)wait_output_semaphore << "\n"; + DPRINT << "reset_global_semaphore: " << (uint32_t)reset_global_semaphore << "\n"; + DPRINT << "out_ready_sem_bank_addr: " << (uint32_t)out_ready_sem_bank_addr << "\n"; + DPRINT << "out_ready_sem_noc0_x: " << (uint32_t)out_ready_sem_noc0_x << "\n"; + DPRINT << "out_ready_sem_noc0_y: " << (uint32_t)out_ready_sem_noc0_y << "\n"; + DPRINT << "out_ready_sem_wait_value: " << (uint32_t)out_ready_sem_wait_value << "\n"; + + DPRINT << "arg_for_fab: " << (uint32_t)arg_for_fab << "\n"; + DPRINT << "fabric_connection arg 0" << get_arg_val(arg_for_fab++) << "\n"; + DPRINT << "fabric_connection arg 1" << get_arg_val(arg_for_fab++) << "\n"; + DPRINT << "fabric_connection arg 2" << get_arg_val(arg_for_fab++) << "\n"; + DPRINT << "fabric_connection arg 3" << get_arg_val(arg_for_fab++) << "\n"; + DPRINT << "fabric_connection arg 4" << get_arg_val(arg_for_fab++) << "\n"; + + // packet header cb + cb_reserve_back(reserved_packet_header_cb_id, 1); + auto packet_header_buffer_addr_forward = get_write_ptr(reserved_packet_header_cb_id); + cb_push_back(reserved_packet_header_cb_id, 1); + cb_reserve_back(reserved_packet_header_cb_id, 1); + auto packet_header_buffer_addr_backward = get_write_ptr(reserved_packet_header_cb_id); + cb_push_back(reserved_packet_header_cb_id, 1); + cb_reserve_back(reserved_packet_header_cb_id, 1); + auto packet_header_buffer_seminc = get_write_ptr(reserved_packet_header_cb_id); + cb_push_back(reserved_packet_header_cb_id, 1); + DPRINT << "packet_header_buffer_addr_forward: " << (uint32_t)packet_header_buffer_addr_forward << "\n"; + DPRINT << "packet_header_buffer_addr_backward: " << (uint32_t)packet_header_buffer_addr_backward << "\n"; + 
DPRINT << "packet_header_buffer_seminc: " << (uint32_t)packet_header_buffer_seminc << "\n"; + + // pre-populate packet headers + volatile tt::fabric::PacketHeader* pkt_hdr_forward = + reinterpret_cast(packet_header_buffer_addr_forward); + volatile tt::fabric::PacketHeader* pkt_hdr_backward = + reinterpret_cast(packet_header_buffer_addr_backward); + pkt_hdr_forward->to_chip_multicast( + tt::fabric::MulticastRoutingCommandHeader{1, static_cast(num_targets_forward_direction)}); + pkt_hdr_backward->to_chip_multicast( + tt::fabric::MulticastRoutingCommandHeader{1, static_cast(num_targets_backward_direction)}); + + // interleaved addrgen + constexpr bool is_dram = buffer0_type == tt::tt_metal::BufferType::DRAM; + auto tensor0_addrgen = InterleavedAddrGenFast{ + .bank_base_address = tensor_address0, .page_size = tensor0_page_size, .data_format = get_dataformat(cb0_id)}; + + if (fabric_connection.is_logically_connected()) { + fabric_connection.open(); + } + + // 1. mcast via fabric to remote tensor addresses + DPRINT << "num_targets_forward_direction: " << num_targets_forward_direction << "\n"; + DPRINT << "num_targets_backward_direction: " << num_targets_backward_direction << "\n"; + DPRINT << "my_chip_id: " << my_chip_id << "\n"; + + DPRINT << "tensor -> CB: " << (uint32_t)cb0_id << "\n"; + DPRINT << "packet size in pages: " << (uint32_t)packet_size_in_pages << "\n"; + uint32_t tile_id = tile_id_start; + while (tile_id < tile_id_end) { + DPRINT << "tile_id: " << tile_id << "\n"; + cb_wait_front(cb0_id, packet_size_in_pages); + size_t l1_read_addr = get_read_ptr(cb0_id); + uint32_t num_pages_to_read = std::min(tile_id_end - tile_id, packet_size_in_pages); + + uint32_t contig_pages_advanced = 1; // always 1 for interleaved + for (uint32_t j = 0; j < num_pages_to_read; j += contig_pages_advanced) { + uint64_t noc0_dest_noc_addr = get_noc_addr(tile_id, tensor0_addrgen, 0 /*offset*/, 0 /*noc_id*/); + + DPRINT << "j: " << j << "\n"; + DPRINT << "noc0_dest_noc_addr: " << noc0_dest_noc_addr << "\n"; + DPRINT << "tile_id: " << tile_id << "\n"; + + write_and_advance_local_read_address_for_fabric_write( + noc0_dest_noc_addr, + pkt_hdr_forward, + pkt_hdr_backward, + fabric_connection, + l1_read_addr, + contig_pages_advanced * tensor0_page_size); + + tile_id++; + } + noc_async_writes_flushed(); + + cb_pop_front(cb0_id, packet_size_in_pages); + } + + // 2. 
mcast output ready semaphore + auto* pkt_hdr = reinterpret_cast(packet_header_buffer_seminc); + pkt_hdr->to_atomic_inc(); + pkt_hdr->to_noc_unicast_atomic_inc(tt::fabric::NocUnicastAtomicIncCommandHeader{ + out_ready_sem_bank_addr, + static_cast(1), // increment 1 + 32, + static_cast(out_ready_sem_noc0_x), + static_cast(out_ready_sem_noc0_y)}); + // Write the mcast packet (forward) + if (fabric_connection.has_forward_connection()) { + fabric_connection.get_forward_connection().wait_for_empty_write_slot(); + pkt_hdr->to_chip_multicast( + tt::fabric::MulticastRoutingCommandHeader{1, static_cast(num_targets_forward_direction)}); + fabric_connection.get_forward_connection().send_payload_flush_blocking_from_address( + packet_header_buffer_seminc, sizeof(tt::fabric::PacketHeader)); + } + // Write the mcast packet (backward) + if (fabric_connection.has_backward_connection()) { + pkt_hdr->to_chip_multicast( + tt::fabric::MulticastRoutingCommandHeader{1, static_cast(num_targets_backward_direction)}); + fabric_connection.get_backward_connection().wait_for_empty_write_slot(); + fabric_connection.get_backward_connection().send_payload_non_blocking_from_address( + packet_header_buffer_seminc, sizeof(tt::fabric::PacketHeader)); + } + // increment locally + uint64_t out_ready_sem_noc_addr = + safe_get_noc_addr(out_ready_sem_noc0_x, out_ready_sem_noc0_y, out_ready_sem_bank_addr); + noc_semaphore_inc(out_ready_sem_noc_addr, 1); + DPRINT << "inc done\n"; + + // 3. wait for mcast output ready semaphore + if (wait_output_semaphore) { + while (*reinterpret_cast(out_ready_sem_bank_addr) < out_ready_sem_wait_value); + DPRINT << "waitval done\n"; + } + + // 4. global semaphore reset + if (reset_global_semaphore) { + const uint64_t dest_noc_addr = get_noc_addr(my_x[0], my_y[0], out_ready_sem_bank_addr); + noc_inline_dw_write(dest_noc_addr, 0); + DPRINT << "reset done\n"; + } + + if (fabric_connection.is_logically_connected()) { + fabric_connection.close(); + } + + noc_async_write_barrier(); + DPRINT << "DONE \n"; +} diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/llama_post_binary_matmul_shape_reader.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/llama_post_binary_matmul_shape_reader.cpp new file mode 100644 index 00000000000..e6403fec74e --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/llama_post_binary_matmul_shape_reader.cpp @@ -0,0 +1,84 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "dataflow_api.h" +#include +#include +#include + +using address_t = uint32_t; + +/////////////////////////////////////////////////// +// COMPILE TIME ARGS +/////////////////////////////////////////////////// + +constexpr uint32_t my_chip_id = get_compile_time_arg_val(0); +constexpr uint32_t cb0_id = get_compile_time_arg_val(1); +constexpr uint32_t tensor0_page_size = get_compile_time_arg_val(2); + +/* + * CCL Send will present various operating modes. Although there is only a single send kernel, it may (compile time) + * dispatch implementations depending on those invocation parameters. 
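[Editor's note] The output-ready handshake above behaves the same way on every chip: after its data mcast, each writer sends one atomic-inc packet forward and one backward along the line and also increments its own chip's drain-core semaphore, so every drain core eventually counts ring_size * num_links; only the link-0 writer waits on that value and then resets the semaphore for the next invocation. A small host-side simulation of the counting, with an assumed 4-chip line and 2 links per chip:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
    // Assumed example values only: a line of 4 chips, 2 links (workers) per chip.
    const uint32_t ring_size = 4;
    const uint32_t num_links = 2;

    // Out-ready semaphore value on each chip's drain sync core.
    std::vector<uint32_t> sem(ring_size, 0);

    // Every writer increments its own chip's semaphore locally and mcasts an
    // atomic-inc over the forward/backward fabric to every other chip in the line.
    for (uint32_t chip = 0; chip < ring_size; ++chip) {
        for (uint32_t link = 0; link < num_links; ++link) {
            for (uint32_t target = 0; target < ring_size; ++target) {
                sem[target] += 1;
            }
        }
    }

    // Each drain core therefore ends at ring_size * num_links and can stop waiting.
    for (uint32_t chip = 0; chip < ring_size; ++chip) {
        std::cout << "chip " << chip << " semaphore=" << sem[chip]
                  << " (wait value " << ring_size * num_links << ")\n";
    }
}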
+ */ +void kernel_main() { + /////////////////////////////////////////////////// + // ARGS + /////////////////////////////////////////////////// + + size_t arg_idx = 0; + // Load the input tensor spec + address_t tensor_address0 = get_arg_val(arg_idx++); + uint32_t num_tiles_per_core = get_arg_val(arg_idx++); + uint32_t num_tiles_to_read = get_arg_val(arg_idx++); + uint32_t first_core_tile_start_offset = get_arg_val(arg_idx++); + uint32_t num_cores = get_arg_val(arg_idx++); + tt_l1_ptr uint32_t* core_noc_x = (tt_l1_ptr uint32_t*)(get_arg_addr(arg_idx)); + arg_idx += num_cores; + tt_l1_ptr uint32_t* core_noc_y = (tt_l1_ptr uint32_t*)(get_arg_addr(arg_idx)); + arg_idx += num_cores; + + // print every compile and runtime arg in uint32_t + DPRINT << "ct args: \n"; + DPRINT << "my_chip_id: " << (uint32_t)my_chip_id << "\n"; + DPRINT << "cb0_id: " << (uint32_t)cb0_id << "\n"; + DPRINT << "tensor0_page_size: " << (uint32_t)tensor0_page_size << "\n"; + + DPRINT << "rt args: \n"; + DPRINT << "tensor_address0: " << (uint32_t)tensor_address0 << "\n"; + DPRINT << "num_tiles_per_core: " << (uint32_t)num_tiles_per_core << "\n"; + DPRINT << "num_tiles_to_read: " << (uint32_t)num_tiles_to_read << "\n"; + DPRINT << "first_core_tile_start_offset: " << (uint32_t)first_core_tile_start_offset << "\n"; + DPRINT << "num_cores: " << (uint32_t)num_cores << "\n"; + for (uint32_t i = 0; i < num_cores; i++) { + DPRINT << "core_noc_x[" << i << "]: " << (uint32_t)core_noc_x[i] << "\n"; + DPRINT << "core_noc_y[" << i << "]: " << (uint32_t)core_noc_y[i] << "\n"; + } + + // interleaved addrgen + + DPRINT << "tensor -> CB: " << (uint32_t)cb0_id << "\n"; + + uint32_t tiles_read = 0; + uint32_t shard_tile_id = first_core_tile_start_offset; + uint32_t core_id = 0; + while (tiles_read < num_tiles_to_read) { + DPRINT << "tiles_read: " << tiles_read << "\n"; + uint32_t num_tiles_to_read_this_core = + std::min(num_tiles_per_core - shard_tile_id, num_tiles_to_read - tiles_read); + cb_reserve_back(cb0_id, num_tiles_to_read_this_core); + const uint32_t l1_write_addr = get_write_ptr(cb0_id); + uint64_t read_addr = get_noc_addr(core_noc_x[core_id], core_noc_y[core_id], tensor_address0); + read_addr += shard_tile_id * tensor0_page_size; + + noc_async_read(read_addr, l1_write_addr, num_tiles_to_read_this_core * tensor0_page_size); + noc_async_read_barrier(); + + cb_push_back(cb0_id, num_tiles_to_read_this_core); + tiles_read += num_tiles_to_read_this_core; + shard_tile_id = 0; + core_id++; + } + + DPRINT << "DONE \n"; +} diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp new file mode 100644 index 00000000000..54bfa996d39 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/llama_post_binary_matmul_shape_writer.cpp @@ -0,0 +1,210 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "dataflow_api.h" +#include +#include "cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/fabric_connection_manager.hpp" +#include "cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/noc_addr.hpp" +#include "minimal_ccl_common.hpp" +#include +#include + +using address_t = uint32_t; + +/////////////////////////////////////////////////// +// COMPILE TIME ARGS +/////////////////////////////////////////////////// + +constexpr uint32_t 
my_chip_id = get_compile_time_arg_val(0); +constexpr uint32_t reserved_packet_header_cb_id = get_compile_time_arg_val(1); +constexpr uint32_t num_packet_headers_storable = get_compile_time_arg_val(2); +constexpr uint32_t cb0_id = get_compile_time_arg_val(3); +constexpr uint32_t packet_size_in_pages = get_compile_time_arg_val(4); +constexpr uint32_t tensor0_page_size = get_compile_time_arg_val(5); +constexpr uint32_t num_targets_forward_direction = get_compile_time_arg_val(6); +constexpr uint32_t num_targets_backward_direction = get_compile_time_arg_val(7); + +/* + * CCL Send will present various operating modes. Although there is only a single send kernel, it may (compile time) + * dispatch implementations depending on those invocation parameters. + */ +void kernel_main() { + /////////////////////////////////////////////////// + // ARGS + /////////////////////////////////////////////////// + + size_t arg_idx = 0; + // Load the input tensor spec + address_t tensor_address0 = get_arg_val(arg_idx++); + const size_t out_ready_sem_bank_addr = get_arg_val(arg_idx++); + uint32_t num_tiles_per_core = get_arg_val(arg_idx++); + uint32_t num_tiles_to_read = get_arg_val(arg_idx++); + uint32_t first_core_tile_start_offset = get_arg_val(arg_idx++); + uint32_t num_cores = get_arg_val(arg_idx++); + bool wait_output_semaphore = get_arg_val(arg_idx++); + bool reset_global_semaphore = get_arg_val(arg_idx++); + const uint8_t out_ready_sem_noc0_x = get_arg_val(arg_idx++); + const uint8_t out_ready_sem_noc0_y = get_arg_val(arg_idx++); + uint32_t out_ready_sem_wait_value = get_arg_val(arg_idx++); + tt_l1_ptr uint32_t* core_noc_x = (tt_l1_ptr uint32_t*)(get_arg_addr(arg_idx)); + arg_idx += num_cores; + tt_l1_ptr uint32_t* core_noc_y = (tt_l1_ptr uint32_t*)(get_arg_addr(arg_idx)); + arg_idx += num_cores; + size_t arg_for_fab = arg_idx; + auto fabric_connection = FabricConnectionManager::build_from_args(arg_idx); + + DPRINT << "ct args: \n"; + DPRINT << "my_chip_id: " << (uint32_t)my_chip_id << "\n"; + DPRINT << "reserved_packet_header_cb_id: " << (uint32_t)reserved_packet_header_cb_id << "\n"; + DPRINT << "num_packet_headers_storable: " << (uint32_t)num_packet_headers_storable << "\n"; + DPRINT << "cb0_id: " << (uint32_t)cb0_id << "\n"; + DPRINT << "packet_size_in_pages: " << (uint32_t)packet_size_in_pages << "\n"; + DPRINT << "tensor0_page_size: " << (uint32_t)tensor0_page_size << "\n"; + DPRINT << "num_targets_forward_direction: " << (uint32_t)num_targets_forward_direction << "\n"; + DPRINT << "num_targets_backward_direction: " << (uint32_t)num_targets_backward_direction << "\n"; + + DPRINT << "rt args: \n"; + DPRINT << "tensor_address0: " << (uint32_t)tensor_address0 << "\n"; + DPRINT << "num_tiles_per_core: " << (uint32_t)num_tiles_per_core << "\n"; + DPRINT << "num_tiles_to_read: " << (uint32_t)num_tiles_to_read << "\n"; + DPRINT << "first_core_tile_start_offset: " << (uint32_t)first_core_tile_start_offset << "\n"; + DPRINT << "num_cores: " << (uint32_t)num_cores << "\n"; + for (uint32_t i = 0; i < num_cores; i++) { + DPRINT << "core_noc_x[" << i << "]: " << (uint32_t)core_noc_x[i] << "\n"; + DPRINT << "core_noc_y[" << i << "]: " << (uint32_t)core_noc_y[i] << "\n"; + } + DPRINT << "wait_output_semaphore: " << (uint32_t)wait_output_semaphore << "\n"; + DPRINT << "reset_global_semaphore: " << (uint32_t)reset_global_semaphore << "\n"; + DPRINT << "out_ready_sem_bank_addr: " << (uint32_t)out_ready_sem_bank_addr << "\n"; + DPRINT << "out_ready_sem_noc0_x: " << (uint32_t)out_ready_sem_noc0_x << "\n"; + DPRINT << 
"out_ready_sem_noc0_y: " << (uint32_t)out_ready_sem_noc0_y << "\n"; + DPRINT << "out_ready_sem_wait_value: " << (uint32_t)out_ready_sem_wait_value << "\n"; + + DPRINT << "arg_for_fab: " << (uint32_t)arg_for_fab << "\n"; + DPRINT << "fabric_connection arg 0" << get_arg_val(arg_for_fab++) << "\n"; + DPRINT << "fabric_connection arg 1" << get_arg_val(arg_for_fab++) << "\n"; + DPRINT << "fabric_connection arg 2" << get_arg_val(arg_for_fab++) << "\n"; + DPRINT << "fabric_connection arg 3" << get_arg_val(arg_for_fab++) << "\n"; + DPRINT << "fabric_connection arg 4" << get_arg_val(arg_for_fab++) << "\n"; + + // packet header cb + cb_reserve_back(reserved_packet_header_cb_id, 1); + auto packet_header_buffer_addr_forward = get_write_ptr(reserved_packet_header_cb_id); + cb_push_back(reserved_packet_header_cb_id, 1); + cb_reserve_back(reserved_packet_header_cb_id, 1); + auto packet_header_buffer_addr_backward = get_write_ptr(reserved_packet_header_cb_id); + cb_push_back(reserved_packet_header_cb_id, 1); + cb_reserve_back(reserved_packet_header_cb_id, 1); + auto packet_header_buffer_seminc = get_write_ptr(reserved_packet_header_cb_id); + cb_push_back(reserved_packet_header_cb_id, 1); + DPRINT << "packet_header_buffer_addr_forward: " << (uint32_t)packet_header_buffer_addr_forward << "\n"; + DPRINT << "packet_header_buffer_addr_backward: " << (uint32_t)packet_header_buffer_addr_backward << "\n"; + DPRINT << "packet_header_buffer_seminc: " << (uint32_t)packet_header_buffer_seminc << "\n"; + + // pre-populate packet headers + volatile tt::fabric::PacketHeader* pkt_hdr_forward = + reinterpret_cast(packet_header_buffer_addr_forward); + volatile tt::fabric::PacketHeader* pkt_hdr_backward = + reinterpret_cast(packet_header_buffer_addr_backward); + pkt_hdr_forward->to_chip_multicast( + tt::fabric::MulticastRoutingCommandHeader{1, static_cast(num_targets_forward_direction)}); + pkt_hdr_backward->to_chip_multicast( + tt::fabric::MulticastRoutingCommandHeader{1, static_cast(num_targets_backward_direction)}); + + if (fabric_connection.is_logically_connected()) { + fabric_connection.open(); + } + + // 1. 
mcast via fabric to remote tensor addresses + uint32_t tiles_read = 0; + uint32_t shard_tile_id = first_core_tile_start_offset; + uint32_t core_id = 0; + while (tiles_read < num_tiles_to_read) { + DPRINT << "tiles_read: " << tiles_read << "\n"; + uint32_t num_tiles_to_read_this_core = std::min(num_tiles_per_core - shard_tile_id, packet_size_in_pages); + num_tiles_to_read_this_core = std::min(num_tiles_to_read - tiles_read, num_tiles_to_read_this_core); + cb_wait_front(cb0_id, num_tiles_to_read_this_core); + size_t l1_read_addr = get_read_ptr(cb0_id); + + uint64_t noc0_dest_noc_addr = + get_noc_addr(core_noc_x[core_id], core_noc_y[core_id], tensor_address0, 0 /*noc_id*/); + DPRINT << "core_noc_x[core_id]: " << (uint32_t)core_noc_x[core_id] << "\n"; + DPRINT << "core_noc_y[core_id]: " << (uint32_t)core_noc_y[core_id] << "\n"; + DPRINT << "noc0_dest_noc_addr_base: " << noc0_dest_noc_addr << "\n"; + noc0_dest_noc_addr += shard_tile_id * tensor0_page_size; + + DPRINT << "core_id: " << core_id << "\n"; + DPRINT << "num_tiles_to_read_this_core: " << num_tiles_to_read_this_core << "\n"; + DPRINT << "noc0_dest_noc_addr: " << noc0_dest_noc_addr << "\n"; + DPRINT << "shard_tile_id: " << shard_tile_id << "\n"; + + write_and_advance_local_read_address_for_fabric_write( + noc0_dest_noc_addr, + pkt_hdr_forward, + pkt_hdr_backward, + fabric_connection, + l1_read_addr, + num_tiles_to_read_this_core * tensor0_page_size); + noc_async_writes_flushed(); + + cb_pop_front(cb0_id, num_tiles_to_read_this_core); + tiles_read += num_tiles_to_read_this_core; + shard_tile_id += num_tiles_to_read_this_core; + if (shard_tile_id >= num_tiles_per_core) { + shard_tile_id = 0; + core_id++; + } + } + + // 2. mcast output ready semaphore + auto* pkt_hdr = reinterpret_cast(packet_header_buffer_seminc); + pkt_hdr->to_atomic_inc(); + pkt_hdr->to_noc_unicast_atomic_inc(tt::fabric::NocUnicastAtomicIncCommandHeader{ + out_ready_sem_bank_addr, + static_cast(1), // increment 1 + 32, + static_cast(out_ready_sem_noc0_x), + static_cast(out_ready_sem_noc0_y)}); + // Write the mcast packet (forward) + if (fabric_connection.has_forward_connection()) { + fabric_connection.get_forward_connection().wait_for_empty_write_slot(); + pkt_hdr->to_chip_multicast( + tt::fabric::MulticastRoutingCommandHeader{1, static_cast(num_targets_forward_direction)}); + fabric_connection.get_forward_connection().send_payload_flush_blocking_from_address( + packet_header_buffer_seminc, sizeof(tt::fabric::PacketHeader)); + } + // Write the mcast packet (backward) + if (fabric_connection.has_backward_connection()) { + pkt_hdr->to_chip_multicast( + tt::fabric::MulticastRoutingCommandHeader{1, static_cast(num_targets_backward_direction)}); + fabric_connection.get_backward_connection().wait_for_empty_write_slot(); + fabric_connection.get_backward_connection().send_payload_non_blocking_from_address( + packet_header_buffer_seminc, sizeof(tt::fabric::PacketHeader)); + } + // increment locally + uint64_t out_ready_sem_noc_addr = + safe_get_noc_addr(out_ready_sem_noc0_x, out_ready_sem_noc0_y, out_ready_sem_bank_addr); + noc_semaphore_inc(out_ready_sem_noc_addr, 1); + DPRINT << "inc done\n"; + + // 3. wait for mcast output ready semaphore + if (wait_output_semaphore) { + while (*reinterpret_cast(out_ready_sem_bank_addr) < out_ready_sem_wait_value); + DPRINT << "waitval done\n"; + } + + // 4. 
global semaphore reset + if (reset_global_semaphore) { + const uint64_t dest_noc_addr = get_noc_addr(my_x[0], my_y[0], out_ready_sem_bank_addr); + noc_inline_dw_write(dest_noc_addr, 0); + DPRINT << "reset done\n"; + } + + if (fabric_connection.is_logically_connected()) { + fabric_connection.close(); + } + + noc_async_write_barrier(); + DPRINT << "DONE \n"; +} diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/minimal_ccl_common.hpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/minimal_ccl_common.hpp new file mode 100644 index 00000000000..777010fb399 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/kernels/minimal_ccl_common.hpp @@ -0,0 +1,48 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "dataflow_api.h" +#include +#include "cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/fabric_connection_manager.hpp" +#include "cpp/ttnn/operations/ccl/common/interpreter_backends/kernel_common/noc_addr.hpp" +#include +#include + +FORCE_INLINE void write_and_advance_local_read_address_for_fabric_write( + uint64_t noc0_dest_noc_addr, + volatile tt::fabric::PacketHeader* pkt_hdr_forward, + volatile tt::fabric::PacketHeader* pkt_hdr_backward, + FabricConnectionManager& fabric_connection, + size_t& l1_read_addr, + uint32_t payload_size_bytes) { + const auto [dest_noc_xy, dest_addr] = get_noc_address_components(noc0_dest_noc_addr); + const size_t payload_l1_address = l1_read_addr; + + size_t packet_send_size_bytes = payload_size_bytes + sizeof(tt::fabric::PacketHeader); + pkt_hdr_forward->to_write()->to_noc_unicast(tt::fabric::NocUnicastCommandHeader{ + dest_addr, packet_send_size_bytes, static_cast(dest_noc_xy.x), static_cast(dest_noc_xy.y)}); + pkt_hdr_backward->to_write()->to_noc_unicast(tt::fabric::NocUnicastCommandHeader{ + dest_addr, packet_send_size_bytes, static_cast(dest_noc_xy.x), static_cast(dest_noc_xy.y)}); + + noc_async_write(payload_l1_address, safe_get_noc_addr(dest_noc_xy.x, dest_noc_xy.y, dest_addr), payload_size_bytes); + if (fabric_connection.has_forward_connection()) { + fabric_connection.get_forward_connection().wait_for_empty_write_slot(); + fabric_connection.get_forward_connection().send_payload_without_header_non_blocking_from_address( + l1_read_addr, payload_size_bytes); + fabric_connection.get_forward_connection().send_payload_flush_blocking_from_address( + (uint32_t)pkt_hdr_forward, sizeof(tt::fabric::PacketHeader)); + } + + if (fabric_connection.has_backward_connection()) { + fabric_connection.get_backward_connection().wait_for_empty_write_slot(); + fabric_connection.get_backward_connection().send_payload_without_header_non_blocking_from_address( + l1_read_addr, payload_size_bytes); + fabric_connection.get_backward_connection().send_payload_flush_blocking_from_address( + (uint32_t)pkt_hdr_backward, sizeof(tt::fabric::PacketHeader)); + } + + noc_async_writes_flushed(); + + l1_read_addr += payload_size_bytes; +} diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.cpp index 47268e789b1..15bd0227fba 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.cpp @@ -7,6 +7,7 @@ #include "ttnn/operations/math.hpp" #include 
"ttnn/tensor/tensor_utils.hpp" #include "ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.hpp" +#include "ttnn/operations/ccl/sharding_addrgen_helper.hpp" /* All Gather Matmul fusion includes */ #include "cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp" @@ -57,8 +58,7 @@ void AllGatherMatmul::validate( auto const& shard_grid = all_gather_output_tensor_shard_spec->grid.bounding_box(); auto const& shard_grid_start = shard_grid.start_coord; auto const& shard_grid_end = shard_grid.end_coord; - const uint32_t num_all_gather_output_shards = - (shard_grid_end.y - shard_grid_start.y + 1) * (shard_grid_end.x - shard_grid_start.x + 1); + const uint32_t num_all_gather_output_shards = shard_builder::get_sharding_core_count(all_gather_output_tensor); TT_FATAL( this->all_gather_struct.ring_size == num_all_gather_output_shards, "AllGatherMatmul requires number of tensor slices to equal the number of output shards of the all_gather."); diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.cpp index e6c804523ad..fe431c64c4b 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.cpp @@ -25,8 +25,8 @@ ReduceScatterAsync create_reduce_scatter_struct( std::optional> forward_output_tensors, std::optional> backward_output_tensors, std::optional num_links_preferred, - const std::vector>& from_remote_sems, - const std::vector>& to_remote_sems, + const std::vector& from_remote_sems, + const std::vector& to_remote_sems, std::optional sub_device_id, std::optional& fabric_handle) { uint32_t num_devices = devices.size(); @@ -54,8 +54,8 @@ ReduceScatterAsync create_reduce_scatter_struct( return *device; }; - std::shared_ptr from_remote_sem = from_remote_sems.at(device_index); - std::shared_ptr to_remote_sem = to_remote_sems.at(device_index); + GlobalSemaphore from_remote_sem = from_remote_sems.at(device_index); + GlobalSemaphore to_remote_sem = to_remote_sems.at(device_index); return ttnn::ReduceScatterAsync{ binary_op_type, @@ -226,16 +226,11 @@ Tensor reduce_scatter( rank - 1, dim); - // get shared_ptr from multi_device_global_semaphore - std::vector> from_remote_inputs_semaphores; - for (auto& sem : from_remote_multi_device_global_semaphore.global_semaphores) { - from_remote_inputs_semaphores.push_back(std::make_shared(sem)); - } + std::vector from_remote_inputs_semaphores = + from_remote_multi_device_global_semaphore.global_semaphores; - std::vector> to_remote_inputs_semaphores; - for (auto& sem : to_remote_multi_device_global_semaphore.global_semaphores) { - to_remote_inputs_semaphores.push_back(std::make_shared(sem)); - } + std::vector to_remote_inputs_semaphores = + to_remote_multi_device_global_semaphore.global_semaphores; std::vector output_tensors = { Tensor(operation::get_workers_for_op_output({input_tensor})), @@ -306,16 +301,11 @@ Tensor reduce_scatter( const auto mesh_view = mesh_device.get_view(); auto devices = input_tensor.get_workers(); - // get shared_ptr from multi_device_global_semaphore - std::vector> from_remote_inputs_semaphores; - for (auto& sem : from_remote_multi_device_global_semaphore.global_semaphores) { - from_remote_inputs_semaphores.push_back(std::make_shared(sem)); - } + std::vector from_remote_inputs_semaphores = + 
from_remote_multi_device_global_semaphore.global_semaphores; - std::vector> to_remote_inputs_semaphores; - for (auto& sem : to_remote_multi_device_global_semaphore.global_semaphores) { - to_remote_inputs_semaphores.push_back(std::make_shared(sem)); - } + std::vector to_remote_inputs_semaphores = + to_remote_multi_device_global_semaphore.global_semaphores; std::vector output_tensors = { Tensor(operation::get_workers_for_op_output({input_tensor})), diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.hpp b/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.hpp index bfc9789c5cc..c6256e6d734 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.hpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.hpp @@ -24,8 +24,8 @@ struct ReduceScatterAsync { std::optional>& foreward_output_tensors, std::optional>& backward_output_tensors, std::optional num_links_preferred, - const std::shared_ptr& from_remote_sem, - const std::shared_ptr& to_remote_sem, + const GlobalSemaphore& from_remote_sem, + const GlobalSemaphore& to_remote_sem, std::optional& sub_device_id, std::optional& fabric_handle) : binary_op_type(binary_op_type), @@ -56,8 +56,8 @@ struct ReduceScatterAsync { std::optional> foreward_output_tensors; std::optional> backward_output_tensors; std::optional num_links_preferred; - std::shared_ptr from_remote_sem; - std::shared_ptr to_remote_sem; + const GlobalSemaphore from_remote_sem; + const GlobalSemaphore to_remote_sem; std::optional& fabric_handle; std::optional sub_device_id; @@ -104,8 +104,8 @@ operation::ProgramWithCallbacks build_reduce_scatter_async_program( const uint32_t line_index, ttnn::ccl::Topology topology, std::optional num_links_preferred, - const std::shared_ptr& from_remote_sem, - const std::shared_ptr& to_remote_sem, + const GlobalSemaphore& from_remote_sem, + const GlobalSemaphore& to_remote_sem, const std::optional& sub_device_id, std::optional& fabric_handle); } @@ -123,8 +123,8 @@ ReduceScatterAsync create_reduce_scatter_struct( std::optional> foreward_output_tensors, std::optional> backward_output_tensors, std::optional num_links_preferred, - const std::vector>& from_remote_sems, - const std::vector>& to_remote_sems, + const std::vector& from_remote_sems, + const std::vector& to_remote_sems, std::optional sub_device_id, std::optional& fabric_handle); } // namespace reduce_scatter_detail diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_program.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_program.cpp index ed497031e0b..11447364c4f 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_program.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_program.cpp @@ -1813,8 +1813,8 @@ static void initialize_op_internal_tensor_syncs( std::array const& neighbour_devices, ProgramTensorsBundle& all_tensors, WorkerCoreBundle const& worker_cores, - std::shared_ptr const& from_remote_sem, - std::shared_ptr const& to_remote_sem) { + GlobalSemaphore const& from_remote_sem, + GlobalSemaphore const& to_remote_sem) { auto core_coord_lt = [](CoreCoord const& a, CoreCoord const& b) { return a.y < b.y || (a.y == b.y && a.x < b.x); }; TT_FATAL( @@ -1836,12 +1836,12 @@ static void 
initialize_op_internal_tensor_syncs( device->worker_core_from_logical_core(worker_core).x, device->worker_core_from_logical_core(worker_core).y, }); - all_tensors.input_tensor_from_remote_sync[direction].semaphore_ids.push_back(from_remote_sem.get()); + all_tensors.input_tensor_from_remote_sync[direction].semaphore_ids.push_back(&from_remote_sem); all_tensors.input_tensor_from_remote_sync[direction].completion_target_value_per_semaphore.push_back(1); // remote output sync if (neighbour_devices[direction] != nullptr) { - all_tensors.remote_output_sync[direction].semaphore_ids.push_back(to_remote_sem.get()); + all_tensors.remote_output_sync[direction].semaphore_ids.push_back(&to_remote_sem); all_tensors.remote_output_sync[direction].completion_target_value_per_semaphore.push_back(1); all_tensors.remote_output_sync[direction] = all_tensors.input_tensor_from_remote_sync[direction]; all_tensors.remote_output_sync[direction].targets.back() = TensorSyncSpec::target_rect{ @@ -2157,10 +2157,9 @@ operation::ProgramWithCallbacks reduce_scatter_async_on_instantiated_edm_fabric( const uint32_t dim, const size_t num_links, ttnn::ccl::Topology topology, - fabric_lifetime_mode fabric_mode, - std::shared_ptr const& from_remote_sems, - std::shared_ptr const& to_remote_sem, + const GlobalSemaphore& from_remote_sems, + const GlobalSemaphore& to_remote_sem, const std::optional& sub_device_id) { using namespace ttnn::ccl::worker_detail; bool do_dynamic_fabric_bringup_and_teardown = fabric_mode == fabric_lifetime_mode::TRANSIENT; @@ -2484,8 +2483,8 @@ operation::ProgramWithCallbacks build_reduce_scatter_async_program( const uint32_t line_index, ttnn::ccl::Topology topology, std::optional num_links_preferred, - std::shared_ptr const& from_remote_sem, - std::shared_ptr const& to_remote_sem, + const tt::tt_metal::GlobalSemaphore& from_remote_sem, + const tt::tt_metal::GlobalSemaphore& to_remote_sem, const std::optional& sub_device_id, std::optional& fabric_handle_) { auto program = tt::tt_metal::Program(); diff --git a/ttnn/cpp/ttnn/operations/experimental/reduction/argmax/argmax.cpp b/ttnn/cpp/ttnn/operations/experimental/reduction/argmax/argmax.cpp index a635e8e8e5c..4fcbe4e09ee 100644 --- a/ttnn/cpp/ttnn/operations/experimental/reduction/argmax/argmax.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/reduction/argmax/argmax.cpp @@ -68,7 +68,7 @@ Tensor ArgmaxOperation::invoke( output_memory_config); max_tensor = ttnn::add(max_tensor, max_val, std::nullopt, output_memory_config); } - tindex = tindex.to(input_a.device()); + tindex = tindex.to_device(input_a.device()); max_val.deallocate(); Tensor cmp_results = ttnn::eq(input_a, max_tensor, std::nullopt, output_memory_config); Tensor max_indices = ttnn::multiply(cmp_results, tindex, std::nullopt, output_memory_config); @@ -119,7 +119,7 @@ Tensor ArgmaxOperation::invoke( input_a.device(), output_memory_config); } - tindex = tindex.to(input_a.device()); + tindex = tindex.to_device(input_a.device()); Tensor max_indices = ttnn::multiply(cmp_results, tindex, std::nullopt, output_memory_config); cmp_results.deallocate(); Tensor midx = full_like(max_indices, size); diff --git a/ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/device/kernels/ssm_prefix_scan.cpp b/ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/device/kernels/ssm_prefix_scan.cpp index fe4999ccd3d..8c8dda81309 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/device/kernels/ssm_prefix_scan.cpp +++ 
b/ttnn/cpp/ttnn/operations/experimental/ssm/prefix_scan/device/kernels/ssm_prefix_scan.cpp @@ -67,7 +67,7 @@ FORCE_INLINE void mul(uint32_t cb_a, uint32_t cb_b, uint32_t cb_out) { reconfig_data_format(cb_a, cb_b); pack_reconfig_data_format(cb_out); - mul_tiles_init(); + mul_tiles_init(cb_a, cb_b); cb_wait_front(cb_a, 1); cb_wait_front(cb_b, 1); @@ -89,7 +89,7 @@ FORCE_INLINE void sum(uint32_t cb_a, uint32_t cb_b, uint32_t cb_out) { reconfig_data_format(cb_a, cb_b); pack_reconfig_data_format(cb_out); - add_tiles_init(); + add_tiles_init(cb_a, cb_b); cb_wait_front(cb_a, 1); cb_wait_front(cb_b, 1); @@ -149,7 +149,7 @@ void MAIN { const uint32_t total_tiles_per_col = get_arg_val(2); const uint32_t num_chunks_per_row = get_arg_val(3); - binary_op_init_common(cb_a_in, cb_bx_in); + binary_op_init_common(cb_a_in, cb_bx_in, cb_out); const uint32_t num_tiles_last_chunk = total_tiles_per_row % NUM_TILES_IN_TILIZED_CHUNK == 0 ? NUM_TILES_IN_TILIZED_CHUNK : total_tiles_per_row % NUM_TILES_IN_TILIZED_CHUNK; diff --git a/ttnn/cpp/ttnn/operations/experimental/ssm/repeat_and_interleave_eltwise_mul/device/kernels/ssm_eltwise_mul.cpp b/ttnn/cpp/ttnn/operations/experimental/ssm/repeat_and_interleave_eltwise_mul/device/kernels/ssm_eltwise_mul.cpp index 3dcf50f828d..0c9fbe44632 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ssm/repeat_and_interleave_eltwise_mul/device/kernels/ssm_eltwise_mul.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ssm/repeat_and_interleave_eltwise_mul/device/kernels/ssm_eltwise_mul.cpp @@ -25,9 +25,10 @@ void MAIN { constexpr uint32_t num_rows_in_one_tile = 32; #ifdef REPEAT_INTERLEAVE_IN1 - binary_op_init_common(cb_in0_transposed, cb_in1_bcast_row); // TODO: Is there a specific one for bcast mul? + binary_op_init_common( + cb_in0_transposed, cb_in1_bcast_row, cb_id_out); // TODO: Is there a specific one for bcast mul? 
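These ssm kernel hunks, like most of the compute-kernel changes in this patch, move to the init variants that name their circular buffers. A minimal sketch of the resulting call sequence inside a compute-kernel MAIN body (cb_in0, cb_in1, cb_out are placeholder CB handles; the tt-metal compute-kernel headers are assumed, so this is illustrative rather than standalone):

    binary_op_init_common(cb_in0, cb_in1, cb_out);  // common init now also names the output CB
    mul_tiles_init(cb_in0, cb_in1);                 // per-op init names the CBs it will consume
    cb_wait_front(cb_in0, 1);
    cb_wait_front(cb_in1, 1);
    tile_regs_acquire();
    mul_tiles(cb_in0, cb_in1, 0, 0, 0);             // dst0 = in0 * in1
    tile_regs_commit();
    tile_regs_wait();
    cb_reserve_back(cb_out, 1);
    pack_tile(0, cb_out);
    cb_push_back(cb_out, 1);
    tile_regs_release();

The same pattern applies to the add_tiles_init and sub_tiles_init updates in the hunks that follow.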
#else - binary_op_init_common(cb_id_in0, cb_id_in1); + binary_op_init_common(cb_id_in0, cb_id_in1, cb_id_out); #endif for (uint32_t block_h_id = 0; block_h_id < in1_num_blocks_h; block_h_id++) { diff --git a/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding/device/kernels/compute/rotary_embedding.cpp b/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding/device/kernels/compute/rotary_embedding.cpp index f5ba999b2c2..1d7d0ec5f71 100644 --- a/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding/device/kernels/compute/rotary_embedding.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding/device/kernels/compute/rotary_embedding.cpp @@ -30,7 +30,7 @@ ALWI void MUL_TILES(uint32_t in0_cb, uint32_t in1_cb, uint32_t out_cb, uint32_t // We don't pop in1 in decode which is sin/cos since we don't stream #else ACQ(); - mul_tiles_init(); + mul_tiles_init(in0_cb, in1_cb); mul_tiles(in0_cb, in1_cb, 0, 0, 0); pack_tile(0, out_cb); REL(); @@ -146,7 +146,7 @@ void MAIN { reconfig_data_format_srca(rotated_in_cb, cos_interm_cb); pack_reconfig_data_format(cos_interm_cb, out_cb); ACQ(); - add_tiles_init(); + add_tiles_init(cos_interm_cb, sin_interm_cb); add_tiles(cos_interm_cb, sin_interm_cb, 0, 0, 0); pack_tile(0, out_cb); REL(); diff --git a/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding_llama/device/kernels/compute/rotary_embedding_llama.cpp b/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding_llama/device/kernels/compute/rotary_embedding_llama.cpp index b228c94016a..7d9dc699d61 100644 --- a/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding_llama/device/kernels/compute/rotary_embedding_llama.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding_llama/device/kernels/compute/rotary_embedding_llama.cpp @@ -37,7 +37,7 @@ void MAIN { const uint32_t my_cos_sin_tiles = my_seq_tiles * Wt; mm_init(); - binary_op_init_common(rotated_in_interm_cb, cos_cb); // General Init for all binary ops + binary_op_init_common(rotated_in_interm_cb, cos_cb, out_cb); // General Init for all binary ops // Get the trans_mat cb_wait_front(trans_mat_cb, onetile); @@ -77,7 +77,7 @@ void MAIN { cb_push_back(rotated_in_interm_cb, Wt); cb_wait_front(rotated_in_interm_cb, Wt); - mul_tiles_init(); + mul_tiles_init(rotated_in_interm_cb, sin_cb); ACQ(); for (uint32_t j = 0; j < Wt; ++j) { // sin_interim = rotated * sin @@ -104,7 +104,7 @@ void MAIN { cb_wait_front(sin_interm_cb, Wt); cb_wait_front(cos_interm_cb, Wt); - add_tiles_init(); + add_tiles_init(cos_interm_cb, sin_interm_cb); ACQ(); for (uint32_t j = 0; j < Wt; ++j) { // out = cos_interim + sin_interim diff --git a/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding_llama/device/kernels/compute/rotary_embedding_llama_sharded.cpp b/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding_llama/device/kernels/compute/rotary_embedding_llama_sharded.cpp index bd749cf8a37..2a4c2562e73 100644 --- a/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding_llama/device/kernels/compute/rotary_embedding_llama_sharded.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/transformer/rotary_embedding_llama/device/kernels/compute/rotary_embedding_llama_sharded.cpp @@ -90,7 +90,7 @@ void MAIN { cb_wait_front(sin_interm_cb, Wt); cb_wait_front(cos_interm_cb, Wt); - add_tiles_init(); + add_tiles_init(cos_interm_cb, sin_interm_cb); ACQ(); for (uint32_t j = 0; j < Wt; ++j) { // out = cos_interim + sin_interim diff --git 
a/ttnn/cpp/ttnn/operations/functions.hpp b/ttnn/cpp/ttnn/operations/functions.hpp index 72fd6650fb5..f70c7bac474 100644 --- a/ttnn/cpp/ttnn/operations/functions.hpp +++ b/ttnn/cpp/ttnn/operations/functions.hpp @@ -66,9 +66,9 @@ static Tensor index_trilu( logical_shape, TensorLayout::fromPaddedShape( data_type, PageConfig(Layout::ROW_MAJOR), MemoryConfig{}, logical_shape, padded_shape))) - .to(layout); + .to_layout(layout); if (device != nullptr) { - output = output.to(device, output_mem_config); + output = output.to_device(device, output_mem_config); } return output; } @@ -108,9 +108,9 @@ static Tensor index_width( logical_shape, TensorLayout::fromPaddedShape( data_type, PageConfig(Layout::ROW_MAJOR), MemoryConfig{}, logical_shape, padded_shape))) - .to(layout); + .to_layout(layout); if (device != nullptr) { - output = output.to(device, output_mem_config); + output = output.to_device(device, output_mem_config); } return output; } @@ -150,9 +150,9 @@ static Tensor index_height( logical_shape, TensorLayout::fromPaddedShape( data_type, PageConfig(Layout::ROW_MAJOR), MemoryConfig{}, logical_shape, padded_shape))) - .to(layout); + .to_layout(layout); if (device != nullptr) { - output = output.to(device, output_mem_config); + output = output.to_device(device, output_mem_config); } return output; } @@ -191,9 +191,9 @@ static Tensor index_all( logical_shape, TensorLayout::fromPaddedShape( data_type, PageConfig(Layout::ROW_MAJOR), MemoryConfig{}, logical_shape, padded_shape))) - .to(layout); + .to_layout(layout); if (device != nullptr) { - output = output.to(device, output_mem_config); + output = output.to_device(device, output_mem_config); } return output; } @@ -227,9 +227,9 @@ static Tensor mask_padded_input( } // dim H } // dim C } // dim N - auto output = Tensor(OwnedStorage{owned_buffer}, padded_shape, data_type, Layout::ROW_MAJOR).to(layout); + auto output = Tensor(OwnedStorage{owned_buffer}, padded_shape, data_type, Layout::ROW_MAJOR).to_layout(layout); if (device != nullptr) { - output = output.to(device, output_mem_config); + output = output.to_device(device, output_mem_config); } return output; } @@ -270,9 +270,9 @@ static Tensor fill_first_val_into_tensor( MemoryConfig{}, input_tensor.get_logical_shape(), input_tensor.get_padded_shape()))) - .to(layout); + .to_layout(layout); if (device != nullptr) { - output = output.to(device, output_mem_config); + output = output.to_device(device, output_mem_config); } return output; } @@ -329,9 +329,9 @@ static Tensor prod_result_computation_GS( MemoryConfig{}, input_tensor.get_logical_shape(), input_tensor.get_padded_shape()))) - .to(layout); + .to_layout(layout); if (device != nullptr) { - output = output.to(device, output_mem_config); + output = output.to_device(device, output_mem_config); } return output; } @@ -392,9 +392,9 @@ static Tensor prod_result_computation_WH_B0( MemoryConfig{}, input_tensor.get_logical_shape(), input_tensor.get_padded_shape()))) - .to(layout); + .to_layout(layout); if (device != nullptr) { - output = output.to(device, output_mem_config); + output = output.to_device(device, output_mem_config); } return output; } @@ -434,9 +434,9 @@ static Tensor index_channel( logical_shape, TensorLayout::fromPaddedShape( data_type, PageConfig(Layout::ROW_MAJOR), MemoryConfig{}, logical_shape, padded_shape))) - .to(layout); + .to_layout(layout); if (device != nullptr) { - output = output.to(device, output_mem_config); + output = output.to_device(device, output_mem_config); } return output; } @@ -475,9 +475,9 @@ static Tensor index_batch( 
logical_shape, TensorLayout::fromPaddedShape( data_type, PageConfig(Layout::ROW_MAJOR), MemoryConfig{}, logical_shape, padded_shape))) - .to(layout); + .to_layout(layout); if (device != nullptr) { - output = output.to(device, output_mem_config); + output = output.to_device(device, output_mem_config); } return output; } @@ -514,9 +514,9 @@ static Tensor manual_insertion( logical_shape, TensorLayout::fromPaddedShape( data_type, PageConfig(Layout::ROW_MAJOR), MemoryConfig{}, logical_shape, padded_shape))) - .to(layout); + .to_layout(layout); if (device != nullptr) { - output = output.to(device, output_mem_config); + output = output.to_device(device, output_mem_config); } return output; } @@ -578,7 +578,7 @@ static Tensor uniform(T low, T high, const ttnn::Shape& shape, const Layout layo } } - return Tensor(OwnedStorage{owned_buffer}, spec).to(layout); + return Tensor(OwnedStorage{owned_buffer}, spec).to_layout(layout); } static Tensor random( diff --git a/ttnn/cpp/ttnn/operations/matmul/device/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp b/ttnn/cpp/ttnn/operations/matmul/device/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp index 79d25ce9c5c..73ef8d67cfb 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp @@ -198,6 +198,14 @@ void MAIN { // accumulation is done by iterating matmul_block across inner dim // in0_block_w is passed as innder dim (kt) to matmul_block, interally used to stride // in0 + +#ifdef ARCH_BLACKHOLE + // FIXME: This is a temporary workaround to avoid hangs on blackhole. + // https://github.com/tenstorrent/tt-metal/issues/16439 + for (uint32_t i = 0; i < 10; i++) { + asm volatile("nop"); + } +#endif matmul_block( in0_cb_id, in1_cb_id, diff --git a/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_receiver_padding_block_sharded.cpp b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_receiver_padding_block_sharded.cpp index c7891faae03..a56029d83bb 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_receiver_padding_block_sharded.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/kernels/dataflow/reader_bmm_tile_layout_in0_sender_receiver_padding_block_sharded.cpp @@ -270,7 +270,6 @@ void kernel_main() { // wait on in0 semaphore value to become VALID (set by mcast sender after it multicasts data) noc_semaphore_wait(in0_mcast_receiver_semaphore_addr_ptr, VALID); } - cb_push_back(cb_id_in0, in0_block_num_tiles); // If core does not produce output block work, free cb_id_in0 immediately. diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp index a2fb8c9b43a..979eebd4233 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp @@ -1608,19 +1608,22 @@ void Matmul::validate( } if (program_config.mcast_in0 || program_config.gather_in0) { if (input_tensor_a.is_sharded()) { - TT_FATAL(program_config.fuse_batch, "Error"); + TT_FATAL(program_config.fuse_batch, "Error: Batch fusion must be enabled."); TT_FATAL( - input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED, "Error"); + input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED, + "Error: input_tensor_a must be width sharded. 
Provided tensor memory layout: {}", + input_tensor_a.memory_config().memory_layout); if (this->output_mem_config.is_sharded()) { TT_FATAL( input_tensor_a.memory_config().buffer_type == this->output_mem_config.buffer_type, - "Error"); + "Error: Buffer type mismatch."); TT_FATAL( input_tensor_a.memory_config().memory_layout == this->output_mem_config.memory_layout, - "Error"); + "Error: Memory layout mismatch."); } TT_FATAL( - input_tensor_a.shard_spec().value().orientation == ShardOrientation::ROW_MAJOR, "Error"); + input_tensor_a.shard_spec().value().orientation == ShardOrientation::ROW_MAJOR, + "Error: Shard orientation must be ROW_MAJOR."); uint32_t M = (program_config.fuse_batch ? input_tensor_a.volume() / input_tensor_a.get_padded_shape()[-1] : input_tensor_a.get_padded_shape()[-2]) / @@ -1632,15 +1635,30 @@ void Matmul::validate( auto shard_shape = input_tensor_a.shard_spec().value().shape; // No padding - TT_FATAL(M == per_core_M, "Error"); - TT_FATAL(per_core_M == (shard_shape[0] / in0_tile_shape[0]), "Error"); - TT_FATAL(K % program_config.in0_block_w == 0, "Error"); + TT_FATAL(M == per_core_M, "Error: M ({}) must be equal to per_core_M ({}).", M, per_core_M); + TT_FATAL( + per_core_M == (shard_shape[0] / in0_tile_shape[0]), + "Error: per_core_M must be equal to shard_shape[0] ({}) / in0_tile_shape[0] ({}).", + shard_shape[0], + in0_tile_shape[0]); + TT_FATAL( + K % program_config.in0_block_w == 0, + "Error: K {} must be divisible by in0_block_w {}.", + K, + program_config.in0_block_w); if (!program_config.gather_in0) { // Padding allowed for gather_in0 - TT_FATAL((shard_shape[1] / in0_tile_shape[1]) % program_config.in0_block_w == 0, "Error"); + TT_FATAL( + (shard_shape[1] / in0_tile_shape[1]) % program_config.in0_block_w == 0, + "Error: shard_shape[1] ({}) / in0_tile_shape[1] ({}) must be divisible by in0_block_w.", + shard_shape[1], + in0_tile_shape[1]); } } if (this->output_mem_config.is_sharded()) { - TT_FATAL(this->output_mem_config.memory_layout == TensorMemoryLayout::WIDTH_SHARDED, "Error"); + TT_FATAL( + this->output_mem_config.memory_layout == TensorMemoryLayout::WIDTH_SHARDED, + "Error: Output memory layout must be WIDTH_SHARDED. Provided tensor memory layout: {}", + this->output_mem_config.memory_layout); uint32_t M = (program_config.fuse_batch ? 
input_tensor_a.volume() / input_tensor_a.get_padded_shape()[-1] : input_tensor_a.get_padded_shape()[-2]) / @@ -1650,10 +1668,11 @@ void Matmul::validate( uint32_t per_core_N = program_config.per_core_N; // No padding - TT_FATAL(M == per_core_M, "Error"); + TT_FATAL(M == per_core_M, "Error: M {} must be equal to per_core_M {}.", M, per_core_M); TT_FATAL( - program_config.out_subblock_w == per_core_N || program_config.out_subblock_h == 1, "Error"); + program_config.out_subblock_w == per_core_N || program_config.out_subblock_h == 1, + "Error: out_subblock_w must be equal to per_core_N or out_subblock_h must be equal to 1."); } if (input_tensor_b.buffer()->buffer_type() == tt_metal::BufferType::L1 && input_tensor_b.memory_config().is_sharded()) { @@ -1674,20 +1693,21 @@ void Matmul::validate( } } else { if (input_tensor_a.memory_config().is_sharded()) { - TT_FATAL(program_config.fuse_batch, "Error"); + TT_FATAL(program_config.fuse_batch, "Error: Batch fusion must be enabled."); TT_FATAL( input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED, - "Error"); + "Error: input_tensor_a must be height sharded."); if (this->output_mem_config.is_sharded()) { TT_FATAL( input_tensor_a.memory_config().buffer_type == this->output_mem_config.buffer_type, - "Error"); + "Error: Buffer type mismatch."); TT_FATAL( input_tensor_a.memory_config().memory_layout == this->output_mem_config.memory_layout, - "Error"); + "Error: Memory layout mismatch."); } TT_FATAL( - input_tensor_a.shard_spec().value().orientation == ShardOrientation::ROW_MAJOR, "Error"); + input_tensor_a.shard_spec().value().orientation == ShardOrientation::ROW_MAJOR, + "Error: Shard orientation must be ROW_MAJOR."); uint32_t M = (program_config.fuse_batch ? input_tensor_a.volume() / input_tensor_a.get_padded_shape()[-1] : input_tensor_a.get_padded_shape()[-2]) / @@ -1696,13 +1716,20 @@ void Matmul::validate( uint32_t per_core_M = program_config.per_core_M; auto shard_shape = input_tensor_a.shard_spec().value().shape; TT_FATAL( - div_up(M, per_core_M) <= input_tensor_a.shard_spec().value().grid.num_cores(), "Error"); - TT_FATAL(per_core_M == (shard_shape[0] / in0_tile_shape[0]), "Error"); - TT_FATAL(K % program_config.in0_block_w == 0, "Error"); - TT_FATAL(K == (shard_shape[1] / in0_tile_shape[1]), "Error"); + div_up(M, per_core_M) <= input_tensor_a.shard_spec().value().grid.num_cores(), + "Error: M must be divisible by per_core_M."); + TT_FATAL( + per_core_M == (shard_shape[0] / in0_tile_shape[0]), + "Error: per_core_M must be equal to shard_shape[0] / in0_tile_shape[0]."); + TT_FATAL(K % program_config.in0_block_w == 0, "Error: K must be divisible by in0_block_w."); + TT_FATAL( + K == (shard_shape[1] / in0_tile_shape[1]), + "Error: K must be equal to shard_shape[1] / in0_tile_shape[1]."); } if (this->output_mem_config.is_sharded()) { - TT_FATAL(this->output_mem_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED, "Error"); + TT_FATAL( + this->output_mem_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED, + "Error: Output memory layout must be HEIGHT_SHARDED."); uint32_t M = (program_config.fuse_batch ? 
input_tensor_a.volume() / input_tensor_a.get_padded_shape()[-1] : input_tensor_a.get_padded_shape()[-2]) / @@ -1711,11 +1738,14 @@ void Matmul::validate( uint32_t per_core_M = program_config.per_core_M; uint32_t per_core_N = program_config.per_core_N; - TT_FATAL(N == per_core_N, "Error"); + TT_FATAL(N == per_core_N, "Error: N must be equal to per_core_N."); TT_FATAL( - program_config.out_subblock_w == per_core_N || program_config.out_subblock_h == 1, "Error"); + program_config.out_subblock_w == per_core_N || program_config.out_subblock_h == 1, + "Error: out_subblock_w must be equal to per_core_N or out_subblock_h must be equal to 1."); } - TT_FATAL(input_tensor_b.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED, "Error"); + TT_FATAL( + input_tensor_b.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED, + "Error: Operand B must be interleaved."); } } else if constexpr (std::is_same_v< ProgramConfigType, @@ -2001,7 +2031,7 @@ std::vector Matmul::compute_output_specs( auto tile_width_ratio = output_tile.get_tile_shape()[1] / in1_tile_shape[1]; auto output_layout = this->untilize_out ? Layout::ROW_MAJOR : Layout::TILE; - TT_FATAL(this->output_dtype.has_value(), "Error"); + TT_FATAL(this->output_dtype.has_value(), "Error: output_dtype field should have been populated"); if (this->output_mem_config.is_sharded()) { const auto& optional_bias = optional_input_tensors.at(0); uint32_t bias_single_tile_size = 0; @@ -2171,15 +2201,15 @@ operation::ProgramWithCallbacks Matmul::create_program( const auto& bias = optional_input_tensors.at(0); auto& output_tensor = output_tensors.at(0); - TT_FATAL(this->output_dtype.has_value(), "Error"); + TT_FATAL(this->output_dtype.has_value(), "Error: output_dtype field should have been populated"); tt::tt_metal::DataType output_dtype = this->output_dtype.value(); bool fuse_batch = true; // TODO: If input_tensor_a.get_padded_shape()[0] * input_tensor_a.get_padded_shape()[1] * ... 
except last two // dimensions == 1, does matmuls work if we treat it as bmm // TODO: Only for MatmulMultiCoreReuseProgramConfig we allow this as single core matmul/bmm - TT_FATAL(this->compute_kernel_config.has_value(), "Error"); - TT_FATAL(this->bcast_batch.has_value(), "Error"); + TT_FATAL(this->compute_kernel_config.has_value(), "Error: compute_kernel_config field should have been populated"); + TT_FATAL(this->bcast_batch.has_value(), "Error: bcast_batch field should have been populated"); bool broadcast_batch = this->bcast_batch.value(); uint32_t bias_single_tile_size = 0; if (bias.has_value()) { diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_abs_pow/device/kernels/moreh_abs_pow_kernel.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_abs_pow/device/kernels/moreh_abs_pow_kernel.cpp index d25ae5a08eb..763550bf151 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_abs_pow/device/kernels/moreh_abs_pow_kernel.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_abs_pow/device/kernels/moreh_abs_pow_kernel.cpp @@ -37,7 +37,7 @@ void MAIN { constexpr uint32_t dst0 = 0; constexpr uint32_t dst1 = 1; - binary_op_init_common(tt::CB::c_in0, tt::CB::c_in0); + binary_op_init_common(tt::CB::c_in0, tt::CB::c_in0, tt::CB::c_out0); cb_wait_front(cb_one, onetile); // comes from the reader cb_wait_front(cb_decimal, onetile); // comes from the reader diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_adam/device/kernels/moreh_adam.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_adam/device/kernels/moreh_adam.cpp index fed8e7ad995..9e6af27e80b 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_adam/device/kernels/moreh_adam.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_adam/device/kernels/moreh_adam.cpp @@ -63,7 +63,7 @@ void MAIN { cb_wait_front(cb_scalar_args, 5); cb_wait_front(cb_one, onetile); - binary_op_init_common(cb_param_in, cb_scalar_args); + binary_op_init_common(cb_param_in, cb_scalar_args, cb_param_out); for (uint32_t b = 0; b < per_core_tile_cnt; ++b) { // grad += grad + param * weight_decay; @@ -140,7 +140,7 @@ void MAIN { cb_wait_front(cb_tmp1, onetile); cb_reserve_back(cb_tmp1, onetile); WITH_FP32_DEST_ACC(reconfig_data_format(cb_one, cb_tmp1)); - sub_tiles_init(); + sub_tiles_init(cb_one, cb_tmp1); sub_tiles(cb_one, cb_tmp1, first_tile, first_tile, dst0); recip_tile_init(); recip_tile(dst0); @@ -189,11 +189,11 @@ void MAIN { cb_reserve_back(cb_tmp1, onetile); #ifdef AMSGRAD - mul_tiles_init(); + mul_tiles_init(tmp_cb_max_exp_avg_sq, cb_tmp1); WITH_FP32_DEST_ACC(reconfig_data_format(tmp_cb_max_exp_avg_sq, cb_tmp1)); mul_tiles(tmp_cb_max_exp_avg_sq, cb_tmp1, first_tile, first_tile, dst0); #else - mul_tiles_init(); + mul_tiles_init(tmp_cb_exp_avg_sq, cb_tmp1); WITH_FP32_DEST_ACC(reconfig_data_format(tmp_cb_exp_avg_sq, cb_tmp1)); mul_tiles(tmp_cb_exp_avg_sq, cb_tmp1, first_tile, first_tile, dst0); #endif @@ -216,7 +216,7 @@ void MAIN { cb_wait_front(cb_tmp1, onetile); cb_reserve_back(cb_tmp1, onetile); WITH_FP32_DEST_ACC(reconfig_data_format(cb_tmp1, cb_scalar_args)); - add_tiles_init(); + add_tiles_init(cb_tmp1, cb_scalar_args); add_tiles(cb_tmp1, cb_scalar_args, first_tile, eps_tile, dst0); recip_tile_init(); recip_tile(dst0); @@ -247,7 +247,7 @@ void MAIN { tile_regs_acquire(); cb_wait_front(cb_tmp2, onetile); WITH_FP32_DEST_ACC(reconfig_data_format(cb_one, cb_tmp2)); - sub_tiles_init(); + sub_tiles_init(cb_one, cb_tmp2); sub_tiles(cb_one, cb_tmp2, first_tile, first_tile, dst0); recip_tile_init(); recip_tile(dst0); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_adamw/device/kernels/moreh_adamw.cpp 
b/ttnn/cpp/ttnn/operations/moreh/moreh_adamw/device/kernels/moreh_adamw.cpp index 85d2f3717af..e444c212870 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_adamw/device/kernels/moreh_adamw.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_adamw/device/kernels/moreh_adamw.cpp @@ -61,7 +61,7 @@ void MAIN { cb_wait_front(cb_beta1_exponent, onetile); cb_wait_front(cb_beta2_exponent, onetile); - binary_op_init_common(cb_param_in, cb_scalar_args); + binary_op_init_common(cb_param_in, cb_scalar_args, cb_param_out); for (uint32_t b = 0; b < per_core_tile_cnt; ++b) { cb_wait_front(cb_param_in, onetile); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step1/device/kernels/moreh_clip_grad_norm_step1_kernel.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step1/device/kernels/moreh_clip_grad_norm_step1_kernel.cpp index c98edddc237..96abcf904f5 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step1/device/kernels/moreh_clip_grad_norm_step1_kernel.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step1/device/kernels/moreh_clip_grad_norm_step1_kernel.cpp @@ -45,7 +45,7 @@ void MAIN { const auto ht = (origin_h + TILE_H - 1) / TILE_H; const auto wt = (origin_w + TILE_W - 1) / TILE_W; - binary_op_init_common(cb_logx, cb_decimal); + binary_op_init_common(cb_logx, cb_decimal, cb_y); cb_wait_front(cb_decimal, onetile); // comes from the reader cb_wait_front(cb_one, onetile); // comes from the reader @@ -115,7 +115,7 @@ void MAIN { cb_wait_front(cb_xpowadd, onetile); cb_reserve_back(cb_xpowadd, onetile); - add_tiles_init(); + add_tiles_init(cb_correct_xpow, cb_xpowadd); add_tiles(cb_correct_xpow, cb_xpowadd, 0, 0, dst0); tile_regs_commit(); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/device/kernels/moreh_clip_grad_norm_step2_kernel.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/device/kernels/moreh_clip_grad_norm_step2_kernel.cpp index f461001ca6a..ee15277ef6f 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/device/kernels/moreh_clip_grad_norm_step2_kernel.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/device/kernels/moreh_clip_grad_norm_step2_kernel.cpp @@ -30,9 +30,9 @@ void MAIN { constexpr uint32_t dst0 = 0; if (num_tiles > 1) { - binary_op_init_common(cb_input, cb_x); + binary_op_init_common(cb_input, cb_x, cb_y); } else { - binary_op_init_common(cb_logx, cb_decimal); + binary_op_init_common(cb_logx, cb_decimal, cb_y); } cb_wait_front(cb_decimal, onetile); // comes from the reader @@ -59,7 +59,7 @@ void MAIN { cb_wait_front(cb_x, onetile); cb_reserve_back(cb_x, onetile); - add_tiles_init(); + add_tiles_init(cb_input, cb_x); add_tiles(cb_input, cb_x, 0, 0, dst0); cb_pop_front(cb_x, onetile); cb_pop_front(cb_input, onetile); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step3/device/kernels/moreh_clip_grad_norm_step3_kernel.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step3/device/kernels/moreh_clip_grad_norm_step3_kernel.cpp index 25c14d12c1c..25f060c633d 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step3/device/kernels/moreh_clip_grad_norm_step3_kernel.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_clip_grad_norm/moreh_clip_grad_norm_step3/device/kernels/moreh_clip_grad_norm_step3_kernel.cpp 
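In the moreh_adam hunks above, the re-initialized sub_tiles feeds a reciprocal taken directly on the destination register; this appears to form the Adam bias-correction factor 1/(1 - beta^t). A condensed sketch of that step as it reads in the kernel (cb_one holds a tile of ones, cb_tmp1 the running beta power; packing back into cb_tmp1 afterwards is elided):

    sub_tiles_init(cb_one, cb_tmp1);
    sub_tiles(cb_one, cb_tmp1, first_tile, first_tile, dst0);  // dst0 = 1 - beta^t
    recip_tile_init();
    recip_tile(dst0);                                          // dst0 = 1 / (1 - beta^t)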
@@ -19,7 +19,7 @@ void MAIN { constexpr uint32_t onetile = 1; constexpr uint32_t dst0 = 0; - binary_op_init_common(cb_x, cb_clip_coef_clamped); + binary_op_init_common(cb_x, cb_clip_coef_clamped, cb_y); cb_wait_front(cb_clip_coef_clamped, onetile); // comes from the reader diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/device/kernels/moreh_cumsum_nc.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/device/kernels/moreh_cumsum_nc.cpp index 41b73687a9a..918b48aa3b1 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/device/kernels/moreh_cumsum_nc.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/device/kernels/moreh_cumsum_nc.cpp @@ -23,7 +23,7 @@ void MAIN { constexpr uint32_t dst0 = 0; constexpr uint32_t first_tile = 0; - binary_op_init_common(cb_in0, cb_in1); + binary_op_init_common(cb_in0, cb_in1, cb_out0); cb_wait_front(cb_in1, onetile); for (uint32_t i = 0; i < num_output_tiles_per_core; i++) { @@ -33,7 +33,7 @@ void MAIN { uint32_t cb_add = (enable_reload) ? (cb_intermed0) : (cb_in1); cb_wait_front(cb_in0, onetile); - add_tiles_init(); + add_tiles_init(cb_in0, cb_add); add_tiles(cb_in0, cb_add, first_tile, first_tile, dst0); cb_pop_front(cb_in0, onetile); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_dot/device/kernels/moreh_dot.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_dot/device/kernels/moreh_dot.cpp index 8e453f62917..c1489081079 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_dot/device/kernels/moreh_dot.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_dot/device/kernels/moreh_dot.cpp @@ -15,7 +15,7 @@ namespace NAMESPACE { void MAIN { constexpr int onetile = 1; uint32_t per_core_block_cnt = get_arg_val(0); - binary_op_init_common(tt::CBIndex::c_0, tt::CBIndex::c_1); + binary_op_init_common(tt::CBIndex::c_0, tt::CBIndex::c_1, tt::CBIndex::c_16); bool enable_reload = false; for (uint32_t block = 0; block < per_core_block_cnt; ++block) { bool last_out = block == (per_core_block_cnt - 1); @@ -26,7 +26,7 @@ void MAIN { cb_wait_front(tt::CBIndex::c_1, onetile); cb_reserve_back(tt::CBIndex::c_24, onetile); - mul_tiles_init(); + mul_tiles_init(tt::CBIndex::c_0, tt::CBIndex::c_1); // dst0 = c_in0 x c_in1 mul_tiles(tt::CBIndex::c_0, tt::CBIndex::c_1, 0, 0, 0); // c_intermed0 = pack(dst0) diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/device/kernels/moreh_layer_norm_large_kernel.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/device/kernels/moreh_layer_norm_large_kernel.cpp index 27970e685ae..a0433a93503 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/device/kernels/moreh_layer_norm_large_kernel.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/device/kernels/moreh_layer_norm_large_kernel.cpp @@ -22,7 +22,7 @@ void MAIN { constexpr bool is_lastdim_layernorm = get_compile_time_arg_val(9) == 1; constexpr bool is_groupnorm = get_compile_time_arg_val(10) == 1; - binary_op_init_common(tt::CBIndex::c_0, tt::CBIndex::c_0); + binary_op_init_common(tt::CBIndex::c_0, tt::CBIndex::c_0, tt::CBIndex::c_16); constexpr auto cb_x = tt::CBIndex::c_0; // input constexpr auto cb_scaler = tt::CBIndex::c_1; // scaler diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/device/kernels/moreh_layer_norm_small_kernel.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/device/kernels/moreh_layer_norm_small_kernel.cpp index 84c55e29650..0e242ab6858 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/device/kernels/moreh_layer_norm_small_kernel.cpp +++ 
b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/device/kernels/moreh_layer_norm_small_kernel.cpp @@ -22,7 +22,7 @@ void MAIN { constexpr bool is_lastdim_layernorm = get_compile_time_arg_val(9) == 1; constexpr bool is_groupnorm = get_compile_time_arg_val(10) == 1; - binary_op_init_common(tt::CBIndex::c_0, tt::CBIndex::c_0); + binary_op_init_common(tt::CBIndex::c_0, tt::CBIndex::c_0, tt::CBIndex::c_16); constexpr auto cb_x = tt::CBIndex::c_0; // input constexpr auto cb_scaler = tt::CBIndex::c_1; // scaler diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/kernels/moreh_layer_norm_backward_input_grad_large_kernel.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/kernels/moreh_layer_norm_backward_input_grad_large_kernel.cpp index 00f6f43a330..fccbf6af828 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/kernels/moreh_layer_norm_backward_input_grad_large_kernel.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/kernels/moreh_layer_norm_backward_input_grad_large_kernel.cpp @@ -18,7 +18,7 @@ void MAIN { constexpr bool is_lastdim_layernorm = get_compile_time_arg_val(5) == 1; constexpr bool is_groupnorm = get_compile_time_arg_val(6) == 1; - binary_op_init_common(tt::CBIndex::c_1, tt::CBIndex::c_2); + binary_op_init_common(tt::CBIndex::c_1, tt::CBIndex::c_2, tt::CBIndex::c_16); constexpr auto cb_dy = tt::CBIndex::c_0; // output_grad(==dy) constexpr auto cb_x = tt::CBIndex::c_1; // input(==x) diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/kernels/moreh_layer_norm_backward_input_grad_small_kernel.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/kernels/moreh_layer_norm_backward_input_grad_small_kernel.cpp index ac958841018..e4b0c1c4675 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/kernels/moreh_layer_norm_backward_input_grad_small_kernel.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/kernels/moreh_layer_norm_backward_input_grad_small_kernel.cpp @@ -18,7 +18,7 @@ void MAIN { constexpr bool is_lastdim_layernorm = get_compile_time_arg_val(5) == 1; constexpr bool is_groupnorm = get_compile_time_arg_val(6) == 1; - binary_op_init_common(tt::CBIndex::c_1, tt::CBIndex::c_2); + binary_op_init_common(tt::CBIndex::c_1, tt::CBIndex::c_2, tt::CBIndex::c_16); constexpr auto cb_dy = tt::CBIndex::c_0; // output_grad(==dy) constexpr auto cb_x = tt::CBIndex::c_1; // input(==x) diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_mean/device/kernels/moreh_mean_h.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_mean/device/kernels/moreh_mean_h.cpp index b41e194d279..9a69ebae8e2 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_mean/device/kernels/moreh_mean_h.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_mean/device/kernels/moreh_mean_h.cpp @@ -25,7 +25,7 @@ void MAIN { constexpr auto cb_out = tt::CBIndex::c_16; constexpr bool do_mask_h = (origin_H % TILE_HEIGHT) != 0; - binary_op_init_common(cb_input, cb_input); + binary_op_init_common(cb_input, cb_input, cb_out); cb_wait_front(cb_scaler, 1); // scaler tile from the reader diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_mean/device/kernels/moreh_mean_nc.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_mean/device/kernels/moreh_mean_nc.cpp index f0bbc3b6cb5..0211eec081f 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_mean/device/kernels/moreh_mean_nc.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_mean/device/kernels/moreh_mean_nc.cpp @@ -24,7 +24,7 @@ void MAIN { constexpr 
uint32_t dst1 = 1; constexpr uint32_t first_tile = 0; - binary_op_init_common(tt::CBIndex::c_0, tt::CBIndex::c_1); + binary_op_init_common(tt::CBIndex::c_0, tt::CBIndex::c_1, tt::CBIndex::c_16); cb_wait_front(cb_in1, onetile); cb_wait_front(cb_scalar, 1); // scalar tile from the reader diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_mean/device/kernels/moreh_mean_w.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_mean/device/kernels/moreh_mean_w.cpp index 8e8ce12e68f..99758bace92 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_mean/device/kernels/moreh_mean_w.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_mean/device/kernels/moreh_mean_w.cpp @@ -26,7 +26,7 @@ void MAIN { constexpr auto cb_out = tt::CBIndex::c_16; constexpr bool do_mask_w = (origin_W % TILE_WIDTH) != 0; - binary_op_init_common(cb_input, cb_input); + binary_op_init_common(cb_input, cb_input, cb_out); cb_wait_front(cb_scaler, 1); // scaler tile from the reader diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_mean_backward/device/kernels/moreh_mean_backward.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_mean_backward/device/kernels/moreh_mean_backward.cpp index b6cf5095db5..49977ba818c 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_mean_backward/device/kernels/moreh_mean_backward.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_mean_backward/device/kernels/moreh_mean_backward.cpp @@ -24,7 +24,7 @@ void MAIN { constexpr uint32_t onetile = 1; constexpr uint32_t dst0 = 0; - binary_op_init_common(tt::CBIndex::c_0, tt::CBIndex::c_1); + binary_op_init_common(tt::CBIndex::c_0, tt::CBIndex::c_1, tt::CBIndex::c_16); cb_wait_front(cb_in1, onetile); for (uint32_t i = 0; i < num_output_tiles; i++) { tile_regs_acquire(); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_nll_loss/moreh_nll_loss_step2/device/kernels/moreh_nll_loss_step2_kernel.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_nll_loss/moreh_nll_loss_step2/device/kernels/moreh_nll_loss_step2_kernel.cpp index 30e145b6086..605cc003420 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_nll_loss/moreh_nll_loss_step2/device/kernels/moreh_nll_loss_step2_kernel.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_nll_loss/moreh_nll_loss_step2/device/kernels/moreh_nll_loss_step2_kernel.cpp @@ -24,7 +24,7 @@ void MAIN { constexpr uint32_t dst0 = 0; constexpr uint32_t onetile = 1; - binary_op_init_common(cb_tmp_weight, cb_tmp_input); + binary_op_init_common(cb_tmp_weight, cb_tmp_input, cb_output); #if defined(DIVISOR) cb_wait_front(cb_divisor, onetile); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_norm/device/moreh_norm_h/kernels/moreh_norm_h_kernel.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_norm/device/moreh_norm_h/kernels/moreh_norm_h_kernel.cpp index 2eb853b5d31..1509ee23cd5 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_norm/device/moreh_norm_h/kernels/moreh_norm_h_kernel.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_norm/device/moreh_norm_h/kernels/moreh_norm_h_kernel.cpp @@ -46,7 +46,7 @@ void MAIN { constexpr uint32_t dst0 = 0; constexpr uint32_t dst1 = 1; - binary_op_init_common(tt::CBIndex::c_0, tt::CBIndex::c_0); + binary_op_init_common(tt::CBIndex::c_0, tt::CBIndex::c_0, tt::CBIndex::c_16); cb_wait_front(cb_one, onetile); // comes from the reader cb_wait_front(cb_decimal, onetile); // comes from the reader diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_norm/device/moreh_norm_other/kernels/moreh_norm_other_kernel.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_norm/device/moreh_norm_other/kernels/moreh_norm_other_kernel.cpp index 22b9f82d986..629f2e665e2 100644 --- 
a/ttnn/cpp/ttnn/operations/moreh/moreh_norm/device/moreh_norm_other/kernels/moreh_norm_other_kernel.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_norm/device/moreh_norm_other/kernels/moreh_norm_other_kernel.cpp @@ -42,7 +42,7 @@ void MAIN { constexpr uint32_t dst0 = 0; constexpr uint32_t dst1 = 1; - binary_op_init_common(tt::CBIndex::c_0, tt::CBIndex::c_0); + binary_op_init_common(tt::CBIndex::c_0, tt::CBIndex::c_0, tt::CBIndex::c_16); cb_wait_front(cb_one, onetile); // comes from the reader cb_wait_front(cb_decimal, onetile); // comes from the reader diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_norm/device/moreh_norm_w/kernels/moreh_norm_w_kernel.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_norm/device/moreh_norm_w/kernels/moreh_norm_w_kernel.cpp index 65a6c1f51a1..3674335b84a 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_norm/device/moreh_norm_w/kernels/moreh_norm_w_kernel.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_norm/device/moreh_norm_w/kernels/moreh_norm_w_kernel.cpp @@ -46,7 +46,7 @@ void MAIN { constexpr uint32_t dst0 = 0; constexpr uint32_t dst1 = 1; - binary_op_init_common(tt::CBIndex::c_0, tt::CBIndex::c_0); + binary_op_init_common(tt::CBIndex::c_0, tt::CBIndex::c_0, tt::CBIndex::c_16); cb_wait_front(cb_one, onetile); // comes from the reader cb_wait_front(cb_decimal, onetile); // comes from the reader diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_norm/device/ord_other/moreh_norm_h/kernels/moreh_norm_h_kernel.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_norm/device/ord_other/moreh_norm_h/kernels/moreh_norm_h_kernel.cpp index 5dd80e588dc..1a8b75610c3 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_norm/device/ord_other/moreh_norm_h/kernels/moreh_norm_h_kernel.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_norm/device/ord_other/moreh_norm_h/kernels/moreh_norm_h_kernel.cpp @@ -31,7 +31,7 @@ void MAIN { constexpr uint32_t dst0 = 0; constexpr uint32_t dst1 = 1; - binary_op_init_common(tt::CB::c_in0, tt::CB::c_in0); + binary_op_init_common(tt::CB::c_in0, tt::CB::c_in0, tt::CB::c_out0); cb_wait_front(cb_one, onetile); // comes from the reader diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_norm/device/ord_other/moreh_norm_nc/kernels/moreh_norm_nc_kernel.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_norm/device/ord_other/moreh_norm_nc/kernels/moreh_norm_nc_kernel.cpp index a1393e93cd3..b654e463e80 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_norm/device/ord_other/moreh_norm_nc/kernels/moreh_norm_nc_kernel.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_norm/device/ord_other/moreh_norm_nc/kernels/moreh_norm_nc_kernel.cpp @@ -27,7 +27,7 @@ void MAIN { constexpr uint32_t dst0 = 0; constexpr uint32_t dst1 = 1; - binary_op_init_common(tt::CB::c_in0, tt::CB::c_in0); + binary_op_init_common(tt::CB::c_in0, tt::CB::c_in0, tt::CB::c_out0); cb_wait_front(cb_one, onetile); // comes from the reader diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_norm/device/ord_other/moreh_norm_w/kernels/moreh_norm_w_kernel.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_norm/device/ord_other/moreh_norm_w/kernels/moreh_norm_w_kernel.cpp index dc721c558cc..6a0b2ec02ab 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_norm/device/ord_other/moreh_norm_w/kernels/moreh_norm_w_kernel.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_norm/device/ord_other/moreh_norm_w/kernels/moreh_norm_w_kernel.cpp @@ -31,7 +31,7 @@ void MAIN { constexpr uint32_t dst0 = 0; constexpr uint32_t dst1 = 1; - binary_op_init_common(tt::CB::c_in0, tt::CB::c_in0); + binary_op_init_common(tt::CB::c_in0, tt::CB::c_in0, tt::CB::c_out0); 
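The output CB threaded into binary_op_init_common across these moreh kernels follows the index layout already visible in the hunks; a sketch of the aliases as typically declared at the top of MAIN (grouping taken from the kernels touched in this patch, not a general rule):

    constexpr auto cb_in0      = tt::CBIndex::c_0;   // first operand, produced by the reader
    constexpr auto cb_in1      = tt::CBIndex::c_1;   // second operand / scaler, produced by the reader
    constexpr auto cb_out      = tt::CBIndex::c_16;  // result, consumed by the writer
    constexpr auto cb_intermed = tt::CBIndex::c_24;  // first intermediate CB, reused between iterations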
cb_wait_front(cb_one, onetile); // comes from the reader diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_norm_backward/device/kernels/moreh_norm_backward_kernel.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_norm_backward/device/kernels/moreh_norm_backward_kernel.cpp index edf232d8597..de3ec556181 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_norm_backward/device/kernels/moreh_norm_backward_kernel.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_norm_backward/device/kernels/moreh_norm_backward_kernel.cpp @@ -48,7 +48,7 @@ void MAIN { constexpr uint32_t onetile = 1; constexpr uint32_t dst0 = 0; - binary_op_init_common(tt::CBIndex::c_0, tt::CBIndex::c_0); + binary_op_init_common(tt::CBIndex::c_0, tt::CBIndex::c_0, tt::CBIndex::c_16); cb_wait_front(cb_decimal, onetile); // comes from the reader for (uint32_t idx = 0; idx < num_input_tiles_per_core; ++idx) { diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_sgd/device/kernels/moreh_sgd.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_sgd/device/kernels/moreh_sgd.cpp index 175e4072cc1..8e207cb4b8c 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_sgd/device/kernels/moreh_sgd.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_sgd/device/kernels/moreh_sgd.cpp @@ -26,7 +26,7 @@ void MAIN { constexpr uint32_t weight_decay_tile = 3; constexpr uint32_t one_tile = 4; - binary_op_init_common(cb_param_in, cb_param_in); + binary_op_init_common(cb_param_in, cb_param_in, cb_param_out); uint32_t num_tiles = get_compile_time_arg_val(0); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/kernels/moreh_softmax_c_large.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/kernels/moreh_softmax_c_large.cpp index 583b9b3db7c..ee207f3f5fc 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/kernels/moreh_softmax_c_large.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/kernels/moreh_softmax_c_large.cpp @@ -26,7 +26,7 @@ void MAIN { uint32_t N = get_compile_time_arg_val(0); uint32_t dim_size = get_compile_time_arg_val(1); - binary_op_init_common(cb_in0, cb_exps); + binary_op_init_common(cb_in0, cb_exps, cb_out0); for (uint32_t n = 0; n < N; ++n) { // find max diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/kernels/moreh_softmax_h.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/kernels/moreh_softmax_h.cpp index 653d790fc70..945c0b0fb78 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/kernels/moreh_softmax_h.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/kernels/moreh_softmax_h.cpp @@ -25,7 +25,7 @@ void MAIN { constexpr int dst1 = 1; constexpr uint32_t onetile = 1; - binary_op_init_common(cb_in0, cb_bcast_scaler); + binary_op_init_common(cb_in0, cb_bcast_scaler, cb_out0); uint32_t N = get_compile_time_arg_val(0); uint32_t Ht = get_compile_time_arg_val(1); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/kernels/moreh_softmax_h_large.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/kernels/moreh_softmax_h_large.cpp index a9ee25ff14f..efa93ef0310 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/kernels/moreh_softmax_h_large.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/kernels/moreh_softmax_h_large.cpp @@ -21,7 +21,7 @@ void MAIN { constexpr auto cb_max = tt::CBIndex::c_27; constexpr auto cb_tmp = tt::CBIndex::c_28; - binary_op_init_common(cb_in0, cb_bcast_scaler); + binary_op_init_common(cb_in0, cb_bcast_scaler, cb_out0); constexpr uint32_t onetile = 1; constexpr int dst0 = 0; diff --git 
a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/kernels/moreh_softmax_w.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/kernels/moreh_softmax_w.cpp index 415974b0814..05dec994a8c 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/kernels/moreh_softmax_w.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/kernels/moreh_softmax_w.cpp @@ -22,7 +22,7 @@ void MAIN { constexpr auto cb_x_m_max = tt::CBIndex::c_27; constexpr auto cb_tmp = tt::CBIndex::c_28; - binary_op_init_common(cb_in0, cb_bcast_scaler); + binary_op_init_common(cb_in0, cb_bcast_scaler, cb_out0); constexpr int dst0 = 0; constexpr int dst1 = 1; diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/kernels/moreh_softmax_w_large.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/kernels/moreh_softmax_w_large.cpp index 0e91c5bdae0..ab7446582d4 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/kernels/moreh_softmax_w_large.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/kernels/moreh_softmax_w_large.cpp @@ -21,7 +21,7 @@ void MAIN { constexpr auto cb_max = tt::CBIndex::c_27; constexpr auto cb_tmp = tt::CBIndex::c_28; - binary_op_init_common(cb_in0, cb_bcast_scaler); + binary_op_init_common(cb_in0, cb_bcast_scaler, cb_out0); constexpr uint32_t onetile = 1; constexpr int dst0 = 0; diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/kernels/moreh_softmax_backward_c_large.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/kernels/moreh_softmax_backward_c_large.cpp index a6c6c635d6d..5b093211e72 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/kernels/moreh_softmax_backward_c_large.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/kernels/moreh_softmax_backward_c_large.cpp @@ -25,7 +25,7 @@ void MAIN { uint32_t N = get_compile_time_arg_val(0); uint32_t dim_size = get_compile_time_arg_val(1); - binary_op_init_common(cb_dy, cb_y); + binary_op_init_common(cb_dy, cb_y, cb_dx); constexpr int dst0 = 0; for (uint32_t n = 0; n < N; ++n) { diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/kernels/moreh_softmax_backward_h.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/kernels/moreh_softmax_backward_h.cpp index 82ed70f676d..6499df2ac8b 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/kernels/moreh_softmax_backward_h.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/kernels/moreh_softmax_backward_h.cpp @@ -23,7 +23,7 @@ void MAIN { constexpr auto cb_sum = tt::CBIndex::c_25; constexpr auto cb_inter2 = tt::CBIndex::c_26; - binary_op_init_common(cb_y, cb_bcast_scaler); + binary_op_init_common(cb_y, cb_bcast_scaler, cb_dx); uint32_t N = get_compile_time_arg_val(0); uint32_t Ht = get_compile_time_arg_val(1); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/kernels/moreh_softmax_backward_h_large.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/kernels/moreh_softmax_backward_h_large.cpp index aa9e128f887..5be9b8f1328 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/kernels/moreh_softmax_backward_h_large.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/kernels/moreh_softmax_backward_h_large.cpp @@ -24,7 +24,7 @@ void MAIN { constexpr auto cb_inter2 = tt::CBIndex::c_26; constexpr auto cb_add = tt::CBIndex::c_27; - binary_op_init_common(cb_y, cb_bcast_scaler); + binary_op_init_common(cb_y, cb_bcast_scaler, cb_dx); 
uint32_t N = get_compile_time_arg_val(0); uint32_t Ht = get_compile_time_arg_val(1); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/kernels/moreh_softmax_backward_w.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/kernels/moreh_softmax_backward_w.cpp index cdc41f87694..ae131be13ad 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/kernels/moreh_softmax_backward_w.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/kernels/moreh_softmax_backward_w.cpp @@ -23,7 +23,7 @@ void MAIN { constexpr auto cb_sum = tt::CBIndex::c_25; constexpr auto cb_inter2 = tt::CBIndex::c_26; - binary_op_init_common(cb_y, cb_bcast_scaler); + binary_op_init_common(cb_y, cb_bcast_scaler, cb_dx); uint32_t N = get_compile_time_arg_val(0); uint32_t Wt = get_compile_time_arg_val(1); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/kernels/moreh_softmax_backward_w_large.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/kernels/moreh_softmax_backward_w_large.cpp index bc6b4afc251..ce81f470bc9 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/kernels/moreh_softmax_backward_w_large.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/kernels/moreh_softmax_backward_w_large.cpp @@ -24,7 +24,7 @@ void MAIN { constexpr auto cb_inter2 = tt::CBIndex::c_26; constexpr auto cb_add = tt::CBIndex::c_27; - binary_op_init_common(cb_y, cb_bcast_scaler); + binary_op_init_common(cb_y, cb_bcast_scaler, cb_dx); uint32_t N = get_compile_time_arg_val(0); uint32_t Wt = get_compile_time_arg_val(1); diff --git a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp index 5a7525b7a4e..a58dedc3697 100644 --- a/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/batch_norm/device/kernels/compute/batch_norm_kernel.cpp @@ -159,9 +159,9 @@ void MAIN { auto cb_bcast = cb_batch_mean; auto cb_other = cb_input; - binary_op_init_common(cb_bcast, cb_other, cb_output_0); + binary_op_init_common(cb_other, cb_bcast, cb_output_0); - sub_tiles_init(); + sub_tiles_init(cb_other, cb_bcast); uint32_t complete_iterations = (num_tiles + tile_start) / tile_freq; uint32_t remaining_iterations = (num_tiles + tile_start) % tile_freq; for (uint32_t i = 0; i < complete_iterations; ++i, tile_start = 0) { diff --git a/ttnn/cpp/ttnn/operations/normalization/groupnorm/device/kernels/compute/groupnorm_sharded_v2.cpp b/ttnn/cpp/ttnn/operations/normalization/groupnorm/device/kernels/compute/groupnorm_sharded_v2.cpp index b864a5a1467..69387a6aa57 100644 --- a/ttnn/cpp/ttnn/operations/normalization/groupnorm/device/kernels/compute/groupnorm_sharded_v2.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/groupnorm/device/kernels/compute/groupnorm_sharded_v2.cpp @@ -168,7 +168,7 @@ void MAIN { // mask input index_h_offset = index_b_offset + index_g_offset; reconfig_data_format_srcb(cb_in0, cb_input_mask); - mul_tiles_init(); + mul_tiles_init(cb_in0, cb_input_mask); cb_reserve_back(cb_x, block_hw); cb_wait_front(cb_input_mask, block_w); for (uint32_t i = 0; i < block_h; ++i) { @@ -263,7 +263,7 @@ void MAIN { // zero out the garbage values by mult mask again reconfig_data_format_srcb(cb_ex_global, cb_input_mask); - mul_tiles_init(); + mul_tiles_init(cb_xmm, cb_input_mask); cb_reserve_back(cb_x, 
block_hw); cb_wait_front(cb_xmm, block_hw); for (uint32_t i = 0; i < block_h; i++) { @@ -291,7 +291,7 @@ void MAIN { // (x - E[x])^2 index_h_offset = 0; - mul_tiles_init(); + mul_tiles_init(cb_x, cb_x); cb_reserve_back(cb_xmm, block_hw); cb_wait_front(cb_x, block_hw); for (uint32_t i = 0; i < block_h; i++) { @@ -360,7 +360,7 @@ void MAIN { cb_reserve_back(cb_ex2pe, 1); // (Var + eps) tile_regs_acquire(); - add_tiles_init(); + add_tiles_init(cb_ex_global, cb_eps); add_tiles(cb_ex_global, cb_eps, 0, 0, dst0); tile_regs_wait(); // sqrt(Var + eps) @@ -415,7 +415,7 @@ void MAIN { if (copy_or_add == true) { copy_tile_init(cb_xmm); } else { - add_tiles_init(); + add_tiles_init(cb_out, cb_xmm); } for (uint32_t i = 0; i < block_h; ++i) { diff --git a/ttnn/cpp/ttnn/operations/normalization/groupnorm/device/multi_core/groupnorm_op_multi_core.cpp b/ttnn/cpp/ttnn/operations/normalization/groupnorm/device/multi_core/groupnorm_op_multi_core.cpp index 37af5cf65c0..f7dcadbca09 100644 --- a/ttnn/cpp/ttnn/operations/normalization/groupnorm/device/multi_core/groupnorm_op_multi_core.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/groupnorm/device/multi_core/groupnorm_op_multi_core.cpp @@ -171,7 +171,7 @@ operation::ProgramWithCallbacks groupnorm_multi_core_sharded( uint32_t per_core_N = a.shard_spec().value().shape[1]; uint32_t per_core_Mt = per_core_M / TILE_HEIGHT; uint32_t per_core_Nt = (per_core_N + TILE_WIDTH - 1) / TILE_WIDTH; - uint32_t per_core_N_bytes_padded = round_up_to_mul32(per_core_N * datum_size_bytes); + uint32_t per_core_N_bytes_padded = tt::round_up(per_core_N * datum_size_bytes, output.buffer()->alignment()); bool reader_repack_output = (per_core_N % TILE_WIDTH) != 0; bool tilize_in = a.get_layout() == Layout::ROW_MAJOR; bool untilize_out = output.get_layout() == Layout::ROW_MAJOR; diff --git a/ttnn/cpp/ttnn/operations/normalization/layernorm/device/kernels/compute/layernorm.cpp b/ttnn/cpp/ttnn/operations/normalization/layernorm/device/kernels/compute/layernorm.cpp index b83a459b4cf..a8e2dfaf501 100644 --- a/ttnn/cpp/ttnn/operations/normalization/layernorm/device/kernels/compute/layernorm.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/layernorm/device/kernels/compute/layernorm.cpp @@ -81,7 +81,7 @@ void MAIN { #ifdef FUSE_PRE_ADD reconfig_data_format(cb_in, cb_inb); pack_reconfig_data_format(cb_x); - add_tiles_init(); + add_tiles_init(cb_in, cb_inb); for (uint32_t wt = 0; wt < Wt; wt += blk) { ACQ(); // UNPACK(( { DPRINT << "Waiting on cb_x" << ENDL(); } )); @@ -166,7 +166,7 @@ void MAIN { /* (x - E[x])^2 * compute temp = xmm*xmm = (x-E[x])^2 */ - mul_tiles_init(); + mul_tiles_init(cb_xmm, cb_xmm); for (uint32_t wt = 0; wt < Wt; wt += blk) { cb_wait_front(cb_xmm, wt + blk); // cumulative wait cb_reserve_back(cb_xmm2, blk); // can probably use less space for this if we block @@ -219,7 +219,7 @@ void MAIN { reconfig_data_format(cb_ex2, cb_eps); } ACQ(); - add_tiles_init(); + add_tiles_init(cb_ex2, cb_eps); add_tiles(cb_ex2, cb_eps, 0, 0, dst0); cb_reserve_back(cb_ex2pe, 1); // 1 diff --git a/ttnn/cpp/ttnn/operations/normalization/layernorm/device/kernels/compute/layernorm_sharded.cpp b/ttnn/cpp/ttnn/operations/normalization/layernorm/device/kernels/compute/layernorm_sharded.cpp index fcf3550f6bb..bd2dd31df22 100644 --- a/ttnn/cpp/ttnn/operations/normalization/layernorm/device/kernels/compute/layernorm_sharded.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/layernorm/device/kernels/compute/layernorm_sharded.cpp @@ -106,7 +106,7 @@ void MAIN { // pre-add x + y #ifdef FUSE_PRE_ADD 
reconfig_data_format_srcb(cb_in0, cb_in1); - add_tiles_init(); + add_tiles_init(cb_in0, cb_in1); cb_reserve_back(cb_in, num_tiles_per_block); for (uint32_t i = 0; i < block_h; i++) { index_subblock_w_offset = 0; @@ -224,7 +224,7 @@ void MAIN { #endif // (x - E[x])^2, cb_mm2 <-- cb_xmm - mul_tiles_init(); + mul_tiles_init(cb_xmm, cb_xmm); index_h_offset = 0; cb_reserve_back(cb_xmm2, num_tiles_per_block); for (uint32_t i = 0; i < block_h; i++) { @@ -311,7 +311,7 @@ void MAIN { cb_wait_front(cb_ex2, 1); cb_reserve_back(cb_ex2pe, 1); tile_regs_acquire(); - add_tiles_init(); + add_tiles_init(cb_ex2, cb_eps); add_tiles(cb_ex2, cb_eps, i, 0, dst0); tile_regs_wait(); // sqrt(Var + eps) diff --git a/ttnn/cpp/ttnn/operations/normalization/layernorm/device/kernels/compute/layernorm_sharded_post_allgather.cpp b/ttnn/cpp/ttnn/operations/normalization/layernorm/device/kernels/compute/layernorm_sharded_post_allgather.cpp index bfb5519f0ec..4a7afad8ad6 100644 --- a/ttnn/cpp/ttnn/operations/normalization/layernorm/device/kernels/compute/layernorm_sharded_post_allgather.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/layernorm/device/kernels/compute/layernorm_sharded_post_allgather.cpp @@ -145,7 +145,7 @@ void MAIN { cb_reserve_back(cb_ex_sqr, 1); cb_wait_front(cb_stats_reduced, 1); tile_regs_acquire(); - mul_tiles_init(); + mul_tiles_init(cb_stats_reduced, cb_stats_reduced); mul_tiles(cb_stats_reduced, cb_stats_reduced, 0, 0, dst0); // first tile in stats is always E(x) tile_regs_commit(); tile_regs_wait(); @@ -161,7 +161,7 @@ void MAIN { cb_wait_front(cb_ex_sqr, 1); cb_reserve_back(cb_var, 1); tile_regs_acquire(); - sub_tiles_init(); + sub_tiles_init(cb_ex2, cb_ex_sqr); sub_tiles(cb_ex2, cb_ex_sqr, 0, 0, dst0); tile_regs_commit(); tile_regs_wait(); @@ -179,7 +179,7 @@ void MAIN { cb_wait_front(cb_eps, 1); cb_reserve_back(cb_stats_reduced, 1); - add_tiles_init(); + add_tiles_init(cb_var, cb_eps); tile_regs_acquire(); add_tiles(cb_var, cb_eps, 0, 0, dst0); tile_regs_wait(); diff --git a/ttnn/cpp/ttnn/operations/normalization/layernorm/device/kernels/compute/layernorm_sharded_pre_allgather.cpp b/ttnn/cpp/ttnn/operations/normalization/layernorm/device/kernels/compute/layernorm_sharded_pre_allgather.cpp index de011b09f7b..587c9e757f0 100644 --- a/ttnn/cpp/ttnn/operations/normalization/layernorm/device/kernels/compute/layernorm_sharded_pre_allgather.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/layernorm/device/kernels/compute/layernorm_sharded_pre_allgather.cpp @@ -83,7 +83,7 @@ void MAIN { // pre-add x + y #ifdef FUSE_PRE_ADD binary_op_init_common(cb_in0, cb_in1, cb_in); - add_tiles_init(); + add_tiles_init(cb_in0, cb_in1); cb_reserve_back(cb_in, num_tiles_per_block); for (uint32_t i = 0; i < block_h; i++) { index_subblock_w_offset = 0; @@ -143,7 +143,7 @@ void MAIN { #endif // not RMSNORM // X^2 - mul_tiles_init(); + mul_tiles_init(cb_in0, cb_in0); index_h_offset = 0; cb_reserve_back(cb_x2, num_tiles_per_block); for (uint32_t i = 0; i < block_h; i++) { diff --git a/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/kernels/compute/layernorm_post_allgather.cpp b/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/kernels/compute/layernorm_post_allgather.cpp index f25498bbb87..2b4a5ecf1d0 100644 --- a/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/kernels/compute/layernorm_post_allgather.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/kernels/compute/layernorm_post_allgather.cpp @@ -125,7 +125,7 @@ void MAIN { */ 
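The layernorm post-allgather hunks above and below re-initialize each binary op along the variance path, Var(x) = E[x^2] - (E[x])^2, before adding eps and taking sqrt/reciprocal. A condensed sketch using the CB names from the sharded kernel (tile-reg and packing steps between ops are elided):

    mul_tiles_init(cb_stats_reduced, cb_stats_reduced);
    mul_tiles(cb_stats_reduced, cb_stats_reduced, 0, 0, dst0);  // (E[x])^2 -- first stats tile is E[x]
    sub_tiles_init(cb_ex2, cb_ex_sqr);
    sub_tiles(cb_ex2, cb_ex_sqr, 0, 0, dst0);                   // Var = E[x^2] - (E[x])^2
    add_tiles_init(cb_var, cb_eps);
    add_tiles(cb_var, cb_eps, 0, 0, dst0);                      // Var + eps, then sqrt and reciprocal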
reconfig_data_format(cb_stats_reduced, cb_stats_reduced); pack_reconfig_data_format(cb_mean_squared); - mul_tiles_init(); + mul_tiles_init(cb_stats_reduced, cb_stats_reduced); cb_reserve_back(cb_mean_squared, onetile); cb_wait_front(cb_stats_reduced, stats_tile_stride); ACQ(); @@ -140,7 +140,7 @@ void MAIN { */ reconfig_data_format(cb_stats_reduced, cb_mean_squared); pack_reconfig_data_format(cb_var); - sub_tiles_init(); + sub_tiles_init(cb_stats_reduced, cb_mean_squared); cb_reserve_back(cb_var, onetile); cb_wait_front(cb_mean_squared, 1); @@ -182,7 +182,7 @@ void MAIN { reconfig_data_format(cb_var, cb_eps); pack_reconfig_data_format(cb_recip_sqrt_var); - add_tiles_init(); + add_tiles_init(cb_var, cb_eps); ACQ(); add_tiles(cb_var, cb_eps, 0, 0, 0); sqrt_tile_init(); diff --git a/ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/compute/softmax.cpp b/ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/compute/softmax.cpp index 3ad8803a6ed..96ba030b47e 100644 --- a/ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/compute/softmax.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/compute/softmax.cpp @@ -132,7 +132,7 @@ void MAIN { #endif #ifdef CAUSAL_MASK - add_tiles_init(); + add_tiles_init(cb_scale_mask, cb_fused_attn); #else add_bcast_rows_init_short(cb_scale_mask, cb_fused_attn); #endif diff --git a/ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/compute/softmax_sharded.cpp b/ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/compute/softmax_sharded.cpp index ec44fa45c15..1dbac078a93 100644 --- a/ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/compute/softmax_sharded.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/softmax/device/kernels/compute/softmax_sharded.cpp @@ -124,7 +124,7 @@ void MAIN { index_subblock_w_offset = 0; #ifdef CAUSAL_MASK - add_tiles_init(); + add_tiles_init(cb_scale_mask, cb_fused_attn); #else add_bcast_rows_init_short(cb_scale_mask, cb_fused_attn); #endif diff --git a/ttnn/cpp/ttnn/operations/pool/generic/device/kernels/compute/max_pool_multi_core.cpp b/ttnn/cpp/ttnn/operations/pool/generic/device/kernels/compute/max_pool_multi_core.cpp index 9600f53e03c..5fafc9cb122 100644 --- a/ttnn/cpp/ttnn/operations/pool/generic/device/kernels/compute/max_pool_multi_core.cpp +++ b/ttnn/cpp/ttnn/operations/pool/generic/device/kernels/compute/max_pool_multi_core.cpp @@ -78,7 +78,7 @@ void MAIN { constexpr uint32_t partial_iter_output_tiles = in_ntiles_c % MAX_TILES_PER_REDUCTION == 0 ? 
max_tiles_per_iter : in_ntiles_c % MAX_TILES_PER_REDUCTION; tilizeA_B_reduce_init(in_cb_id, in_scalar_cb_id, max_tiles_per_iter, out_cb_id, num_faces_in_tile, window_size_hw); - pack_untilize_dst_init_short(out_cb_id, num_out_rows, num_faces_in_tile); + pack_untilize_dst_init_short(out_cb_id, num_out_rows, num_faces_in_tile); cb_wait_front(in_scalar_cb_id, 1); for (uint32_t i = 0; i < nsticks_per_core; ++i) { diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp index 8ec9ec529f4..f7c5cbcb7b0 100644 --- a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp +++ b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_program_factory_multicore.cpp @@ -241,7 +241,7 @@ operation::ProgramWithCallbacks upsample_multi_core( : shard_spec.orientation; ShardSpec config_shard_spec(input.shard_spec().value().grid, shard_shape, config_tensor_shard_orientation); MemoryConfig memory_config{TensorMemoryLayout::HEIGHT_SHARDED, BufferType::L1_SMALL, config_shard_spec}; - auto config_tensor_device = config_tensor.to(device, memory_config); + auto config_tensor_device = config_tensor.to_device(device, memory_config); tt::DataFormat config_df = tt::DataFormat::RawUInt16; auto config_buffer = config_tensor_device.device_buffer(); diff --git a/ttnn/cpp/ttnn/operations/reduction/moe/device/kernels/compute/moe.cpp b/ttnn/cpp/ttnn/operations/reduction/moe/device/kernels/compute/moe.cpp index 0a456b073ea..60f9415a705 100644 --- a/ttnn/cpp/ttnn/operations/reduction/moe/device/kernels/compute/moe.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/moe/device/kernels/compute/moe.cpp @@ -90,7 +90,7 @@ void mul_block_inplace(uint32_t in0_cb, uint32_t in1_cb, uint32_t num_tiles) { // Postcondition: in0_cb has num_tiles produced // Postcondition: in1_cb has num_tiles produced reconfig_data_format(in0_cb, in1_cb); - mul_tiles_init(); + mul_tiles_init(in0_cb, in1_cb); cb_wait_front(in0_cb, num_tiles); cb_wait_front(in1_cb, num_tiles); for (uint32_t i = 0; i < num_tiles; i++) { diff --git a/ttnn/cpp/ttnn/operations/reduction/prod/device/kernels/compute/prod_all.cpp b/ttnn/cpp/ttnn/operations/reduction/prod/device/kernels/compute/prod_all.cpp index f507ad5c341..09f03e26ef1 100644 --- a/ttnn/cpp/ttnn/operations/reduction/prod/device/kernels/compute/prod_all.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/prod/device/kernels/compute/prod_all.cpp @@ -41,7 +41,7 @@ void MAIN { tile_regs_release(); } else { tile_regs_acquire(); - mul_tiles_init(); + mul_tiles_init(tt::CBIndex::c_0, tt::CBIndex::c_24); mul_tiles(tt::CBIndex::c_0, tt::CBIndex::c_24, 0, 0, 0); tile_regs_commit(); tile_regs_wait(); diff --git a/ttnn/cpp/ttnn/operations/reduction/prod/device/kernels/compute/prod_nc.cpp b/ttnn/cpp/ttnn/operations/reduction/prod/device/kernels/compute/prod_nc.cpp index a5c0047bad1..549300102d3 100644 --- a/ttnn/cpp/ttnn/operations/reduction/prod/device/kernels/compute/prod_nc.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/prod/device/kernels/compute/prod_nc.cpp @@ -21,7 +21,7 @@ void MAIN { constexpr uint32_t dst1 = 1; constexpr uint32_t first_tile = 0; - binary_op_init_common(tt::CBIndex::c_0, tt::CBIndex::c_1); + binary_op_init_common(tt::CBIndex::c_0, tt::CBIndex::c_1, tt::CBIndex::c_16); cb_wait_front(cb_in1, onetile); for (uint32_t i = 0; i < num_output_tiles; i++) { @@ -36,7 +36,7 @@ void MAIN { } tile_regs_acquire(); - mul_tiles_init(); + mul_tiles_init(cb_in0, cb_add); 
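The prod_nc.cpp hunk above also shows the companion change to binary_op_init_common, which now receives the output circular buffer (c_16) in addition to the two inputs. A condensed sketch of the resulting init sequence, using the names from that kernel (the cb_wait/cb_reserve bookkeeping is elided):

    binary_op_init_common(tt::CBIndex::c_0, tt::CBIndex::c_1, tt::CBIndex::c_16);  // in0, in1, out
    // ... per-tile loop body, with the CBs made ready beforehand ...
    tile_regs_acquire();
    mul_tiles_init(cb_in0, cb_add);
    mul_tiles(cb_in0, cb_add, first_tile, first_tile, dst0);
    tile_regs_commit();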
mul_tiles(cb_in0, cb_add, first_tile, first_tile, dst0); tile_regs_commit(); diff --git a/ttnn/cpp/ttnn/operations/reduction/sampling/device/kernels/compute/sampling.cpp b/ttnn/cpp/ttnn/operations/reduction/sampling/device/kernels/compute/sampling.cpp index 1e1e9380585..689ba054883 100644 --- a/ttnn/cpp/ttnn/operations/reduction/sampling/device/kernels/compute/sampling.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/sampling/device/kernels/compute/sampling.cpp @@ -90,7 +90,7 @@ void add_block_inplace(uint32_t in0_cb, uint32_t in1_cb, uint32_t num_tiles) { // Postcondition: in0_cb has num_tiles produced // Postcondition: in1_cb has num_tiles produced reconfig_data_format(in0_cb, in1_cb); - add_tiles_init(); + add_tiles_init(in0_cb, in1_cb); cb_wait_front(in0_cb, num_tiles); cb_wait_front(in1_cb, num_tiles); for (uint32_t i = 0; i < num_tiles; i++) { diff --git a/ttnn/cpp/ttnn/operations/sliding_window/sliding_window.cpp b/ttnn/cpp/ttnn/operations/sliding_window/sliding_window.cpp index 6997f3b59c8..b53e4ea806b 100644 --- a/ttnn/cpp/ttnn/operations/sliding_window/sliding_window.cpp +++ b/ttnn/cpp/ttnn/operations/sliding_window/sliding_window.cpp @@ -689,7 +689,7 @@ Tensor move_config_tensor_to_device( : ShardOrientation::ROW_MAJOR; ShardSpec shard_spec(p_config.grid, shard_shape, config_shard_orientation); MemoryConfig memory_config{TensorMemoryLayout::HEIGHT_SHARDED, BufferType::L1_SMALL, shard_spec}; - return config_tensor.to(device, memory_config); + return config_tensor.to_device(device, memory_config); } std::string SlidingWindowConfig::to_string() const { diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/compute_common.hpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/compute_common.hpp index b0ab8f9a4b5..17f6496b7b5 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/compute_common.hpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/compute_common.hpp @@ -192,7 +192,7 @@ void add_block_inplace(uint32_t in0_cb, uint32_t in1_cb, uint32_t num_tiles) { // Postcondition: in0_cb has num_tiles produced // Postcondition: in1_cb has num_tiles consumed - add_tiles_init(); + add_tiles_init(in0_cb, in1_cb); cb_wait_front(in0_cb, num_tiles); cb_wait_front(in1_cb, num_tiles); for (uint32_t i = 0; i < num_tiles; i++) { @@ -213,7 +213,7 @@ void mul_block_inplace(uint32_t in0_cb, uint32_t in1_cb, uint32_t num_tiles) { // Postcondition: in0_cb has num_tiles produced // Postcondition: in1_cb has num_tiles produced - mul_tiles_init(); + mul_tiles_init(in0_cb, in1_cb); cb_wait_front(in0_cb, num_tiles); cb_wait_front(in1_cb, num_tiles); for (uint32_t i = 0; i < num_tiles; i++) { @@ -232,7 +232,7 @@ void sub_exp_block(uint32_t in0_cb, uint32_t in1_cb, uint32_t out_cb, uint32_t n // Postcondition: out_cb has num_tiles produced // Postcondition: in0_cb and in1_cb has num_tiles produced - sub_tiles_init(); + sub_tiles_init(in0_cb, in1_cb); exp_tile_init(); cb_wait_front(in0_cb, num_tiles); cb_wait_front(in1_cb, num_tiles); diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/compute_common.hpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/compute_common.hpp index 716de389108..df6eeb8f45e 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/compute_common.hpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/compute_common.hpp @@ -217,7 +217,7 @@ void add_block_inplace(uint32_t in0_cb, uint32_t 
in1_cb, uint32_t num_tiles) { // Postcondition: in0_cb has num_tiles produced // Postcondition: in1_cb has num_tiles consumed - add_tiles_init(); + add_tiles_init(in0_cb, in1_cb); cb_wait_front(in0_cb, num_tiles); cb_wait_front(in1_cb, num_tiles); for (uint32_t i = 0; i < num_tiles; i++) { @@ -239,7 +239,7 @@ void add_block(uint32_t in0_cb, uint32_t in1_cb, uint32_t out_cb, uint32_t num_t // Postcondition: in0_cb has num_tiles produced // Postcondition: in1_cb has num_tiles consumed - add_tiles_init(); + add_tiles_init(in0_cb, in1_cb); cb_wait_front(in0_cb, num_tiles); cb_wait_front(in1_cb, num_tiles); cb_reserve_back(out_cb, num_tiles); @@ -260,7 +260,7 @@ void mul_block_inplace(uint32_t in0_cb, uint32_t in1_cb, uint32_t num_tiles) { // Postcondition: in0_cb has num_tiles produced // Postcondition: in1_cb has num_tiles produced - mul_tiles_init(); + mul_tiles_init(in0_cb, in1_cb); cb_wait_front(in0_cb, num_tiles); cb_wait_front(in1_cb, num_tiles); for (uint32_t i = 0; i < num_tiles; i++) { @@ -278,7 +278,7 @@ void sub_exp_block(uint32_t in0_cb, uint32_t in1_cb, uint32_t out_cb, uint32_t n // Precondition: in0_cb and in1_cb have num_tiles produced // Postcondition: out_cb has num_tiles produced // Postcondition: in0_cb and in1_cb has num_tiles produced - sub_tiles_init(); + sub_tiles_init(in0_cb, in1_cb); exp_tile_init(); cb_wait_front(in0_cb, num_tiles); cb_wait_front(in1_cb, num_tiles); diff --git a/ttnn/cpp/ttnn/tensor/CMakeLists.txt b/ttnn/cpp/ttnn/tensor/CMakeLists.txt index 5d03e12be5d..417c64b8580 100644 --- a/ttnn/cpp/ttnn/tensor/CMakeLists.txt +++ b/ttnn/cpp/ttnn/tensor/CMakeLists.txt @@ -7,7 +7,6 @@ set(TENSOR_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/tensor_spec.cpp ${CMAKE_CURRENT_SOURCE_DIR}/tensor_utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/serialization.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/shape/shape_base.cpp ${CMAKE_CURRENT_SOURCE_DIR}/shape/shape.cpp ${CMAKE_CURRENT_SOURCE_DIR}/layout/alignment.cpp ${CMAKE_CURRENT_SOURCE_DIR}/layout/page_config.cpp diff --git a/ttnn/cpp/ttnn/tensor/layout/alignment.cpp b/ttnn/cpp/ttnn/tensor/layout/alignment.cpp index e8e55c81a83..1b65a810d10 100644 --- a/ttnn/cpp/ttnn/tensor/layout/alignment.cpp +++ b/ttnn/cpp/ttnn/tensor/layout/alignment.cpp @@ -10,7 +10,7 @@ namespace tt::tt_metal { bool Alignment::operator==(const Alignment& other) const = default; -bool Alignment::operator==(const SmallVector& other) const { return this->value_ == other; } +bool Alignment::operator==(const tt::stl::SmallVector& other) const { return this->value_ == other; } std::ostream& operator<<(std::ostream& os, const tt::tt_metal::Alignment& alignment) { os << "Alignment(["; diff --git a/ttnn/cpp/ttnn/tensor/layout/alignment.hpp b/ttnn/cpp/ttnn/tensor/layout/alignment.hpp index 22486d0f316..fc50427cfae 100644 --- a/ttnn/cpp/ttnn/tensor/layout/alignment.hpp +++ b/ttnn/cpp/ttnn/tensor/layout/alignment.hpp @@ -4,8 +4,8 @@ #pragma once -#include "ttnn/tensor/shape/shape_base.hpp" -#include "ttnn/tensor/shape/small_vector.hpp" +#include +#include namespace tt::tt_metal { @@ -26,7 +26,7 @@ class Alignment final : protected ShapeBase { } bool operator==(const Alignment& other) const; - bool operator==(const SmallVector& other) const; + bool operator==(const tt::stl::SmallVector& other) const; // Needed for reflect / fmt static constexpr auto attribute_names = std::forward_as_tuple("value"); diff --git a/ttnn/cpp/ttnn/tensor/serialization.cpp b/ttnn/cpp/ttnn/tensor/serialization.cpp index c8bd6a7aa6b..455c0b90126 100644 --- a/ttnn/cpp/ttnn/tensor/serialization.cpp +++ 
b/ttnn/cpp/ttnn/tensor/serialization.cpp @@ -353,7 +353,7 @@ Tensor load_tensor_helper(const std::string& file_name, T device) { TensorLayout::fromPaddedShape( data_type, layout, MemoryConfig{}, shape.logical_shape(), shape.padded_shape()))); if (device != nullptr) { - tensor = tensor.to(device, memory_config); + tensor = tensor.to_device(device, memory_config); } else if (has_memory_config) { tt::log_warning("Memory config is ignored when loading the tensor because device is not provided"); } @@ -377,7 +377,7 @@ Tensor load_tensor_helper(const std::string& file_name, T device) { TensorLayout::fromPaddedShape( data_type, layout, MemoryConfig{}, shape.logical_shape(), shape.padded_shape()))); if (device != nullptr) { - tensor = tensor.to(device); + tensor = tensor.to_device(device); } return tensor; } diff --git a/ttnn/cpp/ttnn/tensor/shape/shape.cpp b/ttnn/cpp/ttnn/tensor/shape/shape.cpp index f707020b8fb..24475784172 100644 --- a/ttnn/cpp/ttnn/tensor/shape/shape.cpp +++ b/ttnn/cpp/ttnn/tensor/shape/shape.cpp @@ -6,14 +6,15 @@ #include #include -#include "ttnn/tensor/shape/small_vector.hpp" + #include +#include namespace tt::tt_metal { bool Shape::operator==(const Shape& other) const = default; -bool Shape::operator==(const SmallVector& other) const { return this->value_ == other; } +bool Shape::operator==(const tt::stl::SmallVector& other) const { return this->value_ == other; } size_t Shape::rank() const { return this->size(); } @@ -29,7 +30,7 @@ std::array Shape::to_array_4D() const { } Shape Shape::to_rank(size_t new_rank) const { - SmallVector new_shape(new_rank, 1); + tt::stl::SmallVector new_shape(new_rank, 1); int cur_idx = static_cast(rank()) - 1; int new_idx = static_cast(new_rank) - 1; diff --git a/ttnn/cpp/ttnn/tensor/shape/shape.hpp b/ttnn/cpp/ttnn/tensor/shape/shape.hpp index b3c0034349c..83ddc01a422 100644 --- a/ttnn/cpp/ttnn/tensor/shape/shape.hpp +++ b/ttnn/cpp/ttnn/tensor/shape/shape.hpp @@ -4,7 +4,7 @@ #pragma once -#include "shape_base.hpp" +#include namespace tt::tt_metal { diff --git a/ttnn/cpp/ttnn/tensor/storage.cpp b/ttnn/cpp/ttnn/tensor/storage.cpp index ad385113ed8..e86cc45a2d5 100644 --- a/ttnn/cpp/ttnn/tensor/storage.cpp +++ b/ttnn/cpp/ttnn/tensor/storage.cpp @@ -16,4 +16,31 @@ std::vector> MultiDeviceStorage::get_buffers() const { return buf_vec; } +MultiDeviceStorage::MultiDeviceStorage( + const std::shared_ptr& mesh_buffer_, const TensorSpec& tensor_spec) : + strategy(ReplicateTensor{}), + mesh_buffer(mesh_buffer_) // +{ + // TODO: #17215 - In the long term, this code won't exist: no interactions will be made with individual Buffers, and + // instead the APIs will use MeshBuffer directly. MeshBuffer will also guarantee that all shards have the same + // tensor spec. + // + // For now, this code ensures MeshBuffer backed tensors are compatible with the rest of the ops infra. 
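The new MultiDeviceStorage constructor introduced here takes only the mesh buffer and a single TensorSpec and expands them into the per-device id/buffer/spec maps itself. A short sketch of the call site this enables (it mirrors the simplified allocate_tensor_on_mesh further down in this patch):

    auto mesh_buffer = tensor_impl::allocate_mesh_buffer_on_device(mesh_device, tensor_spec);
    MultiDeviceStorage storage(std::move(mesh_buffer), tensor_spec);  // replicated across the mesh
    Tensor tensor(std::move(storage), tensor_spec);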
+ const auto [num_rows, num_cols] = mesh_buffer->device()->shape(); + + ordered_device_ids.reserve(num_rows * num_cols); + buffers.reserve(num_rows * num_cols); + specs.reserve(num_rows * num_cols); + + for (int row = 0; row < num_rows; ++row) { + for (int col = 0; col < num_cols; ++col) { + auto buffer = mesh_buffer->get_device_buffer(distributed::Coordinate{row, col}); + const int device_id = buffer->device()->id(); + ordered_device_ids.push_back(device_id); + buffers.emplace(device_id, std::move(buffer)); + specs.emplace(device_id, tensor_spec); + } + } +} + } // namespace tt::tt_metal diff --git a/ttnn/cpp/ttnn/tensor/storage.hpp b/ttnn/cpp/ttnn/tensor/storage.hpp index 16f3143edae..ebb7ced0226 100644 --- a/ttnn/cpp/ttnn/tensor/storage.hpp +++ b/ttnn/cpp/ttnn/tensor/storage.hpp @@ -4,6 +4,7 @@ #pragma once +#include #include "ttnn/tensor/types.hpp" #include "ttnn/tensor/tensor_spec.hpp" @@ -243,6 +244,7 @@ struct MultiDeviceStorage { swap(first.mesh_buffer, second.mesh_buffer); } + // Constructs a multi-device tensor backed by a collection of heterogeneous single-device buffers. MultiDeviceStorage( DistributedTensorConfig strategy_, std::vector ordered_device_ids_, @@ -255,6 +257,9 @@ struct MultiDeviceStorage { specs(std::move(specs_)), mesh_buffer(std::move(mesh_buffer_)) {} + // Constructs a replicated multi-device tensor backed by mesh buffer. + MultiDeviceStorage(const std::shared_ptr& mesh_buffer_, const TensorSpec& tensor_spec); + MultiDeviceStorage(MultiDeviceStorage&& other) { swap(*this, other); } MultiDeviceStorage(const MultiDeviceStorage& other) { @@ -378,6 +383,9 @@ struct MultiDeviceStorage { using Storage = std::variant; +template +concept OwnedOrBorrowedStorage = std::is_same_v || std::is_same_v; + template constexpr void raise_unsupported_storage() { static_assert(tt::stl::concepts::always_false_v, "Unsupported Storage"); diff --git a/ttnn/cpp/ttnn/tensor/tensor.cpp b/ttnn/cpp/ttnn/tensor/tensor.cpp index ed6e2da465c..dd21761699d 100644 --- a/ttnn/cpp/ttnn/tensor/tensor.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor.cpp @@ -43,7 +43,7 @@ Tensor create_owned_tensor_from_row_major_data( Tensor output(OwnedStorage{owned_buffer::create(std::move(physical_data))}, spec); if (device.has_value()) { - output = output.to(device->get_devices(), spec.memory_config()); + output = output.to_device(device->get_devices(), spec.memory_config()); } return output; @@ -620,7 +620,7 @@ Tensor Tensor::from_span( Tensor tensor(OwnedStorage{owned_buffer::create(std::move(packed_block_floats))}, spec); if (device.has_value()) { - tensor = tensor.to(device->get_devices(), spec.memory_config()); + tensor = tensor.to_device(device->get_devices(), spec.memory_config()); } return tensor; } @@ -674,8 +674,8 @@ template <> std::vector Tensor::to_vector() const { Tensor cpu_tensor = this->cpu(); switch (cpu_tensor.get_dtype()) { - case DataType::BFLOAT16: return unpad_tensor_to_vec(cpu_tensor.to(Layout::ROW_MAJOR)); - case DataType::FLOAT32: return unpad_tensor_to_vec(cpu_tensor.to(Layout::ROW_MAJOR)); + case DataType::BFLOAT16: return unpad_tensor_to_vec(cpu_tensor.to_layout(Layout::ROW_MAJOR)); + case DataType::FLOAT32: return unpad_tensor_to_vec(cpu_tensor.to_layout(Layout::ROW_MAJOR)); case DataType::BFLOAT8_B: case DataType::BFLOAT4_B: { const auto& tile = cpu_tensor.get_tensor_spec().tile(); @@ -698,7 +698,7 @@ std::vector Tensor::to_vector() const { template std::vector Tensor::to_vector() const { - auto cpu_tensor = this->cpu().to(Layout::ROW_MAJOR); + auto cpu_tensor = 
this->cpu().to_layout(Layout::ROW_MAJOR); TT_FATAL( cpu_tensor.get_dtype() == convert_to_data_type(), "Unsupported data type for to_vector: got {}, expected: {}", @@ -735,17 +735,17 @@ template std::vector Tensor::to_vector() const; template std::vector Tensor::to_vector() const; template std::vector Tensor::to_vector() const; -Tensor Tensor::to(IDevice* target_device, const MemoryConfig& mem_config, uint8_t cq_id) const { - return tensor_ops::tensor_to(*this, target_device, mem_config, cq_id); +Tensor Tensor::to_device(IDevice* target_device, const MemoryConfig& mem_config, uint8_t cq_id) const { + return tensor_ops::tensor_to_device(*this, target_device, mem_config, cq_id); } -Tensor Tensor::to(distributed::MeshDevice* mesh_device, const MemoryConfig& mem_config, uint8_t cq_id) const { +Tensor Tensor::to_device(distributed::MeshDevice* mesh_device, const MemoryConfig& mem_config, uint8_t cq_id) const { std::vector workers_to_use = ttnn::distributed::get_mapped_devices(*this, *mesh_device); - return tensor_ops::tensor_to(*this, workers_to_use, mem_config, cq_id); + return tensor_ops::tensor_to_device(*this, workers_to_use, mem_config, cq_id); } -Tensor Tensor::to(const std::vector& workers, const MemoryConfig& mem_config, uint8_t cq_id) const { - return tensor_ops::tensor_to(*this, workers, mem_config, cq_id); +Tensor Tensor::to_device(const std::vector& workers, const MemoryConfig& mem_config, uint8_t cq_id) const { + return tensor_ops::tensor_to_device(*this, workers, mem_config, cq_id); } Tensor Tensor::cpu(bool blocking, uint8_t cq_id) const { return tensor_ops::tensor_cpu(*this, blocking, cq_id); } @@ -761,12 +761,12 @@ Tensor Tensor::extract_shard(const uint32_t& core_id) const { return tensor_impl::extract_shard_wrapper(*this, core_id); } -Tensor Tensor::to(Layout target_layout, IDevice* worker) const { - return tensor_ops::tensor_to(*this, target_layout, worker); +Tensor Tensor::to_layout(Layout target_layout, IDevice* worker) const { + return tensor_ops::tensor_to_layout(*this, target_layout, worker); } -Tensor Tensor::to(Layout target_layout, distributed::MeshDevice* mesh_device) const { - return tensor_ops::tensor_to(*this, target_layout, mesh_device); +Tensor Tensor::to_layout(Layout target_layout, distributed::MeshDevice* mesh_device) const { + return tensor_ops::tensor_to_layout(*this, target_layout, mesh_device); } const std::string Tensor::write_to_string() const { return tensor_impl::to_string_wrapper(*this); } @@ -1016,29 +1016,7 @@ Tensor allocate_tensor_on_mesh(const TensorSpec& tensor_spec, distributed::MeshD TT_FATAL( tt::tt_metal::detail::InMainThread(), "Allocation of a tensor on mesh must be called from the main thread"); auto mesh_buffer = tensor_impl::allocate_mesh_buffer_on_device(mesh_device, tensor_spec); - - const auto [num_rows, num_cols] = mesh_device->shape(); - std::vector ordered_device_ids; - std::unordered_map> buffers; - std::unordered_map specs; - - ordered_device_ids.reserve(num_rows * num_cols); - buffers.reserve(num_rows * num_cols); - specs.reserve(num_rows * num_cols); - - for (int row = 0; row < num_rows; ++row) { - for (int col = 0; col < num_cols; ++col) { - auto buffer = mesh_buffer->get_device_buffer(distributed::Coordinate{row, col}); - const int device_id = buffer->device()->id(); - ordered_device_ids.push_back(device_id); - buffers.emplace(device_id, std::move(buffer)); - specs.emplace(device_id, tensor_spec); - } - } - - MultiDeviceStorage multi_device_storage( - ReplicateTensor{}, std::move(ordered_device_ids), std::move(buffers), 
std::move(specs), std::move(mesh_buffer)); - + MultiDeviceStorage multi_device_storage(std::move(mesh_buffer), tensor_spec); return Tensor(std::move(multi_device_storage), tensor_spec); } diff --git a/ttnn/cpp/ttnn/tensor/tensor.hpp b/ttnn/cpp/ttnn/tensor/tensor.hpp index 7aded8ad795..79f4adcdd26 100644 --- a/ttnn/cpp/ttnn/tensor/tensor.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor.hpp @@ -174,24 +174,24 @@ class Tensor { template std::vector to_vector() const; - Tensor to( + Tensor to_device( IDevice* target_device, const MemoryConfig& mem_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}, uint8_t cq_id = ttnn::DefaultQueueId) const; - Tensor to( + Tensor to_device( distributed::MeshDevice* mesh_device, const MemoryConfig& mem_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}, uint8_t cq_id = ttnn::DefaultQueueId) const; - Tensor to( + Tensor to_device( const std::vector& workers, const MemoryConfig& mem_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}, uint8_t cq_id = ttnn::DefaultQueueId) const; - Tensor to(Layout target_layout, IDevice* worker = nullptr) const; + Tensor to_layout(Layout target_layout, IDevice* worker = nullptr) const; - Tensor to(Layout target_layout, distributed::MeshDevice* mesh_device) const; + Tensor to_layout(Layout target_layout, distributed::MeshDevice* mesh_device) const; Tensor pad(const ttnn::Shape& output_padded_shape, const ttnn::Shape& input_tensor_start, float pad_value) const; diff --git a/ttnn/cpp/ttnn/tensor/tensor_impl.cpp b/ttnn/cpp/ttnn/tensor/tensor_impl.cpp index 1a45fc43960..da7d5e20e28 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_impl.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor_impl.cpp @@ -6,6 +6,11 @@ #include #include "tt-metalium/mesh_buffer.hpp" +#include "tt-metalium/mesh_device.hpp" +#include "tt-metalium/mesh_command_queue.hpp" +#include "tt-metalium/overloaded.hpp" +#include "ttnn/distributed/distributed_tensor_config.hpp" +#include "ttnn/tensor/storage.hpp" #include "ttnn/tensor/tensor_impl_wrapper.hpp" #include "ttnn/tensor/layout/tensor_layout.hpp" #include "ttnn/tensor/types.hpp" @@ -75,6 +80,9 @@ std::shared_ptr allocate_mesh_buffer_on_device( .buffer_layout = memory_config.memory_layout, .shard_parameters = tensor_spec.compute_shard_spec_buffer(), }; + + // Use replicated buffer, which supports both working with individual shards and replicating data across all shards. + // This is required for the time being, as TTNN has rich multi-device sharding implementation. 
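Both the host-read path added below (to_host_mesh_tensor) and the sharded write path (shard_to_mesh_buffer) visit mesh shards in row-major order over the device grid. A standalone sketch of that coordinate walk, using the names from this patch (num_shards stands in for the number of host buffers or ordered device ids):

    distributed::Coordinate shard_coord = {0, 0};
    for (size_t i = 0; i < num_shards; ++i) {
        // ... read or write the shard at (shard_coord.row, shard_coord.col) ...
        if (++shard_coord.col == num_cols) {  // advance row-major across the mesh
            shard_coord.col = 0;
            ++shard_coord.row;
        }
    }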
const distributed::ReplicatedBufferConfig replicated_buffer_config{ .size = tensor_spec.compute_packed_buffer_size_bytes(), }; @@ -567,6 +575,66 @@ Tensor to_host(const Tensor& tensor, bool blocking, uint8_t cq_id) { return to_host(tensor, blocking, cq_id); } +template +Tensor to_host_mesh_tensor(const Tensor& tensor, bool blocking) { + TT_FATAL(ttnn::distributed::is_mesh_buffer_tensor(tensor), "Tensor is not a mesh buffer tensor!"); + TT_FATAL(tt::tt_metal::detail::InMainThread(), "to_host_mesh_tensor must be called from the main thread"); + const auto& storage = std::get(tensor.get_storage()); + const auto& mesh_buffer = storage.mesh_buffer; + ttnn::MeshDevice* device = mesh_buffer->device(); + distributed::MeshCommandQueue& mesh_cq = device->mesh_command_queue(); + const auto [num_rows, num_cols] = device->shape(); + const auto num_buffers = storage.buffers.size(); + + std::vector shard_data_transfers; + std::vector specs; + std::vector buffers; + specs.reserve(num_buffers); + buffers.reserve(num_buffers); + shard_data_transfers.reserve(num_buffers); + distributed::Coordinate shard_coord = {0, 0}; + for (int id : storage.ordered_device_ids) { + std::vector host_buffer; + const auto& shard_tensor_spec = storage.specs.at(id); + const auto tensor_size_bytes = shard_tensor_spec.compute_packed_buffer_size_bytes(); + host_buffer.resize(tensor_size_bytes / sizeof(T)); + specs.push_back(shard_tensor_spec); + buffers.push_back(owned_buffer::create(std::move(host_buffer))); + + shard_data_transfers.push_back(distributed::MeshCommandQueue::ShardDataTransfer{ + .shard_coord = shard_coord, + .host_data = std::visit([](auto& b) { return b.data(); }, buffers.back()), + .region = BufferRegion(0, tensor_size_bytes)}); + + if (++shard_coord.col == num_cols) { + shard_coord.col = 0; + ++shard_coord.row; + } + } + + mesh_cq.enqueue_read_shards(shard_data_transfers, mesh_buffer, /*blocking=*/true); + + MultiDeviceHostStorage host_storage(storage.strategy, std::move(buffers), std::move(specs)); + return Tensor(std::move(host_storage), tensor.get_tensor_spec()); +} + +template Tensor to_host_mesh_tensor(const Tensor& tensor, bool blocking); +template Tensor to_host_mesh_tensor(const Tensor& tensor, bool blocking); +template Tensor to_host_mesh_tensor(const Tensor& tensor, bool blocking); +template Tensor to_host_mesh_tensor(const Tensor& tensor, bool blocking); +template Tensor to_host_mesh_tensor(const Tensor& tensor, bool blocking); +template Tensor to_host_mesh_tensor(const Tensor& tensor, bool blocking); + +template <> +Tensor to_host_mesh_tensor(const Tensor& tensor, bool blocking) { + return to_host_mesh_tensor(tensor, blocking); +} + +template <> +Tensor to_host_mesh_tensor(const Tensor& tensor, bool blocking) { + return to_host_mesh_tensor(tensor, blocking); +} + // ====================================================================================== // .to_device() details // ====================================================================================== @@ -613,9 +681,8 @@ template std::shared_ptr to_device_buffer( const Storage& storage, IDevice* device, const TensorSpec& tensor_spec, uint8_t cq_id) { return std::visit( - [&device, &tensor_spec, cq_id](auto&& storage) -> std::shared_ptr { - using StorageType = std::decay_t; - if constexpr (std::is_same_v or std::is_same_v) { + tt::stl::overloaded{ + [&device, &tensor_spec, cq_id](const StorageType& storage) { auto data_to_write = host_buffer::get_as(storage.buffer); auto expected_packed_buffer_size_bytes = 
tensor_spec.compute_packed_buffer_size_bytes(); auto input_size_bytes = data_to_write.size() * sizeof(T); @@ -625,16 +692,11 @@ std::shared_ptr to_device_buffer( input_size_bytes, expected_packed_buffer_size_bytes); return initialize_data_on_device(data_to_write, device, tensor_spec, cq_id); - } else if constexpr (std::is_same_v) { - TT_THROW("Device storage doesn't support to_device_buffer"); - } else if constexpr (std::is_same_v) { - TT_THROW("MultiHostStorage storage doesn't support to_device_buffer"); - } else if constexpr (std::is_same_v) { - TT_THROW("MultiDeviceStorage doesn't support to_device_buffer"); - } else { - raise_unsupported_storage(); - } - }, + }, + [](const auto& s) { + TT_THROW("Unexpected storage type {}", tt::stl::get_type_name(s)); + return std::shared_ptr(); + }}, storage); } @@ -645,9 +707,6 @@ std::shared_ptr to_device_buffer( template Tensor to_device(const Tensor& tensor, IDevice* target_device, const MemoryConfig& memory_config, uint8_t cq_id) { TT_FATAL(tensor.storage_type() != StorageType::DEVICE, "Tensor is already on device!"); - if (tensor.storage_type() == StorageType::OWNED) { - TT_FATAL(tensor.is_allocated(), "Need host buffer on device to exist to copy data to device!"); - } TT_FATAL(target_device != nullptr, "Need target device in order to move tensor to device!"); TT_FATAL(tensor.is_allocated(), "Need data to exist in order to move it to device"); @@ -682,6 +741,141 @@ Tensor to_device( return to_device(tensor, target_device, memory_config, cq_id); } +template +MultiDeviceStorage replicate_to_mesh_buffer( + const StorageType& storage, + distributed::MeshDevice* mesh_device, + const std::shared_ptr& mesh_buffer, + const TensorSpec& tensor_spec) { + auto data_to_write = host_buffer::get_as(storage.buffer); + const auto expected_packed_buffer_size_bytes = tensor_spec.compute_packed_buffer_size_bytes(); + const auto input_size_bytes = data_to_write.size() * sizeof(T); + TT_FATAL( + input_size_bytes == expected_packed_buffer_size_bytes, + "Host data with total size {}B does not match expected size {}B of device buffer!", + input_size_bytes, + expected_packed_buffer_size_bytes); + + mesh_device->mesh_command_queue().enqueue_write_mesh_buffer(mesh_buffer, data_to_write.data(), /*blocking=*/false); + return MultiDeviceStorage(mesh_buffer, tensor_spec); +} + +template +MultiDeviceStorage shard_to_mesh_buffer( + const MultiDeviceHostStorage& storage, + distributed::MeshDevice* mesh_device, + const std::shared_ptr& mesh_buffer, + const TensorSpec& tensor_spec) { + std::vector ordered_device_ids; + std::unordered_map> buffers; + std::unordered_map specs; + ordered_device_ids.reserve(storage.buffers.size()); + buffers.reserve(storage.buffers.size()); + specs.reserve(storage.buffers.size()); + + const auto [num_rows, num_cols] = mesh_device->shape(); + TT_FATAL( + storage.buffers.size() <= mesh_device->num_devices(), + "Number of host buffers {} exceeds the number of shards {}", + storage.buffers.size(), + mesh_device->num_devices()); + + std::vector shard_data_transfers; + shard_data_transfers.reserve(storage.buffers.size()); + distributed::Coordinate shard_coord = {0, 0}; + for (int i = 0; i < storage.buffers.size(); i++) { + TensorSpec shard_tensor_spec( + storage.specs[i].logical_shape(), + storage.specs[i].tensor_layout().with_memory_config(tensor_spec.memory_config())); + const auto& shard_host_buffer = storage.buffers[i]; + + const auto& shard_buffer = mesh_buffer->get_device_buffer(shard_coord); + 
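The to_device_buffer change above (and the to_layout/pad/unpad changes later in this file) replaces the if-constexpr type ladder with tt::stl::overloaded, which combines per-alternative lambdas and a catch-all into one visitor for std::visit. A minimal illustrative sketch of the idiom; the storage alternatives are the real ones, the lambda bodies are placeholders:

    std::visit(
        tt::stl::overloaded{
            [](const OwnedStorage& storage) { /* host-owned path */ },
            [](const BorrowedStorage& storage) { /* borrowed path */ },
            [](const auto& storage) {  // catch-all for unsupported alternatives
                TT_THROW("Unexpected storage type {}", tt::stl::get_type_name(storage));
            }},
        tensor.get_storage());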
ordered_device_ids.push_back(shard_buffer->device()->id()); + buffers.insert({shard_buffer->device()->id(), shard_buffer}); + specs.insert({shard_buffer->device()->id(), shard_tensor_spec}); + + auto data_to_write = host_buffer::get_as(shard_host_buffer); + const auto expected_packed_buffer_size_bytes = shard_tensor_spec.compute_packed_buffer_size_bytes(); + const auto input_size_bytes = data_to_write.size() * sizeof(T); + TT_FATAL( + input_size_bytes == expected_packed_buffer_size_bytes, + "Host data with total size {}B does not match expected size {}B of device buffer!", + input_size_bytes, + expected_packed_buffer_size_bytes); + TT_FATAL( + expected_packed_buffer_size_bytes <= tensor_spec.compute_packed_buffer_size_bytes(), + "Shard tensor size exceeds the global tensor size!"); + shard_data_transfers.push_back(distributed::MeshCommandQueue::ShardDataTransfer{ + .shard_coord = shard_coord, + .host_data = data_to_write.data(), + .region = BufferRegion(0, input_size_bytes)}); + if (++shard_coord.col == num_cols) { + shard_coord.col = 0; + ++shard_coord.row; + } + } + + mesh_device->mesh_command_queue().enqueue_write_shards(mesh_buffer, shard_data_transfers, /*blocking=*/false); + + return MultiDeviceStorage( + storage.strategy, std::move(ordered_device_ids), std::move(buffers), std::move(specs), mesh_buffer); +} + +template +Tensor to_device_mesh_tensor( + const Tensor& tensor, distributed::MeshDevice* mesh_device, const MemoryConfig& memory_config) { + TT_FATAL(tt::tt_metal::detail::InMainThread(), "to_device_mesh_tensor must be called from the main thread"); + TT_FATAL(tensor.storage_type() != StorageType::MULTI_DEVICE, "Tensor is already on device!"); + TT_FATAL(mesh_device != nullptr, "Need target device in order to move tensor to device!"); + TT_FATAL(tensor.is_allocated(), "Need data to exist in order to move it to device"); + + TensorSpec tensor_spec( + tensor.get_logical_shape(), tensor.get_tensor_spec().tensor_layout().with_memory_config(memory_config)); + + auto mesh_buffer = allocate_mesh_buffer_on_device(mesh_device, tensor_spec); + MultiDeviceStorage mesh_storage = std::visit( + tt::stl::overloaded{ + [&mesh_device, &mesh_buffer, &tensor_spec](const StorageType& storage) { + // Replicate data across devices in a mesh. + return replicate_to_mesh_buffer(storage, mesh_device, mesh_buffer, tensor_spec); + }, + [&mesh_device, &mesh_buffer, &tensor_spec](const MultiDeviceHostStorage& storage) { + // Shard multi device host shards across devices in a mesh.. 
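to_device_mesh_tensor therefore has exactly two write paths onto the freshly allocated mesh buffer: a single host buffer (OwnedStorage/BorrowedStorage) is replicated to every device, while MultiDeviceHostStorage shards are scattered one per mesh coordinate. Condensed to the two enqueue calls used in this patch:

    // Replicate one host buffer to every device in the mesh (non-blocking):
    mesh_device->mesh_command_queue().enqueue_write_mesh_buffer(mesh_buffer, data_to_write.data(), /*blocking=*/false);

    // Scatter per-shard host buffers, one ShardDataTransfer per mesh coordinate:
    mesh_device->mesh_command_queue().enqueue_write_shards(mesh_buffer, shard_data_transfers, /*blocking=*/false);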
+ return shard_to_mesh_buffer(storage, mesh_device, mesh_buffer, tensor_spec); + }, + [](const auto& s) -> MultiDeviceStorage { + TT_THROW("Unexpected storage type {}", tt::stl::get_type_name(s)); + }}, + tensor.get_storage()); + + return Tensor(std::move(mesh_storage), tensor_spec); +} + +template Tensor to_device_mesh_tensor( + const Tensor& tensor, distributed::MeshDevice* target_device, const MemoryConfig& memory_config); +template Tensor to_device_mesh_tensor( + const Tensor& tensor, distributed::MeshDevice* target_device, const MemoryConfig& memory_config); +template Tensor to_device_mesh_tensor( + const Tensor& tensor, distributed::MeshDevice* target_device, const MemoryConfig& memory_config); +template Tensor to_device_mesh_tensor( + const Tensor& tensor, distributed::MeshDevice* target_device, const MemoryConfig& memory_config); +template Tensor to_device_mesh_tensor( + const Tensor& tensor, distributed::MeshDevice* target_device, const MemoryConfig& memory_config); +template Tensor to_device_mesh_tensor( + const Tensor& tensor, distributed::MeshDevice* target_device, const MemoryConfig& memory_config); + +template <> +Tensor to_device_mesh_tensor( + const Tensor& tensor, distributed::MeshDevice* target_device, const MemoryConfig& memory_config) { + return to_device_mesh_tensor(tensor, target_device, memory_config); +} + +template <> +Tensor to_device_mesh_tensor( + const Tensor& tensor, distributed::MeshDevice* target_device, const MemoryConfig& memory_config) { + return to_device_mesh_tensor(tensor, target_device, memory_config); +} + // ====================================================================================== // Helpers for converting between logical <-> physical data with full tensor spec // ====================================================================================== @@ -909,18 +1103,20 @@ Tensor to_layout(const Tensor& tensor, Layout target_layout) { } }; + using RetType = std::variant; auto output_storage = std::visit( - [&convert, target_layout](auto&& storage) -> std::variant { - using StorageType = std::decay_t; - if constexpr (std::is_same_v) { + tt::stl::overloaded{ + [&convert, target_layout](const OwnedStorage& storage) -> RetType { const auto input_data = owned_buffer::get_as(storage.buffer); auto output_buffer = owned_buffer::create(std::move(convert(input_data))); return OwnedStorage{output_buffer}; - } else if constexpr (std::is_same_v) { + }, + [&convert, target_layout](const BorrowedStorage& storage) -> RetType { const auto input_data = borrowed_buffer::get_as(storage.buffer); auto output_buffer = owned_buffer::create(std::move(convert(input_data))); return OwnedStorage{output_buffer}; - } else if constexpr (std::is_same_v) { + }, + [&convert, target_layout](const MultiDeviceHostStorage& storage) -> RetType { std::vector output_buffers; std::vector output_specs; for (int i = 0; i < storage.num_buffers(); i++) { @@ -938,14 +1134,8 @@ Tensor to_layout(const Tensor& tensor, Layout target_layout) { prev_spec.padded_shape()))); } return MultiDeviceHostStorage{storage.strategy, output_buffers, output_specs}; - } else if constexpr (std::is_same_v) { - TT_THROW("Device storage isn't supported"); - } else if constexpr (std::is_same_v) { - TT_THROW("On-device layout conversion for tensor with MultiDeviceStorage is not supported."); - } else { - raise_unsupported_storage(); - } - }, + }, + [](const auto& s) -> RetType { TT_THROW("Unsupported storage type {}", tt::stl::get_type_name(s)); }}, tensor.get_storage()); return std::visit( @@ -1078,24 
+1268,14 @@ Tensor pad( }; auto output_buffer = std::visit( - [&pad](auto&& storage) -> owned_buffer::Buffer { - using StorageType = std::decay_t; - if constexpr (std::is_same_v) { - const auto input_data = owned_buffer::get_as(storage.buffer); + tt::stl::overloaded{ + [&pad](const StorageType& storage) { + const auto input_data = host_buffer::get_as(storage.buffer); return pad(input_data); - } else if constexpr (std::is_same_v) { - const auto input_data = borrowed_buffer::get_as(storage.buffer); - return pad(input_data); - } else if constexpr (std::is_same_v) { - TT_THROW("Device storage isn't supported"); - } else if constexpr (std::is_same_v) { - TT_THROW("Device storage isn't supported"); - } else if constexpr (std::is_same_v) { - TT_THROW("Device storage isn't supported"); - } else { - raise_unsupported_storage(); - } - }, + }, + [](const auto& s) -> owned_buffer::Buffer { + TT_THROW("Unsupported storage type {}", tt::stl::get_type_name(s)); + }}, tensor.get_storage()); return Tensor( OwnedStorage{output_buffer}, @@ -1196,24 +1376,14 @@ Tensor unpad(const Tensor& tensor, const ttnn::Shape& output_tensor_start, const }; auto output_buffer = std::visit( - [&unpad](auto&& storage) -> owned_buffer::Buffer { - using StorageType = std::decay_t; - if constexpr (std::is_same_v) { - const auto input_data = owned_buffer::get_as(storage.buffer); - return unpad(input_data); - } else if constexpr (std::is_same_v) { - const auto input_data = borrowed_buffer::get_as(storage.buffer); + tt::stl::overloaded{ + [&unpad](const StorageType& storage) { + const auto input_data = host_buffer::get_as(storage.buffer); return unpad(input_data); - } else if constexpr (std::is_same_v) { - TT_THROW("Device storage isn't supported"); - } else if constexpr (std::is_same_v) { - TT_THROW("Device storage isn't supported"); - } else if constexpr (std::is_same_v) { - TT_THROW("Device storage isn't supported"); - } else { - raise_unsupported_storage(); - } - }, + }, + [](const auto& s) -> owned_buffer::Buffer { + TT_THROW("Unsupported storage type {}", tt::stl::get_type_name(s)); + }}, tensor.get_storage()); return Tensor( OwnedStorage{output_buffer}, diff --git a/ttnn/cpp/ttnn/tensor/tensor_impl.hpp b/ttnn/cpp/ttnn/tensor/tensor_impl.hpp index 2602e0e4b2c..2a4654b8aac 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_impl.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor_impl.hpp @@ -8,6 +8,7 @@ #include #include +#include "tt-metalium/mesh_device.hpp" #include "ttnn/tensor/host_buffer/functions.hpp" #include "ttnn/tensor/tensor.hpp" #include "ttnn/tensor/tensor_utils.hpp" @@ -173,23 +174,27 @@ std::shared_ptr allocate_mesh_buffer_on_device( distributed::MeshDevice* mesh_device, const TensorSpec& tensor_spec); template -inline void read_data_from_device_buffer( +void read_data_from_device_buffer( CommandQueue& cq, std::shared_ptr device_buffer, void* host_buffer_data, bool blocking) { EnqueueReadBuffer(cq, device_buffer, host_buffer_data, blocking); } template -inline void read_data_from_device_buffer(std::shared_ptr device_buffer, std::vector& host_buffer) { +void read_data_from_device_buffer(std::shared_ptr device_buffer, std::vector& host_buffer) { ::tt::tt_metal::detail::ReadFromBuffer(device_buffer, host_buffer); } // ====================================================================================== -// .to() +// .to_host() and .to_device() // ====================================================================================== template Tensor to_host(const Tensor& tensor, bool blocking = true, uint8_t cq_id = 
ttnn::DefaultQueueId); +// TODO: #17215 - This will eventually subsume `to_host`, when "mesh buffer" backed tensors become the default. +template +Tensor to_host_mesh_tensor(const Tensor& tensor, bool blocking = true); + template Tensor to_device( const Tensor& tensor, @@ -197,6 +202,15 @@ Tensor to_device( const MemoryConfig& memory_config, uint8_t cq_id = ttnn::DefaultQueueId); +// TODO: #17215 - This will eventually subsume `to_device`, when "mesh buffer" backed tensors become the default. +template +Tensor to_device_mesh_tensor( + const Tensor& tensor, distributed::MeshDevice* mesh_device, const MemoryConfig& memory_config); + +// ====================================================================================== +// .to_layout() +// ====================================================================================== + template Tensor to_layout(const Tensor& tensor, Layout target_layout); diff --git a/ttnn/cpp/ttnn/tensor/tensor_impl_wrapper.hpp b/ttnn/cpp/ttnn/tensor/tensor_impl_wrapper.hpp index 9cf4c810591..7bf2d8690d3 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_impl_wrapper.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor_impl_wrapper.hpp @@ -38,8 +38,10 @@ inline size_t packed_buffer_size_bytes_wrapper(DataType dtype, size_t volume_unp } WRAP_FUNCTION(to_host) +WRAP_FUNCTION(to_host_mesh_tensor) WRAP_FUNCTION(extract_shard) WRAP_FUNCTION(to_device) +WRAP_FUNCTION(to_device_mesh_tensor) WRAP_FUNCTION(to_layout) WRAP_FUNCTION(pad) WRAP_FUNCTION(unpad) diff --git a/ttnn/cpp/ttnn/tensor/tensor_ops.cpp b/ttnn/cpp/ttnn/tensor/tensor_ops.cpp index 67d0a50633d..5896e7b6f3a 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_ops.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor_ops.cpp @@ -27,7 +27,8 @@ namespace tt::tt_metal::tensor_ops { -Tensor tensor_to(const Tensor& input_tensor, IDevice* target_device, const MemoryConfig& mem_config, uint8_t cq_id) { +Tensor tensor_to_device( + const Tensor& input_tensor, IDevice* target_device, const MemoryConfig& mem_config, uint8_t cq_id) { ZoneScoped; GraphTracker::instance().track_function_start("Tensor::to", input_tensor, target_device, mem_config); // Tensor can be using borrowed storage. If so, when running in async mode, copy this tensor to owned storage. 
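The tensor_ops renames here back the Tensor::to -> Tensor::to_device / Tensor::to_layout split made earlier in this patch, so call sites migrate mechanically. A small before/after sketch with placeholder tensor and device names:

    // Before this patch:
    auto on_device = host_tensor.to(device, mem_config);
    auto row_major = on_device.cpu().to(Layout::ROW_MAJOR);

    // After this patch:
    auto on_device = host_tensor.to_device(device, mem_config);
    auto row_major = on_device.cpu().to_layout(Layout::ROW_MAJOR);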
@@ -63,7 +64,7 @@ Tensor tensor_to(const Tensor& input_tensor, IDevice* target_device, const Memor return device_tensor; } -Tensor tensor_to( +Tensor tensor_to_device( const Tensor& input_tensor, const std::vector& workers, const MemoryConfig& mem_config, uint8_t cq_id) { ZoneScoped; GraphTracker::instance().track_function_start("Tensor::to", input_tensor, workers, mem_config); @@ -141,7 +142,7 @@ Tensor tensor_cpu(const Tensor& input_tensor, bool blocking, uint8_t cq_id) { return host_tensor; } -Tensor tensor_to(const Tensor& input_tensor, Layout target_layout, IDevice* worker) { +Tensor tensor_to_layout(const Tensor& input_tensor, Layout target_layout, IDevice* worker) { ZoneScoped; GraphTracker::instance().track_function_start("Tensor::to", input_tensor, target_layout, worker); // Only push layout conversion to worker if running in async mode @@ -173,7 +174,7 @@ Tensor tensor_to(const Tensor& input_tensor, Layout target_layout, IDevice* work return output; } -Tensor tensor_to(const Tensor& input_tensor, Layout target_layout, distributed::MeshDevice* mesh_device) { +Tensor tensor_to_layout(const Tensor& input_tensor, Layout target_layout, distributed::MeshDevice* mesh_device) { ZoneScoped; GraphTracker::instance().track_function_start("Tensor::to", input_tensor, target_layout, mesh_device); if (mesh_device) { diff --git a/ttnn/cpp/ttnn/tensor/tensor_ops.hpp b/ttnn/cpp/ttnn/tensor/tensor_ops.hpp index 9c5ae143e7b..9deb78bad6f 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_ops.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor_ops.hpp @@ -20,14 +20,15 @@ class IDevice; namespace tt::tt_metal::tensor_ops { -Tensor tensor_to(const Tensor& input_tensor, IDevice* target_device, const MemoryConfig& mem_config, uint8_t cq_id); +Tensor tensor_to_device( + const Tensor& input_tensor, IDevice* target_device, const MemoryConfig& mem_config, uint8_t cq_id); -Tensor tensor_to( +Tensor tensor_to_device( const Tensor& input_tensor, const std::vector& workers, const MemoryConfig& mem_config, uint8_t cq_id); -Tensor tensor_to(const Tensor& input_tensor, Layout target_layout, IDevice* worker); +Tensor tensor_to_layout(const Tensor& input_tensor, Layout target_layout, IDevice* worker); -Tensor tensor_to(const Tensor& input_tensor, Layout target_layout, distributed::MeshDevice* mesh_device); +Tensor tensor_to_layout(const Tensor& input_tensor, Layout target_layout, distributed::MeshDevice* mesh_device); Tensor tensor_cpu(const Tensor& input_tensor, bool blocking, uint8_t cq_id); diff --git a/ttnn/cpp/ttnn/tensor/xtensor/conversion_utils.hpp b/ttnn/cpp/ttnn/tensor/xtensor/conversion_utils.hpp index f773e962db3..df97212e648 100644 --- a/ttnn/cpp/ttnn/tensor/xtensor/conversion_utils.hpp +++ b/ttnn/cpp/ttnn/tensor/xtensor/conversion_utils.hpp @@ -4,7 +4,8 @@ #pragma once -#include "ttnn/tensor/shape/small_vector.hpp" +#include + #include "ttnn/tensor/tensor.hpp" #include
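The WRAP_FUNCTION(to_host_mesh_tensor) / WRAP_FUNCTION(to_device_mesh_tensor) additions presumably generate the usual dtype-dispatching *_wrapper entry points, as the existing extract_shard_wrapper and to_string_wrapper do. Assuming that convention (the _wrapper names below are inferred, not shown in this patch), the new paths would be reached roughly as:

    Tensor device_tensor = tensor_impl::to_device_mesh_tensor_wrapper(host_tensor, mesh_device, memory_config);
    Tensor host_copy = tensor_impl::to_host_mesh_tensor_wrapper(device_tensor, /*blocking=*/true);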