From 5543a5089d455708bae3b27e2461ac25e4482da6 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Fri, 7 Apr 2023 15:57:29 -0300 Subject: [PATCH] Auto-submit the whisper extension transcription --- css/chat.css | 5 +++++ extensions/whisper_stt/script.py | 36 ++++++++++++-------------------- server.py | 2 +- 3 files changed, 19 insertions(+), 24 deletions(-) diff --git a/css/chat.css b/css/chat.css index c8a9d70a..b5102e9a 100644 --- a/css/chat.css +++ b/css/chat.css @@ -36,3 +36,8 @@ div.svelte-362y77>*, div.svelte-362y77>.form>* { .wrap.svelte-6roggh.svelte-6roggh { max-height: 92.5%; } + +/* This is for the microphone button in the whisper extension */ +.sm.svelte-1ipelgc { + width: 100%; +} diff --git a/extensions/whisper_stt/script.py b/extensions/whisper_stt/script.py index 6ef60c57..36169535 100644 --- a/extensions/whisper_stt/script.py +++ b/extensions/whisper_stt/script.py @@ -1,5 +1,6 @@ import gradio as gr import speech_recognition as sr +from modules import shared input_hijack = { 'state': False, @@ -7,7 +8,7 @@ input_hijack = { } -def do_stt(audio, text_state=""): +def do_stt(audio): transcription = "" r = sr.Recognizer() @@ -21,34 +22,23 @@ def do_stt(audio, text_state=""): except sr.RequestError as e: print("Could not request results from Whisper", e) - input_hijack.update({"state": True, "value": [transcription, transcription]}) - - text_state += transcription + " " - return text_state, text_state + return transcription -def update_hijack(val): - input_hijack.update({"state": True, "value": [val, val]}) - return val - - -def auto_transcribe(audio, audio_auto, text_state=""): +def auto_transcribe(audio, auto_submit): if audio is None: return "", "" - if audio_auto: - return do_stt(audio, text_state) - return "", "" + + transcription = do_stt(audio) + if auto_submit: + input_hijack.update({"state": True, "value": [transcription, transcription]}) + + return transcription, None def ui(): - tr_state = gr.State(value="") - output_transcription = gr.Textbox(label="STT-Input", - placeholder="Speech Preview. Click \"Generate\" to send", - interactive=True) - output_transcription.change(fn=update_hijack, inputs=[output_transcription], outputs=[tr_state]) - audio_auto = gr.Checkbox(label="Auto-Transcribe", value=True) with gr.Row(): audio = gr.Audio(source="microphone") - audio.change(fn=auto_transcribe, inputs=[audio, audio_auto, tr_state], outputs=[output_transcription, tr_state]) - transcribe_button = gr.Button(value="Transcribe") - transcribe_button.click(do_stt, inputs=[audio, tr_state], outputs=[output_transcription, tr_state]) + auto_submit = gr.Checkbox(label='Submit the transcribed audio automatically', value=True) + audio.change(fn=auto_transcribe, inputs=[audio, auto_submit], outputs=[shared.gradio['textbox'], audio]) + audio.change(None, auto_submit, None, _js="(check) => {if (check) { document.getElementById('Generate').click() }}") diff --git a/server.py b/server.py index 1f6a1f32..1436f9f1 100644 --- a/server.py +++ b/server.py @@ -330,7 +330,7 @@ def create_interface(): shared.gradio['display'] = gr.HTML(value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], 'cai-chat')) shared.gradio['textbox'] = gr.Textbox(label='Input') with gr.Row(): - shared.gradio['Generate'] = gr.Button('Generate') + shared.gradio['Generate'] = gr.Button('Generate', elem_id='Generate') shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop") with gr.Row(): shared.gradio['Impersonate'] = gr.Button('Impersonate')