feat(job_shop): implement dense schedule generator #115
Hi @dluo96, thanks for the contribution! Can you please update the default generator to this new one, and update the docstring, registration and README accordingly?
Thank you @clement-bonnet! I have updated the default generator, registered a new version of the …
```python
key, t, prev_col = carry
key, job_key, reuse_key = jax.random.split(key, num=3)

def reuse_prev_col(key: chex.PRNGKey, prev_col: chex.Array) -> chex.Array:
```
Not sure if it works, but it feels like this function could be vectorized to be a bit faster and more readable, something like:
```python
def reuse_prev_col(key: chex.PRNGKey, prev_col: chex.Array) -> chex.Array:
    reuse_key, job_key = jax.random.split(key, num=2)
    # Reuse the previous job with probability 0.7 for each machine
    reuse_jobs = jax.random.uniform(reuse_key, shape=(self.num_machines,)) > 0.3
    job_mask = jnp.ones_like(all_job_ids).at[prev_col[jnp.where(reuse_jobs)]].set(0)
    return jnp.where(
        reuse_jobs,
        prev_col,
        jax.random.choice(
            job_key,
            all_job_ids,
            (self.num_machines,),
            p=job_mask,
            replace=False,
        ),
    )
```
What do you think?
I tried this, but unfortunately I get some JAX tracing errors on this line:

```python
job_mask = jnp.ones_like(all_job_ids).at[prev_col[jnp.where(reuse_jobs)]].set(0)
```
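A likely cause of the tracing error is that `jnp.where(reuse_jobs)` with a single argument returns indices whose length depends on the data, which `jit` cannot trace. Below is a minimal sketch of a fixed-shape alternative, not the PR's code. It assumes `prev_col` holds distinct job ids (one per machine) and `num_jobs >= num_machines`. The standalone signature, with `num_jobs` and `reuse_prob` as arguments instead of `self` attributes, is hypothetical. The per-machine flags are scattered onto the job axis, and a Gumbel-top-k ordering is used instead of `jax.random.choice(..., replace=False)`, so kept jobs are never handed to another machine even when few jobs are free.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("num_jobs",))
def reuse_prev_col(key, prev_col, num_jobs, reuse_prob=0.7):
    """Keep each machine's previous job with probability `reuse_prob`; give the
    remaining machines distinct jobs that are not kept by any other machine."""
    num_machines = prev_col.shape[0]
    reuse_key, job_key = jax.random.split(key)
    # Fixed-shape boolean mask with one entry per machine.
    reuse_jobs = jax.random.uniform(reuse_key, (num_machines,)) < reuse_prob
    # Scatter the flags onto the job axis instead of prev_col[jnp.where(...)],
    # whose output shape would depend on the data.
    kept_count = jnp.zeros(num_jobs, jnp.int32).at[prev_col].add(reuse_jobs.astype(jnp.int32))
    job_is_kept = kept_count > 0
    # Gumbel-top-k: a random ordering of all jobs in which kept jobs come last.
    scores = jnp.where(job_is_kept, -jnp.inf, jax.random.gumbel(job_key, (num_jobs,)))
    order = jnp.argsort(-scores)
    # The i-th machine that does not reuse its job takes the i-th job in `order`.
    slot = jnp.maximum(jnp.cumsum(~reuse_jobs) - 1, 0)
    return jnp.where(reuse_jobs, prev_col, order[slot])
```

Because every intermediate array has a shape fixed by `num_machines` and `num_jobs`, the function traces under `jit` and can be called inside `lax.scan`.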
```python
# Reuse the previous column with probability 0.6
reuse = jax.random.uniform(reuse_key, shape=()) > 0.4
```
Should 0.4 be an attribute and be included in the docs?
Yeah, I reckon that's a good idea.
```python
_job_mask = _job_mask.at[prev_job_id].set(True)

# Reuse the previous job with probability 0.7
reuse_op = jax.random.uniform(reuse_op_key, shape=()) > 0.3
```
Should 0.3 be an attribute and be included in the docs?
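One way to act on the two threshold questions is to store the reuse probabilities as documented, validated attributes rather than hard-coded cut-offs. The sketch below is a hypothetical helper. The names `ReuseProbabilities`, `prob_reuse_col` and `prob_reuse_op` are not from the PR. Note that `uniform > 0.4` fires with probability 0.6, so storing the probability itself (and drawing with `jax.random.bernoulli`) avoids the inverted threshold.

```python
import jax


class ReuseProbabilities:
    """Hypothetical holder for the column-reuse and op-reuse probabilities."""

    def __init__(self, prob_reuse_col: float = 0.6, prob_reuse_op: float = 0.7):
        for name, p in (("prob_reuse_col", prob_reuse_col), ("prob_reuse_op", prob_reuse_op)):
            if not 0.0 <= p <= 1.0:
                raise ValueError(f"{name} must be in [0, 1], got {p}.")
        self.prob_reuse_col = prob_reuse_col
        self.prob_reuse_op = prob_reuse_op

    def reuse_col(self, key: jax.Array) -> jax.Array:
        # Same probability as `uniform > 0.4` when prob_reuse_col = 0.6.
        return jax.random.bernoulli(key, self.prob_reuse_col)

    def reuse_op(self, key: jax.Array) -> jax.Array:
        # Same probability as `uniform > 0.3` when prob_reuse_op = 0.7.
        return jax.random.bernoulli(key, self.prob_reuse_op)
```

Exposing the values this way also makes them easy to list in the docstring and README.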
```python
    job_key,
    prev_col,
)
carry = key, t + 1, col
```
Is `t` used anywhere?
```python
    shortest possible length of the schedule.
"""
del max_op_duration
del max_num_ops
```
Maybe add a comment to explain why they are deleted?
Sounds good, I have updated the docstring. Let me know if you're happy with that.
```python
# Carry the job id
init_job_id = 0
job_id, (ops_mids, ops_durs) = jax.lax.scan(
    get_job_info, init_job_id, xs=None, length=self.num_jobs
```
Why not `xs = jnp.arange(self.num_jobs)` instead of having `job_id + 1` in the carry?
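The suggestion above can be illustrated with a toy per-job function. This is a minimal sketch in which `job_info` is a hypothetical stand-in for `get_job_info`'s real work. Feeding the ids through `xs` gives the same stacked outputs as incrementing a counter in the carry, and the carry can then be dropped (`None`).

```python
import jax
import jax.numpy as jnp

NUM_JOBS = 4  # hypothetical size


def job_info(job_id):
    # Stand-in for the per-job computation; here it just scales the id.
    return job_id * 10


def step_with_counter(job_id, _):
    # Current pattern: the carry is the job id and is incremented every step.
    return job_id + 1, job_info(job_id)


def step_with_xs(carry, job_id):
    # Suggested pattern: the job id comes from xs and the carry is unused.
    return carry, job_info(job_id)


final_id, infos_carry = jax.lax.scan(step_with_counter, jnp.int32(0), xs=None, length=NUM_JOBS)
_, infos_xs = jax.lax.scan(step_with_xs, None, xs=jnp.arange(NUM_JOBS))
```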
Would it be possible to use `vmap` instead of `scan` to avoid sequential evaluation?
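`vmap` only applies if each job's result does not depend on the previous iteration's carry. The sketch below assumes that holds. Each job derives its own key with `jax.random.fold_in`, so the sequential and vectorized versions agree exactly. `job_durations` and the sizes are hypothetical and stand in for the real per-job logic.

```python
import jax
import jax.numpy as jnp

NUM_JOBS, NUM_OPS, MAX_DURATION = 4, 3, 5  # hypothetical sizes


def job_durations(key, job_id):
    # Derive a per-job key from the id so the result does not depend on the
    # order in which jobs are processed.
    job_key = jax.random.fold_in(key, job_id)
    return jax.random.randint(job_key, (NUM_OPS,), 1, MAX_DURATION + 1)


key = jax.random.PRNGKey(0)
job_ids = jnp.arange(NUM_JOBS)

# Sequential: lax.scan over job ids, with an unused carry.
_, durs_scan = jax.lax.scan(lambda c, j: (c, job_durations(key, j)), None, job_ids)

# Vectorized: vmap over job ids, valid because no job reads another's result.
durs_vmap = jax.vmap(job_durations, in_axes=(None, 0))(key, job_ids)
```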
Thanks a lot for this generator. Could we refactor the code to reduce the amount of nesting? Also, I suggest adding the generator as an optional choice rather than making it the default in the env init. We can make it the default and bump the version in another PR, after testing a complete training run with it (and potentially adapting the network). What do you think?
```python
max_num_ops: int,
max_op_duration: int,
```
Do you have to have these in the init?
With the way the `Generator` interface is currently defined, we would have to, yes. This is because the `__init__` takes those arguments.
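One possible alternative to accepting and then deleting the two arguments is for the subclass to derive them and pass them to the base `__init__`. This is a hedged sketch, not what the PR does. `Generator` below is a simplified stand-in whose signature is assumed from the diff, and `DenseGenerator(num_jobs, num_machines, makespan)` is a hypothetical signature. The bound holds because in a dense schedule every operation lasts at least one step and a job's operations run one after another within the makespan. If the environment sizes its observations from these bounds, a loose bound means extra padding.

```python
class Generator:
    """Simplified stand-in for the job shop Generator base class (assumed signature)."""

    def __init__(self, num_jobs: int, num_machines: int, max_num_ops: int, max_op_duration: int):
        self.num_jobs = num_jobs
        self.num_machines = num_machines
        self.max_num_ops = max_num_ops
        self.max_op_duration = max_op_duration


class DenseGenerator(Generator):
    """Hypothetical variant that takes the makespan instead of the two unused bounds."""

    def __init__(self, num_jobs: int, num_machines: int, makespan: int):
        if num_jobs < num_machines:
            # Every machine is busy at every step and a job occupies at most one
            # machine at a time, so there must be at least one job per machine.
            raise ValueError("A dense schedule needs at least as many jobs as machines.")
        # `makespan` bounds both the number of ops per job and any op's duration.
        super().__init__(num_jobs, num_machines, max_num_ops=makespan, max_op_duration=makespan)
        self.makespan = makespan
```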
```python
# While loop in case op has duration > 1
init_val = (mask, machine_id, t_start + 1)

def next_is_same_op(val: Tuple) -> chex.Array:
```
These are 4 levels of nested functions. Would it be possible to refactor the code with less nesting?
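The `while_loop` that extends an operation while the next time step still belongs to it does not need to be nested. Its condition and body can be module-level functions that receive everything through the loop value. The sketch below assumes a machine's schedule is a 1-D `row` of job ids over time. `op_duration`, `_next_is_same_op` and `_advance` are hypothetical names, not the PR's `mask`/`machine_id` fields.

```python
import jax
import jax.numpy as jnp


def _next_is_same_op(val):
    """Loop condition: step t is inside the schedule and still holds the same job."""
    t, row, job_id = val
    makespan = row.shape[0]
    # Clamp the read index so it is always valid; `t < makespan` masks it out.
    same_job = row[jnp.minimum(t, makespan - 1)] == job_id
    return (t < makespan) & same_job


def _advance(val):
    t, row, job_id = val
    return t + 1, row, job_id


@jax.jit
def op_duration(row, t_start):
    """Length of the run of identical job ids in `row` starting at `t_start`."""
    job_id = row[t_start]
    t_end, _, _ = jax.lax.while_loop(_next_is_same_op, _advance, (t_start + 1, row, job_id))
    return t_end - t_start
```

Keeping the helpers flat makes each one testable on its own and limits nesting to the single `while_loop` call.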
```python
key, job_key, reuse_key = jax.random.split(key, num=3)

def reuse_prev_col(key: chex.PRNGKey, prev_col: chex.Array) -> chex.Array:
    def _maybe_reuse_op(
```
These are 4 levels of nesting. Could we refactor the code to avoid this?
```diff
@@ -82,7 +82,7 @@ problems.
 | 💣 Minesweeper | Logic | `Minesweeper-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/logic/minesweeper/) | [doc](https://instadeepai.github.io/jumanji/environments/minesweeper/) |
 | 🎲 RubiksCube | Logic | `RubiksCube-v0`<br/>`RubiksCube-partly-scrambled-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/logic/rubiks_cube/) | [doc](https://instadeepai.github.io/jumanji/environments/rubiks_cube/) |
 | 📦 BinPack (3D BinPacking Problem) | Packing | `BinPack-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/bin_pack/) | [doc](https://instadeepai.github.io/jumanji/environments/bin_pack/) |
-| 🏭 JobShop (Job Shop Scheduling Problem) | Packing | `JobShop-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/job_shop/) | [doc](https://instadeepai.github.io/jumanji/environments/job_shop/) |
+| 🏭 JobShop (Job Shop Scheduling Problem) | Packing | `JobShop-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/job_shop/) | [doc](https://instadeepai.github.io/jumanji/environments/job_shop/) |
```
Unless we test the new generator with a full training run, we should keep the previous generator as the default (and keep the environment at version v0).
Closing this for now; will reopen after v0.3.
Closes #90.
Implements `DenseGenerator`, which generates an instance of the job shop scheduling problem with a specified schedule length (aka makespan). This is useful for benchmarking RL agents because the optimal return is known in advance.
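Why the optimum is known can be shown with a small sketch. Build a dense grid in which every machine runs some job at every step and no job is on two machines at once. Then read each job's operations off the grid. Every machine carries exactly `makespan` units of work, so no schedule can finish earlier, and the grid itself is a feasible schedule that finishes at `makespan`. This is a simplified illustration under stated assumptions, not the PR's implementation. It samples each column independently and omits the column/op reuse probabilities (0.6 and 0.7 in the diff excerpts) that lengthen operations. `dense_schedule` and `extract_ops` are hypothetical names.

```python
import jax
import numpy as np


def dense_schedule(key, num_jobs, num_machines, makespan):
    """Grid of job ids, shape (num_machines, makespan): every machine is busy at
    every step, and no job is on two machines in the same step."""
    if num_jobs < num_machines:
        raise ValueError("num_jobs must be >= num_machines for a dense schedule.")
    keys = jax.random.split(key, makespan)
    # Column t = the first num_machines entries of a random permutation of job ids.
    cols = jax.vmap(lambda k: jax.random.permutation(k, num_jobs)[:num_machines])(keys)
    return np.asarray(cols.T)


def extract_ops(schedule):
    """Map each job to its [machine_id, duration] operations in start-time order;
    consecutive steps of a job on the same machine form one operation."""
    num_machines, makespan = schedule.shape
    ops = {}
    for t in range(makespan):
        for m in range(num_machines):
            j = int(schedule[m, t])
            job_ops = ops.setdefault(j, [])
            if t > 0 and schedule[m, t - 1] == j:
                job_ops[-1][1] += 1  # the job keeps running on machine m
            else:
                job_ops.append([m, 1])  # a new operation starts on machine m
    return ops
```

Jobs that never appear in the grid simply get no operations in this sketch. A real generator would need to decide how to handle them.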