"""Source code for src.logic_utils."""

import itertools

from infer import InferModule, ClauseInferModule
from tensor_encoder import TensorEncoder
from fol.logic import *
from fol.data_utils import DataUtils
from fol.language import DataType


# Special unary predicate '.' whose single argument has the 'spec' dtype;
# in this module it is used to build the two sentinel atoms below.
p_ = Predicate('.', 1, [DataType('spec')])
# Sentinel atom standing for "false": p_ applied to the constant '__F__'.
# generate_atoms() prepends [false, true] to every atom list it returns.
false = Atom(p_, [Const('__F__', dtype=DataType('spec'))])
# Sentinel atom standing for "true": p_ applied to the constant '__T__'.
true = Atom(p_, [Const('__T__', dtype=DataType('spec'))])


def get_lang(lark_path, lang_base_path, dataset_type, dataset):
    """Load a first-order-logic language together with its clause files.

    A DataUtils instance is built for the given grammar / language paths and
    dataset. The language is loaded first. Then, relative to
    ``DataUtils.base_path``, the following are read:

    * ``clauses.txt``    -> clauses
    * ``bk_clauses.txt`` -> background-knowledge clauses
    * ``bk.txt``         -> background-knowledge atoms

    Finally the ground atoms are enumerated from the language.

    Returns:
        tuple: ``(lang, clauses, bk_clauses, bk, atoms)``.
    """
    loader = DataUtils(
        lark_path=lark_path,
        lang_base_path=lang_base_path,
        dataset_type=dataset_type,
        dataset=dataset,
    )
    language = loader.load_language()

    def _in_base(filename):
        # Plain string concatenation: presumably base_path already ends with a
        # path separator -- TODO confirm in DataUtils.
        return loader.base_path + filename

    rule_clauses = loader.load_clauses(_in_base('clauses.txt'), language)
    background_clauses = loader.load_clauses(_in_base('bk_clauses.txt'), language)
    background_atoms = loader.load_atoms(_in_base('bk.txt'), language)
    ground_atoms = generate_atoms(language)
    return language, rule_clauses, background_clauses, background_atoms, ground_atoms
def get_searched_clauses(lark_path, lang_base_path, dataset_type, dataset):
    """Load the clauses produced by a beam search for ``dataset``.

    A DataUtils instance is built for the given grammar / language paths and
    dataset, and the language is loaded. Clauses are then read from
    ``<base_path><dataset>/beam_searched.txt`` and parsed against that
    language. No background knowledge is loaded and no atoms are generated.

    Args:
        lark_path: grammar path forwarded to DataUtils.
        lang_base_path: language base path forwarded to DataUtils.
        dataset_type: dataset type forwarded to DataUtils.
        dataset: dataset name; also used as the sub-directory holding
            ``beam_searched.txt``.

    Returns:
        The clauses returned by ``DataUtils.load_clauses``.
    """
    du = DataUtils(lark_path=lark_path, lang_base_path=lang_base_path,
                   dataset_type=dataset_type, dataset=dataset)
    lang = du.load_language()
    # NOTE(review): path is built by string concatenation; this presumably
    # relies on base_path ending with a separator -- confirm in DataUtils.
    clauses = du.load_clauses(du.base_path + dataset + '/beam_searched.txt', lang)
    return clauses
def _get_lang(lark_path, lang_base_path, dataset_type, dataset):
    """Variant of get_lang that asks DataUtils for clauses and background atoms.

    Unlike get_lang, no explicit file names are passed: clauses come from
    ``DataUtils.get_clauses`` and background knowledge from ``DataUtils.get_bk``.
    Ground atoms are enumerated from the loaded language.

    Returns:
        tuple: ``(lang, clauses, bk, atoms)``.
    """
    loader = DataUtils(
        lark_path=lark_path,
        lang_base_path=lang_base_path,
        dataset_type=dataset_type,
        dataset=dataset,
    )
    language = loader.load_language()
    rules = loader.get_clauses(language)
    background = loader.get_bk(language)
    return language, rules, background, generate_atoms(language)
def build_infer_module(clauses, bk_clauses, atoms, lang, device, m=3, infer_step=3, train=False):
    """Build an InferModule from clauses encoded over ``atoms``.

    ``clauses`` are encoded with TensorEncoder. If ``bk_clauses`` is non-empty,
    it is encoded the same way and passed as ``I_bk``. Otherwise ``I_bk`` is
    None.

    Args:
        clauses: clauses to encode as the main input of InferModule.
        bk_clauses: background-knowledge clauses; may be empty.
        atoms: ground atoms the encoding is defined over.
        lang: language passed to TensorEncoder.
        device: forwarded to TensorEncoder and InferModule.
        m: forwarded to InferModule.
        infer_step: forwarded to InferModule.
        train: forwarded to InferModule.

    Returns:
        The constructed InferModule.
    """
    encoded = TensorEncoder(lang, atoms, clauses, device=device).encode()
    # Background clauses are optional; InferModule receives None when absent.
    encoded_bk = None
    if bk_clauses:
        encoded_bk = TensorEncoder(lang, atoms, bk_clauses, device=device).encode()
    return InferModule(encoded, m=m, infer_step=infer_step, device=device,
                       train=train, I_bk=encoded_bk)
