chunk_localizer.py 3.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. """Chunk localizer to help localize the most relevant chunks in a file.
  2. This is primarily used to localize the most relevant chunks in a file
  3. for a given query (e.g. edit draft produced by the agent).
  4. """
  5. import pylcs
  6. from pydantic import BaseModel
  7. from tree_sitter_languages import get_parser
  8. class Chunk(BaseModel):
  9. text: str
  10. line_range: tuple[int, int] # (start_line, end_line), 1-index, inclusive
  11. normalized_lcs: float | None = None
  12. def visualize(self) -> str:
  13. lines = self.text.split('\n')
  14. assert len(lines) == self.line_range[1] - self.line_range[0] + 1
  15. ret = ''
  16. for i, line in enumerate(lines):
  17. ret += f'{self.line_range[0] + i}|{line}\n'
  18. return ret
  19. def _create_chunks_from_raw_string(content: str, size: int):
  20. lines = content.split('\n')
  21. ret = []
  22. for i in range(0, len(lines), size):
  23. _cur_lines = lines[i : i + size]
  24. ret.append(
  25. Chunk(
  26. text='\n'.join(_cur_lines),
  27. line_range=(i + 1, i + len(_cur_lines)),
  28. )
  29. )
  30. return ret
  31. def create_chunks(
  32. text: str, size: int = 100, language: str | None = None
  33. ) -> list[Chunk]:
  34. try:
  35. parser = get_parser(language) if language is not None else None
  36. except AttributeError:
  37. # print(f"Language {language} not supported. Falling back to raw string.")
  38. parser = None
  39. if parser is None:
  40. # fallback to raw string
  41. return _create_chunks_from_raw_string(text, size)
  42. # TODO: implement tree-sitter chunking
  43. # return _create_chunks_from_tree_sitter(parser.parse(bytes(text, 'utf-8')), max_chunk_lines=size)
  44. raise NotImplementedError('Tree-sitter chunking not implemented yet.')
  45. def normalized_lcs(chunk: str, query: str) -> float:
  46. """Calculate the normalized Longest Common Subsequence (LCS) to compare file chunk with the query (e.g. edit draft).
  47. We normalize Longest Common Subsequence (LCS) by the length of the chunk
  48. to check how **much** of the chunk is covered by the query.
  49. """
  50. if len(chunk) == 0:
  51. return 0.0
  52. _score = pylcs.lcs_sequence_length(chunk, query)
  53. return _score / len(chunk)
  54. def get_top_k_chunk_matches(
  55. text: str, query: str, k: int = 3, max_chunk_size: int = 100
  56. ) -> list[Chunk]:
  57. """Get the top k chunks in the text that match the query.
  58. The query could be a string of draft code edits.
  59. Args:
  60. text: The text to search for the query.
  61. query: The query to search for in the text.
  62. k: The number of top chunks to return.
  63. max_chunk_size: The maximum number of lines in a chunk.
  64. """
  65. raw_chunks = create_chunks(text, max_chunk_size)
  66. chunks_with_lcs: list[Chunk] = [
  67. Chunk(
  68. text=chunk.text,
  69. line_range=chunk.line_range,
  70. normalized_lcs=normalized_lcs(chunk.text, query),
  71. )
  72. for chunk in raw_chunks
  73. ]
  74. sorted_chunks = sorted(
  75. chunks_with_lcs,
  76. key=lambda x: x.normalized_lcs, # type: ignore
  77. reverse=True,
  78. )
  79. return sorted_chunks[:k]