feat(job_shop): implement dense schedule generator #115

Closed
wants to merge 25 commits into from


dluo96
Contributor

@dluo96 dluo96 commented Apr 5, 2023

Closes #90.

  • Implemented the DenseGenerator, which generates an instance of the job shop scheduling problem with a specified schedule length (aka makespan). This is useful for benchmarking RL agents because the optimal return is known in advance.
  • Wrote unit tests for the generator.
  • Tested this generator with the job shop A2C agent.
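A minimal sketch of why a dense schedule gives a known optimum, under simplifying assumptions that are not taken from the PR: every operation lasts one timestep, and `sample_dense_schedule` is a hypothetical helper name. If every machine is busy at every timestep, each machine carries `makespan` units of work, so no schedule can finish earlier than `makespan`.

```python
import jax
import jax.numpy as jnp


def sample_dense_schedule(key, num_jobs, num_machines, makespan):
    """Sample a (num_machines, makespan) grid of job ids with no idle slots.

    Assumes num_jobs >= num_machines so each column can hold distinct jobs.
    """
    keys = jax.random.split(key, makespan)
    # One column per timestep: the first `num_machines` entries of a random
    # permutation of job ids, so no job runs on two machines at once.
    cols = jax.vmap(
        lambda k: jax.random.permutation(k, num_jobs)[:num_machines]
    )(keys)
    return cols.T


schedule = sample_dense_schedule(
    jax.random.PRNGKey(0), num_jobs=5, num_machines=3, makespan=8
)
```

Reading each row left to right gives the operations assigned to one machine. The generator in this PR additionally reuses the previous column with some probability, so that operations can last longer than one timestep; that part is left out here.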

@dluo96 dluo96 added the enhancement New feature or request label Apr 5, 2023
@dluo96 dluo96 requested a review from clement-bonnet April 5, 2023 08:39
@dluo96 dluo96 self-assigned this Apr 5, 2023
@clement-bonnet
Collaborator

Hi @dluo96, thanks for the contribution! Can you please update the default generator to this new one, and update the docstring, registration and readme accordingly?

@dluo96
Contributor Author

dluo96 commented Apr 18, 2023

Thank you @clement-bonnet! I have updated the default generator, registered a new version of the JobShop environment, and also updated the README accordingly 👍

key, t, prev_col = carry
key, job_key, reuse_key = jax.random.split(key, num=3)

def reuse_prev_col(key: chex.PRNGKey, prev_col: chex.Array) -> chex.Array:
Contributor

Not sure if it works, but it feels like this function could be vectorized to be a bit faster and more readable, something like

            def reuse_prev_col(key: chex.PRNGKey, prev_col: chex.Array) -> chex.Array:
                reuse_key, job_key = jax.random.split(key, num=2)
                # Reuse the previous job with probability 0.7 for each machine
                reuse_jobs = jax.random.uniform(reuse_key, shape=(self.num_machines,)) > 0.3
                job_mask = jnp.ones_like(all_job_ids).at[prev_col[jnp.where(reuse_jobs)]].set(0)
                return jnp.where(
                    reuse_jobs,
                    prev_col,
                    jax.random.choice(
                        job_key,
                        all_job_ids,
                        (self.num_machines,),
                        p=job_mask,
                        replace=False,
                    ),
                )

What do you think?

Contributor Author

I tried this but unfortunately I get some JAX tracing errors on this line:

job_mask = jnp.ones_like(all_job_ids).at[prev_col[jnp.where(reuse_jobs)]].set(0)
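A likely cause is the single-argument `jnp.where(reuse_jobs)`: its output shape depends on how many entries are `True`, which cannot be traced under `jit` unless a static `size` is passed. One possible trace-safe way to build the same mask uses only fixed-shape broadcasting. This is a sketch, not the code from the PR; `masked_job_ids` and its arguments are hypothetical names.

```python
import jax
import jax.numpy as jnp


def masked_job_ids(prev_col, reuse_jobs, num_jobs):
    """Return a 0/1 mask over job ids: 0 for jobs kept by a reusing machine."""
    all_job_ids = jnp.arange(num_jobs)
    # Shape (num_machines, num_jobs): True where machine m reuses job j.
    kept = (prev_col[:, None] == all_job_ids[None, :]) & reuse_jobs[:, None]
    # A job is masked out if any machine keeps it.
    return jnp.where(jnp.any(kept, axis=0), 0, 1)


# num_jobs determines an array shape, so it is marked static for jit.
masked_job_ids_jit = jax.jit(masked_job_ids, static_argnums=2)

prev_col = jnp.array([3, 0, 2])
reuse_jobs = jnp.array([True, False, True])
job_mask = masked_job_ids_jit(prev_col, reuse_jobs, 5)
```

Here machines 0 and 2 keep jobs 3 and 2, so only those two entries of the mask are zero. The three-argument `jnp.where` is fine under `jit` because its output shape is fixed.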

)

# Reuse the previous column with probability 0.6
reuse = jax.random.uniform(reuse_key, shape=()) > 0.4
Contributor

Should 0.4 be an attribute, and be included in the docs?

Contributor Author

Yeah I reckon that's a good idea

_job_mask = _job_mask.at[prev_job_id].set(True)

# Reuse the previous job with probability 0.7
reuse_op = jax.random.uniform(reuse_op_key, shape=()) > 0.3
Contributor

Should 0.3 be an attribute, and be included in the docs?
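One way both thresholds could be exposed, sketched with hypothetical names (`DenseGeneratorSketch`, `reuse_col_prob`, `reuse_op_prob`) that do not come from the PR. The defaults mirror the probabilities stated in the code comments (0.6 and 0.7), and `jax.random.bernoulli` replaces the `uniform(...) > threshold` comparisons.

```python
import jax


class DenseGeneratorSketch:
    """Hypothetical holder for the two reuse probabilities.

    Args:
        reuse_col_prob: probability of reusing the previous column.
        reuse_op_prob: probability of reusing the previous job on a machine.
    """

    def __init__(self, reuse_col_prob: float = 0.6, reuse_op_prob: float = 0.7):
        self.reuse_col_prob = reuse_col_prob
        self.reuse_op_prob = reuse_op_prob

    def should_reuse_col(self, key):
        # True with probability `reuse_col_prob`.
        return jax.random.bernoulli(key, p=self.reuse_col_prob)

    def should_reuse_op(self, key):
        # True with probability `reuse_op_prob`.
        return jax.random.bernoulli(key, p=self.reuse_op_prob)


generator = DenseGeneratorSketch(reuse_col_prob=0.6, reuse_op_prob=0.7)
reuse = generator.should_reuse_col(jax.random.PRNGKey(0))
```

Documenting the attributes in the class docstring would also cover the request to include them in the docs.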

job_key,
prev_col,
)
carry = key, t + 1, col
Contributor

Is the t used somewhere?

shortest possible length of the schedule.
"""
del max_op_duration
del max_num_ops
Contributor

Maybe a comment to explain why they are deleted?

Contributor Author

Sounds good, I have updated the docstring - let me know if you're happy with that

# Carry the job id
init_job_id = 0
job_id, (ops_mids, ops_durs) = jax.lax.scan(
get_job_info, init_job_id, xs=None, length=self.num_jobs
Contributor

Why not xs = jnp.arange(self.num_jobs) instead of having job_id + 1 in the carry?

Contributor

Would it be possible to use vmap instead of scan to avoid sequential evaluation?
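Both suggestions can be illustrated on a toy stand-in for the per-job work. The `job_info` function and the `schedule` array below are hypothetical, not the PR's `get_job_info`. Passing the job ids as `xs` removes the counter from the carry, and `vmap` works when the per-job computation does not depend on the other jobs.

```python
import jax
import jax.numpy as jnp

# Hypothetical schedule: rows are machines, columns are timesteps, values are job ids.
schedule = jnp.array([[0, 0, 1, 2],
                      [1, 2, 2, 0]])
num_jobs = 3


def job_info(job_id):
    # Toy per-job quantity: number of timesteps the job occupies.
    return jnp.sum(schedule == job_id)


def scan_step(carry, job_id):
    # The job id arrives through `xs`, so nothing needs to be carried.
    return carry, job_info(job_id)


_, durations_scan = jax.lax.scan(scan_step, None, xs=jnp.arange(num_jobs))

# Same result without sequential evaluation.
durations_vmap = jax.vmap(job_info)(jnp.arange(num_jobs))
```

The two results agree here; whether `vmap` applies to the real generator depends on whether its per-job step is truly independent across jobs.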

jumanji/environments/packing/job_shop/generator_test.py (outdated, resolved)
Collaborator

@clement-bonnet clement-bonnet left a comment

Thanks a lot for this generator. Could we refactor the code to reduce the amount of nesting? Also, I suggest adding the generator as a possible option but not defaulting to it in the env init. We can default to it and up the version in another PR after testing a complete training with it (and potentially adapt the network). What do you think?

Comment on lines +212 to +213
max_num_ops: int,
max_op_duration: int,
Collaborator

Do you have to have these in the init?

Contributor Author

With the way that the Generator interface is currently defined, we would have to, yes. This is because the __init__ takes those arguments.

# While loop in case op has duration > 1
init_val = (mask, machine_id, t_start + 1)

def next_is_same_op(val: Tuple) -> chex.Array:
Collaborator

These are 4 levels of nested functions. Would it be possible to refactor the code with less nesting?
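One possible way to flatten this kind of nesting is to lift the loop condition and body to module level and bind their context with `functools.partial` instead of closing over it. The sketch below is an illustration with hypothetical names (`_next_is_same_op`, `_advance`, `op_end_time`) and a simplified loop state, not the PR's code. It finds where an operation ends in one machine's row of a schedule.

```python
from functools import partial

import jax
import jax.numpy as jnp


def _next_is_same_op(row, job_id, t):
    """Loop condition: timestep `t` is in bounds and still runs `job_id`."""
    last = row.shape[0] - 1
    return (t <= last) & (row[jnp.minimum(t, last)] == job_id)


def _advance(t):
    """Loop body: move to the next timestep."""
    return t + 1


def op_end_time(row, t_start):
    """Return the first timestep after `t_start` at which the op has ended."""
    job_id = row[t_start]
    init_t = jnp.asarray(t_start + 1, dtype=jnp.int32)
    return jax.lax.while_loop(
        partial(_next_is_same_op, row, job_id), _advance, init_t
    )


row = jnp.array([2, 2, 2, 0, 1])
end = op_end_time(row, 0)
```

With the helpers at module level, each one can be unit-tested on its own, and the function that calls `while_loop` stays at a single level of indentation.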

jumanji/environments/packing/job_shop/generator.py (outdated, resolved)
key, job_key, reuse_key = jax.random.split(key, num=3)

def reuse_prev_col(key: chex.PRNGKey, prev_col: chex.Array) -> chex.Array:
def _maybe_reuse_op(
Collaborator

These are 4 levels of nesting. Could we refactor the code to avoid this?

@@ -82,7 +82,7 @@ problems.
| 💣 Minesweeper | Logic | `Minesweeper-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/logic/minesweeper/) | [doc](https://instadeepai.github.io/jumanji/environments/minesweeper/) |
| 🎲 RubiksCube | Logic | `RubiksCube-v0`<br/>`RubiksCube-partly-scrambled-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/logic/rubiks_cube/) | [doc](https://instadeepai.github.io/jumanji/environments/rubiks_cube/) |
| 📦 BinPack (3D BinPacking Problem) | Packing | `BinPack-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/bin_pack/) | [doc](https://instadeepai.github.io/jumanji/environments/bin_pack/) |
-| 🏭 JobShop (Job Shop Scheduling Problem) | Packing | `JobShop-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/job_shop/) | [doc](https://instadeepai.github.io/jumanji/environments/job_shop/) |
+| 🏭 JobShop (Job Shop Scheduling Problem) | Packing | `JobShop-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/packing/job_shop/) | [doc](https://instadeepai.github.io/jumanji/environments/job_shop/) |
Collaborator

Unless we test the new generator with a full training, we should actually keep the previous generator as default (and keep it to version v0).

@clement-bonnet clement-bonnet changed the title feat(jobshop): implement dense schedule generator feat(job_shop): implement dense schedule generator May 11, 2023
@clement-bonnet
Collaborator

Closing this for now, will re-open after v0.3

Development

Successfully merging this pull request may close these issues.

feat(jobshop): create instance generator with known optimal solution
3 participants