From fae714d6bdf45ab0f6b98554d542a0f4e04a0eb9 Mon Sep 17 00:00:00 2001
From: Omar Elayan <142979319+oelayan7@users.noreply.github.com>
Date: Wed, 15 Jan 2025 21:25:29 +0200
Subject: [PATCH 1/5] [inf] Add config var to enable keeping module on host
 (#6846)

Using the keep_module_on_host config var lets us control whether
checkpoints loaded into model parameters are moved to the device or
kept on the host.
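
For reference, a minimal usage sketch (the model name, dtype, and
single-process setup are illustrative, not part of this patch):

```
import torch
import deepspeed
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")

# keep_module_on_host=True leaves the loaded parameters in host memory
# instead of moving the whole module to the accelerator during injection.
engine = deepspeed.init_inference(model,
                                  mp_size=1,
                                  dtype=torch.float32,
                                  keep_module_on_host=True)

# Mirrors the check added to the unit tests below: parameters stay on CPU.
assert all(p.device == torch.device("cpu") for p in engine.module.parameters())
```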

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
---
 deepspeed/inference/config.py             |  9 +++++++
 deepspeed/inference/engine.py             |  2 +-
 deepspeed/module_inject/auto_tp.py        | 33 +++++++++++++++--------
 deepspeed/module_inject/replace_module.py |  3 ++-
 tests/unit/inference/test_inference.py    | 21 +++++++++++++--
 5 files changed, 53 insertions(+), 15 deletions(-)

diff --git a/deepspeed/inference/config.py b/deepspeed/inference/config.py
index 42ffebbc4386..6df61f7c8841 100644
--- a/deepspeed/inference/config.py
+++ b/deepspeed/inference/config.py
@@ -174,6 +174,15 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
     values for :any:`DeepSpeedMoEConfig`.
     """
 
+    keep_module_on_host: bool = False
+    """
+    When loading checkpoints into model parameters, they are moved to the device. For very large models
+    this might fill the device and cause OOM. Setting this flag to true will keep checkpoints on the
+    host and not move them directly to the device (for example, giving the option to quantize the
+    checkpoint data before moving it to the device).
+    Applies only to models with injection policies and auto TP.
+    """
+
     quant: QuantizationConfig = {}
     """
     NOTE: only works for int8 dtype.
diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py
index 131dce07d22d..be6336d02a19 100755
--- a/deepspeed/inference/engine.py
+++ b/deepspeed/inference/engine.py
@@ -169,7 +169,7 @@ def __init__(self, model, config):
         is_meta_device = hasattr(self.module, "device") and self.module.device.type == 'meta'
         if is_meta_device:
             self.module.to_empty(device=device)
-        else:
+        elif not config.keep_module_on_host:
             self.module.to(device)
 
         if config.tensor_parallel.tp_size > 1:
diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py
index 66d7c2659359..8bdcf6faa053 100755
--- a/deepspeed/module_inject/auto_tp.py
+++ b/deepspeed/module_inject/auto_tp.py
@@ -17,14 +17,14 @@
 from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list
 
 
-def move(tensor, device):
+def move(tensor, device, copy=True):
     if tensor.is_meta:
         return torch.empty_like(tensor, device=device)
     else:
         # Using new tensors help in freeing memory (after split for example) was done before by calling clone().
         # Using copy=True instead of clone() will help in case of cpu --> cpu.
         # Otherwise to() will not create a new copy for the view of the full tensor, and it will not be de-referenced.
-        return tensor.to(device, copy=True)
+        return tensor.to(device, copy=copy)
 
 
 class ReplaceWithTensorSlicing:
@@ -189,7 +189,14 @@ def load(module, state_dict, prefix, mp_group=None):
 
 class AutoTP():
 
-    def __init__(self, module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl):
+    def __init__(self,
+                 module,
+                 all_reduce_linears,
+                 prefix,
+                 state_dict,
+                 linear_layer_setting,
+                 orig_layer_impl,
+                 keep_module_on_host=False):
         self.module = module
         self.all_reduce_linears = all_reduce_linears
         self.prefix = prefix
@@ -201,6 +208,7 @@ def __init__(self, module, all_reduce_linears, prefix, state_dict, linear_layer_
         self.orig_layer_impl = orig_layer_impl
         self.linear_policies = None
         self.conv_linear_layer = False
+        self.keep_module_on_host = keep_module_on_host
 
     def in_module_list(module, module_list):
         for item in module_list:
@@ -331,6 +339,10 @@ def set_tensor_parallel_config(self, mp_size, mp_group):
     def _replace(self, child, name, conv_linear_layer):
         if getattr(child, "replaced", False) == True:
             return
+        device_name = 'cpu' if self.keep_module_on_host else get_accelerator().current_device_name()
+        # keep_module_on_host keeps the module on the host. Checkpoints are loaded to the host first (in some cases
+        # this can even be done straight from disk to avoid filling host memory), so there is no need to create a new copy.
+        return_new_copy = not self.keep_module_on_host
         weight_shape = child.weight.shape
         mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
         # For TP layer skip, e.g., MoE gate, deepseek low rank layer skip
@@ -368,7 +380,7 @@ def _replace(self, child, name, conv_linear_layer):
             data = child.weight.data.split(get_shard_size_list(
                 weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size, name),
                                            dim=1)
-            data_dc = move(data[mp_replace.gpu_index], get_accelerator().current_device_name()).detach()
+            data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach()
             del data
 
             setattr(child, "replaced", True)
@@ -376,10 +388,9 @@ def _replace(self, child, name, conv_linear_layer):
                 return LmHeadLinearAllreduce(
                     torch.nn.parameter.Parameter(data_dc, requires_grad=False), dist.get_rank(), dist.get_world_size(),
                     child.bias if child.bias is None else torch.nn.parameter.Parameter(
-                        move(child.bias,
-                             get_accelerator().current_device_name())), self.mp_group)
+                        move(child.bias, device_name, return_new_copy)), self.mp_group)
             return LinearAllreduce(torch.nn.parameter.Parameter(data_dc, requires_grad=False), child.bias if child.bias is None else \
-                        torch.nn.parameter.Parameter(move(child.bias, get_accelerator().current_device_name())), self.mp_group)
+                        torch.nn.parameter.Parameter(move(child.bias, device_name, return_new_copy)), self.mp_group)
         else:
 
             # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
@@ -392,22 +403,22 @@ def _replace(self, child, name, conv_linear_layer):
                 #The copy is a regular copy, The shape of dst and src is the same
                 data_dc = move(
                     prepare_tp_fused_qkvw(self.module, child.weight.data, self.mp_size, mp_replace.gpu_index),
-                    get_accelerator().current_device_name())
+                    device_name, return_new_copy)
 
                 bias_data_dc = None if child.bias is None else move(
                     prepare_tp_fused_qkvw(self.module, child.bias.data, self.mp_size, mp_replace.gpu_index),
-                    get_accelerator().current_device_name())
+                    device_name, return_new_copy)
             else:
                 data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size, name),
                                                dim=1 if self.conv_linear_layer else 0)
-                data_dc = move(data[mp_replace.gpu_index], get_accelerator().current_device_name()).detach()
+                data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach()
                 del data
 
                 if child.bias is not None:
                     bias_data = child.bias.data.split(get_shard_size_list(
                         weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size, name),
                                                       dim=0)
-                    bias_data = move(bias_data[mp_replace.gpu_index], get_accelerator().current_device_name())
+                    bias_data = move(bias_data[mp_replace.gpu_index], device_name, return_new_copy)
                     bias_data_dc = torch.nn.parameter.Parameter(bias_data, requires_grad=False)
                     del bias_data
                 else:
diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py
index 00b22aac81d8..32c88549c821 100644
--- a/deepspeed/module_inject/replace_module.py
+++ b/deepspeed/module_inject/replace_module.py
@@ -268,7 +268,8 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
         #mp_replace = ReplaceWithTensorSlicing(mp_group=config.tensor_parallel.tp_group)
 
         # 1. Create AutoTP object
-        _autotp = AutoTP(module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl)
+        _autotp = AutoTP(module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl,
+                         config.keep_module_on_host)
 
         # 2. Set the tensor parallelism config
         _autotp.set_tensor_parallel_config(config.tensor_parallel.tp_size, config.tensor_parallel.tp_group)
diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py
index 9b563523dbeb..df85ed232a2e 100644
--- a/tests/unit/inference/test_inference.py
+++ b/tests/unit/inference/test_inference.py
@@ -554,6 +554,7 @@ def test(self, model_w_task, injection_policy, query, inf_kwargs, assert_fn, dty
 
 
 @pytest.mark.seq_inference
+@pytest.mark.parametrize('keep_module_on_host', [True, False])
 @pytest.mark.parametrize(
     "model_w_task",
     [("Helsinki-NLP/opus-mt-en-de", "translation"), ("Salesforce/codegen-350M-mono", "text-generation")],
@@ -570,6 +571,7 @@ def test(
         inf_kwargs,
         assert_fn,
         dtype,
+        keep_module_on_host,
     ):
         invalid_test_msg = validate_test(model_w_task, dtype, enable_cuda_graph=False, enable_triton=False)
         if invalid_test_msg:
@@ -592,13 +594,20 @@ def test(
                         framework="pt")
         bs_output = pipe(query, **inf_kwargs)
 
-        pipe.model = deepspeed.init_inference(pipe.model, mp_size=world_size, dtype=dtype)
+        pipe.model = deepspeed.init_inference(pipe.model,
+                                              mp_size=world_size,
+                                              dtype=dtype,
+                                              keep_module_on_host=keep_module_on_host)
         ds_output = pipe(query, **inf_kwargs)
 
         print(local_rank, "baseline", bs_output)
         print(local_rank, "deepspeed", ds_output)
         assert assert_fn(bs_output, ds_output)
 
+        if keep_module_on_host:
+            for name, param in pipe.model.named_parameters():
+                assert param.device == torch.device('cpu'), f"keep_module_on_host is on but param {name} is not on cpu"
+
     @pytest.mark.world_size(3)
     def test_odd_world_size(
         self,
@@ -607,6 +616,7 @@ def test_odd_world_size(
         inf_kwargs,
         assert_fn,
         dtype,
+        keep_module_on_host,
     ):
         invalid_test_msg = validate_test(model_w_task, dtype, enable_cuda_graph=False, enable_triton=False)
         if invalid_test_msg:
@@ -624,13 +634,20 @@ def test_odd_world_size(
                         framework="pt")
         bs_output = pipe(query, **inf_kwargs)
 
-        pipe.model = deepspeed.init_inference(pipe.model, mp_size=world_size, dtype=dtype)
+        pipe.model = deepspeed.init_inference(pipe.model,
+                                              mp_size=world_size,
+                                              dtype=dtype,
+                                              keep_module_on_host=keep_module_on_host)
         ds_output = pipe(query, **inf_kwargs)
 
         print(local_rank, "baseline", bs_output)
         print(local_rank, "deepspeed", ds_output)
         assert assert_fn(bs_output, ds_output)
 
+        if keep_module_on_host:
+            for name, param in pipe.model.named_parameters():
+                assert param.device == torch.device('cpu'), f"keep_module_on_host is on but param {name} is not on cpu"
+
 
 @pytest.mark.nightly
 @pytest.mark.parametrize(

From 05eaf3d1cab0f42f130a153802c7b94d86ecc872 Mon Sep 17 00:00:00 2001
From: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date: Wed, 15 Jan 2025 23:08:56 +0100
Subject: [PATCH 2/5] `warn` to `warning` (#6952)

`warn` is deprecated; see
https://docs.python.org/3/library/logging.html#logging.Logger.warning

```DeprecationWarning: The 'warn' method is deprecated, use 'warning' instead```
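
A minimal illustration of the two spellings (the logger name is arbitrary):

```
import logging

logger = logging.getLogger("deepspeed")
logger.warning("preferred spelling")  # supported API
logger.warn("deprecated alias")       # emits DeprecationWarning
```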
---
 accelerator/real_accelerator.py         | 2 +-
 deepspeed/runtime/base_optimizer.py     | 2 +-
 deepspeed/runtime/engine.py             | 4 ++--
 deepspeed/runtime/lr_schedules.py       | 2 +-
 deepspeed/runtime/zero/stage_1_and_2.py | 2 +-
 5 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/accelerator/real_accelerator.py b/accelerator/real_accelerator.py
index ced9218d7aca..eb4e17850882 100644
--- a/accelerator/real_accelerator.py
+++ b/accelerator/real_accelerator.py
@@ -178,7 +178,7 @@ def get_accelerator():
         if accelerator_name is None:
             # borrow this log from PR#5084
             if accel_logger is not None:
-                accel_logger.warn(
+                accel_logger.warning(
                     "Setting accelerator to CPU. If you have GPU or other accelerator, we were unable to detect it.")
             # cpu added as catch-all when accelerator detection fails
             accelerator_name = "cpu"
diff --git a/deepspeed/runtime/base_optimizer.py b/deepspeed/runtime/base_optimizer.py
index b8df7499450d..d2c54155da89 100644
--- a/deepspeed/runtime/base_optimizer.py
+++ b/deepspeed/runtime/base_optimizer.py
@@ -28,7 +28,7 @@ def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, chec
 
         tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
         if self.mpu is None:
-            logger.warn("MPU is not provided, setting tp size to 1 in checkpoint loading.")
+            logger.warning("MPU is not provided, setting tp size to 1 in checkpoint loading.")
             tp_world_size = 1
         else:
             tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 5f023d87f375..9b9a2e509d61 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -3120,7 +3120,7 @@ def _get_all_zero_checkpoints(self, load_dir, tag):
                 if bf16_mode is not self.bfloat16_enabled():
                     checkpoint_bit16 = BFLOAT16 if bf16_mode else FP16
                     engine_bit16 = BFLOAT16 if self.bfloat16_enabled() else FP16
-                    logger.warn(f'Loading {checkpoint_bit16} zero checkpoints into {engine_bit16} training engine')
+                    logger.warning(f'Loading {checkpoint_bit16} zero checkpoints into {engine_bit16} training engine')
                 return self._get_all_zero_checkpoint_state_dicts(zero_ckpt_names)
 
         return None
@@ -3276,7 +3276,7 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa
 
                     local_expert_id = None
                     if not m:
-                        logger.warn(f'No expert found in key {key}.')
+                        logger.warning(f'No expert found in key {key}.')
                     else:
                         local_expert_id = m.group(1)
 
diff --git a/deepspeed/runtime/lr_schedules.py b/deepspeed/runtime/lr_schedules.py
index 899358e2c5ef..2ffd0bf9f036 100755
--- a/deepspeed/runtime/lr_schedules.py
+++ b/deepspeed/runtime/lr_schedules.py
@@ -508,7 +508,7 @@ def _initialize_lr(self, optimizer, cycle_min_lr, cycle_max_lr, decay_lr_rate, l
     def _initialize_momentum(self, optimizer, cycle_min_mom, cycle_max_mom, decay_mom_rate, last_batch_iteration):
         if 'betas' not in optimizer.defaults:
             optimizer_name = type(optimizer).__name__
-            logger.warn(
+            logger.warning(
                 f"cycle_momentum is disabled because optimizer {optimizer_name} does not support momentum, no betas attribute in defaults"
             )
             self.cycle_momentum = False
diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index ecb2a527f870..0508766f8896 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -614,7 +614,7 @@ def _configure_moe_settings(self):
             assert self.contiguous_gradients, "Contiguous Gradients in ZeRO Stage 2 must be set to True for MoE. Other code paths are not tested with MoE"
         # NOTE: To run ZeRO stage 1 with MoE, we need to set self.contiguous_gradients to True or ignore the assertion
         if not self.partition_gradients and not self.contiguous_gradients:
-            logger.warn(
+            logger.warning(
                 "ZeRO Stage 1 has not been thoroughly tested with MoE. This configuration is still experimental.")
         assert self.reduce_scatter, "Reduce Scatter in ZeRO Stage 2 must be set to True for MoE. Other code paths are not tested with MoE"
 

From 018ece5af2d89a11a4a235f81f94496c78b4f990 Mon Sep 17 00:00:00 2001
From: Xia Weiwen <xia.weiwen@hotmail.com>
Date: Thu, 16 Jan 2025 10:11:07 -0800
Subject: [PATCH 3/5] Add extra_repr to Linear classes for debugging purposes
 (#6954)

**Summary**
This PR adds an `extra_repr` method to some Linear classes so that
additional info is printed when such modules are printed, which is useful
for debugging.
Affected modules:
- LinearLayer
- LinearAllreduce
- LmHeadLinearAllreduce

The `extra_repr` method gives the following info:
- in_features
- out_features
- bias (true or false)
- dtype

**Example**
Printing a llama-2-7b model on rank 0 after `init_inference` with world
size = 2. Previously we only got the class names of these modules:
```
InferenceEngine(
  (module): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(32000, 4096)
      (layers): ModuleList(
        (0-31): 32 x LlamaDecoderLayer(
          (self_attn): LlamaSdpaAttention(
            (q_proj): LinearLayer()
            (k_proj): LinearLayer()
            (v_proj): LinearLayer()
            (o_proj): LinearAllreduce()
            (rotary_emb): LlamaRotaryEmbedding()
          )
          (mlp): LlamaMLP(
            (gate_proj): LinearLayer()
            (up_proj): LinearLayer()
            (down_proj): LinearAllreduce()
            (act_fn): SiLU()
          )
          (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
          (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        )
      )
      (norm): LlamaRMSNorm((4096,), eps=1e-05)
      (rotary_emb): LlamaRotaryEmbedding()
    )
    (lm_head): LmHeadLinearAllreduce()
  )
)
```
Now we get more useful info:
```
InferenceEngine(
  (module): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(32000, 4096)
      (layers): ModuleList(
        (0-31): 32 x LlamaDecoderLayer(
          (self_attn): LlamaSdpaAttention(
            (q_proj): LinearLayer(in_features=4096, out_features=2048, bias=False, dtype=torch.bfloat16)
            (k_proj): LinearLayer(in_features=4096, out_features=2048, bias=False, dtype=torch.bfloat16)
            (v_proj): LinearLayer(in_features=4096, out_features=2048, bias=False, dtype=torch.bfloat16)
            (o_proj): LinearAllreduce(in_features=2048, out_features=4096, bias=False, dtype=torch.bfloat16)
            (rotary_emb): LlamaRotaryEmbedding()
          )
          (mlp): LlamaMLP(
            (gate_proj): LinearLayer(in_features=4096, out_features=5504, bias=False, dtype=torch.bfloat16)
            (up_proj): LinearLayer(in_features=4096, out_features=5504, bias=False, dtype=torch.bfloat16)
            (down_proj): LinearAllreduce(in_features=5504, out_features=4096, bias=False, dtype=torch.bfloat16)
            (act_fn): SiLU()
          )
          (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
          (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        )
      )
      (norm): LlamaRMSNorm((4096,), eps=1e-05)
      (rotary_emb): LlamaRotaryEmbedding()
    )
    (lm_head): LmHeadLinearAllreduce(in_features=2048, out_features=32000, bias=False, dtype=torch.bfloat16)
  )
)
```
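
`torch.nn.Module.__repr__` appends whatever `extra_repr` returns inside the
parentheses after the class name, so no caller changes are needed. A minimal
sketch of the pattern (a toy module, not the DeepSpeed classes themselves):

```
import torch
import torch.nn as nn


class TinyLinear(nn.Module):

    def __init__(self, in_features, out_features, dtype=torch.bfloat16):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype))
        self.bias = None

    def extra_repr(self):
        # Same fields the patch reports: shape, bias presence, and dtype.
        out_features, in_features = self.weight.shape
        return "in_features={}, out_features={}, bias={}, dtype={}".format(
            in_features, out_features, self.bias is not None, self.weight.dtype)


print(TinyLinear(4096, 2048))
# TinyLinear(in_features=4096, out_features=2048, bias=False, dtype=torch.bfloat16)
```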
---
 deepspeed/module_inject/layers.py | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/deepspeed/module_inject/layers.py b/deepspeed/module_inject/layers.py
index 722ba413a671..2f884ba4fb09 100644
--- a/deepspeed/module_inject/layers.py
+++ b/deepspeed/module_inject/layers.py
@@ -91,6 +91,13 @@ def forward(self, input):
             output += self.bias
         return output
 
+    def extra_repr(self):
+        out_features, in_features = self.weight.shape if self.weight is not None else (None, None)
+        dtype = self.weight.dtype if self.weight is not None else None
+        extra_repr_str = "in_features={}, out_features={}, bias={}, dtype={}".format(
+            in_features, out_features, self.bias is not None, dtype)
+        return extra_repr_str
+
 
 class LmHeadLinearAllreduce(nn.Module):
 
@@ -120,6 +127,13 @@ def forward(self, input):
             output += self.bias
         return output
 
+    def extra_repr(self):
+        out_features, in_features = self.weight.shape if self.weight is not None else (None, None)
+        dtype = self.weight.dtype if self.weight is not None else None
+        extra_repr_str = "in_features={}, out_features={}, bias={}, dtype={}".format(
+            in_features, out_features, self.bias is not None, dtype)
+        return extra_repr_str
+
 
 class LinearLayer(nn.Module):
 
@@ -144,6 +158,13 @@ def forward(self, input):
             output += self.bias
         return output
 
+    def extra_repr(self):
+        out_features, in_features = self.weight.shape
+        dtype = self.weight.dtype
+        extra_repr_str = "in_features={}, out_features={}, bias={}, dtype={}".format(
+            in_features, out_features, self.bias is not None, dtype)
+        return extra_repr_str
+
 
 class Normalize(nn.Module):
 

From f97f0885cf942aab1559d0f6a060d4801bff3a61 Mon Sep 17 00:00:00 2001
From: Logan Adams <114770087+loadams@users.noreply.github.com>
Date: Fri, 17 Jan 2025 09:43:51 -0800
Subject: [PATCH 4/5] Update import for torchvision.transforms (#6958)

Fixes import - found via
[torchfix](https://github.com/pytorch-labs/torchfix).
---
 tests/unit/alexnet_model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/unit/alexnet_model.py b/tests/unit/alexnet_model.py
index 51e80e7f9e62..6fe84edf4eda 100644
--- a/tests/unit/alexnet_model.py
+++ b/tests/unit/alexnet_model.py
@@ -84,7 +84,7 @@ def cast_to_half(x):
 
 def cifar_trainset(fp16=False):
     torchvision = pytest.importorskip("torchvision", minversion="0.5.0")
-    import torchvision.transforms as transforms
+    from torchvision import transforms
 
     transform_list = [
         transforms.ToTensor(),

From 7f3d669b40f8d29010efd9578d4a2cdd0f16b20e Mon Sep 17 00:00:00 2001
From: Hyogeun Oh (오효근) <ohg3417@gmail.com>
Date: Sat, 18 Jan 2025 02:44:49 +0900
Subject: [PATCH 5/5] Remove Duplicate Declaration of pandas in `Dockerfile`
 (#6959)

### Description

This pull request removes the redundant installation of `pandas` from
the `Dockerfile`.
It was previously declared twice, and this update eliminates the
duplicate entry, improving the clarity and maintainability of the
`Dockerfile`.


https://github.com/microsoft/DeepSpeed/blob/018ece5af2d89a11a4a235f81f94496c78b4f990/docker/Dockerfile#L124


https://github.com/microsoft/DeepSpeed/blob/018ece5af2d89a11a4a235f81f94496c78b4f990/docker/Dockerfile#L135

### Changes

Removed the duplicate `pandas` installation line from the `RUN pip
install` command.
---
 docker/Dockerfile | 1 -
 1 file changed, 1 deletion(-)

diff --git a/docker/Dockerfile b/docker/Dockerfile
index 035a094d0051..5a62a5a01aba 100755
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -132,7 +132,6 @@ RUN pip install psutil \
         sentencepiece \
         msgpack \
         requests \
-        pandas \
         sphinx \
         sphinx_rtd_theme \
         scipy \