Source code for myGym.envs.vision_module

try:
    import torch
except ImportError:
    print("Could not import torch; the VAE vision source will be unavailable.")
import sys
import numpy as np
import cv2
import random
import pkg_resources
currentdir = pkg_resources.resource_filename("myGym", "envs")

# import vision models (YOLACT, VAE)
sys.path.append(pkg_resources.resource_filename("myGym", "yolact_vision"))  # may be moved somewhere else
try:
    from inference_tool import InfTool
except ImportError:
    print("Could not import the YOLACT inference tool; the 'yolact' vision source will be unavailable.")
from myGym.vae.vis_helpers import load_checkpoint
from myGym.vae import sample

class VisionModule:
    """
    Vision class that retrieves information from the environment based on a visual subsystem (YOLACT, VAE) or ground truth

    Parameters:
        :param vision_src: (string) Source of information from the environment (ground_truth, yolact, vae)
        :param env: (object) Environment where the training takes place
        :param vae_path: (string) Path to a trained VAE for the 2dvu reward type
        :param yolact_path: (string) Path to a trained YOLACT for the 3dvu reward type
        :param yolact_config: (string) Path to a saved YOLACT config object, the name of an existing one in the data/Config script, or None for autodetection
    """
    def __init__(self, vision_src="ground_truth", env=None, vae_path=None, yolact_path=None, yolact_config=None):
        self.src = vision_src
        self.env = env
        self.vae_embedder = None
        self.vae_imsize = None
        self.vae_path = vae_path
        self.yolact_path = yolact_path
        self.yolact_config = yolact_config
        self.obsdim = None
        self._initialize_network(self.src)
        self.mask = {}
        self.centroid = {}
        self.centroid_transformed = {}
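    # Illustrative usage sketch (not part of the original module): creating a
    # ground-truth vision module. `env` is a hypothetical placeholder for an
    # already-initialized myGym environment.
    #
    #   vision = VisionModule(vision_src="ground_truth", env=env)
    #   vision.get_module_type()   # -> "ground_truth"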
    def get_module_type(self):
        """
        Get the source of information from the environment (ground_truth, yolact, vae)

        Returns:
            :return source: (string) Source of information
        """
        return self.src
    def crop_image(self, img):
        """
        Crop an image by 1/4 from each side

        Parameters:
            :param img: (array) Original image
        Returns:
            :return img: (array) Cropped image
        """
        dim1 = img.shape[0]
        crop1 = [int(dim1/4), int(dim1-(dim1/4))]
        dim2 = img.shape[1]
        crop2 = [int(dim2/4), int(dim2-(dim2/4))]
        img = img[crop1[0]:crop1[1], crop2[0]:crop2[1]]
        return img
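    # Worked example of the cropping arithmetic: for a 400x600 input, rows
    # 100:300 and columns 150:450 are kept, i.e. the central half of each axis.
    #
    #   img = np.zeros((400, 600, 3), dtype="uint8")
    #   cropped = VisionModule().crop_image(img)   # cropped.shape == (200, 300, 3)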
    def get_obj_pixel_position(self, obj=None, img=None):
        """
        Get the mask and centroid of an object in pixel space coordinates from a 2D image

        Parameters:
            :param obj: (object) Object to find the mask and centroid of
            :param img: (array) 2D input image for vision model inference
        Returns:
            :return mask: (list) Mask of the object
            :return centroid: (list) Centroid of the object in pixel space coordinates
        """
        if self.src == "ground_truth":
            pass
        elif self.src in ["dope", "yolact"]:
            if img is not None:
                if self.src == "yolact":
                    classes, class_names, scores, boxes, masks, centroids = self.inference_yolact(img)
                    if self.env.visualize == 1:
                        img_numpy = self.yolact_cnn.label_image(img)
                        cv2.imshow("Yolact(3dvs) inference", img_numpy)
                        cv2.waitKey(1)
                    try:
                        # object detected in the current image
                        self.mask[obj.get_name()] = masks[class_names.index(obj.get_name())]
                        self.centroid[obj.get_name()] = centroids[class_names.index(obj.get_name())]
                    except ValueError:
                        # object not detected; fall back to a sentinel value unless
                        # a detection from a previous frame is available
                        if obj.get_name() not in self.mask.keys():
                            self.mask[obj.get_name()] = [[-1]]
                            self.centroid[obj.get_name()] = [-1, -1]
                    return self.mask[obj.get_name()], self.centroid[obj.get_name()]
                elif self.src == "dope":
                    pass  # @TODO
            else:
                raise Exception("You need to provide the image argument for segmentation")
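    # Illustrative sketch of querying a mask and pixel centroid with a
    # YOLACT-based module; `env`, `cube` and `rgb_img` are hypothetical
    # placeholders for an environment, a task object and a rendered camera
    # image, and both model paths below are hypothetical.
    #
    #   vision = VisionModule(vision_src="yolact", env=env,
    #                         yolact_path="trained_models/weights.pth",    # hypothetical path
    #                         yolact_config="trained_models/config.obj")   # hypothetical path
    #   mask, centroid = vision.get_obj_pixel_position(obj=cube, img=rgb_img)
    #   # centroid == [-1, -1] if the object has never been detected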
    def get_obj_bbox(self, obj=None, img=None):
        """
        Get the bounding box of an object from a 2D image

        Parameters:
            :param obj: (object) Object to find the bounding box of
            :param img: (array) 2D input image for vision model inference
        Returns:
            :return bbox: (list) Bounding box of the object
        """
        if self.src == "ground_truth":
            if obj is not None:
                return obj.get_bounding_box()
            else:
                raise Exception("You need to provide the obj argument to get the ground-truth bounding box")
        elif self.src in ["dope", "yolact"]:
            if img is not None:
                if self.src == "yolact":
                    classes, class_names, scores, boxes, masks, centroids = self.inference_yolact(img)
                    try:
                        bbox = boxes[class_names.index(obj.get_name())]
                    except ValueError:
                        bbox = []
                        print("Object not detected in the current image")
                    return bbox
                elif self.src == "dope":
                    pass  # @TODO
            else:
                raise Exception("You need to provide the image argument for bounding box detection")
        else:
            raise Exception("{} module does not provide bounding boxes!".format(self.src))
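    # Sketch: with the ground_truth source no image is needed, while the yolact
    # source requires one; a missed detection yields an empty list. `cube` and
    # `rgb_img` remain hypothetical placeholders.
    #
    #   bbox = vision.get_obj_bbox(obj=cube)               # ground_truth
    #   bbox = vision.get_obj_bbox(obj=cube, img=rgb_img)  # yolact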
    def get_obj_position(self, obj=None, img=None, depth=None):
        """
        Get an object's position in the world coordinates of the environment from a 2D image and a depth image

        Parameters:
            :param obj: (object) Object to find the position of
            :param img: (array) 2D input image for vision model inference
            :param depth: (array) Depth input image for vision model inference
        Returns:
            :return position: (list) Centroid of the object in world coordinates
        """
        if self.src == "ground_truth":
            if obj is not None:
                return list(obj.get_position())
            else:
                raise Exception("You need to provide the obj argument to get the ground-truth position")
        elif self.src in ["yolact", "dope"]:
            if img is not None:
                if self.src == "yolact":
                    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
                    mask, centroid = self.get_obj_pixel_position(obj, img)
                    centroid_transformed = self.yolact_cnn.find_3d_centroids_(mask, depth, self.env.unwrapped.cameras[self.env.active_cameras].view_x_proj)
                    if centroid_transformed.size == 3:
                        # object detected; store the newly transformed centroid
                        self.centroid_transformed[obj.get_name()] = centroid_transformed
                    elif obj.get_name() not in self.centroid_transformed.keys():
                        # object has never been detected; assign an out-of-scene fallback position
                        self.centroid_transformed[obj.get_name()] = [10, 10, 10]
                    else:
                        # object not detected; keep the position from a previous frame
                        pass
                    return list(self.centroid_transformed[obj.get_name()])
            else:
                raise Exception("You need to provide the image argument to infer the object position")
        return
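    # Sketch of the 3D lookup, reusing the hypothetical `cube` and `rgb_img`
    # placeholders plus `depth_img`, a depth render from the active camera.
    # Until the object is first detected, the [10, 10, 10] fallback is returned.
    #
    #   position = vision.get_obj_position(obj=cube, img=rgb_img, depth=depth_img)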
    def get_obj_orientation(self, obj=None, img=None):
        """
        Get an object's orientation in the world coordinates of the environment from a 2D image

        Parameters:
            :param obj: (object) Object to find the orientation of
            :param img: (array) 2D input image for vision model inference
        Returns:
            :return orientation: (list) Orientation of the object in world coordinates
        """
        if self.src == "ground_truth":
            if obj is not None:
                return obj.get_orientation()
            else:
                raise Exception("You need to provide the obj argument to get the ground-truth orientation")
        elif self.src in ["yolact", "dope"]:
            if img is not None:
                pass  # @TODO
            else:
                raise Exception("You need to provide the image argument to infer the orientation")
        return
    def vae_generate_sample(self):
        """
        Generate an image as a sample of the VAE latent representation

        Returns:
            :return dec_img: (array) Image generated from the VAE latent representation
        """
        latent_z = torch.tensor([random.uniform(-2, 2) for _ in range(self.vae_embedder.n_latents)]).unsqueeze(0)
        decoded = self.vae_embedder.image_decoder(latent_z)
        img = decoded.squeeze(0).reshape(self.vae_imsize, self.vae_imsize, 3)
        dec_img = np.asarray((img * 255).cpu().detach(), dtype="uint8")
        return dec_img
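    # Sketch: drawing a random latent vector and saving the decoded image,
    # assuming a VAE-based module; the vae_path value below is hypothetical.
    #
    #   vision = VisionModule(vision_src="vae", env=env, vae_path="trained_models/vae_model.pth.tar")
    #   cv2.imwrite("vae_sample.png", vision.vae_generate_sample())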
    def encode_with_vae(self, imgs, task="reach", decode=0):
        """
        Encode input images into n-dimensional latent variables using the VAE model

        Parameters:
            :param imgs: (list of arrays) Input images
            :param task: (string) Type of learned task (reach, push, ...)
            :param decode: (bool) Whether to decode the encoded images from the latent representation back to image arrays
        Returns:
            :return latent_z: (list) Latent representation of the images
            :return dec_img: (list of arrays) Images decoded from the latent representation
        """
        if self.src != "vae":
            raise Exception("Encoding can only be done with the VAE module!")
        imgs_input = []
        for img in imgs:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
            if img.shape[0] != self.vae_imsize:
                # crop to the task-specific region of interest, then resize to the VAE input size
                res = [0, 450, 100, 500] if task == "reach" else [60, 390, 160, 480]
                img = cv2.resize(img[res[0]:res[1], res[2]:res[3]], (self.vae_imsize, self.vae_imsize))
            # convert to a normalized, channel-first float tensor with a batch dimension
            im = torch.tensor(img).type(torch.FloatTensor)
            im = im.reshape(img.shape[2], img.shape[0], img.shape[1]).unsqueeze(0) / 255
            imgs_input = torch.cat((imgs_input, im), dim=0) if torch.is_tensor(imgs_input) else im
        latent_z = self.vae_embedder.infer(imgs_input)[0].detach().cpu()
        dec_img = sample.decode_images(self.vae_embedder, latent_z) if decode == 1 else []
        return latent_z.squeeze().tolist(), dec_img
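    # Sketch: encoding two rendered frames; `img_a` and `img_b` are hypothetical
    # RGB arrays. Each returned latent vector has vae_embedder.n_latents entries.
    #
    #   latents, _ = vision.encode_with_vae([img_a, img_b], task="reach", decode=0)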
    def inference_yolact(self, img):
        """
        Run inference with the YOLACT model

        Parameters:
            :param img: (array) Input 2D image
        Returns:
            :return classes: (list of ints) Class IDs of detected objects
            :return class_names: (list of strings) Class names of detected objects
            :return scores: (list of floats) Confidence scores of the detections
            :return boxes: (list of lists) Bounding boxes of detected objects
            :return masks: (list of lists) Masks of detected objects
            :return centroids: (list of lists) Centroids of detected objects in pixel space coordinates
        """
        classes, class_names, scores, boxes, masks, centroids = self.yolact_cnn.raw_inference(img)
        return classes, class_names, scores, boxes, masks, centroids
    def _initialize_network(self, network):
        """
        Initialize the pre-trained vision model and set the corresponding dimension of the observation data

        Parameters:
            :param network: (string) Source of information from the environment (yolact, vae, dope)
        """
        if network == "vae":
            weights_pth = pkg_resources.resource_filename("myGym", self.vae_path)
            try:
                self.vae_embedder, imsize = load_checkpoint(weights_pth, use_cuda=True)
            except Exception:
                raise Exception("For a reward_type other than 'gt', you need to download a pre-trained vision model and specify the path to it in the config. Specified {} not found.".format(self.vae_path))
            self.vae_imsize = imsize
            self.obsdim = (2 * self.vae_embedder.n_latents) + 3
        elif network == "yolact":
            weights = pkg_resources.resource_filename("myGym", self.yolact_path)
            # a ".obj" value is a path inside the package; otherwise pass the
            # config name (or None, for autodetection) through unchanged
            config = self.yolact_config
            if self.yolact_config is not None and ".obj" in self.yolact_config:
                config = pkg_resources.resource_filename("myGym", self.yolact_config)
            try:
                self.yolact_cnn = InfTool(weights=weights, config=config, score_threshold=0.2)
            except Exception:
                raise Exception("For a reward_type other than 'gt', you need to download a pre-trained vision model and specify the path to it in the config. Specified {} and {} not found.".format(self.yolact_path, self.yolact_config))
            self.obsdim = (len(self.env.task_objects_names) + 1) * 3
        elif network == "dope":
            self.obsdim = (len(self.env.task_objects_names) + 1) * 7
        return
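# Worked example of the observation dimensions set above (the interpretation of
# the constants is an assumption, not stated in the source): with 2 task object
# names, yolact gives obsdim = (2 + 1) * 3 = 9 (an xyz centroid per object, the
# +1 presumably covering the gripper), dope gives (2 + 1) * 7 = 21 (presumably
# an xyz position plus a quaternion per object), and a VAE with n_latents = 8
# gives obsdim = 2 * 8 + 3 = 19.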