Skip to content

Instantly share code, notes, and snippets.

@davispuh
Last active February 17, 2022 03:21
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save davispuh/a1236904a2dc42bfa5d100a654c043e7 to your computer and use it in GitHub Desktop.
Save davispuh/a1236904a2dc42bfa5d100a654c043e7 to your computer and use it in GitHub Desktop.
Script to automatically unbind devices from the host for PCI passthrough; it also binds them back after the VM is turned off
#!/usr/bin/env ruby
require 'pathname'
require 'open3'
require 'nokogiri'
require 'ostruct'
require 'shellwords'
require 'timeout'
# Name of the VM; appended as the last argument of every virsh command.
VM_NAME = 'WindowsVM'
# Connection URI passed to virsh with -c; when empty the -c option is omitted.
VIRSH_URI = 'qemu:///system'
# Unit name handed to `systemctl stop` / `systemctl start` around GPU passthrough.
DISPLAY_SERVER = 'sddm'
# Seconds a single writeData call may take before Timeout::Error is raised.
TIMEOUT = 60
# Base sysfs directory for the PCI bus and the per-device directories below it.
PCI_PATH = Pathname.new('/sys/bus/pci')
PCI_DEVICE_PATH = PCI_PATH / 'devices'
# Driver name devices are switched to before the VM starts.
VFIO_PCI_DRIVER = 'vfio-pci'
# probeDriver writes a device id into this file.
DRIVER_PROBE_PATH = '/sys/bus/pci/drivers_probe'
# Directory whose children are scanned for the frame buffer console.
VTCONSOLE_PATH = Pathname.new('/sys/class/vtconsole/')
# EFI framebuffer driver directory and the device id that is bound/unbound in it.
EFIFB_PATH = Pathname.new('/sys/bus/platform/drivers/efi-framebuffer/')
EFIFB_ID = 'efi-framebuffer.0'
# Disable buffering so progress and error messages show up immediately.
$stderr.sync = true
$stdout.sync = true
# Prints a hint on stderr that the preceding failure probably comes from a
# kernel bug and that dmesg should be checked.
def self.infoKernelBug
  $stderr.puts 'Most likely means we hit a kernel bug, check your dmesg!'
end
# Overwrites +file+ with +data+, giving up after TIMEOUT seconds.
#
# file         - path (String or Pathname) of the file to write
# data         - String to write
# ignoreErrors - when true, a timeout is only reported on stderr; when false
#                the Timeout::Error is re-raised to the caller
#
# Only Timeout::Error is handled here; every other exception raised by the
# open/write propagates unchanged regardless of ignoreErrors.
# Returns the number of bytes written when the write completes.
def self.writeData(file, data, ignoreErrors = false)
  begin
    Timeout.timeout(TIMEOUT) do
      File.open(file, 'wb') { |io| io.write_nonblock(data) }
    end
  rescue Timeout::Error => timeoutError
    raise timeoutError unless ignoreErrors
    $stderr.puts("This is taking way too long! Tried to write #{data.inspect} into #{file}")
    self.infoKernelBug
  end
end
# Builds the argument list for a virsh invocation against VM_NAME.
#
# command - virsh sub-command (e.g. 'dumpxml', 'start')
#
# Returns an Array: 'virsh', then '-c' and VIRSH_URI when VIRSH_URI is not
# empty, then the sub-command and the VM name.
def self.getVirshCMD(command)
  connectionArgs = VIRSH_URI.to_s.empty? ? [] : ['-c', VIRSH_URI]
  ['virsh', *connectionArgs, command, VM_NAME]
end
# Runs `virsh dumpxml` for the VM and parses the output with Nokogiri.
#
# Returns the parsed XML document. When virsh fails, its output is printed
# on stderr and the script exits with status 1.
def self.getVMConfig
  puts("Loading #{VM_NAME} config...")
  xml, status = Open3.capture2e(*self.getVirshCMD('dumpxml'))
  if status.success?
    parsedConfig = Nokogiri::XML(xml)
    puts "Config loaded!"
    return parsedConfig
  end
  $stderr.puts(xml)
  $stderr.puts('Failed to get VM config!')
  exit(1)
