Source code for src.fol.logic

from abc import ABC, abstractmethod

import itertools


def flatten(x):
    """Recursively flatten nested iterables into a single flat list.

    Every element of ``x`` that is iterable (has ``__iter__``) and is not a
    ``str`` is expanded recursively. Everything else, including strings found
    inside ``x``, is kept as a single item. ``x`` itself is always iterated,
    even when it is a string.

    Args:
        x: An iterable, possibly containing nested iterables.

    Returns:
        list: The leaf items in left-to-right order.
    """
    flat = []
    for item in x:
        is_nested = hasattr(item, '__iter__') and not isinstance(item, str)
        if is_nested:
            flat.extend(flatten(item))
        else:
            flat.append(item)
    return flat
class Term(ABC):
    """Abstract base class for terms in first-order logic.

    Concrete subclasses (e.g. constants, variables, function terms) must
    supply string forms, equality and hashing, plus the structural queries
    declared below. This base class stores no state itself; subclasses such
    as constants and variables carry a ``name`` (constants also a ``dtype``).
    """

    # --- string forms, equality, hashing -------------------------------

    @abstractmethod
    def __str__(self):
        """Return the textual form of the term."""

    @abstractmethod
    def __repr__(self, level=0):
        """Return the debugging form of the term."""

    @abstractmethod
    def __eq__(self, other):
        """Return whether ``other`` is the same term."""

    @abstractmethod
    def __hash__(self):
        """Return a hash consistent with ``__eq__``."""

    # --- symbol collection ----------------------------------------------

    @abstractmethod
    def all_vars(self):
        """Return the variables occurring in the term."""

    @abstractmethod
    def all_consts(self):
        """Return the constants occurring in the term."""

    @abstractmethod
    def all_funcs(self):
        """Return the function symbols occurring in the term."""

    # --- structural measures --------------------------------------------

    @abstractmethod
    def max_depth(self):
        """Return the maximum nesting depth of the term."""

    @abstractmethod
    def min_depth(self):
        """Return the minimum nesting depth of the term."""

    @abstractmethod
    def size(self):
        """Return the number of symbols in the term."""

    @abstractmethod
    def is_var(self):
        """Return a truthy value iff the term is a variable."""
class Const(Term):
    """A constant symbol of first-order logic.

    Attributes:
        name (str): The constant's name; it is also its string form.
        dtype: The data type attached to the constant (``None`` when not given).
    """

    def __init__(self, name, dtype=None):
        self.name, self.dtype = name, dtype

    # --- string forms, equality, ordering -------------------------------

    def __str__(self):
        return self.name

    def __repr__(self, level=0):
        return self.name

    def __eq__(self, other):
        # Only another Const (exact type) with the same name is equal.
        if type(other) != Const:
            return False
        return self.name == other.name

    def __hash__(self):
        return hash(self.__str__())

    def __lt__(self, other):
        return self.__str__() < other.__str__()

    # --- structural access ----------------------------------------------

    def head(self):
        """A constant is its own head symbol."""
        return self

    def subs(self, target_var, const):
        """Substitution has no effect on a constant; return it unchanged."""
        return self

    def to_list(self):
        """Return the single-symbol pre-order listing ``[self]``."""
        return [self]

    def get_ith_term(self, i):
        """Return the constant itself; only index 0 is valid."""
        assert i == 0, 'Invalid ith term for constant ' + str(self)
        return self

    # --- symbol collection and measures ---------------------------------

    def all_vars(self):
        return []

    def all_consts(self):
        return [self]

    def all_funcs(self):
        return []

    def max_depth(self):
        return 0

    def min_depth(self):
        return 0

    def size(self):
        return 1

    def is_var(self):
        return 0
