Skip to content

Commit

Permalink
Add typestate-MGD support with unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
LakshyAAAgrawal committed Nov 21, 2023
1 parent fa50b39 commit 3751934
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 0 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,9 @@ A monitor under the Monitor-Guided Decoding framework, is instantiated using `mu
#### Monitor for valid number of arguments to function calls
[src/monitors4codegen/monitor_guided_decoding/monitors/numargs_monitor.py](src/monitors4codegen/monitor_guided_decoding/monitors/numargs_monitor.py) provides the instantiation of `Monitor` class for numargs_monitor. It can be used to guide LMs to generate correct number of arguments to function calls. Unit tests, which also provide usage examples are present in [tests/monitor_guided_decoding/test_numargs_monitor_java.py](tests/monitor_guided_decoding/test_numargs_monitor_java.py).

### Monitor for typestate specifications
The typestate analysis is used to enforce that methods on an object are called in a certain order, consistent with the ordering constraints provided by the API contracts. Example usage of the typestate monitor for Rust is available in the unit test file [tests/monitor_guided_decoding/test_typestate_monitor_rust.py](tests/monitor_guided_decoding/test_typestate_monitor_rust.py).

#### Switch-Enum Monitor
[src/monitors4codegen/monitor_guided_decoding/monitors/switch_enum_monitor.py](src/monitors4codegen/monitor_guided_decoding/monitors/switch_enum_monitor.py) provides the instantiation of `Monitor` for generating valid named enum constants in C#. Unit tests for the switch-enum monitor are present in [tests/monitor_guided_decoding/test_switchenum_monitor_csharp.py](tests/monitor_guided_decoding/test_switchenum_monitor_csharp.py), which also provide usage examples for the switch-enum monitor.

Expand Down
110 changes: 110 additions & 0 deletions tests/monitor_guided_decoding/test_typestate_monitor_rust.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""
This file contains tests for MGD for typestate
"""

import pytest
import torch
import transformers

from pathlib import PurePath
from monitors4codegen.multilspy.language_server import SyncLanguageServer
from monitors4codegen.multilspy.multilspy_config import Language
from tests.test_utils import create_test_context, is_cuda_available
from transformers import AutoTokenizer, AutoModelForCausalLM
from monitors4codegen.multilspy.multilspy_utils import TextUtils
from monitors4codegen.monitor_guided_decoding.monitors.dereferences_monitor import DereferencesMonitor
from monitors4codegen.monitor_guided_decoding.monitor import MonitorFileBuffer
from monitors4codegen.monitor_guided_decoding.hf_gen import MGDLogitsProcessor
from transformers.generation.utils import LogitsProcessorList
from monitors4codegen.multilspy.multilspy_types import Position
from monitors4codegen.monitor_guided_decoding.tokenizer_wrapper import HFTokenizerWrapper

pytest_plugins = ("pytest_asyncio",)

@pytest.mark.asyncio
async def test_typestate_monitor_rust_huggingface_models_mediaplayer() -> None:
"""
Test the working of typestate monitor with Rust repository - mediaplayer
"""
code_language = Language.RUST
params = {
"code_language": code_language,
"repo_url": "https://github.com/LakshyAAAgrawal/MediaPlayer_example/",
"repo_commit": "80cd910cfeb2a05c9e74b69773373c077b00b4c2",
}

device = torch.device('cuda' if is_cuda_available() else 'cpu')

model: transformers.modeling_utils.PreTrainedModel = AutoModelForCausalLM.from_pretrained(
"bigcode/santacoder", trust_remote_code=True
).to(device) #
tokenizer = AutoTokenizer.from_pretrained("bigcode/santacoder")

with create_test_context(params) as context:
lsp = SyncLanguageServer.create(context.config, context.logger, context.source_directory)
filepath = "src/playlist.rs"
# All the communication with the language server must be performed inside the context manager
# The server process is started when the context manager is entered and is terminated when the context manager is exited.
with lsp.start_server():
with lsp.open_file(filepath):
filebuffer = MonitorFileBuffer(lsp.language_server, filepath, (10, 40), (10, 40), code_language, "")
deleted_text = filebuffer.lsp.delete_text_between_positions(
filepath, Position(line=10, character=40), Position(line=12, character=4)
)
assert (
deleted_text
== """reset();
media_player1 = media_player;
"""
)
monitor = DereferencesMonitor(HFTokenizerWrapper(tokenizer), filebuffer)
mgd_logits_processor = MGDLogitsProcessor([monitor], lsp.language_server.server.loop)

with open(str(PurePath(context.source_directory, filepath)), "r") as f:
filecontent = f.read()

pos_idx = TextUtils.get_index_from_line_col(filecontent, 10, 40)
assert filecontent[pos_idx] == "r"

with open(str(PurePath(context.source_directory, "src/media_player.rs")), "r") as f:
classExprTypes = f.read()

prompt = filecontent[:pos_idx] + """<fim-suffix>();
media_player1 = media_player;
}
}<fim-middle>."""
assert prompt[-1] == "."
prompt_tokenized = tokenizer.encode("<fim-prefix>" + classExprTypes + '\n' + prompt, return_tensors="pt")[:, -(2048-512):].to(device)

generated_code_without_mgd = model.generate(
prompt_tokenized, do_sample=False, max_new_tokens=5, top_p=0.95, temperature=0.2,
)

num_gen_tokens = generated_code_without_mgd.shape[-1] - prompt_tokenized.shape[-1]
generated_code_without_mgd = tokenizer.decode(generated_code_without_mgd[0, -num_gen_tokens:])

assert (
generated_code_without_mgd
== """stop();
let media"""
)

# Generate code using santacoder model with the MGD logits processor and greedy decoding
logits_processor = LogitsProcessorList([mgd_logits_processor])
generated_code = model.generate(
prompt_tokenized,
do_sample=False,
max_new_tokens=30,
logits_processor=logits_processor,
early_stopping=True,
)

num_gen_tokens = generated_code.shape[-1] - prompt_tokenized.shape[-1]
generated_code = tokenizer.decode(generated_code[0, -num_gen_tokens:])

assert (
generated_code
== """reset();
let media_player = media_player.set_media_file_path(song);
let media_player = media_"""
)

0 comments on commit 3751934

Please sign in to comment.