Source code for src.optimizer

import random
import time
import torch
from collections import OrderedDict
from tqdm import tqdm

# Seed Python's global ``random`` generator once, at import time, so that the
# ``random.sample`` call used for minibatch selection draws the same sequence
# on every run. NOTE: this is a module-level side effect on the process-wide
# generator; any other code using ``random`` after this import is affected.
random.seed(a=1)



class WeightOptimizer(object):
    """
    optimizer of clause weights using gradient descent

    Inputs
    ------
    infer_module : .infer.InferModule
        module exposing ``Ws`` (the parameters handed to the torch
        optimizer) and ``infer()`` (returns the valuation tensor)
    train_idxs : torch.tensor
        indices passed to ``torch.gather(valuation, 0, train_idxs)`` to pick
        the valuation of each training example (``torch.gather`` requires an
        int64 index tensor)
    labels : torch.tensor((|train_idxs|, ))
        label vector of examples, used as the target of ``torch.nn.BCELoss``
    lr : float
        learning rate of RMSprop
    wd : float
        weight decay of RMSprop
    """

    def __init__(self, infer_module, train_idxs, labels, lr=1e-2, wd=0.0):
        self.IM = infer_module
        self.train_idxs = train_idxs
        self.labels = labels
        self.lr = lr
        self.wd = wd
        # fraction of the training examples drawn for each minibatch
        self.batch_size = 0.05
        self.bce_loss = torch.nn.BCELoss()
        self.set_optimizer(self.IM.Ws)

    def set_optimizer(self, params):
        """
        set torch optimizer

        Inputs
        ------
        params : iterable
            parameters to optimize, forwarded to ``torch.optim.RMSprop``
        """
        self.optimizer = torch.optim.RMSprop(
            params, lr=self.lr, weight_decay=self.wd)

    def minibatch(self, probs, labels):
        """
        get minibatch

        Inputs
        ------
        probs : torch.tensor((|train_idxs|, ))
            valuation vector of examples
            each dimension represents each example of the ilp problem
        labels : torch.tensor((|train_idxs|, ))
            label vector of examples
            each dimension represents each example of the ilp problem

        Returns
        -------
        probs_batch : torch.tensor((batch_size, ))
            valuation vector of examples selected in the minibatch
            each dimension represents each example of the ilp problem
        labels_batch : torch.tensor((batch_size, ))
            label vector of examples selected in the minibatch
            each dimension represents each example of the ilp problem
        """
        n_examples = len(probs)
        # Draw at least one example whenever any exist (a plain
        # int(n * fraction) is 0 for small problems), and never more than
        # are available.
        batch_num = min(n_examples, max(1, int(n_examples * self.batch_size)))
        picked = random.sample(range(n_examples), batch_num)
        # Explicit long dtype keeps an empty selection usable as an index;
        # the index is placed on the device of ``probs`` rather than relying
        # on a module-level ``device`` global.
        ls_batch = torch.tensor(picked, dtype=torch.long, device=probs.device)
        return probs[ls_batch], labels[ls_batch]

    def optimize_weights(self, epoch=500):
        """
        perform gradient descent to optimize clause weights

        Inputs
        ------
        epoch : int
            number of steps in gradient descent

        Returns
        -------
        IM : .infer.InferModule
            infer module that contains optimized weight vectors
        loss_list : List[float]
            list of training loss
        """
        loss_list = []
        with tqdm(range(epoch)) as pbar:
            for _ in pbar:
                valuation = self.IM.infer()
                probs = torch.gather(valuation, 0, self.train_idxs)
                probs_batch, labels_batch = self.minibatch(probs, self.labels)
                loss = self.bce_loss(probs_batch, labels_batch)
                loss_value = loss.item()
                loss_list.append(loss_value)

                # skip the update when the loss is exactly zero (or NaN)
                if loss_value > 0:
                    # clear gradients left over from the previous step so
                    # that each update uses only the current minibatch
                    self.optimizer.zero_grad()
                    # NOTE(review): retain_graph=True is kept from the
                    # original; presumably the infer module reuses parts of
                    # the graph across calls -- TODO confirm it is needed.
                    loss.backward(retain_graph=True)
                    self.optimizer.step()

                pbar.set_postfix(OrderedDict(loss=loss_value))
        return self.IM, loss_list