load_fr_tf.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. import numpy as np
  2. np.set_printoptions(threshold=np.inf)
  3. import logging
  4. def load_ckpt(checkpoint_path):
  5. import tensorflow as tf
  6. if tf.__version__.startswith('2'):
  7. import tensorflow.compat.v1 as tf
  8. tf.disable_v2_behavior()
  9. reader = tf.compat.v1.train.NewCheckpointReader(checkpoint_path)
  10. else:
  11. from tensorflow.python import pywrap_tensorflow
  12. reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
  13. var_to_shape_map = reader.get_variable_to_shape_map()
  14. var_dict = dict()
  15. for var_name in sorted(var_to_shape_map):
  16. if "Adam" in var_name:
  17. continue
  18. tensor = reader.get_tensor(var_name)
  19. # print("in ckpt: {}, {}".format(var_name, tensor.shape))
  20. # print(tensor)
  21. var_dict[var_name] = tensor
  22. return var_dict
  23. def load_tf_pb_dict(pb_model):
  24. import tensorflow as tf
  25. if tf.__version__.startswith('2'):
  26. import tensorflow.compat.v1 as tf
  27. tf.disable_v2_behavior()
  28. # import tensorflow_addons as tfa
  29. # from tensorflow_addons.seq2seq.python.ops import beam_search_ops
  30. else:
  31. from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
  32. from tensorflow.python.ops import lookup_ops as lookup
  33. from tensorflow.python.framework import tensor_util
  34. from tensorflow.python.platform import gfile
  35. sess = tf.Session()
  36. with gfile.FastGFile(pb_model, 'rb') as f:
  37. graph_def = tf.GraphDef()
  38. graph_def.ParseFromString(f.read())
  39. sess.graph.as_default()
  40. tf.import_graph_def(graph_def, name='')
  41. var_dict = dict()
  42. for node in sess.graph_def.node:
  43. if node.op == 'Const':
  44. value = tensor_util.MakeNdarray(node.attr['value'].tensor)
  45. if len(value.shape) >= 1:
  46. var_dict[node.name] = value
  47. return var_dict
  48. def load_tf_dict(pb_model):
  49. if "model.ckpt-" in pb_model:
  50. var_dict = load_ckpt(pb_model)
  51. else:
  52. var_dict = load_tf_pb_dict(pb_model)
  53. return var_dict