Manual decoder_input_ids/prefix injection for the decoder #15393

@azziko

Description

Is your feature request related to a problem? Please describe.
Real-time inference with offline models under streaming policies often requires working directly with the decoding history to make inference more context-aware, hence the need for manual decoder_input_ids injection. Examples of such policies are StreamAtt and AlignAtt. See also hlt-mt/simulstream#17 (comment).

Primarily, I would like this feature for the Canary-v2 model.

Describe the solution you'd like

I think this feature would be best exposed through a top-level API call, such as transcribe. For comparison, the following snippet shows how this is done for the SeamlessM4T model: https://github.com/hlt-mt/simulstream/blob/82c75ee77ee39a7f672cef12f75373035f94c73a/simulstream/server/speech_processors/seamless_streamatt.py#L175

For Canary-v2, one possible solution is to add decoder_input_ids as a field in MultiTaskTranscriptionConfig and then, in the _transcribe_forward implementation of EncDecMultiTaskModel, replace the default decoder_input_ids with the ones provided in the config, if any.
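
A rough sketch of that override logic, in plain Python rather than actual NeMo code. `SketchTranscribeConfig` and `resolve_decoder_input_ids` are hypothetical names used only for illustration, the token ids are made up, and the real `_transcribe_forward` would work with tensors rather than lists:

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class SketchTranscribeConfig:
    """Hypothetical stand-in for the transcription config (not the NeMo class)."""

    batch_size: int = 4
    # Proposed new field: user-supplied decoder prefix as token ids, or None.
    decoder_input_ids: Optional[List[int]] = None


def resolve_decoder_input_ids(
    default_prompt_ids: List[int],
    cfg: SketchTranscribeConfig,
    num_items: int,
) -> List[List[int]]:
    """Return one decoder prefix per batch item.

    If the config carries decoder_input_ids, they replace the default prompt;
    otherwise the default prompt is used unchanged.
    """
    if cfg.decoder_input_ids is not None:
        prefix = cfg.decoder_input_ids
    else:
        prefix = default_prompt_ids
    # Copy the prefix for every item so later edits do not alias each other.
    return [list(prefix) for _ in range(num_items)]


# Default behaviour: the prompt built by the model is used as-is.
default_prompt = [3, 7, 11]  # made-up ids for the task/language prompt
print(resolve_decoder_input_ids(default_prompt, SketchTranscribeConfig(), 2))

# Streaming policy: prompt plus previously emitted tokens as decoding history.
cfg = SketchTranscribeConfig(decoder_input_ids=default_prompt + [42, 43])
print(resolve_decoder_input_ids(default_prompt, cfg, 2))
```

With this shape, existing callers that never set the field would see no change in behaviour, while a streaming policy could pass the prompt plus its retained history on each call.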

Describe alternatives you've considered

Any clean and easy way to run inference with control over decoder_input_ids, without dropping down to the lower-level API, would suffice. I would like to hear your thoughts on this. Injection via the prompt might work too.

Additional context

If you find the solution described above (and the feature in general) acceptable, I would be happy to open a PR with the implementation. Thank you.