end
# Collects the host PCI addresses referenced by the VM's <hostdev type="pci"> entries.
#
# config - parsed VM XML; must respond to #xpath and yield nodes whose #to_h
#          gives the attribute hash (domain, bus, slot, function)
#
# Returns an Array of ids formatted as "domain:bus:slot.function".
def self.findPCIDevices(config)
  config.xpath('//devices/hostdev[@type="pci"]/source/address').map do |address|
    # Drop the first two characters of every attribute value (the "0x" prefix in the XML).
    attrs = address.to_h.transform_values { |value| value[2..] }
    slotAndFunction = attrs['slot'] + '.' + attrs['function']
    [attrs['domain'], attrs['bus'], slotAndFunction].join(':')
  end
end
# Checks that a directory for +deviceId+ exists under PCI_DEVICE_PATH.
#
# Returns true when it does; otherwise reports on stderr and exits with status 4.
def self.pciExist?(deviceId)
  return true if (PCI_DEVICE_PATH / deviceId).directory?

  $stderr.puts("Didn't find #{deviceId} device!")
  exit(4)
end
# Checks that the device directory contains an 'iommu_group' directory.
#
# Returns true when it does; otherwise reports on stderr and exits with status 5.
def self.hasIOMMU?(deviceId)
  return true if (PCI_DEVICE_PATH / deviceId / 'iommu_group').directory?

  $stderr.puts("It seems IOMMU isn't enabled!")
  exit(5)
end
# Lists the ids of all devices that share an IOMMU group with +deviceId+.
#
# The ids are the names of the symlinks found in the device's
# iommu_group/devices directory; non-symlink entries are ignored.
# Returns an Array of Strings (nil only if one of the precondition checks
# returns a falsy value).
def self.findGroupedDevices(deviceId)
  return unless self.pciExist?(deviceId) && self.hasIOMMU?(deviceId)

  groupDir = PCI_DEVICE_PATH / deviceId / 'iommu_group' / 'devices'
  groupDir.children.select(&:symlink?).map { |entry| entry.basename.to_s }
end
# Reads the device's 'config' file and decodes its leading fields.
#
# The data is unpacked as little-endian 16-bit words and single bytes
# ("S<S<S<S<CCS<CCCC") and paired, in order, with the field names below.
# NOTE(review): the field names are taken as-is from the original code;
# confirm them against the PCI configuration header layout if more than
# headerType is ever used.
#
# Returns an OpenStruct with one attribute per field name.
def self.getPCIConfig(deviceId)
  return unless self.pciExist?(deviceId)

  rawHeader = File.read(PCI_DEVICE_PATH / deviceId / 'config')
  fieldNames = %i[
    vendorId deviceId commandRegister statusRegister revisionId
    subclass classCode cacheLine latencyTimer
    headerType bist
  ]
  decoded = rawHeader.unpack('S<S<S<S<CCS<CCCC')
  OpenStruct.new(fieldNames.zip(decoded).to_h)
end
# Bit 7 of the header type byte; it is masked out before the type is compared.
MULTIPLE_FUNCTIONS_BIT = 0x80
# Remaining seven bits of the header type byte.
TYPE_MASK = 0xFF & ~MULTIPLE_FUNCTIONS_BIT

# Tells whether the device's header type (ignoring MULTIPLE_FUNCTIONS_BIT) is 0.
# getVFIODevices uses this to keep regular devices and drop PCI-to-PCI bridges.
def self.isRegularPCI?(deviceId)
  headerType = self.getPCIConfig(deviceId).headerType
  (headerType & TYPE_MASK).zero?
end
# Returns the IOMMU group id of the first device in +deviceIds+, i.e. the
# last path component of the target of its 'iommu_group' symlink.
def self.getIOMMUID(deviceIds)
  groupLink = PCI_DEVICE_PATH / deviceIds.first / 'iommu_group'
  File.basename(groupLink.readlink.to_s)
