Description
Is your feature request related to a problem? Please describe.
Real-time inference with offline models under streaming policies often requires working directly with the decoding history for more context-aware inference, hence the need for manual decoder_input_ids injection. Examples include StreamAtt and AlignAtt; see also hlt-mt/simulstream#17 (comment).
Primarily, I would like this feature for the Canary-v2 model.
Describe the solution you'd like
I think this feature would be best abstracted behind a top-level API call such as transcribe. For reference, the following snippet shows how it is done for the SeamlessM4T model: https://github.com/hlt-mt/simulstream/blob/82c75ee77ee39a7f672cef12f75373035f94c73a/simulstream/server/speech_processors/seamless_streamatt.py#L175
For Canary-v2, one possible solution is to add decoder_input_ids as a field in MultitaskTranscriptionConfig and then, in the _transcribe_forward implementation of the EncDecMultiTask model, replace the default decoder_input_ids with the ones provided in the config, if any.
Describe alternatives you've considered
Any clean and easy way to run inference while keeping control over decoder_input_ids, without dropping into the lower-level API, would suffice. I would like to hear your thoughts on this. Perhaps injection via the prompt could work too.
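For context, this is roughly the caller-side streaming loop the feature would enable. fake_transcribe below is a stand-in for a model.transcribe() call that accepts the proposed decoder_input_ids override; the name and behaviour are hypothetical, it simply emits one new token per chunk to show how history would be threaded through.

```python
from typing import List


def fake_transcribe(chunk: str, decoder_input_ids: List[int]) -> List[int]:
    # Stand-in for the proposed API: condition decoding on the supplied
    # history and return the extended token sequence. A real model would
    # run encoder + decoder here; we just append one dummy token.
    return decoder_input_ids + [len(decoder_input_ids)]


# Streaming loop: each chunk is decoded with the accumulated history
# injected as the decoder prefix, as StreamAtt/AlignAtt require.
history: List[int] = []
for chunk in ["chunk_0", "chunk_1", "chunk_2"]:
    history = fake_transcribe(chunk, decoder_input_ids=history)
```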
Additional context
If you find the solution described above (and the feature in general) acceptable, I would be happy to submit a PR with the implementation. Thank you!