fix test_inference_engine
AlexCheema committed Jul 20, 2024
1 parent 56e5e34 commit 30ab126
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions exo/inference/test_inference_engine.py
@@ -7,7 +7,7 @@

 # An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
 async def test_inference_engine(inference_engine_1: InferenceEngine, inference_engine_2: InferenceEngine, model_id: str):
-  prompt = "In a single word only, what is the capital of Japan? "
+  prompt = "In a single word only, what is the last name of the current president of the USA?"
   resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt(shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), prompt=prompt)
   next_resp_full, next_inference_state_full, _ = await inference_engine_1.infer_tensor(shard=Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32), input_data=resp_full, inference_state=inference_state_full)
@@ -33,5 +33,5 @@ async def test_inference_engine(inference_engine_1: InferenceEngine, inference_e
 asyncio.run(test_inference_engine(
   TinygradDynamicShardInferenceEngine(),
   TinygradDynamicShardInferenceEngine(),
-  "/Users/alex/Library/Caches/tinygrad/downloads/llama3-8b-sfr",
+  "llama3-8b-sfr",
 ))
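
The comment in the first hunk states the property this test exercises: an inference engine should produce the same output whether the model runs as one shard or as several continuous shards. Below is a minimal sketch (not the repository's code) of that check, using only the Shard/InferenceEngine API visible in the diff; the helper name, the 16/16 layer split, and the import path are assumptions, and the outputs are assumed to be numpy-compatible arrays.

    import asyncio
    import numpy as np

    from exo.inference.shard import Shard  # assumed import path

    # Hypothetical helper, not the repo's test: compare a single-shard run
    # against a run split across two continuous shards of the same model.
    async def check_sharding_equivalence(engine, model_id: str, prompt: str):
        # Full model as one shard covering all 32 layers.
        full = Shard(model_id=model_id, start_layer=0, end_layer=31, n_layers=32)
        # The same model split into two continuous shards (assumed 16/16 split).
        first_half = Shard(model_id=model_id, start_layer=0, end_layer=15, n_layers=32)
        second_half = Shard(model_id=model_id, start_layer=16, end_layer=31, n_layers=32)

        # One-shard run over the whole model.
        resp_full, _, _ = await engine.infer_prompt(shard=full, prompt=prompt)

        # Two-shard run: the first shard's output tensor feeds the second shard.
        resp_1, state_1, _ = await engine.infer_prompt(shard=first_half, prompt=prompt)
        resp_2, _, _ = await engine.infer_tensor(shard=second_half, input_data=resp_1, inference_state=state_1)

        # Continuous shards should reproduce the single-shard result.
        assert np.array_equal(resp_full, resp_2)

The second hunk then swaps the hard-coded local cache path for the bare model ID "llama3-8b-sfr", letting the engine resolve the model location itself instead of depending on one developer's filesystem.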