end
# Determines every device that has to be handed to vfio-pci for the VM.
#
# For each PCI device in the VM config, all devices of its IOMMU group are
# collected, PCI-to-PCI bridges are dropped, and the groups are printed.
#
# Returns an Array of device ids, grouped by IOMMU id (ids sorted as strings)
# and sorted within each group.
def self.getVFIODevices
  vmConfig = self.getVMConfig
  deviceIds = self.findPCIDevices(vmConfig)
  $stderr.puts("Warning! VM isn't using any PCI devices!") if deviceIds.empty?
  puts 'Looking for PCI device IOMMU groups...'
  iommuGroups = {}
  deviceIds.each do |deviceId|
    # Keep only regular PCI devices, remove PCI-to-PCI Bridges
    # (distinct block parameter name so the loop's deviceId is not shadowed).
    groupedDevices = self.findGroupedDevices(deviceId).select do |groupedId|
      self.isRegularPCI?(groupedId)
    end
    # Every member of the group links to the same iommu_group, so ask for the
    # id via the requested device itself; this stays valid even when the
    # filter above leaves the list empty.
    iommuID = self.getIOMMUID([deviceId])
    iommuGroups[iommuID] = groupedDevices.sort
  end
  iommuGroups.keys.sort.flat_map do |iommuID|
    puts "IOMMU group #{iommuID}: " + iommuGroups[iommuID].join(', ')
    iommuGroups[iommuID]
  end
end
# Asks `virsh domstate` whether the VM is in any state other than 'shut off'.
#
# Returns true/false accordingly. When virsh itself fails, its output is
# printed on stderr and false is returned.
def self.isVMRunning?
  output, status = Open3.capture2e(*self.getVirshCMD('domstate'))
  return output.strip != 'shut off' if status.success?

  $stderr.puts(output)
  $stderr.puts("Failed to get VM state!")
  false
end
# Loads a kernel module with modprobe.
#
# name        - module name
# exitOnError - when true a modprobe failure ends the script with status 5;
#               otherwise the failure is only reported on stderr
def self.loadModule(name, exitOnError = false)
  puts "Loading #{name} module..."
  output, status = Open3.capture2e('modprobe', name)
  unless status.success?
    $stderr.puts(output)
    $stderr.puts('Failed load module!')
    exit(5) if exitOnError
    return
  end
  puts "Module loaded!"
end
# Unloads a kernel module with rmmod.
#
# An rmmod failure whose output says the module "is not currently loaded"
# counts as success; any other failure is reported and ends the script with
# status 4.
def self.unloadModule(name)
  puts "Unloading #{name} module..."
  output, status = Open3.capture2e('rmmod', name)
  unless status.success? || /is not currently loaded/.match?(output)
    $stderr.puts(output)
    $stderr.puts('Failed unload module!')
    exit(4)
  end
  puts "Module unloaded!"
end
# Loads the vfio_pci module; loadModule is told to exit the script on failure.
def self.loadVFIO
  vfioModule = 'vfio_pci'
  self.loadModule(vfioModule, true)
end
# Lists the PCI ids of all devices that `lspci -vnD` marks as "[VGA controller]".
#
# The id is the first whitespace-separated token of each matching line.
# When lspci fails, its output is printed on stderr and the script exits
# with status 1.
def self.findGPUs
  output, status = Open3.capture2e('lspci', '-vnD')
  unless status.success?
    $stderr.puts(output)
    $stderr.puts('Failed to list PCI devices!')
    exit(1)
  end
  vgaLines = output.each_line.select { |line| /\[VGA controller\]/.match?(line) }
  vgaLines.map { |line| line.split(' ').first.strip }
end
# Stops DISPLAY_SERVER through systemctl.
# A failure is reported on stderr and ends the script with status 1.
def self.stopDisplayServer
  puts "Stopping display server..."
  output, status = Open3.capture2e('systemctl', 'stop', DISPLAY_SERVER)
  unless status.success?
    $stderr.puts(output)
    $stderr.puts('Failed to stop display server!')
    exit(1)
  end
  puts "Display server stopped!"
