from abc import ABC, abstractmethod
import itertools
def flatten(x):
    """Recursively flatten nested iterables into one flat list.

    Any element that has ``__iter__`` is expanded in place, depth-first and
    left to right; strings are treated as atomic values and kept whole.

    Args:
        x: An iterable whose elements may themselves be iterables.

    Returns:
        list: The non-iterable leaves of ``x`` in traversal order.
    """
    flat = []
    for item in x:
        if hasattr(item, '__iter__') and not isinstance(item, str):
            flat.extend(flatten(item))
        else:
            flat.append(item)
    return flat
class Term(ABC):
    """Abstract base class for terms in first-order logic.

    This class only declares the interface that concrete terms (constants,
    variables, function terms) must implement; it defines no state of its
    own and cannot be instantiated directly.
    """

    # -- textual form -------------------------------------------------------
    @abstractmethod
    def __str__(self):
        """Return the printed form of the term."""

    @abstractmethod
    def __repr__(self, level=0):
        """Return a debug representation of the term (``level`` defaults to 0)."""

    # -- identity -----------------------------------------------------------
    @abstractmethod
    def __eq__(self, other):
        """Return whether ``other`` is the same term."""

    @abstractmethod
    def __hash__(self):
        """Return a hash consistent with ``__eq__``."""

    # -- structural queries -------------------------------------------------
    @abstractmethod
    def all_vars(self):
        """Return the variables occurring in the term."""

    @abstractmethod
    def all_consts(self):
        """Return the constants occurring in the term."""

    @abstractmethod
    def all_funcs(self):
        """Return the function symbols occurring in the term."""

    @abstractmethod
    def max_depth(self):
        """Return the maximum nesting depth of the term."""

    @abstractmethod
    def min_depth(self):
        """Return the minimum nesting depth of the term."""

    @abstractmethod
    def size(self):
        """Return the number of symbols in the term."""

    @abstractmethod
    def is_var(self):
        """Return a truthy value when the term is a variable."""
class Const(Term):
    """A constant symbol of first-order logic.

    Constants are leaves of a term: they contain no variables and no function
    symbols, have depth 0 and size 1, and are unaffected by substitution.

    Attributes:
        name (str): Name of the constant; also its printed form.
        dtype (datatype): Data type of the constant, or None when unset.
    """

    def __init__(self, name, dtype=None):
        self.name = name
        self.dtype = dtype

    # -- printing -----------------------------------------------------------
    def __str__(self):
        return self.name

    def __repr__(self, level=0):
        # ``level`` is accepted for compatibility with Term.__repr__; unused.
        return self.name

    # -- comparison and hashing --------------------------------------------
    def __eq__(self, other):
        # Only objects whose exact type is Const can be equal to a Const.
        if type(other) != Const:
            return False
        return self.name == other.name

    def __hash__(self):
        return hash(self.__str__())

    def __lt__(self, other):
        # Ordering is by printed form.
        return self.__str__() < other.__str__()

    # -- term interface -----------------------------------------------------
    def head(self):
        return self

    def subs(self, target_var, const):
        # No variables inside a constant, so substitution is the identity.
        return self

    def to_list(self):
        return [self]

    def get_ith_term(self, i):
        # A constant has exactly one position: index 0.
        assert i == 0, 'Invalid ith term for constant ' + str(self)
        return self

    def all_vars(self):
        return []

    def all_consts(self):
        return [self]

    def all_funcs(self):
        return []

    def max_depth(self):
        return 0

    def min_depth(self):
        return 0

    def size(self):
        return 1

    def is_var(self):
        return 0
class Var(Term):
    """Variables in first-order logic.

    A variable is a leaf term that may be replaced by ``subs``.

    Attributes:
        name (str): Name of the variable; also its printed form.
    """

    def __init__(self, name):
        self.name = name

    def __repr__(self, level=0):
        # ``level`` is accepted for compatibility with Term.__repr__; unused.
        return self.name

    def __str__(self):
        return self.name

    def __eq__(self, other):
        return type(other) == Var and self.name == other.name

    def __hash__(self):
        return hash(self.__str__())

    def __lt__(self, other):
        return self.__str__() < other.__str__()

    def head(self):
        return self

    def subs(self, target_var, const):
        """Return ``const`` if this variable is named like ``target_var``, else self.

        Args:
            target_var (Var): The variable to replace (matched by name).
            const (Term): The replacement term.
        """
        return const if self.name == target_var.name else self

    def to_list(self):
        return [self]

    def get_ith_term(self, i):
        # A variable has exactly one position: index 0.
        # (The message previously said "constant" for a variable.)
        assert i == 0, 'Invalid ith term for variable ' + str(self)
        return self

    def all_vars(self):
        return [self]

    def all_consts(self):
        return []

    def all_funcs(self):
        return []

    def max_depth(self):
        return 0

    def min_depth(self):
        return 0

    def size(self):
        return 1

    def is_var(self):
        return 1
