using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.Sentis;
using System.IO;
using Newtonsoft.Json;
using System.Text;

/*
 * Whisper Inference Code
 * ======================
 *
 * Put this script on the Main Camera.
 *
 * In Assets/StreamingAssets put:
 *
 *   AudioDecoder_Tiny.sentis
 *   AudioEncoder_Tiny.sentis
 *   LogMelSepctro.sentis   (the file name really is spelled this way)
 *   vocab.json
 *
 * Drag an uncompressed 16 kHz audio clip of at most 30 s into the audioClip field.
 * Multi-channel clips are down-mixed to mono before encoding.
 *
 * Install package com.unity.nuget.newtonsoft-json from the Package Manager.
 * Install package com.unity.sentis.
 */
public class RunWhisper : MonoBehaviour
{
    IWorker decoderEngine, encoderEngine, spectroEngine;

    const BackendType backend = BackendType.GPUCompute;

    // Link your audio clip here. Format must be 16 kHz and non-compressed.
    public AudioClip audioClip;

    // Length of the token window fed to the decoder (prompt + generated tokens).
    const int maxTokens = 100;

    // The pipeline pads (never truncates) the clip to 30 s at 16 kHz.
    const int sampleRate = 16000;
    const int maxSamples = 30 * sampleRate;

    // Special tokens
    const int END_OF_TEXT = 50257;
    const int START_OF_TRANSCRIPT = 50258;
    const int ENGLISH = 50259;
    const int TRANSCRIBE = 50359;
    const int START_TIME = 50364;

    // Used to turn the vocab's byte-as-character strings back into raw bytes.
    static readonly Encoding latin1 = Encoding.GetEncoding("ISO-8859-1");

    Ops ops;
    ITensorAllocator allocator;

    int numSamples;
    float[] data;
    string[] tokens;

    int currentToken = 0;
    int[] outputTokens = new int[maxTokens];

    // Used for special character decoding: shiftDownDict[n] is the byte that
    // the vocab represents with the character (256 + n).
    int[] shiftDownDict = new int[256];

    TensorFloat encodedAudio;

    bool transcribe = false;
    string outputString = "";

    /// <summary>
    /// Loads the models and vocabulary, encodes the audio clip once, and arms
    /// the per-frame decoding loop in <see cref="Update"/>.
    /// </summary>
    void Start()
    {
        if (audioClip == null)
        {
            Debug.LogError("RunWhisper: no AudioClip assigned.");
            return;
        }

        allocator = new TensorCachingAllocator();
        ops = WorkerFactory.CreateOps(backend, allocator);

        SetupCharacterShifts();
        GetTokens();

        Model decoder = ModelLoader.Load(Application.streamingAssetsPath + "/AudioDecoder_Tiny.sentis");
        Model encoder = ModelLoader.Load(Application.streamingAssetsPath + "/AudioEncoder_Tiny.sentis");
        Model spectro = ModelLoader.Load(Application.streamingAssetsPath + "/LogMelSepctro.sentis");

        decoderEngine = WorkerFactory.CreateWorker(backend, decoder);
        encoderEngine = WorkerFactory.CreateWorker(backend, encoder);
        spectroEngine = WorkerFactory.CreateWorker(backend, spectro);

        // Decoder prompt; generation starts after the last prompt token.
        outputTokens[0] = START_OF_TRANSCRIPT;
        outputTokens[1] = ENGLISH;
        outputTokens[2] = TRANSCRIBE;
        outputTokens[3] = START_TIME;
        currentToken = 3;

        LoadAudio();
        // Only start decoding if the audio was actually encoded; otherwise
        // Update would hand a null tensor to the decoder.
        transcribe = EncodeAudio();
    }

    /// <summary>
    /// Copies the clip's samples into <see cref="data"/> as a mono buffer of
    /// <see cref="numSamples"/> floats. Interleaved multi-channel data is averaged.
    /// </summary>
    void LoadAudio()
    {
        if (audioClip.frequency != sampleRate)
        {
            Debug.LogWarning($"The audio clip should have frequency 16kHz. It has frequency {audioClip.frequency / 1000f}kHz");
        }

        numSamples = audioClip.samples;
        int channels = Mathf.Max(1, audioClip.channels);

        // AudioClip.GetData returns interleaved samples for multi-channel clips,
        // so the buffer must hold samples * channels values.
        var interleaved = new float[numSamples * channels];
        audioClip.GetData(interleaved, 0);

        if (channels == 1)
        {
            data = interleaved;
            return;
        }

        data = new float[numSamples];
        for (int i = 0; i < numSamples; i++)
        {
            float sum = 0f;
            for (int c = 0; c < channels; c++)
            {
                sum += interleaved[i * channels + c];
            }
            data[i] = sum / channels;
        }
    }