class Var(Term):
    """Variables in first-order logic.

    A class of variables in first-order logic.

    Attributes:
        name (str): Name of the variable; also its string form.
    """

    def __init__(self, name):
        self.name = name

    def __repr__(self, level=0):
        return self.name

    def __str__(self):
        return self.name

    def __eq__(self, other):
        # Only another Var (exact type) with the same name is equal.
        return type(other) == Var and self.name == other.name

    def __hash__(self):
        return hash(self.__str__())

    def __lt__(self, other):
        return self.__str__() < other.__str__()

    def head(self):
        """A variable is its own head symbol."""
        return self

    def subs(self, target_var, const):
        """Return ``const`` if this is ``target_var`` (matched by name), else self."""
        if self.name == target_var.name:
            return const
        else:
            return self

    def to_list(self):
        """Return the single-symbol pre-order listing ``[self]``."""
        return [self]

    def get_ith_term(self, i):
        """Return the variable itself; only index 0 is valid."""
        # The message previously said "constant" (copied from Const).
        assert i == 0, 'Invalid ith term for variable ' + str(self)
        return self

    def all_vars(self):
        return [self]

    def all_consts(self):
        return []

    def all_funcs(self):
        return []

    def max_depth(self):
        return 0

    def min_depth(self):
        return 0

    def size(self):
        return 1

    def is_var(self):
        return 1
class FuncSymbol():
    """Function symbols in first-order logic.

    A class of function symbols in first-order logic.

    Attributes:
        name (str): Name of the function symbol; also its string form.
        arity (int): Number of arguments the symbol takes.
    """

    def __init__(self, name, arity):
        self.name = name
        self.arity = arity

    def __str__(self):
        return self.name

    def __repr__(self):
        return self.name

    def __eq__(self, other):
        # Non-symbols compare unequal instead of raising AttributeError.
        if not isinstance(other, FuncSymbol):
            return False
        return self.name == other.name and self.arity == other.arity

    def __hash__(self):
        # Defining __eq__ alone sets __hash__ to None (unhashable); hash the
        # same fields __eq__ compares so equal symbols hash equally.
        return hash((self.name, self.arity))

    def __lt__(self, other):
        return self.__str__() < other.__str__()
