
fix: Enhance checks around KIND_GPU and tensor parallelism #42

Merged: 10 commits from rmccormick-multi-gpu-default into main, May 31, 2024

Conversation

rmccorm4 (Contributor) commented May 24, 2024

What does the PR do?

Problem

When loading multiple instances of a vLLM model on a multi-GPU system (the default behavior with KIND_GPU, which is the default instance group kind when left unspecified), all model instances default to the same device. This can cause a CUDA OOM, rather than loading one model on each GPU device assigned by the instance group settings.

This is rooted in Triton's KIND_GPU behavior, which assumes the model is assigned exactly one GPU. In the future, KIND_GPU may be expanded to define a set of multiple GPUs for a single model, but for now the recommendation is to use KIND_MODEL when a model can have multiple GPUs and use them freely (as Python-based models can).
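
For reference, the instance group kind is set in the model's config.pbtxt. The sketch below shows the two options discussed here; the counts are illustrative and not taken from this PR.

# KIND_GPU: Triton creates `count` instance(s) per listed/available GPU and
# assigns each instance a single device ID.
instance_group [
  {
    count: 1
    kind: KIND_GPU
  }
]

# KIND_MODEL: Triton leaves device placement to the model, which is what a
# multi-GPU vLLM engine (tensor_parallel_size > 1) needs.
instance_group [
  {
    count: 1
    kind: KIND_MODEL
  }
]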

Solution

These changes account for this assumption by isolating the assigned GPU device ID when KIND_GPU is used, and by raising an error that recommends KIND_MODEL when the vLLM config implies a multi-GPU model (e.g., tensor_parallel_size > 1).
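
A minimal sketch of the kind of check involved (the helper and argument names are assumptions for illustration, not the exact code in src/model.py):

# Illustrative only. In the Triton Python backend, initialize() receives the
# instance kind assigned by Triton Core via `args`.
def _validate_device_config(args, vllm_engine_args):
    kind = args["model_instance_kind"]
    tensor_parallel_size = int(vllm_engine_args.get("tensor_parallel_size", 1))
    if kind == "GPU" and tensor_parallel_size > 1:
        raise ValueError(
            "KIND_GPU maps each model instance to a single GPU. For multi-GPU "
            "models (tensor_parallel_size > 1), use KIND_MODEL instead."
        )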

Checklist

  • PR title reflects the change and is of format <commit_type>: <Title>
  • Changes are described in the pull request.
  • Related issues are referenced.
  • Populated GitHub labels field
  • Added test plan and verified test passes.
  • Verified that the PR passes existing CI.
  • Verified copyright is correct on all changed files.
  • Added succinct git squash message before merging ref.
  • All template sections are filled out.
  • Optional: Additional screenshots for behavior/output changes with before/after.

Commit Type:

Check the conventional commit type box here and add the label to the GitHub PR.

  • build
  • chore
  • ci
  • docs
  • feat
  • fix
  • perf
  • refactor
  • revert
  • style
  • test

Test plan

  1. Single-GPU model on a multi-GPU system with KIND_GPU (default)

Change: Success. The device ID assigned by Triton Core (from the instance group) is passed when initializing vLLM, avoiding OOM from all instances defaulting to device 0.

I0524 19:05:41.036859 3948 model.py:170] Detected KIND_GPU model instance, explicitly setting GPU device=0 for llama-3-8b-instruct_0
I0524 19:05:41.043640 3948 model.py:170] Detected KIND_GPU model instance, explicitly setting GPU device=1 for llama-3-8b-instruct_1
  2. Multi-GPU model (tensor_parallel_size > 1) with KIND_GPU (default)

Change: Failure, with a clear error recommending KIND_MODEL for multi-GPU models.


  3. Multi-GPU model (tensor_parallel_size > 1) with KIND_MODEL (manually specified)

Success, and vLLM automatically uses Ray workers.


Caveats

  1. If there are other vLLM config fields besides tensor_parallel_size that can imply a multi-GPU model, we can add checks for those too.
  2. For KIND_MODEL models with tensor_parallel_size == 1 and model instance count > 1, we run into the same issue: vLLM will try to allocate each instance on the same GPU. This could be addressed with a check similar to the one in this PR, or deferred to a future enhancement.
  3. This PR doesn't provide an explicit way to support multiple multi-GPU instances.
    • Take, as an example, a 2-GPU model that you want 4 copies of on an 8-GPU system.
    • If you specify KIND_MODEL with 4 model instances and tensor_parallel_size: 2, with 8 GPUs at your disposal, there is currently no explicit check or validation around this. I don't know how this would behave; it would likely come down to how vLLM's RayWorker logic attempts to assign GPUs.

rmccorm4 requested a review from Tabrizian, May 24, 2024 19:18
rmccorm4 (Contributor, Author):

Didn't mean for the LoRA stuff to be removed; will fix that.

nnshah1 previously approved these changes May 24, 2024

nnshah1 (Contributor) left a comment:

LGTM until/unless we find a different way.

It would be good to add a section on deploying on multiple GPUs (TP = 1 and TP > 1), but I defer to others on other ways to tackle it.

rmccorm4 (Contributor, Author):

Marking as draft while I fix a couple of things.

rmccorm4 marked this pull request as draft, May 24, 2024 19:36
oandreeva-nv (Contributor):

I need LoRA back before I can approve.

rmccorm4 requested review from oandreeva-nv and nnshah1, May 24, 2024 21:13
rmccorm4 marked this pull request as ready for review, May 24, 2024 21:13
GuanLuo previously approved these changes May 24, 2024
src/model.py (outdated):
)
# NOTE: this only affects this process and its subprocesses, not other processes.
# vLLM doesn't currently seem to expose selecting a specific device in the APIs.
os.environ["CUDA_VISIBLE_DEVICES"] = triton_device_id
Reviewer comment (Contributor):

I'm going to leave my observations here for the record.

I also tried LOCAL_RANK. It is very flaky, and the GPU block calculations do not correspond to what is calculated in the "GPU-isolated" case.

oandreeva-nv previously approved these changes May 24, 2024

oandreeva-nv (Contributor) left a comment:

LGTM; my only ask is to try re-running the tests a couple of times to see if any flakiness is happening.

nnshah1 previously approved these changes May 25, 2024

nnshah1 (Contributor) left a comment:

Thanks for getting the fix in quickly!

rmccorm4 dismissed stale reviews from nnshah1, oandreeva-nv, and GuanLuo via 935bf92, May 29, 2024 22:18
rmccorm4 changed the title from "Enhance checks around KIND_GPU and tensor parallelism" to "fix: Enhance checks around KIND_GPU and tensor parallelism", May 29, 2024
Tabrizian (Member) commented May 30, 2024:

I think we want to allow the model to interact with other GPUs even if KIND_GPU and a specific device are specified. The reason is that the model can be part of an ensemble pipeline, and it may want to copy tensors from other devices even though the actual execution happens on device_id.

I think it is better to set the default context to the device ID specified by Triton. For vLLM, this can be achieved using torch.cuda.set_device.
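
A sketch of that suggestion, assuming the device ID comes from the args passed to the Python backend's initialize() (illustrative, not the exact PR code):

import torch

# Make the Triton-assigned device the default CUDA device for this instance,
# instead of hiding the other GPUs via CUDA_VISIBLE_DEVICES.
triton_device_id = int(args["model_instance_device_id"])
torch.cuda.set_device(triton_device_id)

Unlike the CUDA_VISIBLE_DEVICES approach, the other GPUs stay visible, so tensors can still be copied from other devices when the model is part of an ensemble.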

rmccorm4 (Contributor, Author):

@Tabrizian Sure, I'll try setting the device instead. There may be fewer unintended consequences that way.

rmccorm4 requested a review from oandreeva-nv, May 30, 2024 17:44
function run_multi_gpu_test() {
export KIND="${1}"
export TENSOR_PARALLELISM="${2}"
export INSTANCE_COUNT="${3}"
Reviewer comment:

Do you need export? It looks like all usages are local.

Reviewer comment:

Okay, I see it now. We should try to move the server setup into the Python test (setup/teardown); @jbkyang-nvi had done something similar.

rmccorm4 (Contributor, Author), May 30, 2024:

Do you or @jbkyang-nvi have a reference for that? If not, I can probably do all of this inside the pytest more easily using the in-process Python API, if we don't need any frontend features and @oandreeva-nv doesn't mind.

Reviewer comment:

It's just a matter of turning what we do in bash into Python (spawning processes, file system manipulation, etc.).

Seems like that change was reverted, sad: triton-inference-server/server#7195 (comment)

rmccorm4 (Contributor, Author):

I agree a common set of utils to prep/start/stop the server via Python + subprocess would be great. That would probably take me some time to write well, though. Can I merge these tests using bash and follow up on this after we deal with the P0s and pipeline failures? I'll take this test as a specific example to refactor using the common util I write. @GuanLuo @oandreeva-nv
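
As a rough illustration of what that follow-up could look like, a hypothetical pytest fixture that spawns tritonserver via subprocess (paths, ports, and timeouts are placeholders):

import subprocess
import time

import pytest
import tritonclient.grpc as grpcclient


@pytest.fixture(scope="module")
def triton_server():
    # Spawn tritonserver the same way the bash wrapper does today.
    proc = subprocess.Popen(
        ["tritonserver", "--model-repository", "./models"]  # placeholder path
    )
    client = grpcclient.InferenceServerClient("localhost:8001")
    try:
        # Poll until the server is ready; a real helper would also surface the
        # server log on failure.
        for _ in range(120):
            try:
                if client.is_server_ready():
                    break
            except Exception:
                pass
            time.sleep(1)
        yield client
    finally:
        proc.terminate()
        proc.wait()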

# Run unit tests
set +e
CLIENT_LOG="./vllm_multi_gpu_test--${KIND}_tp${TENSOR_PARALLELISM}_count${INSTANCE_COUNT}--client.log"
python3 $CLIENT_PY -v > $CLIENT_LOG 2>&1
Reviewer comment:

Running all unit tests against different settings? Is that necessary?

rmccorm4 (Contributor, Author):

There's only a single test right now, just lots of helpers. If I move the server/model setup into the Python test like you mentioned, the bash part can be simplified.

if int(tp) * int(instance_count) != 2:
msg = "TENSOR_PARALLELISM and INSTANCE_COUNT must have a product of 2 for this 2-GPU test"
print("Skipping Test:", msg)
self.skipTest(msg)
Reviewer comment (Contributor):

Can we put this into @unittest.skipIf? It would be easier to locate then.

rmccorm4 (Contributor, Author):

I think I'd have to make tp and instance_count global, or pass them directly to the test somehow, to do this. I was trying to avoid being too fancy with these tests, but it looks like I'll need to rethink them based on the comments so far.
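
For example, one hypothetical way to do that is to read the exported settings at module scope (the class and test names below are placeholders):

import os
import unittest

KIND = os.environ.get("KIND", "KIND_GPU")
TP = int(os.environ.get("TENSOR_PARALLELISM", "1"))
INSTANCE_COUNT = int(os.environ.get("INSTANCE_COUNT", "1"))


class VLLMMultiGPUTest(unittest.TestCase):
    @unittest.skipIf(
        TP * INSTANCE_COUNT != 2,
        "TENSOR_PARALLELISM and INSTANCE_COUNT must have a product of 2 for this 2-GPU test",
    )
    @unittest.skipIf(
        KIND == "KIND_MODEL" and INSTANCE_COUNT > 1,
        "Testing multiple model instances of KIND_MODEL is not implemented at this time",
    )
    def test_multi_gpu_model(self):
        ...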

if kind == "KIND_MODEL" and int(instance_count) > 1:
msg = "Testing multiple model instances of KIND_MODEL is not implemented at this time"
print("Skipping Test:", msg)
self.skipTest(msg)
Reviewer comment (Contributor):

Ditto.

rmccorm4 (Contributor, Author) commented May 31, 2024:

Follow-up ticket for the threads I'm leaving unresolved: DLIS-6804

rmccorm4 merged commit 18a96e3 into main, May 31, 2024 (3 checks passed)
rmccorm4 deleted the rmccormick-multi-gpu-default branch, May 31, 2024 18:57