end
# Starts DISPLAY_SERVER through systemctl.
# Unlike stopping, a failure here is only reported on stderr.
def self.startDisplayServer
  puts "Starting display server..."
  output, status = Open3.capture2e('systemctl', 'start', DISPLAY_SERVER)
  unless status.success?
    $stderr.puts(output)
    $stderr.puts('Failed to start display server!')
    return
  end
  puts "Display server started!"
end
# Returns the entries of VTCONSOLE_PATH as an Array of Pathnames.
def self.getVTConsoles
  VTCONSOLE_PATH.each_child.to_a
end
# Finds the first VT console whose 'name' file mentions "frame buffer device".
#
# Returns its Pathname, or nil (after a note on stderr) when none matches.
def self.findFBConsole
  fbConsole = self.getVTConsoles.find do |candidate|
    /frame buffer device/.match?(File.read(candidate / 'name'))
  end
  $stderr.puts("Didn't find VT console!") unless fbConsole
  fbConsole
end
# Unbinds the frame buffer VT console by writing '0' into its 'bind' file.
#
# Does nothing when no such console is found. The file is read back
# afterwards; if it does not contain '0' the script exits with status 2.
def self.unbindVTconsoles
  console = self.findFBConsole
  return unless console

  puts "Unbinding #{console.basename}..."
  bindPath = console / 'bind'
  self.writeData(bindPath, '0')
  unless File.read(bindPath).strip == '0'
    $stderr.puts('Failed to unbind VT console!')
    exit(2)
  end
  puts "Unbinded!"
end
# Binds the frame buffer VT console again by writing '1' into its 'bind' file.
#
# Does nothing when no such console is found. Up to 10 attempts are made,
# each followed by a one second pause and a read-back of the file; a final
# failure is only reported on stderr.
def self.bindVTconsoles
  console = self.findFBConsole
  return unless console

  puts "Binding #{console.basename}..."
  bindPath = console / 'bind'
  # any? stops at the first attempt whose read-back shows '1'.
  bound = 10.times.any? do
    self.writeData(bindPath, '1')
    sleep(1)
    File.read(bindPath).strip == '1'
  end
  if bound
    puts "Binded!"
  else
    $stderr.puts('Failed to bind VT console!')
  end
end
# Unbinds the EFI framebuffer by writing EFIFB_ID into the driver's 'unbind' file.
# Only attempted when the EFIFB_ID symlink exists in EFIFB_PATH.
def self.unbindEFIFB
  unless (EFIFB_PATH / EFIFB_ID).symlink?
    puts 'EFI framebuffer not present (probably already unbinded)'
    return
  end
  puts "Unbinding EFI framebuffer..."
  self.writeData(EFIFB_PATH / 'unbind', EFIFB_ID)
  puts "Unbinded!"
end
# Binds the EFI framebuffer by writing EFIFB_ID into the driver's 'bind' file.
# An Errno::EINVAL from the write is reported on stderr instead of being raised.
def self.bindEFIFB
  puts "Binding EFI framebuffer..."
  bindTarget = EFIFB_PATH / 'bind'
  begin
    self.writeData(bindTarget, EFIFB_ID)
    puts "Binded!"
  rescue Errno::EINVAL
    $stderr.puts('Failed to bind EFI framebuffer!')
  end
end
# Unloads the AMD GPU driver stack, in this order: amdgpu, drm_ttm_helper, ttm.
# (drm_kms_helper is intentionally left loaded; its unload was disabled in the
# original script.)
def self.unloadModulesAMD
  amdModules = %w[amdgpu drm_ttm_helper ttm]
  lastResult = nil
  amdModules.each { |moduleName| lastResult = self.unloadModule(moduleName) }
  lastResult
end
# Loads the amdgpu module again; a failure is not fatal because loadModule
# is called without exitOnError.
def self.loadModulesAMD
  gpuModule = 'amdgpu'
  self.loadModule(gpuModule)