class FuncTerm(Term):
    """Term with a function symbol: f(t_1, ..., t_n).

    A class of terms that consist of a function symbol applied to argument
    terms in first-order logic.

    Attributes:
        func_symbol (FuncSymbol): The function symbol of the term.
        args (List[Term]): The arguments of the function symbol.
    """

    def __init__(self, func_symbol, args):
        assert func_symbol.arity == len(
            args), 'Invalid arguments for function symbol ' + func_symbol.name
        self.func_symbol = func_symbol
        self.args = args

    def __str__(self):
        # join() renders a nullary term as "f()"; the previous
        # strip-the-trailing-comma approach turned it into "f)".
        return self.func_symbol.name + '(' + ','.join(
            str(arg) for arg in self.args) + ')'

    def __lt__(self, other):
        return self.__str__() < other.__str__()

    def __repr__(self, level=0):
        return self.__str__()

    def __eq__(self, other):
        # Equal iff same exact type, same function symbol, and pairwise-equal
        # arguments.
        if type(other) != FuncTerm:
            return False
        if self.func_symbol != other.func_symbol:
            return False
        return len(self.args) == len(other.args) and all(
            a == b for a, b in zip(self.args, other.args))

    def __hash__(self):
        # Defining __eq__ without __hash__ set __hash__ to None, making
        # FuncTerm unhashable. Equal terms have equal string forms, so hash
        # the string, as Const and Var do.
        return hash(self.__str__())

    def head(self):
        """Return the function symbol at the root of the term."""
        return self.func_symbol

    def _subterms(self):
        """Return this term and all nested subterms in pre-order."""
        result = []
        stack = [self]
        while stack:
            term = stack.pop()
            result.append(term)
            if type(term) == FuncTerm:
                # Push arguments reversed so the first argument is visited first.
                stack.extend(reversed(term.args))
        return result

    def pre_order(self, i):
        """Return the i-th symbol of the term in pre-order (as ``to_list()[i]``).

        The previous body recursed on ``i - 1`` without descending into the
        arguments. It therefore returned the root symbol for every ``i >= 0``
        and recursed without end for negative ``i``.
        """
        return self.to_list()[i]

    def get_ith_symbol(self, i):
        """Return the i-th entry of ``to_list()``."""
        return self.to_list()[i]

    def get_ith_term(self, i):
        """Return the i-th subterm in pre-order; index 0 is the term itself.

        Indices line up with ``to_list()``: position i of ``to_list()`` holds
        the symbol heading the i-th subterm.

        Raises:
            IndexError: If ``i`` is outside ``[0, self.size())``.
        """
        # The previous version seeded its result with ``Term()``, which always
        # raised TypeError because Term is abstract.
        subterms = self._subterms()
        if not 0 <= i < len(subterms):
            raise IndexError('Invalid ith term for function term ' + str(self))
        return subterms[i]

    def to_list(self):
        """Return the symbols of the term in pre-order.

        Function terms contribute their function symbol; constants and
        variables contribute themselves.
        """
        return [t.func_symbol if type(t) == FuncTerm else t
                for t in self._subterms()]

    def subs(self, target_var, const):
        """Substitute ``const`` for ``target_var`` in the arguments, in place.

        Returns:
            FuncTerm: ``self``, after its ``args`` list has been replaced.
        """
        self.args = [arg.subs(target_var, const) for arg in self.args]
        return self

    def all_vars(self):
        """Return the variables of all arguments (duplicates kept)."""
        return [v for arg in self.args for v in arg.all_vars()]

    def all_consts(self):
        """Return the constants of all arguments (duplicates kept)."""
        return [c for arg in self.args for c in arg.all_consts()]

    def all_funcs(self):
        """Return this term's function symbol followed by those of its arguments."""
        return [self.func_symbol] + [
            f for arg in self.args for f in arg.all_funcs()]

    def max_depth(self):
        """Return 1 + the deepest argument depth (nullary terms give 1)."""
        return max((arg.max_depth() for arg in self.args), default=0) + 1

    def min_depth(self):
        """Return 1 + the shallowest argument depth (nullary terms give 1)."""
        return min((arg.min_depth() for arg in self.args), default=0) + 1

    def size(self):
        """Return the number of symbols in the term (this one plus its args)."""
        return 1 + sum(arg.size() for arg in self.args)

    def is_var(self):
        return 0
class Predicate():
    """Predicates in first-order logic.

    A class of predicates in first-order logic.

    Attributes:
        name (str): The name of the predicate.
        arity (int): The arity of the predicate.
        dtypes (List[DataTypes]): The data types of the predicate's arguments.
    """

    def __init__(self, name, arity, dtypes):
        self.name = name
        self.arity = arity
        self.dtypes = dtypes  # mode = List[dtype]

    def __str__(self):
        return self.name + '/' + str(self.arity) + '/' + str(self.dtypes)

    def __hash__(self):
        # __eq__ compares names only, so the hash must depend on the name only.
        # Hashing the full "name/arity/dtypes" string let equal predicates
        # hash differently, which breaks sets and dicts.
        return hash(self.name)

    def __repr__(self):
        return self.__str__()

    def __eq__(self, other):
        if type(other) == Predicate:
            return self.name == other.name
        else:
            return False

    def __lt__(self, other):
        return self.__str__() < other.__str__()
class NeuralPredicate(Predicate):
    """Neural predicates.

    A class of neural predicates, which are associated with a differentiable
    function.

    Attributes:
        name (str): The name of the predicate.
        arity (int): The arity of the predicate.
        dtypes (List[DataTypes]): The data types of the predicate's arguments.
    """

    def __init__(self, name, arity, dtypes):
        # Predicate.__init__ already stores name, arity and dtypes.
        super().__init__(name, arity, dtypes)

    def __str__(self):
        return self.name + '/' + str(self.arity) + '/' + str(self.dtypes)

    def __hash__(self):
        # Must agree with __eq__, which compares names only.
        return hash(self.name)

    def __repr__(self):
        return self.__str__()

    def __eq__(self, other):
        return type(other) == NeuralPredicate and self.name == other.name

    def __lt__(self, other):
        return self.__str__() < other.__str__()
