Use Assist pipeline STT/TTS on Wear OS (#3611)
* Use Assist pipeline STT/TTS on Wear OS

 - Update Assist pipeline support on Wear OS to use the pipelines' STT/TTS capabilities, if available and if the app has the required permission
 - Move UrlHandler functions (app) to UrlUtil (common)

* Create a base AssistViewModel for sharing code

 - Creates AssistViewModelBase in common to share Assist tasks that appear in both the main app and watch app

* Keep screen on during voice input to avoid interruption
jpelgrom authored Jun 30, 2023 · 1 parent 72722bd · commit edf6ba5
Showing 15 changed files with 503 additions and 304 deletions.
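Only the last commit-message bullet ("Keep screen on during voice input") has no hunk in this excerpt; that change lives in one of the Wear OS files among the 15 changed. A minimal Compose sketch of the usual pattern for it, with a hypothetical helper name, not necessarily what this commit does:

import androidx.compose.runtime.Composable
import androidx.compose.runtime.DisposableEffect
import androidx.compose.ui.platform.LocalView

// Keeps the screen on while this composable is in composition (e.g. on the
// voice input screen) and clears the flag when it leaves. View.keepScreenOn
// maps to WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON.
@Composable
fun KeepScreenOn() {
    val view = LocalView.current
    DisposableEffect(Unit) {
        view.keepScreenOn = true
        onDispose { view.keepScreenOn = false }
    }
}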
AssistViewModel.kt
@@ -1,29 +1,20 @@
 package io.homeassistant.companion.android.assist
 
 import android.app.Application
-import android.content.pm.PackageManager
 import android.util.Log
 import androidx.compose.runtime.getValue
 import androidx.compose.runtime.mutableStateListOf
 import androidx.compose.runtime.mutableStateOf
 import androidx.compose.runtime.setValue
-import androidx.lifecycle.AndroidViewModel
 import androidx.lifecycle.viewModelScope
 import dagger.hilt.android.lifecycle.HiltViewModel
 import io.homeassistant.companion.android.assist.ui.AssistMessage
 import io.homeassistant.companion.android.assist.ui.AssistUiPipeline
+import io.homeassistant.companion.android.common.assist.AssistViewModelBase
 import io.homeassistant.companion.android.common.data.servers.ServerManager
-import io.homeassistant.companion.android.common.data.websocket.impl.entities.AssistPipelineError
-import io.homeassistant.companion.android.common.data.websocket.impl.entities.AssistPipelineEventType
-import io.homeassistant.companion.android.common.data.websocket.impl.entities.AssistPipelineIntentEnd
 import io.homeassistant.companion.android.common.data.websocket.impl.entities.AssistPipelineResponse
-import io.homeassistant.companion.android.common.data.websocket.impl.entities.AssistPipelineRunStart
-import io.homeassistant.companion.android.common.data.websocket.impl.entities.AssistPipelineSttEnd
-import io.homeassistant.companion.android.common.data.websocket.impl.entities.AssistPipelineTtsEnd
 import io.homeassistant.companion.android.common.util.AudioRecorder
 import io.homeassistant.companion.android.common.util.AudioUrlPlayer
-import io.homeassistant.companion.android.util.UrlHandler
-import kotlinx.coroutines.Job
 import kotlinx.coroutines.launch
 import javax.inject.Inject
 import io.homeassistant.companion.android.common.R as commonR
@@ -32,40 +23,22 @@ import io.homeassistant.companion.android.common.R as commonR
 class AssistViewModel @Inject constructor(
     val serverManager: ServerManager,
     private val audioRecorder: AudioRecorder,
-    private val audioUrlPlayer: AudioUrlPlayer,
+    audioUrlPlayer: AudioUrlPlayer,
     application: Application
-) : AndroidViewModel(application) {
+) : AssistViewModelBase(serverManager, audioRecorder, audioUrlPlayer, application) {
 
     companion object {
         const val TAG = "AssistViewModel"
     }
 
-    enum class AssistInputMode {
-        TEXT,
-        TEXT_ONLY,
-        VOICE_INACTIVE,
-        VOICE_ACTIVE,
-        BLOCKED
-    }
-
-    private val app = application
-
     private var filteredServerId: Int? = null
-    private var selectedServerId = ServerManager.SERVER_ID_ACTIVE
     private val allPipelines = mutableMapOf<Int, List<AssistPipelineResponse>>()
     private var selectedPipeline: AssistPipelineResponse? = null
 
-    private var recorderJob: Job? = null
-    private var recorderQueue: MutableList<ByteArray>? = null
     private var recorderAutoStart = true
-    private var hasMicrophone = true
     private var hasPermission = false
     private var requestPermission: (() -> Unit)? = null
     private var requestSilently = true
 
-    private var binaryHandlerId: Int? = null
-    private var conversationId: String? = null
-
     private val startMessage = AssistMessage(application.getString(commonR.string.assist_how_can_i_assist), isInput = false)
     private val _conversation = mutableStateListOf(startMessage)
     val conversation: List<AssistMessage> = _conversation
@@ -79,10 +52,6 @@ class AssistViewModel @Inject constructor(
     var inputMode by mutableStateOf<AssistInputMode?>(null)
         private set
 
-    init {
-        hasMicrophone = app.packageManager.hasSystemFeature(PackageManager.FEATURE_MICROPHONE)
-    }
-
     fun onCreate(serverId: Int?, pipelineId: String?, startListening: Boolean?) {
         viewModelScope.launch {
             serverId?.let {
@@ -125,6 +94,12 @@
         }
     }
 
+    override fun getInput(): AssistInputMode? = inputMode
+
+    override fun setInput(inputMode: AssistInputMode) {
+        this.inputMode = inputMode
+    }
+
     private suspend fun checkSupport(): Boolean? {
         if (!serverManager.isRegistered()) return false
         if (!serverManager.integrationRepository(selectedServerId).isHomeAssistantVersionAtLeast(2023, 5, 0)) return false
@@ -175,8 +150,7 @@

         _conversation.clear()
         _conversation.add(startMessage)
-        binaryHandlerId = null
-        conversationId = null
+        clearPipelineData()
         if (hasMicrophone && it.sttEngine != null) {
             if (recorderAutoStart && (hasPermission || requestSilently)) {
                 inputMode = AssistInputMode.VOICE_INACTIVE
@@ -241,13 +215,7 @@
         }
 
         if (recording) {
-            recorderQueue = mutableListOf()
-            recorderJob = viewModelScope.launch {
-                audioRecorder.audioBytes.collect {
-                    recorderQueue?.add(it) ?: sendVoiceData(it)
-                }
-            }
-
+            setupRecorderQueue()
             inputMode = AssistInputMode.VOICE_ACTIVE
             runAssistPipeline(null)
         } else {
@@ -264,100 +232,20 @@
         if (!isVoice) _conversation.add(haMessage)
         var message = if (isVoice) userMessage else haMessage
 
-        var job: Job? = null
-        job = viewModelScope.launch {
-            val flow = if (isVoice) {
-                serverManager.webSocketRepository(selectedServerId).runAssistPipelineForVoice(
-                    sampleRate = AudioRecorder.SAMPLE_RATE,
-                    outputTts = selectedPipeline?.ttsEngine?.isNotBlank() == true,
-                    pipelineId = selectedPipeline?.id,
-                    conversationId = conversationId
-                )
-            } else {
-                serverManager.webSocketRepository(selectedServerId).runAssistPipelineForText(
-                    text = text!!,
-                    pipelineId = selectedPipeline?.id,
-                    conversationId = conversationId
-                )
-            }
-
-            flow?.collect {
-                when (it.type) {
-                    AssistPipelineEventType.RUN_START -> {
-                        if (!isVoice) return@collect
-                        val data = (it.data as? AssistPipelineRunStart)?.runnerData
-                        binaryHandlerId = data?.get("stt_binary_handler_id") as? Int
-                    }
-                    AssistPipelineEventType.STT_START -> {
-                        viewModelScope.launch {
-                            recorderQueue?.forEach { item ->
-                                sendVoiceData(item)
-                            }
-                            recorderQueue = null
-                        }
-                    }
-                    AssistPipelineEventType.STT_END -> {
-                        stopRecording()
-                        (it.data as? AssistPipelineSttEnd)?.sttOutput?.let { response ->
-                            _conversation.indexOf(message).takeIf { pos -> pos >= 0 }?.let { index ->
-                                _conversation[index] = message.copy(message = response["text"] as String)
-                            }
-                        }
-                        _conversation.add(haMessage)
-                        message = haMessage
-                    }
-                    AssistPipelineEventType.INTENT_END -> {
-                        val data = (it.data as? AssistPipelineIntentEnd)?.intentOutput ?: return@collect
-                        conversationId = data.conversationId
-                        data.response.speech.plain["speech"]?.let { response ->
-                            _conversation.indexOf(message).takeIf { pos -> pos >= 0 }?.let { index ->
-                                _conversation[index] = message.copy(message = response)
-                            }
-                        }
-                    }
-                    AssistPipelineEventType.TTS_END -> {
-                        if (!isVoice) return@collect
-                        val audioPath = (it.data as? AssistPipelineTtsEnd)?.ttsOutput?.url
-                        if (!audioPath.isNullOrBlank()) {
-                            playAudio(audioPath)
-                        }
-                    }
-                    AssistPipelineEventType.RUN_END -> {
-                        stopRecording()
-                        job?.cancel()
-                    }
-                    AssistPipelineEventType.ERROR -> {
-                        val errorMessage = (it.data as? AssistPipelineError)?.message ?: return@collect
-                        _conversation.indexOf(message).takeIf { pos -> pos >= 0 }?.let { index ->
-                            _conversation[index] = message.copy(message = errorMessage, isError = true)
-                        }
-                        stopRecording()
-                        job?.cancel()
-                    }
-                    else -> { /* Do nothing */ }
-                }
-            } ?: run {
-                _conversation.indexOf(message).takeIf { pos -> pos >= 0 }?.let { index ->
-                    _conversation[index] = message.copy(message = app.getString(commonR.string.assist_error), isError = true)
-                }
-                stopRecording()
-            }
-        }
+        runAssistPipelineInternal(
+            text,
+            selectedPipeline
+        ) { newMessage, isInput, isError ->
+            _conversation.indexOf(message).takeIf { pos -> pos >= 0 }?.let { index ->
+                _conversation[index] = message.copy(
+                    message = newMessage,
+                    isInput = isInput ?: message.isInput,
+                    isError = isError
+                )
+                if (isInput == true) {
+                    _conversation.add(haMessage)
+                    message = haMessage
+                }
+            }
+        }
     }
 
-    private fun sendVoiceData(data: ByteArray) {
-        binaryHandlerId?.let {
-            viewModelScope.launch {
-                // Launch to prevent blocking the output flow if the network is slow
-                serverManager.webSocketRepository(selectedServerId).sendVoiceData(it, data)
-            }
-        }
-    }
-
-    private fun playAudio(path: String) {
-        UrlHandler.handle(serverManager.getServer(selectedServerId)?.connection?.getUrl(), path)?.let {
-            viewModelScope.launch {
-                audioUrlPlayer.playAudio(it.toString())
-            }
-        }
-    }
@@ -385,27 +273,4 @@
         stopRecording()
         stopPlayback()
     }
-
-    private fun stopRecording() {
-        audioRecorder.stopRecording()
-        recorderJob?.cancel()
-        recorderJob = null
-        if (binaryHandlerId != null) {
-            viewModelScope.launch {
-                recorderQueue?.forEach {
-                    sendVoiceData(it)
-                }
-                recorderQueue = null
-                sendVoiceData(byteArrayOf()) // Empty message to indicate end of recording
-                binaryHandlerId = null
-            }
-        } else {
-            recorderQueue = null
-        }
-        if (inputMode == AssistInputMode.VOICE_ACTIVE) {
-            inputMode = AssistInputMode.VOICE_INACTIVE
-        }
-    }
-
-    private fun stopPlayback() = audioUrlPlayer.stop()
 }
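The overridden getInput()/setInput() above, plus the calls to clearPipelineData(), setupRecorderQueue(), runAssistPipelineInternal(), stopRecording() and stopPlayback(), imply roughly the following contract for the new AssistViewModelBase in common. This is a sketch reconstructed from the call sites in this diff; the real file is among the 15 changed files but is not shown here, and the visibility modifiers and elided bodies are assumptions.

package io.homeassistant.companion.android.common.assist

import android.app.Application
import androidx.lifecycle.AndroidViewModel
import io.homeassistant.companion.android.common.data.servers.ServerManager
import io.homeassistant.companion.android.common.data.websocket.impl.entities.AssistPipelineResponse
import io.homeassistant.companion.android.common.util.AudioRecorder
import io.homeassistant.companion.android.common.util.AudioUrlPlayer

abstract class AssistViewModelBase(
    val serverManager: ServerManager,
    private val audioRecorder: AudioRecorder,
    private val audioUrlPlayer: AudioUrlPlayer,
    application: Application
) : AndroidViewModel(application) {

    enum class AssistInputMode { TEXT, TEXT_ONLY, VOICE_INACTIVE, VOICE_ACTIVE, BLOCKED }

    // Shared state previously held by the app's AssistViewModel
    protected val app = application // context for getString() etc.
    var selectedServerId = ServerManager.SERVER_ID_ACTIVE
    protected var hasMicrophone = true // set from PackageManager.FEATURE_MICROPHONE

    abstract fun getInput(): AssistInputMode?
    abstract fun setInput(inputMode: AssistInputMode)

    /** Reset the STT binary handler id and conversation id between pipeline runs. */
    fun clearPipelineData() { /* ... */ }

    /** Buffer recorder output until the pipeline is ready to accept voice data. */
    fun setupRecorderQueue() { /* ... */ }

    /** Run the pipeline for text or voice input, reporting progress via [onMessage]. */
    fun runAssistPipelineInternal(
        text: String?,
        pipeline: AssistPipelineResponse?,
        onMessage: (newMessage: String, isInput: Boolean?, isError: Boolean) -> Unit
    ) { /* ... */ }

    fun stopRecording() { /* ... */ }
    fun stopPlayback() = audioUrlPlayer.stop()
}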
AssistSheetView.kt
@@ -44,7 +44,6 @@ import androidx.compose.material.MaterialTheme
 import androidx.compose.material.ModalBottomSheetLayout
 import androidx.compose.material.ModalBottomSheetValue
 import androidx.compose.material.OutlinedButton
-import androidx.compose.material.Surface
 import androidx.compose.material.Text
 import androidx.compose.material.TextField
 import androidx.compose.material.icons.Icons
@@ -61,6 +60,7 @@ import androidx.compose.runtime.saveable.rememberSaveable
 import androidx.compose.runtime.setValue
 import androidx.compose.ui.Alignment
 import androidx.compose.ui.Modifier
+import androidx.compose.ui.draw.clip
 import androidx.compose.ui.draw.scale
 import androidx.compose.ui.focus.FocusRequester
 import androidx.compose.ui.focus.focusRequester
@@ -78,7 +78,7 @@ import androidx.compose.ui.unit.sp
 import com.mikepenz.iconics.compose.Image
 import com.mikepenz.iconics.typeface.library.community.material.CommunityMaterial
 import io.homeassistant.companion.android.R
-import io.homeassistant.companion.android.assist.AssistViewModel
+import io.homeassistant.companion.android.common.assist.AssistViewModelBase
 import kotlinx.coroutines.launch
 import io.homeassistant.companion.android.common.R as commonR
 
@@ -87,7 +87,7 @@ import io.homeassistant.companion.android.common.R as commonR
 fun AssistSheetView(
     conversation: List<AssistMessage>,
     pipelines: List<AssistUiPipeline>,
-    inputMode: AssistViewModel.AssistInputMode?,
+    inputMode: AssistViewModelBase.AssistInputMode?,
     currentPipeline: AssistUiPipeline?,
     fromFrontend: Boolean,
     onSelectPipeline: (Int, String) -> Unit,
@@ -227,7 +227,7 @@ fun AssistSheetHeader(

 @Composable
 fun AssistSheetControls(
-    inputMode: AssistViewModel.AssistInputMode?,
+    inputMode: AssistViewModelBase.AssistInputMode?,
     onChangeInput: () -> Unit,
     onTextInput: (String) -> Unit,
     onMicrophoneInput: () -> Unit
@@ -237,18 +237,18 @@
         return
     }
 
-    if (inputMode == AssistViewModel.AssistInputMode.BLOCKED) { // No info and not recoverable, no space
+    if (inputMode == AssistViewModelBase.AssistInputMode.BLOCKED) { // No info and not recoverable, no space
         return
     }
 
     val focusRequester = remember { FocusRequester() }
     LaunchedEffect(inputMode) {
-        if (inputMode == AssistViewModel.AssistInputMode.TEXT || inputMode == AssistViewModel.AssistInputMode.TEXT_ONLY) {
+        if (inputMode == AssistViewModelBase.AssistInputMode.TEXT || inputMode == AssistViewModelBase.AssistInputMode.TEXT_ONLY) {
             focusRequester.requestFocus()
         }
     }
 
-    if (inputMode == AssistViewModel.AssistInputMode.TEXT || inputMode == AssistViewModel.AssistInputMode.TEXT_ONLY) {
+    if (inputMode == AssistViewModelBase.AssistInputMode.TEXT || inputMode == AssistViewModelBase.AssistInputMode.TEXT_ONLY) {
         var text by rememberSaveable(stateSaver = TextFieldValue.Saver) {
             mutableStateOf(TextFieldValue())
         }
@@ -273,13 +273,13 @@
                     if (text.text.isNotBlank()) {
                         onTextInput(text.text)
                         text = TextFieldValue("")
-                    } else if (inputMode != AssistViewModel.AssistInputMode.TEXT_ONLY) {
+                    } else if (inputMode != AssistViewModelBase.AssistInputMode.TEXT_ONLY) {
                         onChangeInput()
                     }
                 },
-                enabled = (inputMode != AssistViewModel.AssistInputMode.TEXT_ONLY || text.text.isNotBlank())
+                enabled = (inputMode != AssistViewModelBase.AssistInputMode.TEXT_ONLY || text.text.isNotBlank())
             ) {
-                val inputIsSend = text.text.isNotBlank() || inputMode == AssistViewModel.AssistInputMode.TEXT_ONLY
+                val inputIsSend = text.text.isNotBlank() || inputMode == AssistViewModelBase.AssistInputMode.TEXT_ONLY
                 Image(
                     asset = if (inputIsSend) CommunityMaterial.Icon3.cmd_send else CommunityMaterial.Icon3.cmd_microphone,
                     contentDescription = stringResource(
@@ -296,7 +296,7 @@
             modifier = Modifier.size(64.dp),
             contentAlignment = Alignment.Center
         ) {
-            val inputIsActive = inputMode == AssistViewModel.AssistInputMode.VOICE_ACTIVE
+            val inputIsActive = inputMode == AssistViewModelBase.AssistInputMode.VOICE_ACTIVE
             if (inputIsActive) {
                 val transition = rememberInfiniteTransition()
                 val scale by transition.animateFloat(
@@ -307,11 +307,12 @@
                         repeatMode = RepeatMode.Reverse
                     )
                 )
-                Surface(
-                    color = colorResource(commonR.color.colorSpeechText),
-                    modifier = Modifier.size(48.dp).scale(scale),
-                    shape = CircleShape,
-                    content = {}
+                Box(
+                    modifier = Modifier
+                        .size(48.dp)
+                        .scale(scale)
+                        .background(color = colorResource(commonR.color.colorSpeechText), shape = CircleShape)
+                        .clip(CircleShape)
                 )
             }
             OutlinedButton(
NfcSetupActivity.kt
@@ -15,7 +15,7 @@ import com.google.accompanist.themeadapter.material.MdcTheme
 import dagger.hilt.android.AndroidEntryPoint
 import io.homeassistant.companion.android.BaseActivity
 import io.homeassistant.companion.android.nfc.views.LoadNfcView
-import io.homeassistant.companion.android.util.UrlHandler
+import io.homeassistant.companion.android.util.UrlUtil
 import kotlinx.coroutines.launch
 import io.homeassistant.companion.android.common.R as commonR
 
@@ -106,7 +106,7 @@ class NfcSetupActivity : BaseActivity() {
             // Create new nfc tag
             if (!viewModel.nfcEventShouldWrite) {
                 val url = NFCUtil.extractUrlFromNFCIntent(intent)
-                val nfcTagId = UrlHandler.splitNfcTagId(url)
+                val nfcTagId = UrlUtil.splitNfcTagId(url)
                 if (nfcTagId == null) {
                     viewModel.onNfcReadEmpty()
                 } else {
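The UrlHandler to UrlUtil rename in the two hunks above is mechanical; the call shapes are unchanged. A minimal sketch of the resolution step the removed playAudio() relied on, assuming UrlUtil.handle() kept the signature implied by that call site (a base java.net.URL plus a possibly relative path, returning an absolute URL or null):

import java.net.URL
import io.homeassistant.companion.android.util.UrlUtil

// Resolve the TTS audio path announced by the pipeline against the server's
// base URL before handing it to AudioUrlPlayer, mirroring the removed
// playAudio() in AssistViewModel.
fun resolveTtsUrl(base: URL?, path: String): String? =
    UrlUtil.handle(base, path)?.toString()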