complex_utils.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. """Beamformer module."""
  2. from distutils.version import LooseVersion
  3. from typing import Sequence
  4. from typing import Tuple
  5. from typing import Union
  6. import torch
  7. try:
  8. from torch_complex import functional as FC
  9. from torch_complex.tensor import ComplexTensor
  10. except:
  11. print("Please install torch_complex firstly")
  12. EPS = torch.finfo(torch.double).eps
  13. is_torch_1_8_plus = LooseVersion(torch.__version__) >= LooseVersion("1.8.0")
  14. is_torch_1_9_plus = LooseVersion(torch.__version__) >= LooseVersion("1.9.0")
  15. def new_complex_like(
  16. ref: Union[torch.Tensor, ComplexTensor],
  17. real_imag: Tuple[torch.Tensor, torch.Tensor],
  18. ):
  19. if isinstance(ref, ComplexTensor):
  20. return ComplexTensor(*real_imag)
  21. elif is_torch_complex_tensor(ref):
  22. return torch.complex(*real_imag)
  23. else:
  24. raise ValueError(
  25. "Please update your PyTorch version to 1.9+ for complex support."
  26. )
  27. def is_torch_complex_tensor(c):
  28. return (
  29. not isinstance(c, ComplexTensor) and is_torch_1_9_plus and torch.is_complex(c)
  30. )
  31. def is_complex(c):
  32. return isinstance(c, ComplexTensor) or is_torch_complex_tensor(c)
  33. def to_double(c):
  34. if not isinstance(c, ComplexTensor) and is_torch_1_9_plus and torch.is_complex(c):
  35. return c.to(dtype=torch.complex128)
  36. else:
  37. return c.double()
  38. def to_float(c):
  39. if not isinstance(c, ComplexTensor) and is_torch_1_9_plus and torch.is_complex(c):
  40. return c.to(dtype=torch.complex64)
  41. else:
  42. return c.float()
  43. def cat(seq: Sequence[Union[ComplexTensor, torch.Tensor]], *args, **kwargs):
  44. if not isinstance(seq, (list, tuple)):
  45. raise TypeError(
  46. "cat(): argument 'tensors' (position 1) must be tuple of Tensors, "
  47. "not Tensor"
  48. )
  49. if isinstance(seq[0], ComplexTensor):
  50. return FC.cat(seq, *args, **kwargs)
  51. else:
  52. return torch.cat(seq, *args, **kwargs)
  53. def complex_norm(
  54. c: Union[torch.Tensor, ComplexTensor], dim=-1, keepdim=False
  55. ) -> torch.Tensor:
  56. if not is_complex(c):
  57. raise TypeError("Input is not a complex tensor.")
  58. if is_torch_complex_tensor(c):
  59. return torch.norm(c, dim=dim, keepdim=keepdim)
  60. else:
  61. return torch.sqrt(
  62. (c.real**2 + c.imag**2).sum(dim=dim, keepdim=keepdim) + EPS
  63. )
  64. def einsum(equation, *operands):
  65. # NOTE: Do not mix ComplexTensor and torch.complex in the input!
  66. # NOTE (wangyou): Until PyTorch 1.9.0, torch.einsum does not support
  67. # mixed input with complex and real tensors.
  68. if len(operands) == 1:
  69. if isinstance(operands[0], (tuple, list)):
  70. operands = operands[0]
  71. complex_module = FC if isinstance(operands[0], ComplexTensor) else torch
  72. return complex_module.einsum(equation, *operands)
  73. elif len(operands) != 2:
  74. op0 = operands[0]
  75. same_type = all(op.dtype == op0.dtype for op in operands[1:])
  76. if same_type:
  77. _einsum = FC.einsum if isinstance(op0, ComplexTensor) else torch.einsum
  78. return _einsum(equation, *operands)
  79. else:
  80. raise ValueError("0 or More than 2 operands are not supported.")
  81. a, b = operands
  82. if isinstance(a, ComplexTensor) or isinstance(b, ComplexTensor):
  83. return FC.einsum(equation, a, b)
  84. elif is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)):
  85. if not torch.is_complex(a):
  86. o_real = torch.einsum(equation, a, b.real)
  87. o_imag = torch.einsum(equation, a, b.imag)
  88. return torch.complex(o_real, o_imag)
  89. elif not torch.is_complex(b):
  90. o_real = torch.einsum(equation, a.real, b)
  91. o_imag = torch.einsum(equation, a.imag, b)
  92. return torch.complex(o_real, o_imag)
  93. else:
  94. return torch.einsum(equation, a, b)
  95. else:
  96. return torch.einsum(equation, a, b)
  97. def inverse(
  98. c: Union[torch.Tensor, ComplexTensor]
  99. ) -> Union[torch.Tensor, ComplexTensor]:
  100. if isinstance(c, ComplexTensor):
  101. return c.inverse2()
  102. else:
  103. return c.inverse()
  104. def matmul(
  105. a: Union[torch.Tensor, ComplexTensor], b: Union[torch.Tensor, ComplexTensor]
  106. ) -> Union[torch.Tensor, ComplexTensor]:
  107. # NOTE: Do not mix ComplexTensor and torch.complex in the input!
  108. # NOTE (wangyou): Until PyTorch 1.9.0, torch.matmul does not support
  109. # multiplication between complex and real tensors.
  110. if isinstance(a, ComplexTensor) or isinstance(b, ComplexTensor):
  111. return FC.matmul(a, b)
  112. elif is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)):
  113. if not torch.is_complex(a):
  114. o_real = torch.matmul(a, b.real)
  115. o_imag = torch.matmul(a, b.imag)
  116. return torch.complex(o_real, o_imag)
  117. elif not torch.is_complex(b):
  118. o_real = torch.matmul(a.real, b)
  119. o_imag = torch.matmul(a.imag, b)
  120. return torch.complex(o_real, o_imag)
  121. else:
  122. return torch.matmul(a, b)
  123. else:
  124. return torch.matmul(a, b)
  125. def trace(a: Union[torch.Tensor, ComplexTensor]):
  126. # NOTE (wangyou): until PyTorch 1.9.0, torch.trace does not
  127. # support bacth processing. Use FC.trace() as fallback.
  128. return FC.trace(a)
  129. def reverse(a: Union[torch.Tensor, ComplexTensor], dim=0):
  130. if isinstance(a, ComplexTensor):
  131. return FC.reverse(a, dim=dim)
  132. else:
  133. return torch.flip(a, dims=(dim,))
  134. def solve(b: Union[torch.Tensor, ComplexTensor], a: Union[torch.Tensor, ComplexTensor]):
  135. """Solve the linear equation ax = b."""
  136. # NOTE: Do not mix ComplexTensor and torch.complex in the input!
  137. # NOTE (wangyou): Until PyTorch 1.9.0, torch.solve does not support
  138. # mixed input with complex and real tensors.
  139. if isinstance(a, ComplexTensor) or isinstance(b, ComplexTensor):
  140. if isinstance(a, ComplexTensor) and isinstance(b, ComplexTensor):
  141. return FC.solve(b, a, return_LU=False)
  142. else:
  143. return matmul(inverse(a), b)
  144. elif is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)):
  145. if torch.is_complex(a) and torch.is_complex(b):
  146. return torch.linalg.solve(a, b)
  147. else:
  148. return matmul(inverse(a), b)
  149. else:
  150. if is_torch_1_8_plus:
  151. return torch.linalg.solve(a, b)
  152. else:
  153. return torch.solve(b, a)[0]
  154. def stack(seq: Sequence[Union[ComplexTensor, torch.Tensor]], *args, **kwargs):
  155. if not isinstance(seq, (list, tuple)):
  156. raise TypeError(
  157. "stack(): argument 'tensors' (position 1) must be tuple of Tensors, "
  158. "not Tensor"
  159. )
  160. if isinstance(seq[0], ComplexTensor):
  161. return FC.stack(seq, *args, **kwargs)
  162. else:
  163. return torch.stack(seq, *args, **kwargs)