end
# Releases the host's hold on the GPU when one of +deviceIds+ is a GPU.
#
# Steps, in order: stop the display server, unbind the VT console, unbind the
# EFI framebuffer, unload the AMD modules. Nothing happens when none of the
# given devices is reported by findGPUs.
def self.unloadGPU(deviceIds)
  hostGPUs = self.findGPUs
  unless deviceIds.any? { |deviceId| hostGPUs.include?(deviceId) }
    puts("No GPU for VM! Won't unload GPU!")
    return
  end
  self.stopDisplayServer
  self.unbindVTconsoles
  self.unbindEFIFB
  self.unloadModulesAMD
end
# Gives the GPU back to the host when one of +deviceIds+ is a GPU.
#
# Steps, in order: load the AMD modules, restore the driver of every GPU that
# is in +deviceIds+, bind the EFI framebuffer, bind the VT console, start the
# display server.
def self.loadGPU(deviceIds)
  vmGPUs = self.findGPUs & deviceIds
  if vmGPUs.empty?
    puts("No GPU for VM!")
    return
  end
  self.loadModulesAMD
  vmGPUs.each { |gpuId| self.restoreDriver(gpuId) }
  self.bindEFIFB
  self.bindVTconsoles
  self.startDisplayServer
end
# Returns the name of the driver currently bound to the device (the last
# component of the 'driver' symlink target), or '' when there is no such symlink.
def self.getDeviceDriver(deviceId)
  driverLink = PCI_DEVICE_PATH / deviceId / 'driver'
  driverLink.symlink? ? driverLink.readlink.basename.to_s : ''
end
# Writes +deviceId+ into DRIVER_PROBE_PATH so a driver gets probed for it.
# ignoreErrors is handed to writeData unchanged.
def self.probeDriver(deviceId, ignoreErrors = false)
  probeFile = DRIVER_PROBE_PATH
  self.writeData(probeFile, deviceId, ignoreErrors)
end
# Triggers a PCI bus rescan by writing '1' into the bus' 'rescan' file.
# ignoreErrors is handed to writeData unchanged.
def self.rescanPCI(ignoreErrors = false)
  rescanFile = PCI_PATH / 'rescan'
  self.writeData(rescanFile, '1', ignoreErrors)
end
# Switches a device to another driver.
#
# Steps: write +driver+ into the device's 'driver_override' file, unbind the
# device from its current driver (only when a 'driver' symlink exists), then
# ask for a driver probe. ignoreErrors is handed to every write.
def self.overrideDriver(deviceId, driver, ignoreErrors = false)
  deviceDir = PCI_DEVICE_PATH / deviceId
  self.writeData(deviceDir / 'driver_override', driver, ignoreErrors)

  currentDriverDir = deviceDir / 'driver'
  self.writeData(currentDriverDir / 'unbind', deviceId, ignoreErrors) if currentDriverDir.symlink?

  self.probeDriver(deviceId, ignoreErrors)
end
# Clears the driver override by writing a bare newline as the override and
# re-probing; errors are ignored (ignoreErrors = true).
def self.restoreDriver(deviceId)
  clearedOverride = "\n"
  self.overrideDriver(deviceId, clearedOverride, true)
end
# Makes sure the device is driven by VFIO_PCI_DRIVER.
#
# Nothing is changed when it already is. Otherwise the driver is overridden
# and checked again; if the device still uses another driver the script
# exits with status 5.
def self.bindVFIO(deviceId)
  if self.getDeviceDriver(deviceId) == VFIO_PCI_DRIVER
    puts "Device #{deviceId} is already using #{VFIO_PCI_DRIVER}!"
    return
  end

  puts "Overriding device's #{deviceId} driver to #{VFIO_PCI_DRIVER}"
  self.overrideDriver(deviceId, VFIO_PCI_DRIVER)
  return if self.getDeviceDriver(deviceId) == VFIO_PCI_DRIVER

  $stderr.puts('Failed to bind driver!')
  exit(5)
