Source code for src.fol.data_utils

import os.path

from lark import Lark
from .exp_parser import ExpTree
from .language import Language, DataType
from .logic import Predicate, NeuralPredicate, FuncSymbol, Const


class DataUtils(object):
    """Utilities about logic.

    A class of utilities about first-order logic. It reads predicates,
    constants, clauses and background-knowledge atoms from the text files
    of one dataset directory and parses them into logic objects.

    Args:
        lark_path (str): Path to the Lark grammar file. Two parsers are
            built from it, one starting at ``atom`` and one at ``clause``.
        lang_base_path (str): Prefix of the language directory. It is
            concatenated as-is, so it must end with a path separator.
        dataset_type (str): A dataset type (kandinsky or clevr).
        dataset (str): A dataset to be used.

    Attrs:
        base_path: The base path of the dataset.
        lp_atom: Lark parser whose start rule is ``atom``.
        lp_clause: Lark parser whose start rule is ``clause``.
    """

    def __init__(self, lark_path, lang_base_path,
                 dataset_type='kandinsky', dataset='twopairs'):
        self.base_path = lang_base_path + dataset_type + '/' + dataset + '/'
        # Read the grammar once; both parsers are built from the same text.
        with open(lark_path, encoding="utf-8") as grammar:
            grammar_text = grammar.read()
        self.lp_atom = Lark(grammar_text, start="atom")
        self.lp_clause = Lark(grammar_text, start="clause")

    @staticmethod
    def _read_nonblank_lines(path):
        """Return the lines of ``path`` without their trailing newline.

        Lines that contain only whitespace are skipped. The file is closed
        before the method returns.
        """
        with open(path) as f:
            return [line.rstrip('\n') for line in f if line.strip()]

    @staticmethod
    def _parse_pred_fields(line):
        """Split a ``name:arity:dtype1,dtype2,...`` line into its parts.

        Returns:
            tuple: ``(name, arity, dtypes)`` where ``arity`` is an int and
            ``dtypes`` is a list of DataType objects.

        Raises:
            ValueError: If the line does not have exactly three
                ``:``-separated fields or the arity is not an integer.
            AssertionError: If the arity differs from the number of dtypes.
        """
        line = line.replace('\n', '')
        pred, arity, dtype_names_str = line.split(':')
        dtypes = [DataType(dt) for dt in dtype_names_str.split(',')]
        # Kept as an assert so callers keep seeing AssertionError here.
        assert int(arity) == len(dtypes), \
            'Invalid arity and dtypes in ' + pred + '.'
        return pred, int(arity), dtypes

    def load_clauses(self, path, lang):
        """Read lines and parse them to clause objects.

        Args:
            path (str): Path of the clause file, one clause per line.
            lang: The language handed to ExpTree for the transformation.

        Returns:
            list: The parsed clauses in file order. Empty if ``path`` is
            not an existing file. Blank lines are skipped.
        """
        if not os.path.isfile(path):
            return []
        return [self.parse_clause(line, lang)
                for line in self._read_nonblank_lines(path)]

    def load_atoms(self, path, lang):
        """Read lines and parse them to atom objects.

        The last character of every line is dropped before parsing
        (presumably the terminating period; confirm against the grammar).

        Args:
            path (str): Path of the atom file, one atom per line.
            lang: The language handed to ExpTree for the transformation.

        Returns:
            list: The parsed atoms in file order. Empty if ``path`` is
            not an existing file. Blank lines are skipped.
        """
        atoms = []
        if not os.path.isfile(path):
            return atoms
        for line in self._read_nonblank_lines(path):
            # The newline is already gone; remove the final character too.
            tree = self.lp_atom.parse(line[:-1])
            atoms.append(ExpTree(lang).transform(tree))
        return atoms

    def load_preds(self, path):
        """Load predicates from a file with one predicate per line.

        Blank lines are skipped. See ``parse_pred`` for the line format.
        """
        return [self.parse_pred(line)
                for line in self._read_nonblank_lines(path)]

    def load_neural_preds(self, path):
        """Load neural predicates from a file with one predicate per line.

        Blank lines are skipped. See ``parse_neural_pred`` for the line
        format.
        """
        return [self.parse_neural_pred(line)
                for line in self._read_nonblank_lines(path)]

    def load_consts(self, path):
        """Load constants from a file with one data type per line.

        Blank lines are skipped. See ``parse_const`` for the line format.

        Returns:
            list: All constants of all lines, flattened in file order.
        """
        consts = []
        for line in self._read_nonblank_lines(path):
            consts.extend(self.parse_const(line))
        return consts

    def parse_pred(self, line):
        """Parse a ``name:arity:dtype1,dtype2,...`` string to a Predicate.
        """
        pred, arity, dtypes = self._parse_pred_fields(line)
        return Predicate(pred, arity, dtypes)

    def parse_neural_pred(self, line):
        """Parse a ``name:arity:dtype1,dtype2,...`` string to a
        NeuralPredicate.
        """
        pred, arity, dtypes = self._parse_pred_fields(line)
        return NeuralPredicate(pred, arity, dtypes)

    def parse_funcs(self, line):
        """Parse a ``name:arity,name:arity,...`` string to function symbols.

        Returns:
            list: One FuncSymbol per comma-separated entry.
        """
        funcs = []
        for func_arity in line.split(','):
            func, arity = func_arity.split(':')
            funcs.append(FuncSymbol(func, int(arity)))
        return funcs

    def parse_const(self, line):
        """Parse a ``dtype:const1,const2,...`` string to constants.

        Returns:
            list: One Const per comma-separated name, all sharing the
            DataType named before the colon.
        """
        line = line.replace('\n', '')
        dtype_name, const_names_str = line.split(':')
        dtype = DataType(dtype_name)
        return [Const(const_name, dtype)
                for const_name in const_names_str.split(',')]

    def parse_clause(self, clause_str, lang):
        """Parse a single clause string and transform it with ExpTree."""
        tree = self.lp_clause.parse(clause_str)
        return ExpTree(lang).transform(tree)

    def get_clauses(self, lang):
        """Load the clauses stored in ``clauses.txt`` under ``base_path``."""
        return self.load_clauses(self.base_path + 'clauses.txt', lang)

    def get_bk(self, lang):
        """Load the atoms stored in ``bk.txt`` under ``base_path``."""
        return self.load_atoms(self.base_path + 'bk.txt', lang)

    def load_language(self):
        """Build the language from the files under ``base_path``.

        Predicates come from ``preds.txt`` followed by
        ``neural_preds.txt``, constants from ``consts.txt``. The list of
        function symbols passed to Language is empty.
        """
        preds = self.load_preds(self.base_path + 'preds.txt') + \
            self.load_neural_preds(self.base_path + 'neural_preds.txt')
        consts = self.load_consts(self.base_path + 'consts.txt')
        return Language(preds, [], consts)