class Atom(object):
    """Atoms in first-order logic.

    A class of atoms: p(t1, ..., tn)

    Attributes:
        pred (Predicate): The predicate of the atom.
        terms (List[Term]): The argument terms of the atom.
        neg_state (bool): Negation flag toggled by ``neg()``. It is not
            reflected in ``__str__``, ``__eq__`` or ``__hash__``.
    """

    def __init__(self, pred, terms):
        assert pred.arity == len(
            terms), 'Invalid arguments for predicate symbol ' + pred.name
        self.pred = pred
        self.terms = terms
        self.neg_state = False

    def __eq__(self, other):
        # Non-atoms (including None) are unequal. Previously any object
        # without a ``pred`` attribute raised AttributeError.
        if not isinstance(other, Atom):
            return False
        if not self.pred == other.pred:
            return False
        return len(self.terms) == len(other.terms) and all(
            a == b for a, b in zip(self.terms, other.terms))

    def __str__(self):
        # join() renders a nullary atom as "p()" rather than "p)".
        return self.pred.name + '(' + ','.join(
            str(term) for term in self.terms) + ')'

    def __hash__(self):
        return hash(self.__str__())

    def __repr__(self):
        return self.__str__()

    def __lt__(self, other):
        """comparison < (by string form)"""
        return self.__str__() < other.__str__()

    def __gt__(self, other):
        """comparison > (by string form)"""
        # Previously used '<', making a > b behave like a < b.
        return self.__str__() > other.__str__()

    def subs(self, target_var, const):
        """Substitute ``const`` for ``target_var`` in the terms, in place."""
        self.terms = [term.subs(target_var, const) for term in self.terms]

    def neg(self):
        """Toggle the ``neg_state`` flag."""
        self.neg_state = not self.neg_state

    def all_vars(self):
        """Return the variables of all terms (duplicates kept)."""
        return [v for term in self.terms for v in term.all_vars()]

    def all_consts(self):
        """Return the constants of all terms (duplicates kept)."""
        return [c for term in self.terms for c in term.all_consts()]

    def all_funcs(self):
        """Return the function symbols of all terms (duplicates kept)."""
        return [f for term in self.terms for f in term.all_funcs()]

    def max_depth(self):
        """Return the deepest term depth (0 for an atom without terms)."""
        return max((term.max_depth() for term in self.terms), default=0)

    def min_depth(self):
        """Return the shallowest term depth (0 for an atom without terms)."""
        return min((term.min_depth() for term in self.terms), default=0)

    def size(self):
        """Return the total size of the terms."""
        return sum(term.size() for term in self.terms)

    def get_terms_by_dtype(self, dtype):
        """Return the terms whose argument position has data type ``dtype``.

        Returns:
            (list(Term))
        """
        return [term for i, term in enumerate(self.terms)
                if self.pred.dtypes[i] == dtype]
