chunk_localizer.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. """Chunk localizer to help localize the most relevant chunks in a file.
  2. This is primarily used to localize the most relevant chunks in a file
  3. for a given query (e.g. edit draft produced by the agent).
  4. """
  5. import pylcs
  6. from pydantic import BaseModel
  7. from tree_sitter_languages import get_parser
  8. from openhands.core.logger import openhands_logger as logger
  9. class Chunk(BaseModel):
  10. text: str
  11. line_range: tuple[int, int] # (start_line, end_line), 1-index, inclusive
  12. normalized_lcs: float | None = None
  13. def visualize(self) -> str:
  14. lines = self.text.split('\n')
  15. assert len(lines) == self.line_range[1] - self.line_range[0] + 1
  16. ret = ''
  17. for i, line in enumerate(lines):
  18. ret += f'{self.line_range[0] + i}|{line}\n'
  19. return ret
  20. def _create_chunks_from_raw_string(content: str, size: int):
  21. lines = content.split('\n')
  22. ret = []
  23. for i in range(0, len(lines), size):
  24. _cur_lines = lines[i : i + size]
  25. ret.append(
  26. Chunk(
  27. text='\n'.join(_cur_lines),
  28. line_range=(i + 1, i + len(_cur_lines)),
  29. )
  30. )
  31. return ret
  32. def create_chunks(
  33. text: str, size: int = 100, language: str | None = None
  34. ) -> list[Chunk]:
  35. try:
  36. parser = get_parser(language) if language is not None else None
  37. except AttributeError:
  38. logger.debug(f'Language {language} not supported. Falling back to raw string.')
  39. parser = None
  40. if parser is None:
  41. # fallback to raw string
  42. return _create_chunks_from_raw_string(text, size)
  43. # TODO: implement tree-sitter chunking
  44. # return _create_chunks_from_tree_sitter(parser.parse(bytes(text, 'utf-8')), max_chunk_lines=size)
  45. raise NotImplementedError('Tree-sitter chunking not implemented yet.')
  46. def normalized_lcs(chunk: str, query: str) -> float:
  47. """Calculate the normalized Longest Common Subsequence (LCS) to compare file chunk with the query (e.g. edit draft).
  48. We normalize Longest Common Subsequence (LCS) by the length of the chunk
  49. to check how **much** of the chunk is covered by the query.
  50. """
  51. if len(chunk) == 0:
  52. return 0.0
  53. _score = pylcs.lcs_sequence_length(chunk, query)
  54. return _score / len(chunk)
  55. def get_top_k_chunk_matches(
  56. text: str, query: str, k: int = 3, max_chunk_size: int = 100
  57. ) -> list[Chunk]:
  58. """Get the top k chunks in the text that match the query.
  59. The query could be a string of draft code edits.
  60. Args:
  61. text: The text to search for the query.
  62. query: The query to search for in the text.
  63. k: The number of top chunks to return.
  64. max_chunk_size: The maximum number of lines in a chunk.
  65. """
  66. raw_chunks = create_chunks(text, max_chunk_size)
  67. chunks_with_lcs: list[Chunk] = [
  68. Chunk(
  69. text=chunk.text,
  70. line_range=chunk.line_range,
  71. normalized_lcs=normalized_lcs(chunk.text, query),
  72. )
  73. for chunk in raw_chunks
  74. ]
  75. sorted_chunks = sorted(
  76. chunks_with_lcs,
  77. key=lambda x: x.normalized_lcs, # type: ignore
  78. reverse=True,
  79. )
  80. return sorted_chunks[:k]