extract_embeds.py 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. from transformers import AutoTokenizer, AutoModel, pipeline
  2. import numpy as np
  3. import sys
  4. import os
  5. import torch
  6. from kaldiio import WriteHelper
  7. import re
  8. text_file_json = sys.argv[1]
  9. out_ark = sys.argv[2]
  10. out_scp = sys.argv[3]
  11. out_shape = sys.argv[4]
  12. device = int(sys.argv[5])
  13. model_path = sys.argv[6]
  14. model = AutoModel.from_pretrained(model_path)
  15. tokenizer = AutoTokenizer.from_pretrained(model_path)
  16. extractor = pipeline(task="feature-extraction", model=model, tokenizer=tokenizer, device=device)
  17. with open(text_file_json, 'r') as f:
  18. js = f.readlines()
  19. f_shape = open(out_shape, "w")
  20. with WriteHelper('ark,scp:{},{}'.format(out_ark, out_scp)) as writer:
  21. with torch.no_grad():
  22. for idx, line in enumerate(js):
  23. id, tokens = line.strip().split(" ", 1)
  24. tokens = re.sub(" ", "", tokens.strip())
  25. tokens = ' '.join([j for j in tokens])
  26. token_num = len(tokens.split(" "))
  27. outputs = extractor(tokens)
  28. outputs = np.array(outputs)
  29. embeds = outputs[0, 1:-1, :]
  30. token_num_embeds, dim = embeds.shape
  31. if token_num == token_num_embeds:
  32. writer(id, embeds)
  33. shape_line = "{} {},{}\n".format(id, token_num_embeds, dim)
  34. f_shape.write(shape_line)
  35. else:
  36. print("{}, size has changed, {}, {}, {}".format(id, token_num, token_num_embeds, tokens))
  37. f_shape.close()