Skip to content

Commit

Permalink
refactor: remove max_tokens configuration and other minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
carlrobertoh committed May 13, 2024
1 parent 0b21652 commit 014f26f
Show file tree
Hide file tree
Showing 23 changed files with 219 additions and 259 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package ee.carlrobert.codegpt.treesitter;

import org.treesitter.TSLanguage;
import org.treesitter.TSNode;
import org.treesitter.TSParser;
import org.treesitter.TSTree;

Expand All @@ -16,7 +15,7 @@ public CodeCompletionParser(TSLanguage language) {
public String parse(String prefix, String suffix, String output) {
var result = new StringBuilder(output);
while (!result.isEmpty()) {
if (containsSyntaxErrors(prefix + result + suffix)) {
if (containsError(prefix + result + suffix)) {
result.deleteCharAt(result.length() - 1);
} else {
return result.toString();
Expand All @@ -30,21 +29,11 @@ public String parse(String prefix, String suffix, String output) {
return output;
}

private boolean containsSyntaxErrors(String input) {
return containsSyntaxErrors(getTree(input).getRootNode());
}

private boolean containsSyntaxErrors(TSNode node) {
if (node.isMissing() || node.hasError()) {
return true;
}

for (int i = 0; i < node.getChildCount(); i++) {
if (containsSyntaxErrors(node.getChild(i))) {
return true;
}
}
return false;
private boolean containsError(String input) {
var treeString = getTree(input).getRootNode().toString();
return treeString.contains("ERROR")
|| treeString.contains("MISSING \"}\"")
|| treeString.contains("MISSING \")\"");
}

private TSTree getTree(String input) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,17 @@ public class CodeCompletionParserTest {
@Test
public void shouldGetValidReturnValue() {
var prefix = """
class Main {
public int getRandomNumber() {
return\s""";
class Main {
public int getRandomNumber() {
return\s""";
var suffix = """
}
}""";
}
}""";
var output = """
10;}
}
public int getRandomNumber(int k) {""";
10;}
}
public int getRandomNumber(int k) {""";

var parsedResponse = CodeCompletionParserFactory
.getParserForFileExtension("java")
Expand All @@ -31,16 +31,16 @@ public int getRandomNumber() {
@Test
public void shouldGetValidParenthesisValue() {
var prefix = """
class Main {
public int getRandomNumber(int\s""";
class Main {
public int getRandomNumber(int\s""";
var suffix = """
) {
return 10;
}
}""";
) {
return 10;
}
}""";
var output = """
prevNumber) {
if() {""";
prevNumber) {
if() {""";

var parsedResponse = CodeCompletionParserFactory
.getParserForFileExtension("java")
Expand All @@ -49,41 +49,16 @@ class Main {
assertThat(parsedResponse).isEqualTo("prevNumber");
}

@Test
public void shouldHandleFieldDeclaration() {
var prefix = """
class Main {
\t
private i""";
var suffix = """
public int getRandomNumber(int prevNumber) {
return Math.of()
}
}""";
var output = """
nt randomNumber;
\s
public void get() {""";

var result = CodeCompletionParserFactory
.getParserForFileExtension("java")
.parse(prefix, suffix, output);

assertThat(result).isEqualTo("nt randomNumber;");
}

@Test
public void shouldHandleFormalParameters() {
var prefix = """
class Main {
public int getRandomNumber(""";
class Main {
public int getRandomNumber(""";
var suffix = """
) {
return 10;
}
}""";
) {
return 10;
}
}""";
var output = "int prevNumber) }";

var result = CodeCompletionParserFactory
Expand Down
2 changes: 1 addition & 1 deletion gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jsoup = "1.17.2"
jtokkit = "1.0.0"
junit = "5.10.2"
kotlin = "1.9.24"
llm-client = "0.8.3"
llm-client = "0.8.4"
okio = "3.9.0"
tree-sitter = "0.22.5"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ public class ConfigurationComponent {
private final JBCheckBox openNewTabCheckBox;
private final JBCheckBox methodNameGenerationCheckBox;
private final JBCheckBox autoFormattingCheckBox;
private final JBCheckBox autocompletionPostProcessingCheckBox;
private final JTextArea systemPromptTextArea;
private final JTextArea commitMessagePromptTextArea;
private final IntegerField maxTokensField;
Expand Down Expand Up @@ -123,6 +124,10 @@ public void changedUpdate(DocumentEvent e) {
autoFormattingCheckBox = new JBCheckBox(
CodeGPTBundle.get("configurationConfigurable.autoFormatting.label"),
configuration.isAutoFormattingEnabled());
autocompletionPostProcessingCheckBox = new JBCheckBox(
CodeGPTBundle.get("configurationConfigurable.autocompletionPostProcessing.label"),
configuration.isAutocompletionPostProcessingEnabled()
);

mainPanel = FormBuilder.createFormBuilder()
.addComponent(tablePanel)
Expand All @@ -132,6 +137,7 @@ public void changedUpdate(DocumentEvent e) {
.addComponent(openNewTabCheckBox)
.addComponent(methodNameGenerationCheckBox)
.addComponent(autoFormattingCheckBox)
.addComponent(autocompletionPostProcessingCheckBox)
.addVerticalGap(4)
.addComponent(new TitledSeparator(
CodeGPTBundle.get("configurationConfigurable.section.assistant.title")))
Expand Down Expand Up @@ -159,6 +165,7 @@ public ConfigurationState getCurrentFormState() {
state.setCreateNewChatOnEachAction(openNewTabCheckBox.isSelected());
state.setMethodNameGenerationEnabled(methodNameGenerationCheckBox.isSelected());
state.setAutoFormattingEnabled(autoFormattingCheckBox.isSelected());
state.setAutocompletionPostProcessingEnabled(autocompletionPostProcessingCheckBox.isSelected());
return state;
}

Expand All @@ -174,6 +181,8 @@ public void resetForm() {
openNewTabCheckBox.setSelected(configuration.isCreateNewChatOnEachAction());
methodNameGenerationCheckBox.setSelected(configuration.isMethodNameGenerationEnabled());
autoFormattingCheckBox.setSelected(configuration.isAutoFormattingEnabled());
autocompletionPostProcessingCheckBox.setSelected(
configuration.isAutocompletionPostProcessingEnabled());
}

private Map<String, String> getTableData() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ public class ConfigurationState {
private boolean methodNameGenerationEnabled = true;
private boolean captureCompileErrors = true;
private boolean autoFormattingEnabled = true;
private boolean autocompletionPostProcessingEnabled = true;
private Map<String, String> tableData = EditorActionsUtil.DEFAULT_ACTIONS;

public String getSystemPrompt() {
Expand Down Expand Up @@ -118,6 +119,14 @@ public void setAutoFormattingEnabled(boolean autoFormattingEnabled) {
this.autoFormattingEnabled = autoFormattingEnabled;
}

public boolean isAutocompletionPostProcessingEnabled() {
return autocompletionPostProcessingEnabled;
}

public void setAutocompletionPostProcessingEnabled(boolean autocompletionPostProcessingEnabled) {
this.autocompletionPostProcessingEnabled = autocompletionPostProcessingEnabled;
}

@Override
public boolean equals(Object o) {
if (this == o) {
Expand All @@ -135,6 +144,7 @@ public boolean equals(Object o) {
&& methodNameGenerationEnabled == that.methodNameGenerationEnabled
&& captureCompileErrors == that.captureCompileErrors
&& autoFormattingEnabled == that.autoFormattingEnabled
&& autocompletionPostProcessingEnabled == that.autocompletionPostProcessingEnabled
&& Objects.equals(systemPrompt, that.systemPrompt)
&& Objects.equals(commitMessagePrompt, that.commitMessagePrompt)
&& Objects.equals(tableData, that.tableData);
Expand All @@ -145,6 +155,6 @@ public int hashCode() {
return Objects.hash(systemPrompt, commitMessagePrompt, maxTokens, temperature,
checkForPluginUpdates, checkForNewScreenshots, createNewChatOnEachAction,
ignoreGitCommitTokenLimit, methodNameGenerationEnabled, captureCompileErrors,
autoFormattingEnabled, tableData);
autoFormattingEnabled, autocompletionPostProcessingEnabled, tableData);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ public class LlamaSettingsState {
private double minP = 0.05;
private double repeatPenalty = 1.1;
private boolean codeCompletionsEnabled = true;
private int codeCompletionMaxTokens = 128;

public boolean isUseCustomModel() {
return useCustomModel;
Expand Down Expand Up @@ -187,14 +186,6 @@ public void setCodeCompletionsEnabled(boolean codeCompletionsEnabled) {
this.codeCompletionsEnabled = codeCompletionsEnabled;
}

public int getCodeCompletionMaxTokens() {
return codeCompletionMaxTokens;
}

public void setCodeCompletionMaxTokens(int codeCompletionMaxTokens) {
this.codeCompletionMaxTokens = codeCompletionMaxTokens;
}

private static Integer getRandomAvailablePortOrDefault() {
try (ServerSocket socket = new ServerSocket(0)) {
return socket.getLocalPort();
Expand Down Expand Up @@ -230,8 +221,7 @@ public boolean equals(Object o) {
&& Objects.equals(serverPort, that.serverPort)
&& Objects.equals(additionalParameters, that.additionalParameters)
&& Objects.equals(additionalBuildParameters, that.additionalBuildParameters)
&& codeCompletionsEnabled == that.codeCompletionsEnabled
&& codeCompletionMaxTokens == that.codeCompletionMaxTokens;
&& codeCompletionsEnabled == that.codeCompletionsEnabled;
}

@Override
Expand All @@ -240,6 +230,6 @@ public int hashCode() {
localModelPromptTemplate, remoteModelPromptTemplate, localModelInfillPromptTemplate,
remoteModelInfillPromptTemplate, baseHost, serverPort, contextSize, threads,
additionalParameters, additionalBuildParameters, topK, topP, minP, repeatPenalty,
codeCompletionsEnabled, codeCompletionMaxTokens);
codeCompletionsEnabled);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ public LlamaSettingsForm(LlamaSettingsState settings) {
llamaRequestPreferencesForm = new LlamaRequestPreferencesForm(settings);
codeCompletionConfigurationForm = new CodeCompletionConfigurationForm(
settings.isCodeCompletionsEnabled(),
settings.getCodeCompletionMaxTokens(),
null);
init();
}
Expand Down Expand Up @@ -50,9 +49,7 @@ public LlamaSettingsState getCurrentState() {
state.setUseCustomModel(modelPreferencesForm.isUseCustomLlamaModel());
state.setLocalModelPromptTemplate(modelPreferencesForm.getPromptTemplate());
state.setLocalModelInfillPromptTemplate(modelPreferencesForm.getInfillPromptTemplate());

state.setCodeCompletionsEnabled(codeCompletionConfigurationForm.isCodeCompletionsEnabled());
state.setCodeCompletionMaxTokens(codeCompletionConfigurationForm.getMaxTokens());
return state;
}

Expand All @@ -61,7 +58,6 @@ public void resetForm() {
llamaServerPreferencesForm.resetForm(state);
llamaRequestPreferencesForm.resetForm(state);
codeCompletionConfigurationForm.setCodeCompletionsEnabled(state.isCodeCompletionsEnabled());
codeCompletionConfigurationForm.setMaxTokens(state.getCodeCompletionMaxTokens());
}

public LlamaServerPreferencesForm getLlamaServerPreferencesForm() {
Expand Down
Loading

0 comments on commit 014f26f

Please sign in to comment.