class FuncSymbol():
    """Function symbols in first-order logic.

    Two function symbols are equal when both their name and arity match.

    Attributes:
        name (str): Name of the function.
        arity (int): Number of arguments the function takes.
    """

    def __init__(self, name, arity):
        self.name = name
        self.arity = arity

    def __str__(self):
        return self.name

    def __repr__(self):
        return self.name

    def __eq__(self, other):
        # Guard the type first: comparing against None, a Const (e.g. via
        # Term.head()) or a Predicate used to raise or match spuriously.
        return (isinstance(other, FuncSymbol)
                and self.name == other.name
                and self.arity == other.arity)

    def __hash__(self):
        # Defining __eq__ alone made instances unhashable; hash exactly the
        # fields that __eq__ compares.
        return hash((self.name, self.arity))

    def __lt__(self, other):
        return self.__str__() < other.__str__()
class FuncTerm(Term):
    """Term built from a function symbol: f(t_1, ..., t_n).

    A class of terms that consist of a function symbol applied to argument
    terms in first-order logic.

    Attributes:
        func_symbol (FuncSymbol): The function symbol of the term.
        args (List[Term]): The arguments of the function symbol.
    """

    def __init__(self, func_symbol, args):
        assert func_symbol.arity == len(
            args), 'Invalid arguments for function symbol ' + func_symbol.name
        self.func_symbol = func_symbol
        self.args = args

    def __str__(self):
        # join() yields 'f()' for a zero-arity symbol; the old trailing-comma
        # slice removed the '(' instead and printed 'f)'.
        return (self.func_symbol.name + '('
                + ','.join(arg.__str__() for arg in self.args) + ')')

    def __lt__(self, other):
        return self.__str__() < other.__str__()

    def __repr__(self, level=0):
        return self.__str__()

    def __eq__(self, other):
        """Structural equality: same function symbol and pairwise-equal args."""
        if type(other) != FuncTerm:
            return False
        return (self.func_symbol == other.func_symbol
                and list(self.args) == list(other.args))

    def __hash__(self):
        # Defining __eq__ without __hash__ left FuncTerm unhashable. Hash the
        # printed form, as Const and Var do; equal terms print identically.
        # NOTE: subs() mutates args, so do not mutate a term held in a set.
        return hash(self.__str__())

    def head(self):
        return self.func_symbol

    def pre_order(self, i):
        """Return the i-th symbol of the term in pre-order (0 is func_symbol).

        The old body recursed on ``i - 1`` without visiting the arguments, so
        it returned the root symbol for every ``i >= 0`` and never terminated
        for negative ``i``.
        """
        return self.to_list()[i]

    def get_ith_symbol(self, i):
        """Return the i-th symbol of ``to_list()``."""
        return self.to_list()[i]

    def get_ith_term(self, i):
        """Return the subterm rooted at pre-order position ``i``.

        Positions agree with ``to_list()``: 0 is this term, then its
        arguments depth-first, left to right.

        Raises:
            IndexError: if ``i`` is not a valid (non-negative) position.
        """
        # The previous version seeded its result with ``Term()``; Term is
        # abstract, so every call raised TypeError.
        position = 0
        stack = [self]
        while stack:
            term = stack.pop()
            if position == i:
                return term
            position += 1
            if type(term) == FuncTerm:
                # Push right-to-left so the leftmost argument is visited next.
                stack.extend(reversed(term.args))
        raise IndexError('Invalid ith term for function term ' + str(self))

    def to_list(self):
        """Return the symbols of the term in pre-order.

        Function terms contribute their ``func_symbol``; constants and
        variables contribute themselves.
        """
        symbols = []

        def _walk(term):
            if type(term) == FuncTerm:
                symbols.append(term.func_symbol)
                for arg in term.args:
                    _walk(arg)
            else:
                # const or var
                symbols.append(term)

        _walk(self)
        return symbols

    def subs(self, target_var, const):
        """Substitute ``const`` for ``target_var`` throughout the arguments.

        Mutates this term in place and returns it.
        """
        self.args = [arg.subs(target_var, const) for arg in self.args]
        return self

    def all_vars(self):
        return [v for arg in self.args for v in arg.all_vars()]

    def all_consts(self):
        return [c for arg in self.args for c in arg.all_consts()]

    def all_funcs(self):
        return [self.func_symbol] + [f for arg in self.args for f in arg.all_funcs()]

    def max_depth(self):
        # default=0 gives a zero-arity term depth 1 instead of a ValueError.
        return 1 + max((arg.max_depth() for arg in self.args), default=0)

    def min_depth(self):
        return 1 + min((arg.min_depth() for arg in self.args), default=0)

    def size(self):
        # One for the function symbol plus the sizes of all arguments.
        return 1 + sum(arg.size() for arg in self.args)

    def is_var(self):
        return 0
