diff --git a/.gitignore b/.gitignore index 4a78a8b9..609d9a38 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,8 @@ __pycache__/ *.so *.log *.onnx +*.pb +*.npy /scripts/*.py !/scripts/format.py diff --git a/examples/distributed/launch.py b/examples/distributed/launch.py new file mode 100644 index 00000000..246b987c --- /dev/null +++ b/examples/distributed/launch.py @@ -0,0 +1,157 @@ +import argparse +import os +import time +import multiprocessing as mp +from refactor_graph.onnx import make_compiler +import onnx +from onnx.external_data_helper import convert_model_to_external_data +from onnx.shape_inference import infer_shapes_path +import numpy as np +from parallel_opt import parallel_model + + +os.environ["NVIDIA_TF32_OVERRIDE"] = "0" + + +def parse_args(): + parser = argparse.ArgumentParser(description="launch distributed infinitensor") + parser.add_argument("--num_nodes", type=int, default=1, help="number of nodes") + parser.add_argument( + "--nproc_per_node", type=int, default=1, help="number of processes per node" + ) + parser.add_argument( + "--name", type=str, default="test", help="name of this instance." + ) + parser.add_argument( + "--model", type=str, required=True, help="path to the ONNX model file." + ) + parser.add_argument("--batch_size", type=int, default=1, help="batch size.") + parser.add_argument("--length", type=int, default=512, help="sequence length.") + parser.add_argument( + "--gen_std", + action="store_true", + help="whether to generate the standard results.", + ) + args = parser.parse_args() + print("arg setting: ", args) + return ( + args.num_nodes, + args.nproc_per_node, + args.name, + args.model, + args.batch_size, + args.length, + args.gen_std, + ) + + +def run_model(executor, inputs, n=10): + for i in range(executor.input_count()): + executor.set_input(i, inputs[i]) + + executor.prepare() + executor.run() + # get outputs + outputs = executor.get_output(0) + + # bench + begin = time.time() + for _ in range(n): + executor.run() + end = time.time() + avg_time = (end - begin) / n + print(f"average time: {avg_time}") + return outputs + + +def run_and_compare(name, executor): + input_ids = np.load(f"{name}_inputs.npy") + position_ids = np.arange(input_ids.shape[-1]) + results = np.load(f"{name}_results.npy") + outputs = run_model(executor, (input_ids, position_ids)) + print("outputs abs mean:", abs(outputs).mean()) + np.testing.assert_allclose(outputs, results, rtol=1e-6, atol=1e-3) + + +def start_worker( + name: str, world_size: int, rank: int, local_rank: int, model: onnx.ModelProto +): + dist_name = name + "_dist" + model = parallel_model(model, world_size, rank) + extern_path = f"./{dist_name}_rank{rank}.pb" + if os.path.exists(extern_path): + os.remove(extern_path) + onnx.save_model( + model, + f"./{dist_name}_rank{rank}.onnx", + save_as_external_data=True, + location=extern_path, + ) + infer_shapes_path(f"./{dist_name}_rank{rank}.onnx") + + compiler = make_compiler(model, ".") + + executor = compiler.compile("cuda", "default", [], rank) + executor.set_cuda_commnication(world_size, rank) + run_and_compare(name, executor) + + +def start_single(name, model): + compiler = make_compiler(model) + executor = compiler.compile("cuda", "default", [], 0) + run_and_compare(name, executor) + + +def gen_standard(name, model, voc_size, bs, len): + # generate standard results + input_ids = np.random.randint(0, voc_size, (bs, len)).astype(np.int32) + position_ids = np.arange(len) + np.save(f"{name}_inputs", input_ids) + compiler = make_compiler(model) + executor = 
compiler.compile("cuda", "default", [], 0) + outputs = run_model(executor, (input_ids, position_ids), 1) + print("outputs abs mean:", abs(outputs).mean()) + np.save(f"{name}_results", outputs) + + +def main(): + nnodes, nproc_per_node, name, model_path, bs, length, gen_std = parse_args() + + model = onnx.load(model_path) + + gen_std =False + # generate standart output + if gen_std: + print(f"generate standard data for {name}.") + # a small vocabulary size to fit all LLM. + voc_size = 1000 + gen_standard(name, model, voc_size, bs, length) + return + + # run single process. + # use standalone process to isolate cuda. + print("run model by single GPU.") + p = mp.Process(target=start_single, args=(name, model)) + p.start() + p.join() + + # run distributed parallel. + world_size = nnodes * nproc_per_node + print(f"run model by {world_size} GPU in parallel.") + workers = [ + mp.Process( + target=start_worker, + args=(name, world_size, rank, rank % nproc_per_node, model), + ) + for rank in range(world_size) + ] + + for w in workers: + w.start() + + for w in workers: + w.join() + + +if __name__ == "__main__": + main() diff --git a/examples/distributed/parallel_opt.py b/examples/distributed/parallel_opt.py new file mode 100644 index 00000000..6830af4a --- /dev/null +++ b/examples/distributed/parallel_opt.py @@ -0,0 +1,239 @@ +import onnx +from onnx import ModelProto, NodeProto, TensorProto, ValueInfoProto +from onnx import helper, numpy_helper +from typing import Dict, List +from placement import Placement, Replicate, Shard, _Partial +import numpy as np + + +def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0): + data = {init.name: init for init in model.graph.initializer} + vinfo = {info.name: info for info in model.graph.value_info} + vinfo.update({info.name: info for info in model.graph.input}) + vinfo.update({info.name: info for info in model.graph.output}) + output = {info.name: info for info in model.graph.output} + place: Dict[str, Placement] = {} + nodes: List[NodeProto] = [] + + def is_sharded(name: str): + return place[name].is_shard() + + def shard_tensor(tensor: TensorProto, plc: Shard, groups: int = 1): + # print(f"shard {tensor.name} at dim {dim}") + assert plc.is_shard(), plc + ndim = len(tensor.dims) + if plc.dim < 0: + plc.dim += ndim + if tensor.dims[plc.dim] == 1: # broadcast dim, no need to shard. 
return tensor + array = numpy_helper.to_array(tensor) + assert array.shape[plc.dim] % tp_world_size == 0, array.shape[plc.dim] + dims = list(tensor.dims) + dims.insert(plc.dim, groups) + dims[plc.dim + 1] //= groups + array = array.reshape(dims) + seg = array.shape[plc.dim + 1] // tp_world_size + array = array.take( + indices=range(tp_rank * seg, (tp_rank + 1) * seg), axis=plc.dim + 1 + ) + dims = list(tensor.dims) + dims[plc.dim] //= tp_world_size + array = array.reshape(dims) + tensor = numpy_helper.from_array(array, name=tensor.name) + place[tensor.name] = plc + return tensor + + def shard_gemm(node: NodeProto, groups: int = 1): + # print("gemm", node.name) + in_plc = place[node.input[0]] + w_plc = Shard(-1) if in_plc.is_replicate() else Shard(0) + transB = next((attr.i for attr in node.attribute if attr.name == "transB"), 0) + if transB: + w_plc.dim = ~w_plc.dim + input = node.input[1] + data[input] = shard_tensor(data[input], w_plc, groups) + + output = node.output[0] + ndim = len(vinfo[output].type.tensor_type.shape.dim) + out_plc = Shard(ndim - 1) if in_plc.is_replicate() else _Partial() + place[node.output[0]] = out_plc + + def shard_concat(node: NodeProto): + # hack for the KV cache + in_plc = place[node.input[1]] + if in_plc.is_shard(): + seq_len_dim = vinfo[node.input[0]].type.tensor_type.shape.dim.pop(1) + seq_len_dim.dim_value //= tp_world_size + vinfo[node.input[0]].type.tensor_type.shape.dim.insert(1, seq_len_dim) + place[node.input[0]] = in_plc + place[node.output[0]] = in_plc + + def shard_binary(node: NodeProto, groups: int = 1): + # print("binary", node.name, node.input[0], place[node.input[0]]) + a = node.input[0] + b = node.input[1] + if a in data: + a, b = b, a + place[node.output[0]] = place[a] + if is_sharded(a) and b in data and len(data[b].dims) == 1: # broadcast + data[b] = shard_tensor(data[b], Shard(0), groups) + + def shard_reshape(node: NodeProto): + # print("reshape", node.name, node.input[0], place[node.input[0]]) + if not is_sharded(node.input[0]): + return + in_plc = place[node.input[0]] + s_dim = -1 + in_dims = [d.dim_value for d in vinfo[node.input[0]].type.tensor_type.shape.dim] + tensor = data[node.input[1]] + out_dims = numpy_helper.to_array(tensor).copy() + if len(in_dims) == 3 and len(out_dims) == 4: + if in_plc.dim == 0: + s_dim = 1 + elif in_plc.dim == 2: + s_dim = 2 + if len(in_dims) == 4 and len(out_dims) == 3: + if in_plc.dim == 1: + s_dim = 0 + elif in_plc.dim == 2: + s_dim = 2 + if len(in_dims) == 2 and len(out_dims) == 3: + if in_plc.dim == 1: + s_dim = 2 + if len(in_dims) == 4 and len(out_dims) == 2: + if in_plc.dim == 1: + s_dim = 0 + elif in_plc.dim == 2: + s_dim = 1 + if len(in_dims) == 3 and len(out_dims) == 2: + if in_plc.dim == 1: + s_dim = 0 + elif in_plc.dim == 2: + s_dim = 1 + + assert s_dim != -1 + assert out_dims[s_dim] % tp_world_size == 0, out_dims + out_dims[s_dim] //= tp_world_size + # if ONNX uses the same tensor for multiple Reshape nodes, rename it to distinguish it from the others.
+ # node.input[1] = node.output[0] + "_shape" + data[node.input[1]] = numpy_helper.from_array(out_dims, name=node.input[1]) + place[node.output[0]] = Shard(s_dim) + + def shard_split(node: NodeProto): + if not is_sharded(node.input[0]): + return + in_plc = place[node.input[0]] + split_tensor = data[node.input[1]] + split = numpy_helper.to_array(split_tensor).copy() + split //= tp_world_size + data[node.input[1]] = numpy_helper.from_array(split, name=node.input[1]) + for output in node.output: + place[output] = in_plc + + def shard_transpose(node: NodeProto): + plc = place[node.input[0]] + if plc.is_shard(): + perm = next(attr.ints for attr in node.attribute if attr.name == "perm") + place[node.output[0]] = Shard(list(perm).index(plc.dim)) + + def shard_node(node: NodeProto): + if node.op_type in ["Relu", "Tanh", "Softmax"]: + place[node.output[0]] = place[node.input[0]] + elif node.op_type in ["Where"]: + place[node.output[0]] = place[node.input[1]] + if node.op_type in {"Add", "Mul", "Div", "Max"}: + shard_binary(node) + elif node.op_type == "Reshape": + shard_reshape(node) + elif node.op_type == "Transpose": + shard_transpose(node) + elif node.op_type == "Split": + shard_split(node) + elif node.op_type == "MatMul": + assert ( + place[node.input[0]] == place[node.input[1]] + ), f"{place[node.input[0]]} != {place[node.input[1]]}" + place[node.output[0]] = place[node.input[0]] + elif node.op_type == "Concat": + shard_concat(node) + + def find_successor(op_type: str, idx: int, search_limit: int = 1): + for node in model.graph.node[idx + 1 : idx + 1 + search_limit]: + if node.op_type == op_type: + return node + return None + + # all tensors are initially replicated. + for v in vinfo: + place[v] = Replicate() + + for t in data: + place[t] = Replicate() + + for index, node in enumerate(model.graph.node): + nodes.append(node) + # linear + if (node.op_type == "MatMul" or node.op_type == "Gemm") and any( + input in data for input in node.input + ): + # FIXME(constroy): the last MatMul should not be sharded as TP. 
if node.output[0] in output: + continue + groups = 1 + # If the Gemm or MatMul is followed by a Split, then its inputs are concatenated in groups + split_node = find_successor("Split", index, search_limit=2) + if split_node is not None: + groups = len(split_node.output) + shard_gemm(node, groups) + plc = place[node.output[0]] + if plc.is_partial(): + new_name = node.output[0] + f":{plc}" + place[new_name] = place[node.output[0]] + # insert all_reduce + nodes.append( + helper.make_node( + op_type="AllReduceSum", + inputs=[new_name], + outputs=[node.output[0]], + name=node.name + "/all_reduce", + ) + ) + place[node.output[0]] = Replicate() + node.output[0] = new_name + if len(node.input) > 2: # split the bias into a separate Add node + prev = nodes[-1] + new_name = prev.output[0] + "_no_bias" + place[new_name] = place[node.output[0]] + bias = helper.make_node( + op_type="Add", + inputs=[new_name, node.input[2]], + outputs=[prev.output[0]], + name=node.name + "/bias", + ) + node.input.pop() + prev.output[0] = new_name + shard_binary(bias, groups) + nodes.append(bias) + continue + shard_node(node) + + new_input = [] + for info in model.graph.input: + new_input.append(vinfo[info.name]) + + graph = helper.make_graph( + nodes, + model.graph.name + f"_{tp_rank}", + new_input, + model.graph.output, + data.values(), + doc_string=model.graph.doc_string, + # value_info=vinfo.values(), + ) + for output in graph.output: + tt = output.type.tensor_type + if tt.HasField("shape"): + tt.ClearField("shape") + model = helper.make_model(graph) + model = onnx.shape_inference.infer_shapes(model) + return model diff --git a/examples/distributed/placement.py b/examples/distributed/placement.py new file mode 100644 index 00000000..634d4fe5 --- /dev/null +++ b/examples/distributed/placement.py @@ -0,0 +1,64 @@ +from typing import Optional + + +class Placement: + # base class for placement types + + # convenience utilities to check for placement types + def is_shard(self, dim: Optional[int] = None) -> bool: + if dim is not None and isinstance(self, Shard): + return self.dim == dim + else: + return isinstance(self, Shard) + + def is_replicate(self) -> bool: + return isinstance(self, Replicate) + + def is_partial(self) -> bool: + return isinstance(self, _Partial) + + +class Replicate(Placement): + def __eq__(self, other: object) -> bool: + if not isinstance(other, Replicate): + return False + return True + + def __repr__(self) -> str: + """ + machine-readable representation of the Replicate placement + """ + return "Replicate()" + + +class Shard(Placement): + # shard placement, shard on a dim + def __init__(self, dim): + self.dim = dim + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Shard): + return False + return self.dim == other.dim + + def __repr__(self) -> str: + """ + machine-readable representation of the Shard placement + """ + return f"Shard(dim={self.dim})" + + +class _Partial(Placement): + def __init__(self, reduce_op: str = "sum"): + self.reduce_op: str = reduce_op + + def __eq__(self, other: object) -> bool: + if not isinstance(other, _Partial): + return False + return self.reduce_op == other.reduce_op + + def __repr__(self) -> str: + """ + machine-readable representation of the Partial placement + """ + return f"_Partial(reduce_op={self.reduce_op})" diff --git a/src/03runtime/include/runtime/stream.h b/src/03runtime/include/runtime/stream.h index d87b01d0..9e452dc3 100644 --- a/src/03runtime/include/runtime/stream.h +++ b/src/03runtime/include/runtime/stream.h @@ -44,6 +44,7 @@ namespace refactor::runtime {
std::vector<_N>, std::vector<_E>); void setInput(count_t, void const *, size_t); + size_t inputCount() const; void setInput(count_t, mem_manager::SharedForeignBlob); void getOutput(count_t, void *, size_t) const; Resources &getResources() { return _resources; }; diff --git a/src/03runtime/src/stream.cc b/src/03runtime/src/stream.cc index e0e724dc..c386d304 100644 --- a/src/03runtime/src/stream.cc +++ b/src/03runtime/src/stream.cc @@ -51,6 +51,11 @@ namespace refactor::runtime { blob->copyIn(data, size); _internal.edges[globalInputs[i]].value = {std::move(blob)}; } + + size_t Stream::inputCount() const{ + return _internal.topology.globalInputs().size(); + } + void Stream::setInput(count_t i, mem_manager::SharedForeignBlob blob) { auto globalInputs = _internal.topology.globalInputs(); ASSERT(i < globalInputs.size(), "input index out of range"); diff --git a/src/04kernel/cuda/include/kernel/cuda/functions.cuh b/src/04kernel/cuda/include/kernel/cuda/functions.cuh index c611dae3..f66d202a 100644 --- a/src/04kernel/cuda/include/kernel/cuda/functions.cuh +++ b/src/04kernel/cuda/include/kernel/cuda/functions.cuh @@ -4,6 +4,7 @@ namespace refactor::kernel::cuda { void sync(); + void setCudaDevice(int); }// namespace refactor::kernel::cuda diff --git a/src/04kernel/cuda/src/functions.cu b/src/04kernel/cuda/src/functions.cu index d34ea540..a24a9a82 100644 --- a/src/04kernel/cuda/src/functions.cu +++ b/src/04kernel/cuda/src/functions.cu @@ -6,4 +6,8 @@ namespace refactor::kernel::cuda { cudaDeviceSynchronize(); } + void setCudaDevice(int id) { + cudaSetDevice(id); + } + }// namespace refactor::kernel::cuda diff --git a/src/04kernel/src/target.cc b/src/04kernel/src/target.cc index a9dbc019..f329a5dd 100644 --- a/src/04kernel/src/target.cc +++ b/src/04kernel/src/target.cc @@ -37,7 +37,7 @@ namespace refactor::kernel { } #ifdef USE_CUDA case NvidiaGpu: { - static thread_local Arc memPool = std::make_shared(10ul << 30, 256, cuda::BasicCudaMemManager::instance()); + static thread_local Arc memPool = std::make_shared(20ul << 30, 256, cuda::BasicCudaMemManager::instance()); return memPool; } #endif diff --git a/src/04kernel/src/utilities/cuda/nccl_communicator.cu b/src/04kernel/src/utilities/cuda/nccl_communicator.cu index 9deed73e..225cc94f 100644 --- a/src/04kernel/src/utilities/cuda/nccl_communicator.cu +++ b/src/04kernel/src/utilities/cuda/nccl_communicator.cu @@ -9,8 +9,6 @@ namespace refactor::kernel::nccl { NcclCommunicator::NcclCommunicator(int worldSize, int rank) : worldSize_(worldSize), rank_(rank) { - cudaSetDevice(rank); - const std::string filePath("./nccl_id.bin"); ncclUniqueId commId; diff --git a/src/04kernel/test/kernels/all_reduce/test_allreduce_nccl.cpp b/src/04kernel/test/kernels/all_reduce/test_allreduce_nccl.cpp index bf6319a4..6b67a6e4 100644 --- a/src/04kernel/test/kernels/all_reduce/test_allreduce_nccl.cpp +++ b/src/04kernel/test/kernels/all_reduce/test_allreduce_nccl.cpp @@ -2,6 +2,7 @@ #include "../src/kernels/all_reduce/nccl_kernel.hh" #include "../src/utilities/cuda/nccl_communicator.hh" +#include "kernel/cuda/functions.cuh" #include "kernel/target.h" #include #include @@ -9,8 +10,10 @@ using namespace refactor; using namespace kernel; using namespace nccl; +using namespace cuda; void allReduce(AllReduceType redType, int rank, int worldSize, std::vector data, std::vector ans) { + cuda::setCudaDevice(rank); auto input = Tensor::share(DataType::F32, Shape{2}, LayoutType::NCHW); auto output = Tensor::share(DataType::F32, Shape{2}, LayoutType::NCHW); auto kernel = 
AllReduceNccl::build(redType, *input, *output); diff --git a/src/08communication/src/operators/all_reduce.cc b/src/08communication/src/operators/all_reduce.cc index 1784b581..d3891128 100644 --- a/src/08communication/src/operators/all_reduce.cc +++ b/src/08communication/src/operators/all_reduce.cc @@ -1,6 +1,6 @@ #include "all_reduce.hh" #include "common.h" - +#include "computation/operators/all_reduce.h" namespace refactor::communication { using Op = AllReduce; @@ -74,7 +74,9 @@ namespace refactor::communication { extractDependency(inputs))}); } - computation::OpBox Op::lower(TensorRefs) const { + computation::OpBox Op::lower(TensorRefs inputs) const { + + return std::make_unique(type); } }// namespace refactor::communication diff --git a/src/09python_ffi/src/compiler.cc b/src/09python_ffi/src/compiler.cc index 8a158f97..f131cdd3 100644 --- a/src/09python_ffi/src/compiler.cc +++ b/src/09python_ffi/src/compiler.cc @@ -1,6 +1,7 @@ #include "compiler.h" #include "common.h" #include "kernel/allocators.h" +#include "kernel/cuda/functions.cuh" #include namespace refactor::python_ffi { @@ -30,7 +31,8 @@ namespace refactor::python_ffi { Arc Compiler::compile(std::string target, std::string allocator, - std::vector passes) { + std::vector passes, + int deviceID) { _g.collectVariables(); std::vector unknownVariables; for (auto const &[_, v] : _g.variables()) { @@ -64,6 +66,9 @@ namespace refactor::python_ffi { target_ = kernel::Target::Cpu; } else if (target == "cuda") { target_ = kernel::Target::NvidiaGpu; + if (deviceID >= 0) { + kernel::cuda::setCudaDevice(deviceID); + } } else { UNREACHABLE(); } diff --git a/src/09python_ffi/src/compiler.h b/src/09python_ffi/src/compiler.h index 44f3232e..517af292 100644 --- a/src/09python_ffi/src/compiler.h +++ b/src/09python_ffi/src/compiler.h @@ -20,7 +20,8 @@ namespace refactor::python_ffi { Arc compile( std::string target, std::string allocator, - std ::vector passes); + std ::vector passes, + int deviceID = 0); std::optional getTensor(CStr) const; }; diff --git a/src/09python_ffi/src/executor.cc b/src/09python_ffi/src/executor.cc index 3a53dcde..8c5ff62f 100644 --- a/src/09python_ffi/src/executor.cc +++ b/src/09python_ffi/src/executor.cc @@ -1,8 +1,8 @@ #include "executor.h" #ifdef USE_CUDA -#include "kernel/cuda/functions.cuh" #include "../../04kernel/src/utilities/cuda/nccl_communicator.hh" +#include "kernel/cuda/functions.cuh" #endif// USE_CUDA namespace refactor::python_ffi { @@ -15,6 +15,10 @@ namespace refactor::python_ffi { _stream.setInput(i, data.data(), data.nbytes()); } + size_t Executor::inputCount() const { + return _stream.inputCount(); + } + auto Executor::getOutput(count_t i) -> pybind11::array { auto globalOutputs = _graph.internal().contiguous().topology.globalOutputs(); ASSERT(i < globalOutputs.size(), "input index out of range"); diff --git a/src/09python_ffi/src/executor.h b/src/09python_ffi/src/executor.h index 7f8c250f..b634cab3 100644 --- a/src/09python_ffi/src/executor.h +++ b/src/09python_ffi/src/executor.h @@ -14,6 +14,7 @@ namespace refactor::python_ffi { public: Executor(computation::Graph, runtime::Stream); void setInput(count_t, pybind11::array); + size_t inputCount() const; auto getOutput(count_t) -> pybind11::array; auto prepare() -> std::vector; void run(); diff --git a/src/09python_ffi/src/main.cpp b/src/09python_ffi/src/main.cpp index 11091d8b..b3703455 100644 --- a/src/09python_ffi/src/main.cpp +++ b/src/09python_ffi/src/main.cpp @@ -35,6 +35,7 @@ namespace refactor::python_ffi { py::class_>(m, "Executor" ) 
.def("set_input" , &Executor::setInput , return_::automatic ) + .def("input_count" , &Executor::inputCount , return_::automatic ) .def("get_output" , &Executor::getOutput , return_::move ) .def("prepare" , &Executor::prepare , return_::move ) .def("run" , &Executor::run , return_::automatic )