speech_recognition.cpp 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587
  1. #include <opencv2/core.hpp>
  2. #include <opencv2/videoio.hpp>
  3. #include <opencv2/highgui.hpp>
  4. #include <opencv2/imgproc.hpp>
  5. #include <opencv2/dnn.hpp>
  6. #include <iostream>
  7. #include <vector>
  8. #include <string>
  9. #include <unordered_map>
  10. #include <cmath>
  11. #include <random>
  12. #include <numeric>
  13. using namespace cv;
  14. using namespace std;
  15. class FilterbankFeatures {
  16. // Initializes pre-processing class. Default values are the values used by the Jasper
  17. // architecture for pre-processing. For more details, refer to the paper here:
  18. // https://arxiv.org/abs/1904.03288
  19. private:
  20. int sample_rate = 16000;
  21. double window_size = 0.02;
  22. double window_stride = 0.01;
  23. int win_length = static_cast<int>(sample_rate * window_size); // Number of samples in window
  24. int hop_length = static_cast<int>(sample_rate * window_stride); // Number of steps to advance between frames
  25. int n_fft = 512; // Size of window for STFT
  26. // Parameters for filterbanks calculation
  27. int n_filt = 64;
  28. double lowfreq = 0.;
  29. double highfreq = sample_rate / 2;
  30. public:
  31. // Mel filterbanks preparation
  32. double hz_to_mel(double frequencies)
  33. {
  34. //Converts frequencies from hz to mel scale
  35. // Fill in the linear scale
  36. double f_min = 0.0;
  37. double f_sp = 200.0 / 3;
  38. double mels = (frequencies - f_min) / f_sp;
  39. // Fill in the log-scale part
  40. double min_log_hz = 1000.0; // beginning of log region (Hz)
  41. double min_log_mel = (min_log_hz - f_min) / f_sp; // same (Mels)
  42. double logstep = std::log(6.4) / 27.0; // step size for log region
  43. if (frequencies >= min_log_hz)
  44. {
  45. mels = min_log_mel + std::log(frequencies / min_log_hz) / logstep;
  46. }
  47. return mels;
  48. }
  49. vector<double> mel_to_hz(vector<double>& mels)
  50. {
  51. // Converts frequencies from mel to hz scale
  52. // Fill in the linear scale
  53. double f_min = 0.0;
  54. double f_sp = 200.0 / 3;
  55. vector<double> freqs;
  56. for (size_t i = 0; i < mels.size(); i++)
  57. {
  58. freqs.push_back(f_min + f_sp * mels[i]);
  59. }
  60. // And now the nonlinear scale
  61. double min_log_hz = 1000.0; // beginning of log region (Hz)
  62. double min_log_mel = (min_log_hz - f_min) / f_sp; // same (Mels)
  63. double logstep = std::log(6.4) / 27.0; // step size for log region
  64. for(size_t i = 0; i < mels.size(); i++)
  65. {
  66. if (mels[i] >= min_log_mel)
  67. {
  68. freqs[i] = min_log_hz * exp(logstep * (mels[i] - min_log_mel));
  69. }
  70. }
  71. return freqs;
  72. }
  73. vector<double> mel_frequencies(int n_mels, double fmin, double fmax)
  74. {
  75. // Calculates n mel frequencies between 2 frequencies
  76. double min_mel = hz_to_mel(fmin);
  77. double max_mel = hz_to_mel(fmax);
  78. vector<double> mels;
  79. double step = (max_mel - min_mel) / (n_mels - 1);
  80. for(double i = min_mel; i < max_mel; i += step)
  81. {
  82. mels.push_back(i);
  83. }
  84. mels.push_back(max_mel);
  85. vector<double> res = mel_to_hz(mels);
  86. return res;
  87. }
  88. vector<vector<double>> mel(int n_mels, double fmin, double fmax)
  89. {
  90. // Generates mel filterbank matrix
  91. double num = 1 + n_fft / 2;
  92. vector<vector<double>> weights(n_mels, vector<double>(static_cast<int>(num), 0.));
  93. // Center freqs of each FFT bin
  94. vector<double> fftfreqs;
  95. double step = (sample_rate / 2) / (num - 1);
  96. for(double i = 0; i <= sample_rate / 2; i += step)
  97. {
  98. fftfreqs.push_back(i);
  99. }
  100. // 'Center freqs' of mel bands - uniformly spaced between limits
  101. vector<double> mel_f = mel_frequencies(n_mels + 2, fmin, fmax);
  102. vector<double> fdiff;
  103. for(size_t i = 1; i < mel_f.size(); ++i)
  104. {
  105. fdiff.push_back(mel_f[i]- mel_f[i - 1]);
  106. }
  107. vector<vector<double>> ramps(mel_f.size(), vector<double>(fftfreqs.size()));
  108. for (size_t i = 0; i < mel_f.size(); ++i)
  109. {
  110. for (size_t j = 0; j < fftfreqs.size(); ++j)
  111. {
  112. ramps[i][j] = mel_f[i] - fftfreqs[j];
  113. }
  114. }
  115. double lower, upper, enorm;
  116. for (int i = 0; i < n_mels; ++i)
  117. {
  118. // using Slaney-style mel which is scaled to be approx constant energy per channel
  119. enorm = 2./(mel_f[i + 2] - mel_f[i]);
  120. for (int j = 0; j < static_cast<int>(num); ++j)
  121. {
  122. // lower and upper slopes for all bins
  123. lower = (-1) * ramps[i][j] / fdiff[i];
  124. upper = ramps[i + 2][j] / fdiff[i + 1];
  125. weights[i][j] = max(0., min(lower, upper)) * enorm;
  126. }
  127. }
  128. return weights;
  129. }
  130. // STFT preparation
  131. vector<double> pad_window_center(vector<double>&data, int size)
  132. {
  133. // Pad the window out to n_fft size
  134. int n = static_cast<int>(data.size());
  135. int lpad = static_cast<int>((size - n) / 2);
  136. vector<double> pad_array;
  137. for(int i = 0; i < lpad; ++i)
  138. {
  139. pad_array.push_back(0.);
  140. }
  141. for(size_t i = 0; i < data.size(); ++i)
  142. {
  143. pad_array.push_back(data[i]);
  144. }
  145. for(int i = 0; i < lpad; ++i)
  146. {
  147. pad_array.push_back(0.);
  148. }
  149. return pad_array;
  150. }
  151. vector<vector<double>> frame(vector<double>& x)
  152. {
  153. // Slices a data array into overlapping frames.
  154. int n_frames = static_cast<int>(1 + (x.size() - n_fft) / hop_length);
  155. vector<vector<double>> new_x(n_fft, vector<double>(n_frames));
  156. for (int i = 0; i < n_fft; ++i)
  157. {
  158. for (int j = 0; j < n_frames; ++j)
  159. {
  160. new_x[i][j] = x[i + j * hop_length];
  161. }
  162. }
  163. return new_x;
  164. }
  165. vector<double> hanning()
  166. {
  167. // https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows
  168. vector<double> window_tensor;
  169. for (int j = 1 - win_length; j < win_length; j+=2)
  170. {
  171. window_tensor.push_back(1 - (0.5 * (1 - cos(CV_PI * j / (win_length - 1)))));
  172. }
  173. return window_tensor;
  174. }
  175. vector<vector<double>> stft_power(vector<double>& y)
  176. {
  177. // Short Time Fourier Transform. The STFT represents a signal in the time-frequency
  178. // domain by computing discrete Fourier transforms (DFT) over short overlapping windows.
  179. // https://en.wikipedia.org/wiki/Short-time_Fourier_transform
  180. // Pad the time series so that frames are centered
  181. vector<double> new_y;
  182. int num = int(n_fft / 2);
  183. for (int i = 0; i < num; ++i)
  184. {
  185. new_y.push_back(y[num - i]);
  186. }
  187. for (size_t i = 0; i < y.size(); ++i)
  188. {
  189. new_y.push_back(y[i]);
  190. }
  191. for (size_t i = y.size() - 2; i >= y.size() - num - 1; --i)
  192. {
  193. new_y.push_back(y[i]);
  194. }
  195. // Compute a window function
  196. vector<double> window_tensor = hanning();
  197. // Pad the window out to n_fft size
  198. vector<double> fft_window = pad_window_center(window_tensor, n_fft);
  199. // Window the time series
  200. vector<vector<double>> y_frames = frame(new_y);
  201. // Multiply on fft_window
  202. for (size_t i = 0; i < y_frames.size(); ++i)
  203. {
  204. for (size_t j = 0; j < y_frames[0].size(); ++j)
  205. {
  206. y_frames[i][j] *= fft_window[i];
  207. }
  208. }
  209. // Transpose frames for computing stft
  210. vector<vector<double>> y_frames_transpose(y_frames[0].size(), vector<double>(y_frames.size()));
  211. for (size_t i = 0; i < y_frames[0].size(); ++i)
  212. {
  213. for (size_t j = 0; j < y_frames.size(); ++j)
  214. {
  215. y_frames_transpose[i][j] = y_frames[j][i];
  216. }
  217. }
  218. // Short Time Fourier Transform
  219. // and get power of spectrum
  220. vector<vector<double>> spectrum_power(y_frames_transpose[0].size() / 2 + 1 );
  221. for (size_t i = 0; i < y_frames_transpose.size(); ++i)
  222. {
  223. Mat dstMat;
  224. dft(y_frames_transpose[i], dstMat, DFT_COMPLEX_OUTPUT);
  225. // we need only the first part of the spectrum, the second part is symmetrical
  226. for (int j = 0; j < static_cast<int>(y_frames_transpose[0].size()) / 2 + 1; ++j)
  227. {
  228. double power_re = dstMat.at<double>(2 * j) * dstMat.at<double>(2 * j);
  229. double power_im = dstMat.at<double>(2 * j + 1) * dstMat.at<double>(2 * j + 1);
  230. spectrum_power[j].push_back(power_re + power_im);
  231. }
  232. }
  233. return spectrum_power;
  234. }
  235. Mat calculate_features(vector<double>& x)
  236. {
  237. // Calculates filterbank features matrix.
  238. // Do preemphasis
  239. std::default_random_engine generator;
  240. std::normal_distribution<double> normal_distr(0, 1);
  241. double dither = 1e-5;
  242. for(size_t i = 0; i < x.size(); ++i)
  243. {
  244. x[i] += dither * static_cast<double>(normal_distr(generator));
  245. }
  246. double preemph = 0.97;
  247. for (size_t i = x.size() - 1; i > 0; --i)
  248. {
  249. x[i] -= preemph * x[i-1];
  250. }
  251. // Calculate Short Time Fourier Transform and get power of spectrum
  252. auto spectrum_power = stft_power(x);
  253. vector<vector<double>> filterbanks = mel(n_filt, lowfreq, highfreq);
  254. // Calculate log of multiplication of filterbanks matrix on spectrum_power matrix
  255. vector<vector<double>> x_stft(filterbanks.size(), vector<double>(spectrum_power[0].size(), 0));
  256. for (size_t i = 0; i < filterbanks.size(); ++i)
  257. {
  258. for (size_t j = 0; j < filterbanks[0].size(); ++j)
  259. {
  260. for (size_t k = 0; k < spectrum_power[0].size(); ++k)
  261. {
  262. x_stft[i][k] += filterbanks[i][j] * spectrum_power[j][k];
  263. }
  264. }
  265. for (size_t k = 0; k < spectrum_power[0].size(); ++k)
  266. {
  267. x_stft[i][k] = std::log(x_stft[i][k] + 1e-20);
  268. }
  269. }
  270. // normalize data
  271. auto elments_num = x_stft[0].size();
  272. for(size_t i = 0; i < x_stft.size(); ++i)
  273. {
  274. double x_mean = std::accumulate(x_stft[i].begin(), x_stft[i].end(), 0.) / elments_num; // arithmetic mean
  275. double x_std = 0; // standard deviation
  276. for(size_t j = 0; j < elments_num; ++j)
  277. {
  278. double subtract = x_stft[i][j] - x_mean;
  279. x_std += subtract * subtract;
  280. }
  281. x_std /= elments_num;
  282. x_std = sqrt(x_std) + 1e-10; // make sure x_std is not zero
  283. for(size_t j = 0; j < elments_num; ++j)
  284. {
  285. x_stft[i][j] = (x_stft[i][j] - x_mean) / x_std; // standard score
  286. }
  287. }
  288. Mat calculate_features(static_cast<int>(x_stft.size()), static_cast<int>(x_stft[0].size()), CV_32F);
  289. for(int i = 0; i < calculate_features.size[0]; ++i)
  290. {
  291. for(int j = 0; j < calculate_features.size[1]; ++j)
  292. {
  293. calculate_features.at<float>(i, j) = static_cast<float>(x_stft[i][j]);
  294. }
  295. }
  296. return calculate_features;
  297. }
  298. };
  299. class Decoder {
  300. // Used for decoding the output of jasper model
  301. private:
  302. unordered_map<int, char> labels_map = fillMap();
  303. int blank_id = 28;
  304. public:
  305. unordered_map<int, char> fillMap()
  306. {
  307. vector<char> labels={' ','a','b','c','d','e','f','g','h','i','j','k','l','m','n','o','p'
  308. ,'q','r','s','t','u','v','w','x','y','z','\''};
  309. unordered_map<int, char> map;
  310. for(int i = 0; i < static_cast<int>(labels.size()); ++i)
  311. {
  312. map[i] = labels[i];
  313. }
  314. return map;
  315. }
  316. string decode(Mat& x)
  317. {
  318. // Takes output of Jasper model and performs ctc decoding algorithm to
  319. // remove duplicates and special symbol. Returns prediction
  320. vector<int> prediction;
  321. for(int i = 0; i < x.size[1]; ++i)
  322. {
  323. double maxEl = -1e10;
  324. int ind = 0;
  325. for(int j = 0; j < x.size[2]; ++j)
  326. {
  327. if (maxEl <= x.at<float>(0, i, j))
  328. {
  329. maxEl = x.at<float>(0, i, j);
  330. ind = j;
  331. }
  332. }
  333. prediction.push_back(ind);
  334. }
  335. // CTC decoding procedure
  336. vector<double> decoded_prediction = {};
  337. int previous = blank_id;
  338. for(int i = 0; i < static_cast<int>(prediction.size()); ++i)
  339. {
  340. if (( prediction[i] != previous || previous == blank_id) && prediction[i] != blank_id)
  341. {
  342. decoded_prediction.push_back(prediction[i]);
  343. }
  344. previous = prediction[i];
  345. }
  346. string hypotheses = {};
  347. for(size_t i = 0; i < decoded_prediction.size(); ++i)
  348. {
  349. auto it = labels_map.find(static_cast<char>(decoded_prediction[i]));
  350. if (it != labels_map.end())
  351. hypotheses.push_back(it->second);
  352. }
  353. return hypotheses;
  354. }
  355. };
  356. static string predict(Mat& features, dnn::Net net, Decoder decoder)
  357. {
  358. // Passes the features through the Jasper model and decodes the output to english transcripts.
  359. // expand 2d features matrix to 3d
  360. vector<int> sizes = {1, static_cast<int>(features.size[0]),
  361. static_cast<int>(features.size[1])};
  362. features = features.reshape(0, sizes);
  363. // make prediction
  364. net.setInput(features);
  365. Mat output = net.forward();
  366. // decode output to transcript
  367. auto prediction = decoder.decode(output);
  368. return prediction;
  369. }
  370. static int readAudioFile(vector<double>& inputAudio, string file, int audioStream)
  371. {
  372. VideoCapture cap;
  373. int samplingRate = 16000;
  374. vector<int> params { CAP_PROP_AUDIO_STREAM, audioStream,
  375. CAP_PROP_VIDEO_STREAM, -1,
  376. CAP_PROP_AUDIO_DATA_DEPTH, CV_32F,
  377. CAP_PROP_AUDIO_SAMPLES_PER_SECOND, samplingRate
  378. };
  379. cap.open(file, CAP_ANY, params);
  380. if (!cap.isOpened())
  381. {
  382. cerr << "Error : Can't read audio file: '" << file << "' with audioStream = " << audioStream << endl;
  383. return -1;
  384. }
  385. const int audioBaseIndex = (int)cap.get(CAP_PROP_AUDIO_BASE_INDEX);
  386. vector<double> frameVec;
  387. Mat frame;
  388. for (;;)
  389. {
  390. if (cap.grab())
  391. {
  392. cap.retrieve(frame, audioBaseIndex);
  393. frameVec = frame;
  394. inputAudio.insert(inputAudio.end(), frameVec.begin(), frameVec.end());
  395. }
  396. else
  397. {
  398. break;
  399. }
  400. }
  401. return samplingRate;
  402. }
  403. static int readAudioMicrophone(vector<double>& inputAudio, int microTime)
  404. {
  405. VideoCapture cap;
  406. int samplingRate = 16000;
  407. vector<int> params { CAP_PROP_AUDIO_STREAM, 0,
  408. CAP_PROP_VIDEO_STREAM, -1,
  409. CAP_PROP_AUDIO_DATA_DEPTH, CV_32F,
  410. CAP_PROP_AUDIO_SAMPLES_PER_SECOND, samplingRate
  411. };
  412. cap.open(0, CAP_ANY, params);
  413. if (!cap.isOpened())
  414. {
  415. cerr << "Error: Can't open microphone" << endl;
  416. return -1;
  417. }
  418. const int audioBaseIndex = (int)cap.get(CAP_PROP_AUDIO_BASE_INDEX);
  419. vector<double> frameVec;
  420. Mat frame;
  421. if (microTime <= 0)
  422. {
  423. cerr << "Error: Duration of audio chunk must be > 0" << endl;
  424. return -1;
  425. }
  426. size_t sizeOfData = static_cast<size_t>(microTime * samplingRate);
  427. while (inputAudio.size() < sizeOfData)
  428. {
  429. if (cap.grab())
  430. {
  431. cap.retrieve(frame, audioBaseIndex);
  432. frameVec = frame;
  433. inputAudio.insert(inputAudio.end(), frameVec.begin(), frameVec.end());
  434. }
  435. else
  436. {
  437. cerr << "Error: Grab error" << endl;
  438. break;
  439. }
  440. }
  441. return samplingRate;
  442. }
  443. int main(int argc, char** argv)
  444. {
  445. const String keys =
  446. "{help h usage ? | | This script runs Jasper Speech recognition model }"
  447. "{input_file i | | Path to input audio file. If not specified, microphone input will be used }"
  448. "{audio_duration t | 15 | Duration of audio chunk to be captured from microphone }"
  449. "{audio_stream a | 0 | CAP_PROP_AUDIO_STREAM value }"
  450. "{show_spectrogram s | false | Show a spectrogram of the input audio: true / false / 1 / 0 }"
  451. "{model m | jasper.onnx | Path to the onnx file of Jasper. You can download the converted onnx model "
  452. "from https://drive.google.com/drive/folders/1wLtxyao4ItAg8tt4Sb63zt6qXzhcQoR6?usp=sharing}"
  453. "{backend b | dnn::DNN_BACKEND_DEFAULT | Select a computation backend: "
  454. "dnn::DNN_BACKEND_DEFAULT, "
  455. "dnn::DNN_BACKEND_INFERENCE_ENGINE, "
  456. "dnn::DNN_BACKEND_OPENCV }"
  457. "{target t | dnn::DNN_TARGET_CPU | Select a target device: "
  458. "dnn::DNN_TARGET_CPU, "
  459. "dnn::DNN_TARGET_OPENCL, "
  460. "dnn::DNN_TARGET_OPENCL_FP16 }"
  461. ;
  462. CommandLineParser parser(argc, argv, keys);
  463. if (parser.has("help"))
  464. {
  465. parser.printMessage();
  466. return 0;
  467. }
  468. // Load Network
  469. dnn::Net net = dnn::readNetFromONNX(parser.get<std::string>("model"));
  470. net.setPreferableBackend(parser.get<int>("backend"));
  471. net.setPreferableTarget(parser.get<int>("target"));
  472. // Get audio
  473. vector<double>inputAudio = {};
  474. int samplingRate = 0;
  475. if (parser.has("input_file"))
  476. {
  477. string audio = samples::findFile(parser.get<std::string>("input_file"));
  478. samplingRate = readAudioFile(inputAudio, audio, parser.get<int>("audio_stream"));
  479. }
  480. else
  481. {
  482. samplingRate = readAudioMicrophone(inputAudio, parser.get<int>("audio_duration"));
  483. }
  484. if ((inputAudio.size() == 0) || samplingRate <= 0)
  485. {
  486. cerr << "Error: problems with audio reading, check input arguments" << endl;
  487. return -1;
  488. }
  489. if (inputAudio.size() / samplingRate < 6)
  490. {
  491. cout << "Warning: For predictable network performance duration of audio must exceed 6 sec."
  492. " Audio will be extended with zero samples" << endl;
  493. for(int i = static_cast<int>(inputAudio.size()) - 1; i < samplingRate * 6; ++i)
  494. {
  495. inputAudio.push_back(0);
  496. }
  497. }
  498. // Calculate features
  499. FilterbankFeatures filter;
  500. auto calculated_features = filter.calculate_features(inputAudio);
  501. // Show spectogram if required
  502. if (parser.get<bool>("show_spectrogram") == true)
  503. {
  504. Mat spectogram;
  505. normalize(calculated_features, spectogram, 0, 255, NORM_MINMAX, CV_8U);
  506. applyColorMap(spectogram, spectogram, COLORMAP_INFERNO);
  507. imshow("spectogram", spectogram);
  508. waitKey(0);
  509. }
  510. Decoder decoder;
  511. string prediction = predict(calculated_features, net, decoder);
  512. for( auto &transcript: prediction)
  513. {
  514. cout << transcript;
  515. }
  516. return 0;
  517. }