
Commit

Merge pull request #114 from edgeintelligence/brushup
Brushup
milhidaka authored Mar 17, 2022
2 parents 3dfa2b4 + 0396da7 commit 607c505
Showing 20 changed files with 337 additions and 52 deletions.
22 changes: 22 additions & 0 deletions distributed/README.ja.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# kakiage distributed machine learning server

The server side is implemented as a Python library.

# Setup

Python 3.8+

```
pip install -r requirements.txt
python setup.py develop
```

To run the samples, see `samples/*/README.md`.

# Build for distribution

```
python setup.py bdist_wheel
```

`dist/kakiage-<version>-py3-none-any.whl` will be generated. Users can install kakiage together with its required dependencies (numpy, etc.) by running `pip install /path/to/kakiage-<version>-py3-none-any.whl`.
12 changes: 6 additions & 6 deletions distributed/README.md
@@ -1,8 +1,8 @@
# kakiage 分散機械学習サーバ
# kakiage distributed training server

サーバ側は Python ライブラリとして実装されている。
The server side is implemented as a Python library.

# セットアップ
# Setup

Python 3.8+

Expand All @@ -11,12 +11,12 @@ pip install -r requirements.txt
python setup.py develop
```

サンプル動作方法: `samples/*/README.md`参照
To run the samples, see `samples/*/README.md`.

# 配布用ビルド
# Build for distribution

```
python setup.py bdist_wheel
```

`dist/kakiage-<version>-py3-none-any.whl` が生成される。利用者は、`pip install /path/to/kakiage-<version>-py3-none-any.whl`を実行することで必須依存パッケージ(numpy 等)とともに kakiage をインストールすることが可能。
`dist/kakiage-<version>-py3-none-any.whl` will be generated. Users can install kakiage together with its required dependencies (numpy, etc.) by running `pip install /path/to/kakiage-<version>-py3-none-any.whl`.
39 changes: 39 additions & 0 deletions distributed/sample/mnist_data_parallel/README.ja.md
@@ -0,0 +1,39 @@
# Sample: MNIST data parallel training

Trains an MLP to classify the MNIST image dataset using data-parallel distributed training.

The CIFAR10 and CIFAR100 datasets can also be used.

# Build

```
npm install
npm run build
```

# Run training

Configuration is done via environment variables.

- MODEL: one of mlp, conv, resnet18. Specifies the model type.
- N_CLIENTS: the number of clients participating in distributed computation; an integer of 1 or more. Defaults to 1.
- EPOCH: number of training epochs. Default is 2.
- BATCH_SIZE: batch size, summed over all clients. Default is 32.

The server runs via uvicorn. Example command (Mac/Linux):

```
MODEL=conv N_CLIENTS=2 npm run train
```

On Windows, use the set command:

```
set MODEL=conv
set N_CLIENTS=2
npm run train
```

Open [http://localhost:8081/](http://localhost:8081/) in a web browser. If `N_CLIENTS` is set, you must open `N_CLIENTS` browser windows so that computation runs in parallel. Note: if multiple tabs are opened in one window, computation in hidden tabs will be slowed down.

The trained model is output in ONNX format and can be used for inference with WebDNN, ONNX Runtime Web, etc.
31 changes: 26 additions & 5 deletions distributed/sample/mnist_data_parallel/README.md
@@ -1,18 +1,39 @@
# サンプル: MNIST data parallel training
# Sample: MNIST data parallel training

MNIST画像データセットを分類するMLPを、データ並列で学習する
Trains an MLP to classify the MNIST image dataset using data-parallel distributed training.

# ビルド
The CIFAR10 and CIFAR100 datasets can also be used.

# Build

```
npm install
npm run build
```

# 学習実行
# Run training

Configuration is done via environment variables.

- MODEL: one of mlp, conv, resnet18. Specifies the model type.
- N_CLIENTS: the number of clients participating in distributed computation; an integer of 1 or more. Defaults to 1.
- EPOCH: number of training epochs. Default is 2.
- BATCH_SIZE: batch size, summed over all clients. Default is 32.
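
As a rough illustration only (not the sample's actual code), these settings could be read on the Python side with the standard `os.environ`, using the defaults documented above:

```python
import os

# Read training settings from environment variables, with the documented defaults
MODEL = os.environ.get("MODEL", "mlp")                # one of: mlp, conv, resnet18
N_CLIENTS = int(os.environ.get("N_CLIENTS", "1"))     # number of participating clients
EPOCH = int(os.environ.get("EPOCH", "2"))             # number of training epochs
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "32"))  # total across all clients

print(MODEL, N_CLIENTS, EPOCH, BATCH_SIZE)
```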

The server runs via uvicorn. Example command (Mac/Linux):

```
MODEL=conv N_CLIENTS=2 npm run train
```

On Windows, use the set command:

```
set MODEL=conv
set N_CLIENTS=2
npm run train
```

ブラウザで[http://localhost:8081/](http://localhost:8081/)を開く。3並列で計算するため、3つのタブで開く必要がある。
Open [http://localhost:8081/](http://localhost:8081/) in a web browser. If `N_CLIENTS` is set, you must open `N_CLIENTS` browser windows so that computation runs in parallel. Note: if multiple tabs are opened in one window, computation in hidden tabs will be slowed down.

The trained model is output in ONNX format and can be used for inference with WebDNN, ONNX Runtime Web, etc.
14 changes: 10 additions & 4 deletions distributed/sample/mnist_data_parallel/main.py
Expand Up @@ -29,13 +29,10 @@
from kakiage.tensor_serializer import serialize_tensors_to_bytes, deserialize_tensor_from_bytes
from sample_net import make_net, get_io_shape, get_dataset_loader

# スクリプトの配布
# setup server to distribute javascript and communicate
kakiage_server = setup_server()
app = kakiage_server.app

# Create the initial model with PyTorch and evaluate the trained model on the server side


def test(model, loader):
model.eval()
loss_sum = 0.0
Expand All @@ -56,6 +53,7 @@ def test(model, loader):
def snake2camel(name):
"""
running_mean -> runningMean
PyTorch uses snake_case, kakiage uses camelCase
"""
upper = False
cs = []
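
The body of `snake2camel` is truncated in this diff; as a sketch (assuming only the behavior stated in the docstring, not necessarily the repository's exact implementation), the conversion could be written as:

```python
def snake2camel(name: str) -> str:
    # running_mean -> runningMean (PyTorch uses snake_case, kakiage uses camelCase)
    head, *rest = name.split("_")
    return head + "".join(part.capitalize() for part in rest)

print(snake2camel("running_mean"))  # -> runningMean
```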
Expand Down Expand Up @@ -99,6 +97,7 @@ async def main():
client_ids = []
print(f"Waiting {n_client_wait} clients to connect")

# Get the next server event from the queue
async def get_event():
while True:
event = await kakiage_server.event_queue.get()
Expand Down Expand Up @@ -137,11 +136,13 @@ async def get_event():
chunk_size = math.ceil(batch_size / n_clients)
chunk_sizes = []
grad_item_ids = []
# split batch into len(client_ids) chunks
for c, client_id in enumerate(client_ids):
image_chunk = image[c*chunk_size:(c+1)*chunk_size]
label_chunk = label[c*chunk_size:(c+1)*chunk_size]
chunk_sizes.append(len(image_chunk))
dataset_item_id = uuid4().hex
# set a blob (binary data) on the server so that the client can download it by specifying the id
kakiage_server.blobs[dataset_item_id] = serialize_tensors_to_bytes(
{
"image": image_chunk.detach().numpy().astype(np.float32),
Expand All @@ -150,6 +151,7 @@ async def get_event():
)
item_ids_to_delete.append(dataset_item_id)
grad_item_id = uuid4().hex
# ask the client to compute the gradient given the weights and batch
await kakiage_server.send_message(client_id, {
"model": model_name,
"inputShape": list(input_shape),
Expand All @@ -161,6 +163,9 @@ async def get_event():
grad_item_ids.append(grad_item_id)
item_ids_to_delete.append(grad_item_id)
complete_count = 0
# Wait for all clients to complete
# Disconnection and dynamic addition of clients are not supported (this implementation waits for a disconnected client forever)
# To support them, handle events such as KakiageServerWSConnectEvent
while True:
event = await get_event()
if isinstance(event, KakiageServerWSReceiveEvent):
Expand All @@ -184,6 +189,7 @@ async def get_event():
for k, v in weights.items():
grad = grad_arrays[k]
if is_trainable_key(k):
# update weight using SGD (no momentum)
v -= lr * grad
else:
# not trainable = BN stats = average latest value
Expand Down
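
Taken together, the loop above splits each batch into per-client chunks, waits for the returned gradients, and applies plain SGD. A minimal numpy sketch of those two steps (hypothetical helper names; the chunking and update rules mirror the code above):

```python
import math
import numpy as np

def chunk_sizes(n_samples, n_clients):
    # ceil-sized chunks, as in the slicing loop above; the last chunk may be smaller
    size = math.ceil(n_samples / n_clients)
    return [min((c + 1) * size, n_samples) - min(c * size, n_samples)
            for c in range(n_clients)]

def sgd_step(weights, grads, lr, is_trainable_key):
    # SGD without momentum for trainable parameters; non-trainable entries
    # (e.g. BatchNorm running stats) simply take the averaged latest value
    for k, v in weights.items():
        if is_trainable_key(k):
            v -= lr * grads[k]
        else:
            v[...] = grads[k]
```

For example, a batch of 32 samples over 3 clients is split as chunks of 11, 11, and 10 samples.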
39 changes: 29 additions & 10 deletions distributed/sample/mnist_data_parallel/public/index.html
@@ -1,14 +1,33 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<head>
<meta charset="UTF-8" />
<meta http-equiv="X-UA-Compatible" content="IE=edge" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Kakiage Distributed MNIST Training Sample</title>
<script src="static/index.js"></script>
</head>
<body>
<p id="state"></p>
<p id="messages"></p>
</body>
</html>
<link href="static/index.css" rel="stylesheet" />
</head>
<body>
<h1>Kakiage Distributed Training</h1>
<main>
<p id="state"></p>
<table>
<tbody>
<tr>
<td>Processed batches</td>
<td id="table-batches"></td>
</tr>
<tr>
<td>Last loss</td>
<td id="table-loss"></td>
</tr>
<tr>
<td>Batch size</td>
<td id="table-batchsize"></td>
</tr>
</tbody>
</table>
</main>
</body>
</html>
13 changes: 13 additions & 0 deletions distributed/sample/mnist_data_parallel/public/static/index.css
@@ -0,0 +1,13 @@
body {
margin: 0;
}

h1 {
background: linear-gradient(180deg, orange, white);
margin: 0;
padding: 1em;
}

main {
padding: 1em;
}
21 changes: 10 additions & 11 deletions distributed/sample/mnist_data_parallel/src/index.ts
Expand Up @@ -222,14 +222,16 @@ function makeModel(
}
}

function writeLog(message: string) {
document.getElementById('messages')!.innerText += message + '\n';
}

const writeState = throttle((message: string) => {
document.getElementById('state')!.innerText = message;
}, 1000);

const writeBatchInfo = throttle((processedBatches: number, lastLoss: number, batchSize: number) => {
document.getElementById('table-batches')!.innerText = processedBatches.toString();
document.getElementById('table-loss')!.innerText = lastLoss.toString();
document.getElementById('table-batchsize')!.innerText = batchSize.toString();
}, 1000);

async function sendBlob(itemId: string, data: Uint8Array): Promise<void> {
const blob = new Blob([data]);
const f = await fetch(`/kakiage/blob/${itemId}`, {
Expand Down Expand Up @@ -285,23 +287,21 @@ async function compute(msg: { weight: string; dataset: string; grad: string }) {
}
await sendBlob(msg.grad, new TensorSerializer().serialize(grads));
totalBatches += 1;
writeState(
`total batch: ${totalBatches}, last loss: ${lossValue}, last batch size: ${y.data.shape[0]}`
);
writeBatchInfo(totalBatches, lossValue, y.data.shape[0]);
}

async function run() {
writeState('Connecting');
writeState('Connecting to distributed training server...');
ws = new WebSocket(
(window.location.protocol === 'https:' ? 'wss://' : 'ws://') +
window.location.host +
'/kakiage/ws'
);
ws.onopen = () => {
writeState('Connected to WS server');
writeState('Connected to server');
};
ws.onclose = () => {
writeState('Disconnected from WS server');
writeState('Disconnected from server');
};
ws.onmessage = async (ev) => {
const msg = JSON.parse(ev.data);
Expand All @@ -320,7 +320,6 @@ async function run() {
window.addEventListener('load', async () => {
backend = (new URLSearchParams(window.location.search).get('backend') ||
'webgl') as K.Backend;
writeLog(`backend: ${backend}`);
if (backend === 'webgl') {
await K.tensor.initializeNNWebGLContext();
}
Expand Down
18 changes: 18 additions & 0 deletions distributed/sample/multiply/README.ja.md
@@ -0,0 +1,18 @@
# Sample: multiplying a tensor by a constant

A simple sample that returns a tensor multiplied by a constant

# Build

```
npm install
npm run build
```

# Run

```
npm run train
```

Open [http://localhost:8081/](http://localhost:8081/) in a web browser.
10 changes: 5 additions & 5 deletions distributed/sample/multiply/README.md
@@ -1,18 +1,18 @@
# サンプル: テンソルの定数倍
# Sample: multiplying a tensor by a constant

テンソルを定数倍にして返す、シンプルなサンプル
A simple sample that returns a tensor multiplied by a constant

# ビルド
# Build

```
npm install
npm run build
```

# 学習実行
# Run

```
npm run train
```

ブラウザで[http://localhost:8081/](http://localhost:8081/)を開く。
Open [http://localhost:8081/](http://localhost:8081/) in a web browser.
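
Conceptually, the computation the client performs in this sample is just elementwise scaling; in numpy terms (an illustration, not the sample's actual code):

```python
import numpy as np

def multiply_constant(x, k):
    # return the input tensor scaled by a constant factor
    return k * np.asarray(x)

print(multiply_constant([1.0, 2.0, 3.0], 2.0))  # -> [2. 4. 6.]
```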
2 changes: 1 addition & 1 deletion distributed/setup.py
Expand Up @@ -3,6 +3,6 @@
setup(
name="kakiage",
packages=find_packages(),
version="0.0.1",
version="1.0.0",
install_requires=["numpy", "fastapi", "uvicorn[standard]"]
)
2 changes: 2 additions & 0 deletions sample/mnist_train/README.md
Expand Up @@ -9,6 +9,8 @@ This sample describes how to
- train and evaluate the model
- save and load the model

See comments in the source code.

# Build

```
Expand Down
