Added segmentation based imitation learning model as a new brain #632

Merged: 8 commits, Oct 3, 2023
@@ -0,0 +1,152 @@
from brains.CARLA.utils.pilotnet_onehot import PilotNetOneHot
from brains.CARLA.utils.test_utils import traffic_light_to_int, model_control
from utils.constants import PRETRAINED_MODELS_DIR, ROOT_PATH
from os import path

import math
import time

import numpy as np
import torch
import carla

PRETRAINED_MODELS = ROOT_PATH + '/' + PRETRAINED_MODELS_DIR + 'il_models/'

class Brain:

    def __init__(self, sensors, actuators, model=None, handler=None, config=None):
        self.motors = actuators.get_motor('motors_0')
        self.camera_rgb = sensors.get_camera('camera_0')  # RGB front-view camera
        self.camera_seg = sensors.get_camera('camera_2')  # segmentation camera
        self.handler = handler
        self.inference_times = []
        self.gpu_inference = config['GPU']
        self.device = torch.device('cuda' if (torch.cuda.is_available() and self.gpu_inference) else 'cpu')

        client = carla.Client('localhost', 2000)
        client.set_timeout(10.0)
        world = client.get_world()
        self.map = world.get_map()

        weather = carla.WeatherParameters.ClearNoon
        world.set_weather(weather)

        self.vehicle = None
        while self.vehicle is None:
            for vehicle in world.get_actors().filter('vehicle.*'):
                if vehicle.attributes.get('role_name') == 'ego_vehicle':
                    self.vehicle = vehicle
                    break
            if self.vehicle is None:
                print("Waiting for vehicle with role_name 'ego_vehicle'")
                time.sleep(1)  # sleep for 1 second before checking again

        if model:
            if not path.exists(PRETRAINED_MODELS + model):
                print("File " + model + " cannot be found in " + PRETRAINED_MODELS)

            if config['UseOptimized']:
                self.net = torch.jit.load(PRETRAINED_MODELS + model).to(self.device)
            else:
                # (H, W, C) input shape, 3 control outputs, 4 hlc classes, 4 light classes
                self.net = PilotNetOneHot((288, 200, 6), 3, 4, 4).to(self.device)
                self.net.load_state_dict(torch.load(PRETRAINED_MODELS + model, map_location=self.device))
                self.net.eval()

        self.prev_hlc = 0

    def update_frame(self, frame_id, data):
        """Update the information to be shown in one of the GUI's frames.

        Arguments:
            frame_id {str} -- Id of the frame that will represent the data
            data {*} -- Data to be shown in the frame. Depending on the type of frame (rgbimage, laser, pose3d, etc)
        """
        if data.shape[0] != data.shape[1]:
            if data.shape[0] > data.shape[1]:
                difference = data.shape[0] - data.shape[1]
                extra_left, extra_right = int(difference/2), int(difference/2)
                extra_top, extra_bottom = 0, 0
            else:
                difference = data.shape[1] - data.shape[0]
                extra_left, extra_right = 0, 0
                extra_top, extra_bottom = int(difference/2), int(difference/2)

            data = np.pad(data, ((extra_top, extra_bottom), (extra_left, extra_right), (0, 0)), mode='constant', constant_values=0)

        self.handler.update_frame(frame_id, data)
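
As a quick illustration of the square-padding above (a standalone sketch, not part of the diff): a landscape 480×640 frame gains 80 rows of zeros on top and bottom so the GUI frame ends up square.

import numpy as np

frame = np.zeros((480, 640, 3), dtype=np.uint8)   # height < width
difference = frame.shape[1] - frame.shape[0]      # 160
padded = np.pad(frame, ((difference // 2, difference // 2), (0, 0), (0, 0)),
                mode='constant', constant_values=0)
print(padded.shape)  # (640, 640, 3)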

    def execute(self):
        """Main loop of the brain. This will be called iteratively each TIME_CYCLE (see pilot.py)"""

        rgb_image = self.camera_rgb.getImage().data
        seg_image = self.camera_seg.getImage().data

        self.update_frame('frame_0', rgb_image)
        self.update_frame('frame_1', seg_image)

        try:
            # calculate speed: CARLA returns a velocity vector in m/s; convert its norm to km/h
            velocity = self.vehicle.get_velocity()
            speed = 3.6 * math.sqrt(velocity.x**2 + velocity.y**2 + velocity.z**2)

            # randomly choose high-level command if at junction
            vehicle_location = self.vehicle.get_transform().location
            vehicle_waypoint = self.map.get_waypoint(vehicle_location)
            next_to_junction = False
            next_waypoints = []  # guard against the lookahead loop finding no junction
            for j in range(1, 11):
                next_waypoint = vehicle_waypoint.next(j * 1.0)[0]
                if next_waypoint.is_junction:
                    next_to_junction = True
                    next_waypoints = vehicle_waypoint.next(j * 1.0)
                    break
            if vehicle_waypoint.is_junction or next_to_junction:
                if self.prev_hlc == 0:
                    valid_turns = []
                    for next_wp in next_waypoints:
                        yaw_diff = next_wp.transform.rotation.yaw - vehicle_waypoint.transform.rotation.yaw
                        yaw_diff = (yaw_diff + 180) % 360 - 180  # normalize to [-180, 180); e.g. a raw 350 becomes -10
                        if -15 < yaw_diff < 15:
                            valid_turns.append(3)  # Go Straight
                        elif 15 < yaw_diff < 165:
                            valid_turns.append(1)  # Turn Left
                        elif -165 < yaw_diff < -15:
                            valid_turns.append(2)  # Turn Right
                    hlc = np.random.choice(valid_turns) if valid_turns else 0  # fall back to lane-follow if nothing matched
                else:
                    hlc = self.prev_hlc
            else:
                hlc = 0

            # get traffic light status (-1 means no light is currently affecting the vehicle)
            light_status = -1
            if self.vehicle.is_at_traffic_light():
                traffic_light = self.vehicle.get_traffic_light()
                light_status = traffic_light.get_state()

            print(f'hlc: {hlc}')
            print(f'light: {light_status}')
            frame_data = {
                'hlc': hlc,
                'measurements': speed,
                'rgb': np.copy(rgb_image),
                'segmentation': np.copy(seg_image),
                'light': np.array([traffic_light_to_int(light_status)])
            }

            throttle, steer, brake = model_control(self.net,
                                                   frame_data,
                                                   ignore_traffic_light=False,
                                                   device=self.device,
                                                   combined_control=False)
            self.motors.sendThrottle(throttle)
            self.motors.sendSteer(steer)
            self.motors.sendBrake(brake)
        except Exception as err:
            print(err)
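
The light value in frame_data comes from traffic_light_to_int, which is imported from brains.CARLA.utils.test_utils and not shown in this diff. A hypothetical sketch of what such a helper could look like, assuming a four-way encoding that matches the model's num_light=4 (the real mapping may differ):

import carla

def traffic_light_to_int(light_status):
    # hypothetical mapping; -1 is the "no light" sentinel used in execute() above
    mapping = {
        -1: 0,
        carla.TrafficLightState.Red: 1,
        carla.TrafficLightState.Yellow: 2,
        carla.TrafficLightState.Green: 3,
    }
    return mapping.get(light_status, 0)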
158 changes: 158 additions & 0 deletions behavior_metrics/brains/CARLA/pytorch/utils/pilotnet_onehot.py
@@ -0,0 +1,158 @@
import torch
import torch.nn as nn

from .convlstm import ConvLSTM


class PilotNetOneHot(nn.Module):
    def __init__(self, image_shape, num_labels, num_hlc, num_light):
        super(PilotNetOneHot, self).__init__()
        self.num_channels = image_shape[2]
        self.cn_1 = nn.Conv2d(self.num_channels, 8, kernel_size=3, stride=2)
        self.relu_1 = nn.ReLU()
        self.cn_2 = nn.Conv2d(8, 8, kernel_size=3, stride=2)
        self.relu_2 = nn.ReLU()
        self.cn_3 = nn.Conv2d(8, 8, kernel_size=3, stride=2)
        self.relu_3 = nn.ReLU()
        self.dropout_1 = nn.Dropout(0.2)

        self.clstm_n = ConvLSTM(8, 8, (5, 5), 3, batch_first=True, bias=True, return_all_layers=False)

        self.fc_1 = nn.Linear(8*35*24 + 1 + num_hlc + num_light, 50)  # flattened ConvLSTM features + speed + one-hot hlc + one-hot light
        self.relu_fc_1 = nn.ReLU()
        self.fc_2 = nn.Linear(50, 10)
        self.relu_fc_2 = nn.ReLU()
        self.fc_3 = nn.Linear(10, num_labels)

    def forward(self, img, speed, hlc, light):
        out = self.cn_1(img)
        out = self.relu_1(out)
        out = self.cn_2(out)
        out = self.relu_2(out)
        out = self.cn_3(out)
        out = self.relu_3(out)
        out = self.dropout_1(out)
        out = out.unsqueeze(1)  # add a time dimension of length 1 for the ConvLSTM

        _, last_states = self.clstm_n(out)
        out = last_states[0][0]  # 0 for layer index, 0 for h index

        # flatten & concatenate with speed, one-hot hlc and one-hot light
        out = out.reshape(out.size(0), -1)
        speed = speed.view(speed.size(0), -1)
        hlc = hlc.view(hlc.size(0), -1)
        light = light.view(light.size(0), -1)
        out = torch.cat((out, speed, hlc, light), dim=1)

        out = self.fc_1(out)
        out = self.relu_fc_1(out)
        out = self.fc_2(out)
        out = self.relu_fc_2(out)
        out = self.fc_3(out)

        out = torch.sigmoid(out)

        return out
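
A minimal smoke test for PilotNetOneHot (a standalone sketch, not part of the diff; the constructor arguments mirror the Brain above, which builds the net as PilotNetOneHot((288, 200, 6), 3, 4, 4)):

import torch
net = PilotNetOneHot((288, 200, 6), 3, 4, 4)
img = torch.randn(1, 6, 288, 200)              # 6 channels, presumably RGB + segmentation stacked (an assumption)
speed = torch.randn(1, 1)                      # scalar speed measurement
hlc = torch.zeros(1, 4); hlc[0, 0] = 1.0       # one-hot high-level command
light = torch.zeros(1, 4); light[0, 0] = 1.0   # one-hot traffic-light state
out = net(img, speed, hlc, light)
print(out.shape)  # torch.Size([1, 3]) -- three sigmoid control outputs

The three conv layers (kernel 3, stride 2) reduce a 288×200 input to 35×24 feature maps, which is where the 8*35*24 term in fc_1 comes from.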

class PilotNetOneHotNoLight(nn.Module):
    def __init__(self, image_shape, num_labels, num_hlc):
        super(PilotNetOneHotNoLight, self).__init__()
        self.num_channels = image_shape[2]
        self.cn_1 = nn.Conv2d(self.num_channels, 8, kernel_size=3, stride=2)
        self.relu_1 = nn.ReLU()
        self.cn_2 = nn.Conv2d(8, 8, kernel_size=3, stride=2)
        self.relu_2 = nn.ReLU()
        self.cn_3 = nn.Conv2d(8, 8, kernel_size=3, stride=2)
        self.relu_3 = nn.ReLU()
        self.dropout_1 = nn.Dropout(0.2)

        self.clstm_n = ConvLSTM(8, 8, (5, 5), 3, batch_first=True, bias=True, return_all_layers=False)

        self.fc_1 = nn.Linear(8*35*24 + 1 + num_hlc, 50)  # flattened ConvLSTM features + speed + one-hot hlc
        self.relu_fc_1 = nn.ReLU()
        self.fc_2 = nn.Linear(50, 10)
        self.relu_fc_2 = nn.ReLU()
        self.fc_3 = nn.Linear(10, num_labels)

    def forward(self, img, speed, hlc):
        out = self.cn_1(img)
        out = self.relu_1(out)
        out = self.cn_2(out)
        out = self.relu_2(out)
        out = self.cn_3(out)
        out = self.relu_3(out)
        out = self.dropout_1(out)
        out = out.unsqueeze(1)  # add a time dimension of length 1 for the ConvLSTM

        _, last_states = self.clstm_n(out)
        out = last_states[0][0]  # 0 for layer index, 0 for h index

        # flatten & concatenate with speed and one-hot hlc
        out = out.reshape(out.size(0), -1)
        speed = speed.view(speed.size(0), -1)
        hlc = hlc.view(hlc.size(0), -1)
        out = torch.cat((out, speed, hlc), dim=1)
        out = self.fc_1(out)
        out = self.relu_fc_1(out)
        out = self.fc_2(out)
        out = self.relu_fc_2(out)
        out = self.fc_3(out)

        out = torch.sigmoid(out)

        return out


class PilotNetEmbeddingNoLight(nn.Module):
    def __init__(self, image_shape, num_labels, num_hlc):
        super(PilotNetEmbeddingNoLight, self).__init__()
        self.num_channels = image_shape[2]
        self.cn_1 = nn.Conv2d(self.num_channels, 8, kernel_size=3, stride=2)
        self.relu_1 = nn.ReLU()
        self.cn_2 = nn.Conv2d(8, 8, kernel_size=3, stride=2)
        self.relu_2 = nn.ReLU()
        self.cn_3 = nn.Conv2d(8, 8, kernel_size=3, stride=2)
        self.relu_3 = nn.ReLU()
        self.dropout_1 = nn.Dropout(0.2)

        self.clstm_n = ConvLSTM(8, 8, (5, 5), 3, batch_first=True, bias=True, return_all_layers=False)

        # embedding layer maps each high-level command index to a dense 5-dim vector
        self.embedding = nn.Embedding(num_hlc, 5)

        self.fc_1 = nn.Linear(8*35*24 + 1 + 5, 50)  # flattened ConvLSTM features + speed + hlc embedding
        self.relu_fc_1 = nn.ReLU()
        self.fc_2 = nn.Linear(50, 10)
        self.relu_fc_2 = nn.ReLU()
        self.fc_3 = nn.Linear(10, num_labels)

    def forward(self, img, speed, hlc):
        out = self.cn_1(img)
        out = self.relu_1(out)
        out = self.cn_2(out)
        out = self.relu_2(out)
        out = self.cn_3(out)
        out = self.relu_3(out)
        out = self.dropout_1(out)
        out = out.unsqueeze(1)  # add a time dimension of length 1 for the ConvLSTM

        _, last_states = self.clstm_n(out)
        out = last_states[0][0]  # 0 for layer index, 0 for h index

        # flatten & concatenate with speed and the embedded hlc
        out = out.reshape(out.size(0), -1)
        speed = speed.view(speed.size(0), -1)
        hlc = self.embedding(hlc).view(hlc.size(0), -1)  # convert hlc indices to dense vectors
        out = torch.cat((out, speed, hlc), dim=1)

        out = self.fc_1(out)
        out = self.relu_fc_1(out)
        out = self.fc_2(out)
        out = self.relu_fc_2(out)
        out = self.fc_3(out)

        out = torch.sigmoid(out)

        return out
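
For contrast with the one-hot variants, PilotNetEmbeddingNoLight takes the high-level command as an integer index and learns a 5-dimensional embedding for it (a sketch under the same shape assumptions as the smoke test above):

net = PilotNetEmbeddingNoLight((288, 200, 6), 3, 4)
img = torch.randn(1, 6, 288, 200)
speed = torch.randn(1, 1)
hlc = torch.tensor([2])   # integer command index in [0, num_hlc)
out = net(img, speed, hlc)
print(out.shape)          # torch.Size([1, 3])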