end
# Hands the device back to its regular driver.
#
# The driver is restored only when the device currently has no driver or
# still uses VFIO_PCI_DRIVER; a device that remains on VFIO_PCI_DRIVER
# afterwards is reported on stderr (the script keeps running).
def self.unbindVFIO(deviceId)
  currentDriver = self.getDeviceDriver(deviceId)
  unless currentDriver.empty? || currentDriver == VFIO_PCI_DRIVER
    puts "Device #{deviceId} is already using correct #{currentDriver}!"
    return
  end

  puts "Restoring device's #{deviceId} driver"
  self.restoreDriver(deviceId)
  stillOnVFIO = self.getDeviceDriver(deviceId) == VFIO_PCI_DRIVER
  $stderr.puts('Failed to unbind driver!') if stillOnVFIO
end
# Runs bindVFIO for every id in +deviceIds+, in order.
def self.bindAll(deviceIds)
  deviceIds.each { |id| self.bindVFIO(id) }
end
# Runs unbindVFIO for every id in +deviceIds+, in order.
def self.unbindAll(deviceIds)
  deviceIds.each { |id| self.unbindVFIO(id) }
end
# Runs one complete passthrough cycle: prepare the host, start the VM, wait
# until it is no longer running, then give all devices back to the host.
#
# attachConsole - when true `virsh start` is run with --console (the call
#                 then blocks in virsh itself); when false the VM state is
#                 polled every 80 seconds after a successful start
#
# If the VM is already running, nothing is prepared; the method only polls
# every 30 seconds until it stops. In every case the devices are unbound
# from vfio-pci, the bus is rescanned and the GPU is reloaded at the end.
def self.startVM(attachConsole = true)
  deviceIds = self.getVFIODevices
  if self.isVMRunning?
    puts 'VM is already running! Waiting for it to stop...'
    loop do
      sleep(30)
      break unless self.isVMRunning?
    end
  else
    begin
      self.loadVFIO
      self.unloadGPU(deviceIds)
      self.bindAll(deviceIds)
      self.rescanPCI
      cmd = self.getVirshCMD('start')
      cmd << '--console' if attachConsole
      if system(cmd.shelljoin) && !attachConsole
        loop do
          sleep(80)
          break unless self.isVMRunning?
        end
      end
    rescue Timeout::Error
      $stderr.puts("This is taking way too long! Aborting!")
      # The original called showTimeoutError, which is not defined anywhere;
      # that raised NoMethodError and skipped the cleanup below.
      self.infoKernelBug
    end
  end
  # Cleanup: always hand the devices back to the host.
  self.unbindAll(deviceIds)
  self.rescanPCI(true)
  self.loadGPU(deviceIds)
end
# Stops the VM (if it is running) and gives all devices back to the host.
#
# A graceful `virsh shutdown` is tried first; if the VM is still running 30
# seconds later it is force-stopped with `virsh destroy`. Afterwards the
# devices are unbound from vfio-pci, the bus is rescanned and the GPU is
# reloaded, whether or not the VM was running.
def self.stopVM
  if self.isVMRunning?
    puts "Shutting down #{VM_NAME}..."
    shutdownAccepted = system(self.getVirshCMD('shutdown').shelljoin)
    sleep(30) if shutdownAccepted
    if self.isVMRunning?
      puts "VM didn't stop in given time, will force stop!"
      system(self.getVirshCMD('destroy').shelljoin)
      sleep(30)
      stillRunning = self.isVMRunning?
      $stderr.puts("Failed to stop VM!") if stillRunning
    end
  else
    puts "#{VM_NAME} is not running!"
  end

  deviceIds = self.getVFIODevices
  self.unbindAll(deviceIds)
  self.rescanPCI(true)
  self.loadGPU(deviceIds)
end
# Command line entry point; dispatches on the first argument.
#
#   start       - re-runs this script with 'directStart' in a child process
#                 and performs stopVM (cleanup) when that child fails
#   stop        - stops the VM and restores the devices
#   directStart - runs startVM without attaching a console
#   otherwise   - prints the usage line
def main
  case ARGV.first
  when 'start'
    childSucceeded = system([RbConfig.ruby, $0, 'directStart'].shelljoin)
    self.stopVM unless childSucceeded
  when 'stop'
    self.stopVM
  when 'directStart'
    self.startVM(false)
  else
    puts 'Commands: start or stop'
  end
end

# Skipped when $spec is set, so the file can be loaded without running a command.
main unless $spec
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment