Skip to content

Commit

Permalink
fix: model combobox value upon provider change
Browse files Browse the repository at this point in the history
  • Loading branch information
carlrobertoh committed Sep 24, 2024
1 parent c42ee67 commit 0e9dba0
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ public void sync(Conversation conversation) {
state.setSelectedService(provider);
if (project != null) {
project.getMessageBus()
.syncPublisher(ProviderChangeNotifier.getPROVIDER_CHANGE_TOPIC())
.syncPublisher(ProviderChangeNotifier.getTOPIC())
.providerChanged(provider);
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package ee.carlrobert.codegpt.settings;

import com.intellij.openapi.application.ApplicationManager;
import ee.carlrobert.codegpt.settings.service.ProviderChangeNotifier;
import ee.carlrobert.codegpt.settings.service.ServiceType;

public class GeneralSettingsState {
Expand Down Expand Up @@ -28,5 +30,10 @@ public ServiceType getSelectedService() {

public void setSelectedService(ServiceType selectedService) {
this.selectedService = selectedService;

ApplicationManager.getApplication()
.getMessageBus()
.syncPublisher(ProviderChangeNotifier.getTOPIC())
.providerChanged(selectedService);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ public ChatToolWindowPanel(
(AttachImageNotifier) filePath -> imageFileAttachmentNotification.show(
Path.of(filePath).getFileName().toString(),
"File path: " + filePath));
messageBusConnection.subscribe(ProviderChangeNotifier.getPROVIDER_CHANGE_TOPIC(),
messageBusConnection.subscribe(ProviderChangeNotifier.getTOPIC(),
(ProviderChangeNotifier) provider -> {
if (provider == ServiceType.CODEGPT) {
var userDetails = CodeGPTKeys.CODEGPT_USER_DETAILS.get(project);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import ee.carlrobert.codegpt.Icons;
import ee.carlrobert.codegpt.completions.llama.LlamaModel;
import ee.carlrobert.codegpt.settings.GeneralSettings;
import ee.carlrobert.codegpt.settings.service.ProviderChangeNotifier;
import ee.carlrobert.codegpt.settings.service.ServiceType;
import ee.carlrobert.codegpt.settings.service.codegpt.CodeGPTAvailableModels;
import ee.carlrobert.codegpt.settings.service.codegpt.CodeGPTModel;
Expand Down Expand Up @@ -63,6 +64,11 @@ public ModelComboBoxAction(
this.onModelChange = onModelChange;
this.availableProviders = availableProviders;
updateTemplatePresentation(selectedProvider);
ApplicationManager.getApplication().getMessageBus()
.connect()
.subscribe(
ProviderChangeNotifier.getTOPIC(),
(ProviderChangeNotifier) this::updateTemplatePresentation);
}

public JComponent createCustomComponent(@NotNull String place) {
Expand Down Expand Up @@ -134,8 +140,8 @@ private AnAction[] getCodeGPTModelActions(Project project, Presentation presenta
var googleGroup = DefaultActionGroup.createPopupGroup(() -> "Google (Gemini)");
googleGroup.getTemplatePresentation().setIcon(Icons.Google);
Arrays.stream(GoogleModel.values())
.forEach(model ->
googleGroup.add(createGoogleModelAction(model, presentation)));
.forEach(model ->
googleGroup.add(createGoogleModelAction(model, presentation)));
actionGroup.add(googleGroup);
}
if (availableProviders.contains(LLAMA_CPP)) {
Expand All @@ -154,7 +160,7 @@ private AnAction[] getCodeGPTModelActions(Project project, Presentation presenta
.getState()
.getAvailableModels()
.forEach(model ->
ollamaGroup.add(createOllamaModelAction(model, presentation)));
ollamaGroup.add(createOllamaModelAction(model, presentation)));
actionGroup.add(ollamaGroup);
}

Expand Down Expand Up @@ -240,10 +246,10 @@ private String getSelectedHuggingFace() {
}

private AnAction createModelAction(
ServiceType serviceType,
String label,
Icon icon,
Presentation comboBoxPresentation) {
ServiceType serviceType,
String label,
Icon icon,
Presentation comboBoxPresentation) {
return createModelAction(serviceType, label, icon, comboBoxPresentation, null);
}

Expand Down Expand Up @@ -288,33 +294,33 @@ private void handleModelChange(

private AnAction createCodeGPTModelAction(CodeGPTModel model, Presentation comboBoxPresentation) {
return createModelAction(CODEGPT, model.getName(), model.getIcon(), comboBoxPresentation,
() -> ApplicationManager.getApplication()
.getService(CodeGPTServiceSettings.class)
.getState()
.getChatCompletionSettings()
.setModel(model.getCode()));
() -> ApplicationManager.getApplication()
.getService(CodeGPTServiceSettings.class)
.getState()
.getChatCompletionSettings()
.setModel(model.getCode()));
}

private AnAction createOllamaModelAction(String model, Presentation comboBoxPresentation) {
return createModelAction(OLLAMA, model, Icons.Ollama, comboBoxPresentation,
() -> ApplicationManager.getApplication()
.getService(OllamaSettings.class)
.getState()
.setModel(model));
() -> ApplicationManager.getApplication()
.getService(OllamaSettings.class)
.getState()
.setModel(model));
}

private AnAction createOpenAIModelAction(
OpenAIChatCompletionModel model,
Presentation comboBoxPresentation) {
OpenAIChatCompletionModel model,
Presentation comboBoxPresentation) {
return createModelAction(OPENAI, model.getDescription(), Icons.OpenAI, comboBoxPresentation,
() -> OpenAISettings.getCurrentState().setModel(model.getCode()));
() -> OpenAISettings.getCurrentState().setModel(model.getCode()));
}

private AnAction createGoogleModelAction(GoogleModel model, Presentation comboBoxPresentation) {
return createModelAction(GOOGLE, model.getDescription(), Icons.Google, comboBoxPresentation,
() -> ApplicationManager.getApplication()
.getService(GoogleSettings.class)
.getState()
.setModel(model.getCode()));
() -> ApplicationManager.getApplication()
.getService(GoogleSettings.class)
.getState()
.setModel(model.getCode()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ interface ProviderChangeNotifier {

companion object {
@JvmStatic
val PROVIDER_CHANGE_TOPIC =
Topic.create("providerChange", ProviderChangeNotifier::class.java)
val TOPIC = Topic.create("providerChange", ProviderChangeNotifier::class.java)
}
}

0 comments on commit 0e9dba0

Please sign in to comment.