class Predicate():
    """Predicates in first-order logic.

    Two predicates compare equal when their names match; arity and dtypes
    are not part of equality.

    Attributes:
        name (str): A name of the predicate.
        arity (int): The arity of the predicate.
        dtypes (List[DataTypes]): The data types of the arguments for the predicate.
    """

    def __init__(self, name, arity, dtypes):
        self.name = name
        self.arity = arity
        self.dtypes = dtypes  # one dtype per argument position

    def __str__(self):
        return self.name + '/' + str(self.arity) + '/' + str(self.dtypes)

    def __hash__(self):
        # __eq__ only compares names, so the hash must only depend on the
        # name; hashing name/arity/dtypes gave equal objects different hashes.
        return hash(self.name)

    def __repr__(self):
        return self.__str__()

    def __eq__(self, other):
        return type(other) == Predicate and self.name == other.name

    def __lt__(self, other):
        return self.__str__() < other.__str__()
class NeuralPredicate(Predicate):
    """Neural predicates.

    A class of neural predicates, which are associated with a differentiable
    function. A NeuralPredicate only compares equal to another
    NeuralPredicate with the same name.

    Attributes:
        name (str): A name of the predicate.
        arity (int): The arity of the predicate.
        dtypes (List[DataTypes]): The data types of the arguments for the predicate.
    """

    def __init__(self, name, arity, dtypes):
        # Predicate.__init__ already stores name, arity and dtypes.
        super(NeuralPredicate, self).__init__(name, arity, dtypes)

    def __str__(self):
        return self.name + '/' + str(self.arity) + '/' + str(self.dtypes)

    def __hash__(self):
        # Keep the hash consistent with __eq__, which only compares names.
        return hash(self.name)

    def __repr__(self):
        return self.__str__()

    def __eq__(self, other):
        return type(other) == NeuralPredicate and self.name == other.name

    def __lt__(self, other):
        return self.__str__() < other.__str__()
class Atom(object):
    """Atoms in first-order logic.

    A class of atoms: p(t1, ..., tn).

    Attributes:
        pred (Predicate): A predicate of the atom.
        terms (List[Term]): The terms for the atoms.
        neg_state (bool): Negation flag toggled by ``neg``; starts False and
            is not part of equality.
    """

    def __init__(self, pred, terms):
        assert pred.arity == len(
            terms), 'Invalid arguments for predicate symbol ' + pred.name
        self.pred = pred
        self.terms = terms
        self.neg_state = False

    def __eq__(self, other):
        # Type guard replaces `other == None`, which still crashed on any
        # other non-Atom; list comparison also rejects term lists of a
        # different length instead of indexing past their end.
        if not isinstance(other, Atom):
            return False
        return self.pred == other.pred and list(self.terms) == list(other.terms)

    def __str__(self):
        # join() yields 'p()' for a zero-arity predicate (previously 'p)').
        return self.pred.name + '(' + ','.join(arg.__str__() for arg in self.terms) + ')'

    def __hash__(self):
        return hash(self.__str__())

    def __repr__(self):
        return self.__str__()

    def __lt__(self, other):
        """comparison < """
        return self.__str__() < other.__str__()

    def __gt__(self, other):
        """comparison > """
        # Previously used `<`, which made `a > b` identical to `a < b`.
        return self.__str__() > other.__str__()

    def subs(self, target_var, const):
        """Substitute ``const`` for ``target_var`` in every term, in place."""
        self.terms = [term.subs(target_var, const) for term in self.terms]

    def neg(self):
        """Toggle the negation flag in place."""
        self.neg_state = not self.neg_state

    def all_vars(self):
        return [v for term in self.terms for v in term.all_vars()]

    def all_consts(self):
        return [c for term in self.terms for c in term.all_consts()]

    def all_funcs(self):
        return [f for term in self.terms for f in term.all_funcs()]

    def max_depth(self):
        # default=0 lets zero-arity atoms report depth 0 instead of ValueError.
        return max((term.max_depth() for term in self.terms), default=0)

    def min_depth(self):
        return min((term.min_depth() for term in self.terms), default=0)

    def size(self):
        return sum(term.size() for term in self.terms)

    def get_terms_by_dtype(self, dtype):
        """Return the terms whose argument position has data type ``dtype``.

        Returns: (list(Term))
        """
        return [term for i, term in enumerate(self.terms)
                if self.pred.dtypes[i] == dtype]