    /// <summary>
    /// Reads vocab.json (token string -> id) and builds the id -> token lookup table.
    /// </summary>
    void GetTokens()
    {
        var jsonText = File.ReadAllText(Application.streamingAssetsPath + "/vocab.json");
        var vocab = JsonConvert.DeserializeObject<Dictionary<string, int>>(jsonText);

        // Size by the largest id rather than the entry count, so a gap in the ids
        // cannot cause an out-of-range write.
        int size = 0;
        foreach (int id in vocab.Values)
        {
            if (id + 1 > size) size = id + 1;
        }

        tokens = new string[size];
        foreach (var item in vocab)
        {
            tokens[item.Value] = item.Key;
        }
    }

    /// <summary>
    /// Pads the audio to 30 s, runs the spectrogram and encoder models, and stores
    /// the encoder output in <see cref="encodedAudio"/>.
    /// </summary>
    /// <returns>true if the audio was encoded; false if the clip is empty or too long.</returns>
    bool EncodeAudio()
    {
        if (numSamples > maxSamples)
        {
            Debug.LogError("The AudioClip is too long.");
            return false;
        }
        if (numSamples == 0)
        {
            Debug.LogError("The AudioClip is empty.");
            return false;
        }

        var input = new TensorFloat(new TensorShape(1, numSamples), data);

        // Pad out to 30 seconds at 16khz if necessary
        var input30seconds = ops.Pad(input, new int[] { 0, 0, 0, maxSamples - numSamples });

        spectroEngine.Execute(input30seconds);
        var spectroOutput = spectroEngine.PeekOutput() as TensorFloat;

        encoderEngine.Execute(spectroOutput);
        // Peeked outputs stay owned by their workers; only our own tensors are disposed.
        encodedAudio = encoderEngine.PeekOutput() as TensorFloat;

        input30seconds.Dispose();
        input.Dispose();
        return true;
    }

    /// <summary>
    /// Decodes one token per frame until END_OF_TEXT or the token window is full,
    /// appending the decoded text to <see cref="outputString"/>.
    /// </summary>
    void Update()
    {
        if (!transcribe || currentToken >= outputTokens.Length - 1)
        {
            return;
        }

        var tokensSoFar = new TensorInt(new TensorShape(1, outputTokens.Length), outputTokens);

        var inputs = new Dictionary<string, Tensor>
        {
            { "encoded_audio", encodedAudio },
            { "tokens", tokensSoFar }
        };

        decoderEngine.Execute(inputs);
        var tokensOut = decoderEngine.PeekOutput() as TensorFloat;

        var tokensPredictions = ops.ArgMax(tokensOut, 2, false);
        tokensPredictions.MakeReadable();

        int ID = tokensPredictions[currentToken];

        // These tensors are recreated every frame; release them to avoid an
        // unbounded leak across the decoding loop.
        tokensPredictions.Dispose();
        tokensSoFar.Dispose();

        currentToken++;
        outputTokens[currentToken] = ID;

        if (ID == END_OF_TEXT)
        {
            transcribe = false;
        }
        else if (ID >= tokens.Length)
        {
            outputString += $"(time={(ID - START_TIME) * 0.02f})";
        }
        else
        {
            outputString += GetUnicodeText(tokens[ID]);
        }

        Debug.Log(outputString);
    }

    /// <summary>
    /// Translates a vocab token (bytes stored as characters) back to a UTF-8 decoded string.
    /// </summary>
    string GetUnicodeText(string text)
    {
        var bytes = latin1.GetBytes(ShiftCharacterDown(text));
        return Encoding.UTF8.GetString(bytes);
    }

    /// <summary>
    /// Maps characters at or above 256 back to the original byte value they stand
    /// for, leaving characters below 256 unchanged.
    /// </summary>
    string ShiftCharacterDown(string text)
    {
        var sb = new StringBuilder(text.Length);
        foreach (char letter in text)
        {
            // Character 256 maps to shiftDownDict[0], so the boundary is exclusive.
            sb.Append(letter < 256 ? letter : (char)shiftDownDict[letter - 256]);
        }
        return sb.ToString();
    }

    /// <summary>
    /// Fills <see cref="shiftDownDict"/> with the byte values that are not directly
    /// printable, in ascending order.
    /// </summary>
    void SetupCharacterShifts()
    {
        for (int i = 0, n = 0; i < 256; i++)
        {
            if (IsNonPrintableByte((char)i)) shiftDownDict[n++] = i;
        }
    }

    /// <summary>
    /// True if the character lies outside the ranges '!'..'~', '¡'..'¬' and '®'..'ÿ'.
    /// </summary>
    bool IsNonPrintableByte(char c)
    {
        return !(('!' <= c && c <= '~') || ('¡' <= c && c <= '¬') || ('®' <= c && c <= 'ÿ'));
    }

    private void OnDestroy()
    {
        decoderEngine?.Dispose();
        encoderEngine?.Dispose();
        spectroEngine?.Dispose();
        ops?.Dispose();
        allocator?.Dispose();
    }
}