# 4-Channel YOLO Training Guide for RGB+IR Drone Detection
```python
import cv2
import numpy as np
import torch
from pathlib import Path
from torch.utils.data import Dataset

class RGBIRDataset(Dataset):
    def __init__(self, rgb_dir, ir_dir, label_dir, img_size=640, augment=False):
        self.rgb_dir = Path(rgb_dir)
        self.ir_dir = Path(ir_dir)
        self.label_dir = Path(label_dir)
        self.img_size = img_size
        self.augment = augment
        # Get list of RGB images (assuming RGB and IR share filenames)
        self.rgb_files = list(self.rgb_dir.glob('*.jpg')) + list(self.rgb_dir.glob('*.png'))

    def __len__(self):
        return len(self.rgb_files)

    def __getitem__(self, idx):
        rgb_path = self.rgb_files[idx]
        ir_path = self.ir_dir / rgb_path.name

        # Load images (RGB as 3 channels, IR as a single channel)
        rgb_img = cv2.cvtColor(cv2.imread(str(rgb_path)), cv2.COLOR_BGR2RGB)
        ir_img = cv2.imread(str(ir_path), cv2.IMREAD_GRAYSCALE)

        # Resize images
        rgb_img = cv2.resize(rgb_img, (self.img_size, self.img_size))
        ir_img = cv2.resize(ir_img, (self.img_size, self.img_size))

        # Stack into a 4-channel CHW array normalized to [0, 1]
        combined_img = np.dstack([rgb_img, ir_img]).transpose(2, 0, 1)
        combined_img = combined_img.astype(np.float32) / 255.0

        # Load labels (YOLO format: class x_center y_center width height)
        label_path = self.label_dir / (rgb_path.stem + '.txt')
        labels = []
        if label_path.exists():
            with open(label_path, 'r') as f:
                for line in f:
                    if line.strip():
                        labels.append(list(map(float, line.strip().split())))

        return torch.from_numpy(combined_img), torch.tensor(labels)
```
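A quick smoke test of the dataset class (the paths below are placeholders for your own layout):

```python
# Hypothetical paths -- adjust to your dataset layout
ds = RGBIRDataset('/content/dataset/rgb/train',
                  '/content/dataset/ir/train',
                  '/content/dataset/labels/train')
img, labels = ds[0]
print(img.shape)     # torch.Size([4, 640, 640])
print(labels.shape)  # (num_boxes, 5) when the image has boxes
```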
```python
import torch
import torch.nn as nn
from ultralytics.nn.modules.conv import autopad  # Ultralytics' padding helper
from ultralytics.nn.tasks import DetectionModel

class Conv4Channel(nn.Module):
    """Modified Conv layer to handle 4-channel input"""
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

class YOLO4Channel(nn.Module):
    def __init__(self, cfg='yolov8n.yaml', ch=4, nc=1):
        super().__init__()
        # Build the detection model directly with a 4-channel input stem
        self.model = DetectionModel(cfg, ch=ch, nc=nc)

    def forward(self, x):
        return self.model(x)
```
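A quick shape check for the custom layer (the channel counts here are arbitrary):

```python
layer = Conv4Channel(4, 16, k=3, s=2)
x = torch.randn(1, 4, 640, 640)
print(layer(x).shape)  # torch.Size([1, 16, 320, 320])
```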
```python
from ultralytics.models.yolo.detect import DetectionTrainer

class RGBIR_Trainer(DetectionTrainer):
    def build_dataset(self, img_path, mode="train", batch=None):
        """Build the custom dataset for RGB+IR training."""
        # Grid size from the model stride (kept for stride-aligned sizing)
        gs = max(int(self.model.stride.max() if self.model else 0), 32)
        return RGBIRDataset(
            rgb_dir=f"{self.data['path']}/images/{mode}",
            ir_dir=f"{self.data['path']}/images/{mode}",  # Assuming IR images are in the same folder
            label_dir=f"{self.data['path']}/labels/{mode}",
            img_size=self.args.imgsz,
            augment=mode == "train"
        )

    @staticmethod
    def collate_fn(batch):
        """Custom collate function for batch processing."""
        imgs = []
        labels = []
        for img, label in batch:
            imgs.append(img)
            if len(label) > 0:
                labels.append(label)
        imgs = torch.stack(imgs, 0)
        # Labels stay a ragged list: one tensor per image that has boxes
        return imgs, labels
```
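Because box counts vary per image, the labels cannot be stacked into a single tensor; the collate function keeps them as a ragged list. A quick check, reusing the `ds` instance from the earlier smoke test:

```python
from torch.utils.data import DataLoader

loader = DataLoader(ds, batch_size=4, collate_fn=RGBIR_Trainer.collate_fn)
imgs, labels = next(iter(loader))
print(imgs.shape)   # torch.Size([4, 4, 640, 640])
print(len(labels))  # <= 4: images without boxes are skipped
```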
```python
from ultralytics import YOLO

def create_4channel_yolo_model(weights='yolov8n.pt'):
    """Expand a pretrained YOLOv8 stem conv from 3 to 4 input channels."""
    model = YOLO(weights)
    first_conv = model.model.model[0].conv
    new_conv = nn.Conv2d(4, first_conv.out_channels, first_conv.kernel_size,
                         first_conv.stride, first_conv.padding, bias=False)
    # Initialize weights
    with torch.no_grad():
        # Copy RGB weights (first 3 channels)
        new_conv.weight[:, :3, :, :] = first_conv.weight.clone()
        # Initialize the IR channel (4th) as the average of the RGB filters
        new_conv.weight[:, 3:4, :, :] = first_conv.weight.mean(dim=1, keepdim=True)
    model.model.model[0].conv = new_conv
    return model
```
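A sanity check that the surgery took (the weight shape shown is for yolov8n, whose stem conv has 16 output filters):

```python
model = create_4channel_yolo_model('yolov8n.pt')
stem = model.model.model[0].conv
print(stem.in_channels)   # 4
print(stem.weight.shape)  # torch.Size([16, 4, 3, 3])
```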
```python
class RGBIRDataset(Dataset):
    """Variant for a side-by-side layout: RGB/IR pairs in one folder (*_rgb.*, *_ir.*).
    This version replaces the earlier class and is what the usage example below expects."""
    def __init__(self, images_dir, labels_dir, img_size=640):
        self.images_dir = Path(images_dir)
        self.labels_dir = Path(labels_dir)
        self.img_size = img_size
        # Find all RGB images (assuming naming convention: *_rgb.jpg and *_ir.jpg)
        self.image_files = []
        for rgb_file in self.images_dir.glob('*_rgb.*'):
            ir_file = self.images_dir / rgb_file.name.replace('_rgb', '_ir')
            if ir_file.exists():
                self.image_files.append((rgb_file, ir_file))

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        rgb_path, ir_path = self.image_files[idx]

        # Load RGB image
        rgb_img = cv2.cvtColor(cv2.imread(str(rgb_path)), cv2.COLOR_BGR2RGB)
        rgb_img = cv2.resize(rgb_img, (self.img_size, self.img_size))

        # Load IR image
        ir_img = cv2.imread(str(ir_path), cv2.IMREAD_GRAYSCALE)
        ir_img = cv2.resize(ir_img, (self.img_size, self.img_size))

        # Combine RGB + IR
        combined = np.dstack([rgb_img, ir_img])   # Shape: (H, W, 4)
        combined = combined.transpose(2, 0, 1)    # Shape: (4, H, W)
        combined = combined.astype(np.float32) / 255.0

        # Load labels
        label_path = self.labels_dir / (rgb_path.stem.replace('_rgb', '') + '.txt')
        labels = []
        if label_path.exists():
            with open(label_path, 'r') as f:
                for line in f:
                    if line.strip():
                        labels.append([float(x) for x in line.strip().split()])

        return torch.from_numpy(combined), torch.tensor(labels)
```
```python
from torch.utils.data import DataLoader

# Training function
def train_4channel_yolo(model, train_dataset, val_dataset, epochs=50, batch_size=8):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Training on device: {device}")
    model.model.to(device)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=RGBIR_Trainer.collate_fn,
        num_workers=2
    )
    # Validation loop is left as an exercise; the loader is built for it
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=RGBIR_Trainer.collate_fn,
        num_workers=2
    )

    # Optimizer
    optimizer = torch.optim.AdamW(model.model.parameters(), lr=0.001, weight_decay=0.0005)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    # Training loop
    for epoch in range(epochs):
        model.model.train()
        total_loss = 0
        num_batches = 0
        for batch_idx, (images, targets) in enumerate(train_loader):
            images = images.to(device)
            optimizer.zero_grad()

            # Forward pass
            try:
                # For YOLO models, the loss calculation depends on the detection
                # head. This is a simplified approach -- you may need to adapt it
                # to your specific model. compute_yolo_loss is a placeholder;
                # one possible implementation is sketched in the notes below.
                outputs = model.model(images)
                loss = compute_yolo_loss(model, outputs, targets, device)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                num_batches += 1
                if batch_idx % 10 == 0:
                    print(f'Epoch {epoch+1}/{epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}')
            except Exception as e:
                print(f"Error in batch {batch_idx}: {e}")
                continue

        scheduler.step()
        avg_loss = total_loss / max(num_batches, 1)
        print(f'Epoch {epoch+1}/{epochs} completed. Average Loss: {avg_loss:.4f}')

        # Save checkpoint
        if (epoch + 1) % 10 == 0:
            checkpoint_path = f'/content/checkpoint_epoch_{epoch+1}.pt'
            torch.save(model.model.state_dict(), checkpoint_path)
            print(f'Checkpoint saved: {checkpoint_path}')

    print("Training completed!")
```
```python
# Example usage
print("Setting up datasets...")
train_dataset = RGBIRDataset('/content/dataset/images/train', '/content/dataset/labels/train')
val_dataset = RGBIRDataset('/content/dataset/images/val', '/content/dataset/labels/val')

# Build the 4-channel model defined earlier
model = create_4channel_yolo_model()

# Start training
if len(train_dataset) > 0:
    print("Starting training...")
    train_4channel_yolo(model, train_dataset, val_dataset, epochs=50, batch_size=4)
else:
    print("No training data found. Please check your dataset structure.")
```
```python
def setup_ultralytics_training():
    """Setup training using the Ultralytics framework with modifications."""
    yolo = YOLO('yolov8n.pt')
    model = yolo.model                      # underlying DetectionModel
    first_conv = model.model[0].conv
    new_conv = nn.Conv2d(4, first_conv.out_channels, first_conv.kernel_size,
                         first_conv.stride, first_conv.padding, bias=False)
    with torch.no_grad():
        new_conv.weight[:, :3] = first_conv.weight
        new_conv.weight[:, 3:4] = first_conv.weight.mean(dim=1, keepdim=True)
    model.model[0].conv = new_conv
    return yolo
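A quick forward-pass check on the patched model (input size is arbitrary; the output structure depends on your Ultralytics version):

```python
patched = setup_ultralytics_training()
patched.model.eval()
with torch.no_grad():
    out = patched.model(torch.randn(1, 4, 640, 640))
print(type(out))
```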
For this approach, you'll need to preprocess your data to 4-channel format first:

```python
import glob
import os

def preprocess_dataset_to_4channel():
    """Convert RGB+IR pairs to 4-channel arrays (saved as .npy here; adapt as needed)."""
    def process_files(rgb_files, out_dir):
        os.makedirs(out_dir, exist_ok=True)
        for rgb_path in rgb_files:
            ir_path = rgb_path.replace('_rgb', '_ir')
            if os.path.exists(ir_path):
                # Load images
                rgb = cv2.imread(rgb_path)
                rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB)
                ir = cv2.imread(ir_path, cv2.IMREAD_GRAYSCALE)
                ir = cv2.resize(ir, (rgb.shape[1], rgb.shape[0]))  # match RGB size
                # Combine to 4-channel
                combined = np.dstack([rgb, ir])
                out_name = os.path.splitext(os.path.basename(rgb_path).replace('_rgb', ''))[0]
                np.save(os.path.join(out_dir, out_name + '.npy'), combined)

    train_rgb_files = glob.glob('/content/dataset/images/train/*_rgb.*')
    val_rgb_files = glob.glob('/content/dataset/images/val/*_rgb.*')
    process_files(train_rgb_files, '/content/dataset/processed/train')
    process_files(val_rgb_files, '/content/dataset/processed/val')
```
```python
def test_4channel_model():
    """Test if the 4-channel model works correctly"""
    # Test model
    model = create_4channel_yolo_model()
    model.model.eval()

    # Dummy 4-channel batch: (batch, channels, height, width)
    dummy_input = torch.randn(1, 4, 640, 640)

    try:
        with torch.no_grad():
            output = model.model(dummy_input)
        print("✅ 4-channel model test passed!")
        print(f"Input shape: {dummy_input.shape}")
        print(f"Output type: {type(output)}")
        return True
    except Exception as e:
        print(f"❌ Model test failed: {e}")
        return False

# Run test
test_4channel_model()
```
```python
def preprocess_rgbir_image(rgb_path, ir_path, size=640):
    """Load an RGB/IR pair and return a normalized (4, size, size) float array."""
    # Load RGB image
    rgb = cv2.imread(rgb_path)
    rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB)
    rgb = cv2.resize(rgb, (size, size))

    # Load IR image
    ir = cv2.imread(ir_path, cv2.IMREAD_GRAYSCALE)
    ir = cv2.resize(ir, (size, size))
    ir = np.expand_dims(ir, axis=2)  # Add channel dimension

    # Combine
    combined = np.concatenate([rgb, ir], axis=2)   # Shape: (size, size, 4)
    combined = combined.astype(np.float32) / 255.0
    combined = np.transpose(combined, (2, 0, 1))   # Shape: (4, size, size)
    return combined
```
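With a real pair on disk this yields a ready-to-batch array (the paths here are placeholders):

```python
arr = preprocess_rgbir_image('/content/sample_rgb.jpg', '/content/sample_ir.jpg')
print(arr.shape, arr.dtype)  # (4, 640, 640) float32
print(arr.min(), arr.max())  # within [0.0, 1.0]
```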
```python
import matplotlib.pyplot as plt

def show_rgbir_pair(rgb_path, ir_path):
    """Display an RGB/IR pair side by side."""
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))

    # RGB image
    rgb = cv2.imread(rgb_path)
    rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB)
    axes[0].imshow(rgb)
    axes[0].set_title('RGB Image')
    axes[0].axis('off')

    # IR image
    ir = cv2.imread(ir_path, cv2.IMREAD_GRAYSCALE)
    axes[1].imshow(ir, cmap='hot')
    axes[1].set_title('IR Image')
    axes[1].axis('off')

    plt.show()
```
```python
# Batched evaluation over a test loader (fragment -- plug in your own metrics)
model.model.eval()
with torch.no_grad():
    for batch_idx, (data, targets) in enumerate(test_loader):
        data = data.to(device)
        outputs = model.model(data)
        # ...compute detection metrics (precision/recall, mAP) from `outputs`

def run_inference(model, rgb_path, ir_path, device='cuda'):
    """Run the 4-channel model on a single RGB/IR pair."""
    # Preprocess images
    combined_img = preprocess_rgbir_image(rgb_path, ir_path)
    combined_img = torch.from_numpy(combined_img).unsqueeze(0).to(device)

    # Run inference
    with torch.no_grad():
        results = model(combined_img)
    return results
```
## Important Notes

1. Data Format: Ensure RGB and IR images share the same filename structure so pairs can be matched automatically.
2. Memory Management: 4-channel processing requires more GPU memory than standard 3-channel training.
3. Loss Function: You may need to implement a custom loss calculation for the training loop; a sketch follows this list.
4. Validation: Implement proper validation metrics (precision, recall, mAP) for drone detection.
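For item 3, one possible shape for the `compute_yolo_loss` placeholder used in the training loop is sketched below. It leans on Ultralytics' `v8DetectionLoss`; the batch-dict keys (`batch_idx`, `cls`, `bboxes`) and the default loss gains (box=7.5, cls=0.5, dfl=1.5) match recent Ultralytics versions, but verify them against the one you have installed.

```python
from types import SimpleNamespace
from ultralytics.utils.loss import v8DetectionLoss

def compute_yolo_loss(model, preds, labels_list, device):
    """Sketch: adapt collate_fn's ragged labels into the batch dict that
    v8DetectionLoss expects, then score the raw head outputs."""
    # Build (and cache) the criterion; it reads loss gains from model.args
    det_model = model.model  # underlying DetectionModel
    if not hasattr(model, '_criterion'):
        det_model.args = SimpleNamespace(box=7.5, cls=0.5, dfl=1.5)  # YOLOv8 defaults
        model._criterion = v8DetectionLoss(det_model)

    batch_idx, cls, bboxes = [], [], []
    for i, labels in enumerate(labels_list):
        for row in labels:  # row = [class, x, y, w, h], normalized
            batch_idx.append(float(i))
            cls.append(row[0:1])
            bboxes.append(row[1:5])
    batch = {
        'batch_idx': torch.tensor(batch_idx, device=device),
        'cls': (torch.stack(cls) if cls else torch.zeros(0, 1)).to(device),
        'bboxes': (torch.stack(bboxes) if bboxes else torch.zeros(0, 4)).to(device),
    }
    loss, loss_items = model._criterion(preds, batch)
    return loss
```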
## Troubleshooting

- OOM Errors: Reduce the batch size and image size; gradient accumulation (sketched below) preserves the effective batch size.
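If a smaller batch hurts convergence, gradient accumulation keeps the effective batch size while lowering peak memory. A minimal sketch, assuming the training-loop variables from above (`accum_steps` is a knob you choose):

```python
accum_steps = 4  # effective batch = batch_size * accum_steps
optimizer.zero_grad()
for batch_idx, (images, targets) in enumerate(train_loader):
    images = images.to(device)
    preds = model.model(images)
    loss = compute_yolo_loss(model, preds, targets, device)
    (loss / accum_steps).backward()  # scale so accumulated grads average out
    if (batch_idx + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```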