class Clause(object):
    """Clauses in first-order logic.

    A class of clauses in first-order logic: A :- B1, ..., Bn.
    The body is stored sorted (using the atoms' ``<`` ordering), so the
    printed form does not depend on the order the body was given in.

    Attributes:
        head (Atom): The head atom.
        body (List[Atom]): The atoms for the body, sorted.
    """

    def __init__(self, head, body):
        self.head = head
        self.body = sorted(body)

    def __str__(self):
        body_str = ','.join(bi.__str__() for bi in self.body)
        return self.head.__str__() + ':-' + body_str + '.'

    def __repr__(self):
        return self.__str__()

    def __eq__(self, other):
        """Equal when heads match and the bodies contain the same set of atoms."""
        if not isinstance(other, Clause):
            return False
        return self.head == other.head and set(self.body) == set(other.body)

    def __hash__(self):
        # Must agree with __eq__, which ignores body multiplicity: hashing the
        # string made `p:-q,q.` and `p:-q.` equal but differently hashed.
        return hash((self.head, frozenset(self.body)))

    def __lt__(self, other):
        return self.__str__() < other.__str__()

    def __gt__(self, other):
        return self.__str__() > other.__str__()

    def _renamed_atoms(self, prefix, suffix):
        """Return (head, body) copies with variables renamed to canonical ids.

        Variables are numbered from 0 in first-occurrence order (head first,
        then body) as given by ``all_vars()``, which is deterministic; the
        previous ``set()``-based ordering in ``_id_str`` depended on the hash
        seed. Only variables that appear directly as atom arguments are
        replaced; variables nested inside function terms keep their names.
        """
        canonical = {v: Var(prefix + str(i) + suffix)
                     for i, v in enumerate(self.all_vars())}

        def _rename_atom(atom):
            terms = [canonical[t] if t.is_var() else t for t in atom.terms]
            return Atom(atom.pred, terms)

        return _rename_atom(self.head), [_rename_atom(bi) for bi in self.body]

    def _rename(self):
        """Rename variables in place to canonical names _X0_, _X1_, ...

        e.g. p(O1,O2):-. and p(O2,O3):-. both become p(_X0_,_X1_):-.
        """
        head_, body_ = self._renamed_atoms('_X', '_')
        self.head = head_
        self.body = sorted(body_)

    def _id_str(self):
        """Return the clause string with variables renamed to __X0__, __X1__, ...

        e.g. p(O1,O2):-. and p(O2,O3):-. both give p(__X0__,__X1__):-.
        The clause itself is not modified.
        """
        head_, body_ = self._renamed_atoms('__X', '__')
        return head_.__str__() + ':-' + ','.join(bi.__str__() for bi in body_) + '.'

    def is_tautology(self):
        """True when the body is exactly the head atom (A :- A)."""
        return len(self.body) == 1 and self.body[0] == self.head

    def is_duplicate(self):
        """True when the body has at least two atoms and all are equal."""
        if len(self.body) < 2:
            return False
        first = self.body[0]
        return all(bi == first for bi in self.body[1:])

    def subs(self, target_var, const):
        """Substitute ``const`` for ``target_var`` in head and body, in place."""
        if type(self.head) == Atom:
            self.head.subs(target_var, const)
        for bi in self.body:
            bi.subs(target_var, const)

    def all_vars(self):
        """Return the clause's variables, de-duplicated, in first-occurrence order."""
        var_list = []
        var_list += self.head.all_vars()
        for bi in self.body:
            var_list += bi.all_vars()
        var_list = flatten(var_list)
        result = []
        for v in var_list:
            if v not in result:
                result.append(v)
        return result

    def all_vars_by_dtype(self, dtype):
        """Get all variables in the clause that have the given data type.

        Returns: list(Var), de-duplicated and sorted.
        """
        result = []
        for atom in [self.head] + self.body:
            result.extend(t for t in atom.get_terms_by_dtype(dtype) if t.is_var())
        return sorted(set(result))

    def count_by_predicate(self, pred):
        """Count head and body atoms whose predicate equals ``pred``."""
        return sum(1 for atom in [self.head] + self.body if pred == atom.pred)

    def all_consts(self):
        const_list = []
        const_list += self.head.all_consts()
        for bi in self.body:
            const_list += bi.all_consts()
        return flatten(const_list)

    def all_funcs(self):
        func_list = []
        func_list += self.head.all_funcs()
        for bi in self.body:
            func_list += bi.all_funcs()
        return flatten(func_list)

    def max_depth(self):
        return max([self.head.max_depth()] + [b.max_depth() for b in self.body])

    def min_depth(self):
        return min([self.head.min_depth()] + [b.min_depth() for b in self.body])

    def size(self):
        return self.head.size() + sum(bi.size() for bi in self.body)