class Clause(object):
    """Clauses in first-order logic.

    A class of clauses in first-order logic: A :- B1, ..., Bn.

    Attributes:
        head (Atom): The head atom.
        body (List[Atom]): The body atoms, kept sorted by ``sorted()``.
    """

    def __init__(self, head, body):
        self.head = head
        # Sorting gives a canonical body order for __str__.
        self.body = sorted(body)

    @staticmethod
    def _format(head, body):
        """Render ``head:-b1,...,bn.``; an empty body renders as ``head:-.``."""
        return str(head) + ':-' + ','.join(str(bi) for bi in body) + '.'

    def __str__(self):
        return self._format(self.head, self.body)

    def __repr__(self):
        return self.__str__()

    def __eq__(self, other):
        # Body atoms are compared as sets, so order and duplicates are ignored.
        if not isinstance(other, Clause):
            return False
        return self.head == other.head and set(self.body) == set(other.body)

    def __hash__(self):
        # Must agree with __eq__: hashing str(self) gave p:-q,q. and p:-q.
        # (equal above) different hashes because duplicates appear in the
        # string. A frozenset of the body ignores order and duplicates.
        return hash((self.head, frozenset(self.body)))

    def __lt__(self, other):
        return self.__str__() < other.__str__()

    def __gt__(self, other):
        return self.__str__() > other.__str__()

    def _renamed_atoms(self, prefix, suffix):
        """Return ``(head, body)`` copies with variables renamed canonically.

        The i-th distinct variable, in order of first appearance (head first,
        then body), becomes ``Var(prefix + str(i) + suffix)``. Only variables
        that appear directly as atom arguments are replaced; variables nested
        inside function terms are left as they are.
        """
        renaming = {}
        for i, var in enumerate(self.all_vars()):
            renaming[var] = Var(prefix + str(i) + suffix)

        def _rename_atom(atom):
            terms = [renaming[term] if term.is_var() else term
                     for term in atom.terms]
            return Atom(atom.pred, terms)

        return _rename_atom(self.head), [_rename_atom(bi) for bi in self.body]

    def _rename(self):
        """Rename variables in place to _X0_, _X1_, ... (first-appearance order).

        e.g. p(O1,O2):-. becomes p(_X0_,_X1_):-.
        """
        head, body = self._renamed_atoms('_X', '_')
        self.head = head
        self.body = sorted(body)

    def _id_str(self):
        """Return the clause string with variables renamed to __X0__, __X1__, ...

        e.g. p(O1,O2):-. and p(O2,O3):-. both give p(__X0__,__X1__):-.

        Variables are numbered by first appearance. The previous version
        numbered them in ``set`` iteration order. That order depends on string
        hashing (randomized per process) rather than on variable positions,
        so the result could vary between runs and between renamed copies.
        """
        head, body = self._renamed_atoms('__X', '__')
        return self._format(head, body)

    def is_tautology(self):
        """Return True if the clause has the form A :- A."""
        return len(self.body) == 1 and self.body[0] == self.head

    def is_duplicate(self):
        """Return True if the body has at least two atoms and all are equal."""
        body = self.body
        return len(body) >= 2 and all(bi == body[0] for bi in body[1:])

    def subs(self, target_var, const):
        """Substitute ``const`` for ``target_var`` in head and body, in place."""
        if type(self.head) == Atom:
            self.head.subs(target_var, const)
        for bi in self.body:
            bi.subs(target_var, const)

    def all_vars(self):
        """Return the distinct variables in order of first appearance."""
        var_list = self.head.all_vars()
        for bi in self.body:
            var_list = var_list + bi.all_vars()
        result = []
        for v in flatten(var_list):
            if v not in result:
                result.append(v)
        return result

    def all_vars_by_dtype(self, dtype):
        """Get the variables in the clause at argument positions of type dtype.

        Returns:
            list(Var): Distinct variables, sorted.
        """
        result = []
        for atom in [self.head] + self.body:
            terms = atom.get_terms_by_dtype(dtype)
            result.extend(t for t in terms if t.is_var())
        return sorted(set(result))

    def count_by_predicate(self, pred):
        """Return how many atoms (head included) use predicate ``pred``."""
        return sum(1 for atom in [self.head] + self.body if pred == atom.pred)

    def all_consts(self):
        """Return the constants of head and body (duplicates kept)."""
        const_list = self.head.all_consts()
        for bi in self.body:
            const_list = const_list + bi.all_consts()
        return flatten(const_list)

    def all_funcs(self):
        """Return the function symbols of head and body (duplicates kept)."""
        func_list = self.head.all_funcs()
        for bi in self.body:
            func_list = func_list + bi.all_funcs()
        return flatten(func_list)

    def max_depth(self):
        """Return the largest atom depth over head and body."""
        return max([self.head.max_depth()] + [b.max_depth() for b in self.body])

    def min_depth(self):
        """Return the smallest atom depth over head and body."""
        return min([self.head.min_depth()] + [b.min_depth() for b in self.body])

    def size(self):
        """Return the total size of head and body atoms."""
        return self.head.size() + sum(bi.size() for bi in self.body)