def build_clause_infer_module(clauses, bk_clauses, atoms, lang, device, m=3, infer_step=3, train=False):
    """Build a ClauseInferModule from clauses encoded over ``atoms``.

    ``clauses`` are encoded with TensorEncoder. If ``bk_clauses`` is non-empty,
    it is encoded the same way and passed as ``I_bk``. Otherwise ``I_bk`` is
    None.

    Args:
        clauses: clauses to encode as the main input of ClauseInferModule.
        bk_clauses: background-knowledge clauses; may be empty.
        atoms: ground atoms the encoding is defined over.
        lang: language passed to TensorEncoder.
        device: forwarded to TensorEncoder and ClauseInferModule.
        m: forwarded to ClauseInferModule.
        infer_step: forwarded to ClauseInferModule.
        train: forwarded to ClauseInferModule.

    Returns:
        The constructed ClauseInferModule.
    """
    encoded = TensorEncoder(lang, atoms, clauses, device=device).encode()
    # Background clauses are optional; ClauseInferModule receives None when absent.
    encoded_bk = None
    if bk_clauses:
        encoded_bk = TensorEncoder(lang, atoms, bk_clauses, device=device).encode()
    return ClauseInferModule(encoded, m=m, infer_step=infer_step, device=device,
                             train=train, I_bk=encoded_bk)
def generate_atoms(lang):
    """Enumerate the ground atoms of ``lang``.

    For every predicate in ``lang.preds``, each combination of constants that
    matches the predicate's argument dtypes (via ``lang.get_by_dtype``) is
    considered. Combinations that repeat a constant are skipped.

    Returns:
        list: the two sentinel atoms ``[false, true]`` followed by all
        generated atoms in sorted order.
    """
    atoms = []
    for pred in lang.preds:
        consts_list = [lang.get_by_dtype(dtype) for dtype in pred.dtypes]
        # set() drops duplicate argument tuples, e.g. if get_by_dtype returns a
        # constant more than once. Iteration order does not matter because
        # the result is sorted below.
        for args in set(itertools.product(*consts_list)):
            # Keep only tuples whose arguments are pairwise distinct; unary
            # (and nullary) tuples always qualify.
            if len(set(args)) == len(args):
                atoms.append(Atom(pred, args))
    return [false, true] + sorted(atoms)
def generate_bk(lang, pred_names=('diff_color', 'diff_shape')):
    """Ground the selected predicates of ``lang`` as background-knowledge atoms.

    For each predicate whose name is in ``pred_names``, every combination of
    constants matching its argument dtypes is enumerated. A combination is
    kept when it is unary, or when its first two arguments differ but have
    the same ``mode``.

    Args:
        lang: language providing ``preds`` and ``get_by_dtype``.
        pred_names: names of the predicates to ground. The default is the
            previously hard-coded ``('diff_color', 'diff_shape')``.

    Returns:
        list: the generated atoms in enumeration order (not sorted).
    """
    atoms = []
    for pred in lang.preds:
        if pred.name not in pred_names:
            continue
        consts_list = [lang.get_by_dtype(dtype) for dtype in pred.dtypes]
        for args in itertools.product(*consts_list):
            # NOTE(review): only args[0] and args[1] are compared, so this
            # assumes the selected predicates are unary or binary.
            if len(args) == 1 or (args[0] != args[1] and args[0].mode == args[1].mode):
                atoms.append(Atom(pred, args))
    return atoms
def get_index_by_predname(pred_str, atoms):
    """Return the index of the first atom whose predicate name is ``pred_str``.

    Args:
        pred_str: predicate name to look for.
        atoms: sequence of atoms; each must expose ``atom.pred.name``.

    Returns:
        int: position of the first matching atom.

    Raises:
        ValueError: if no atom in ``atoms`` has that predicate name.
    """
    for i, atom in enumerate(atoms):
        if atom.pred.name == pred_str:
            return i
    # The original ``assert 1, ...`` could never fail, so a missing predicate
    # silently returned None. An explicit raise also survives ``python -O``.
    raise ValueError(f'{pred_str} not found.')
def parse_clauses(lang, clause_strs):
    """Parse each clause string with a DataUtils instance built from ``lang``.

    Returns:
        list: one parsed clause per entry of ``clause_strs``, in order.
    """
    # NOTE(review): every other DataUtils construction in this module uses
    # keyword args (lark_path=..., lang_base_path=..., ...). Here ``lang`` is
    # passed positionally and presumably lands in the first parameter --
    # confirm this matches the DataUtils constructor.
    parser = DataUtils(lang)
    parsed = []
    for text in clause_strs:
        parsed.append(parser.parse_clause(text))
    return parsed