logic_inference.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. import os
  2. import random
  3. import re
  4. import shutil
  5. from pyke import knowledge_engine
  6. class PykeProgram:
  7. def __init__(
  8. self, logic_program: str, dataset_name='ProntoQA', workspace_mount_path='./'
  9. ) -> None:
  10. self.logic_program = logic_program
  11. self.flag = self.parse_logic_program()
  12. self.dataset_name = dataset_name
  13. self.cache_dir = os.path.join(workspace_mount_path, '.cache_program')
  14. # prepare the files for facts and rules
  15. try:
  16. self.create_fact_file(self.Facts)
  17. self.create_rule_file(self.Rules)
  18. self.flag = True
  19. except Exception:
  20. self.flag = False
  21. self.answer_map = {
  22. 'ProntoQA': self.answer_map_prontoqa,
  23. 'ProofWriter': self.answer_map_proofwriter,
  24. }
  25. def parse_logic_program(self):
  26. keywords = ['Query:', 'Rules:', 'Facts:', 'Predicates:']
  27. program_str = self.logic_program
  28. for keyword in keywords:
  29. try:
  30. program_str, segment_list = self._parse_segment(program_str, keyword)
  31. setattr(self, keyword[:-1], segment_list)
  32. except Exception:
  33. setattr(self, keyword[:-1], None)
  34. return self.validate_program()
  35. def _parse_segment(self, program_str, key_phrase):
  36. remain_program_str, segment = program_str.split(key_phrase)
  37. segment_list = segment.strip().split('\n')
  38. for i in range(len(segment_list)):
  39. segment_list[i] = segment_list[i].split(':::')[0].strip()
  40. return remain_program_str, segment_list
  41. # check if the program is valid; if not, try to fix it
  42. def validate_program(self):
  43. if self.Rules is not None and self.Facts is not None:
  44. if not self.Rules[0] == '' and not self.Facts[0] == '':
  45. return True
  46. # try to fix the program
  47. tmp_rules = []
  48. tmp_facts = []
  49. statements = self.Facts if self.Facts is not None else self.Rules
  50. if statements is None:
  51. return False
  52. for fact in statements:
  53. if fact.find('>>>') >= 0: # this is a rule
  54. tmp_rules.append(fact)
  55. else:
  56. tmp_facts.append(fact)
  57. self.Rules = tmp_rules
  58. self.Facts = tmp_facts
  59. return False
  60. def create_fact_file(self, facts):
  61. with open(os.path.join(self.cache_dir, 'facts.kfb'), 'w') as f:
  62. for fact in facts:
  63. # check for invalid facts
  64. if not fact.find('$x') >= 0:
  65. f.write(fact + '\n')
  66. def create_rule_file(self, rules):
  67. pyke_rules = []
  68. for idx, rule in enumerate(rules):
  69. pyke_rules.append(self.parse_forward_rule(idx + 1, rule))
  70. with open(os.path.join(self.cache_dir, 'rules.krb'), 'w') as f:
  71. f.write('\n\n'.join(pyke_rules))
  72. # example rule: Furry($x, True) && Quite($x, True) >>> White($x, True)
  73. def parse_forward_rule(self, f_index, rule):
  74. premise, conclusion = rule.split('>>>')
  75. premise = premise.strip()
  76. # split the premise into multiple facts if needed
  77. premise = premise.split('&&')
  78. premise_list = [p.strip() for p in premise]
  79. conclusion = conclusion.strip()
  80. # split the conclusion into multiple facts if needed
  81. conclusion = conclusion.split('&&')
  82. conclusion_list = [c.strip() for c in conclusion]
  83. # create the Pyke rule
  84. pyke_rule = f"""fact{f_index}\n\tforeach"""
  85. for p in premise_list:
  86. pyke_rule += f"""\n\t\tfacts.{p}"""
  87. pyke_rule += """\n\tassert"""
  88. for c in conclusion_list:
  89. pyke_rule += f"""\n\t\tfacts.{c}"""
  90. return pyke_rule
  91. """
  92. for example: Is Marvin from Mars?
  93. Query: FromMars(Marvin, $label)
  94. """
  95. def check_specific_predicate(self, subject_name, predicate_name, engine):
  96. results = []
  97. with engine.prove_goal(
  98. f'facts.{predicate_name}({subject_name}, $label)'
  99. ) as gen:
  100. for vars, plan in gen:
  101. results.append(vars['label'])
  102. with engine.prove_goal(
  103. f'rules.{predicate_name}({subject_name}, $label)'
  104. ) as gen:
  105. for vars, plan in gen:
  106. results.append(vars['label'])
  107. if len(results) == 1:
  108. return results[0]
  109. elif len(results) == 2:
  110. return results[0] and results[1]
  111. elif len(results) == 0:
  112. return None
  113. """
  114. Input Example: Metallic(Wren, False)
  115. """
  116. def parse_query(self, query):
  117. pattern = r'(\w+)\(([^,]+),\s*([^)]+)\)'
  118. match = re.match(pattern, query)
  119. if match:
  120. function_name = match.group(1)
  121. arg1 = match.group(2)
  122. arg2 = match.group(3)
  123. arg2 = True if arg2 == 'True' else False
  124. return function_name, arg1, arg2
  125. else:
  126. raise ValueError(f'Invalid query: {query}')
  127. def execute_program(self):
  128. # delete the compiled_krb dir
  129. complied_krb_dir = './models/compiled_krb'
  130. if os.path.exists(complied_krb_dir):
  131. print('removing compiled_krb')
  132. # os.system(f'rm -rf {complied_krb_dir}/*')
  133. shutil.rmtree(complied_krb_dir)
  134. # absolute_path = os.path.abspath(complied_krb_dir)
  135. # print(absolute_path)
  136. try:
  137. engine = knowledge_engine.engine(self.cache_dir)
  138. engine.reset()
  139. engine.activate('rules')
  140. engine.get_kb('facts')
  141. # parse the logic query into pyke query
  142. predicate, subject, value_to_check = self.parse_query(self.Query[0])
  143. result = self.check_specific_predicate(subject, predicate, engine)
  144. answer = self.answer_map[self.dataset_name](result, value_to_check)
  145. except Exception as err:
  146. return None, err
  147. return answer, ''
  148. def answer_mapping(self, answer):
  149. return answer
  150. def answer_map_prontoqa(self, result, value_to_check):
  151. if result == value_to_check:
  152. return 'A'
  153. else:
  154. return 'B'
  155. def answer_map_proofwriter(self, result, value_to_check):
  156. if result is None:
  157. return 'C'
  158. elif result == value_to_check:
  159. return 'A'
  160. else:
  161. return 'B'
  162. class LogicInferenceEngine:
  163. def __init__(self):
  164. self.dataset_name = os.environ.get('DATASET_NAME', 'ProofWriter')
  165. self.workspace_mount_path = '/workspace'
  166. def random_backup(self):
  167. if self.dataset_name == 'ProntoQA':
  168. return random.choice(['A', 'B'])
  169. elif self.dataset_name == 'ProofWriter':
  170. return random.choice(['A', 'B', 'C'])
  171. def safe_execute_program(self, logic_program):
  172. program = PykeProgram(
  173. logic_program, self.dataset_name, self.workspace_mount_path
  174. )
  175. # cannot parse the program
  176. if not program.flag:
  177. answer = self.random_backup()
  178. return answer, 'parsing error', ''
  179. # execute the program
  180. answer, error_message = program.execute_program()
  181. # not executable
  182. if answer is None:
  183. answer = self.random_backup()
  184. return answer, 'execution error', error_message
  185. # successfully executed
  186. answer = program.answer_mapping(answer)
  187. return answer, 'success', ''