Implement a new JIT for Arm devices (#6057)

* Implement a new JIT for Arm devices

* Auto-format

* Make a lot of Assembler members read-only

* More read-only

* Fix more warnings

* ObjectDisposedException.ThrowIf

* New JIT cache for platforms that enforce W^X, currently unused

* Remove unused using

* Fix assert

* Pass memory manager type around

* Safe memory manager mode support + other improvements

* Actual safe memory manager mode masking support

* PR feedback
This commit is contained in:
gdkchan 2024-01-20 11:11:28 -03:00 committed by GitHub
parent 331c07807f
commit 427b7d06b5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
135 changed files with 43322 additions and 24 deletions

View file

@ -0,0 +1,15 @@
namespace Ryujinx.Cpu.LightningJit.CodeGen.Arm64
{
/// <summary>
/// Register bit masks used by the Arm64 code generator.
/// Bit N of each mask stands for register number N.
/// </summary>
static class AbiConstants
{
    // Registers that have a dedicated role and can't be used as general purpose registers:
    //   X18 - Reserved for platform specific usage.
    //   X29 - Frame pointer.
    //   X30 - Return address.
    //   X31 - Not an actual register; in some cases it maps to SP, and in others to ZR.
    public const uint ReservedRegsMask = 1u << 18 | 1u << 29 | 1u << 30 | 1u << 31;

    // X19 to X28: ten consecutive bits starting at bit 19.
    public const uint GprCalleeSavedRegsMask = 0x3ffu << 19;

    // D8 to D15: eight consecutive bits starting at bit 8.
    public const uint FpSimdCalleeSavedRegsMask = 0xffu << 8;
}
}

View file

@ -0,0 +1,30 @@
namespace Ryujinx.Cpu.LightningJit.CodeGen.Arm64
{
    /// <summary>
    /// Condition codes, numbered 0 to 15.
    /// Members are laid out in pairs that differ only in bit 0 (Eq/Ne, GeUn/LtUn, Mi/Pl, ...),
    /// which is what <see cref="ArmConditionExtensions.Invert"/> relies on.
    /// </summary>
    enum ArmCondition
    {
        Eq = 0b0000,
        Ne = 0b0001,
        GeUn = 0b0010,
        LtUn = 0b0011,
        Mi = 0b0100,
        Pl = 0b0101,
        Vs = 0b0110,
        Vc = 0b0111,
        GtUn = 0b1000,
        LeUn = 0b1001,
        Ge = 0b1010,
        Lt = 0b1011,
        Gt = 0b1100,
        Le = 0b1101,
        Al = 0b1110,
        Nv = 0b1111,
    }

    static class ArmConditionExtensions
    {
        /// <summary>
        /// Returns the other member of the pair <paramref name="condition"/> belongs to, by flipping bit 0
        /// of its value (Eq becomes Ne, Ge becomes Lt, and so on).
        /// </summary>
        /// <remarks>
        /// Al and Nv form a pair as well, so inverting Al yields Nv (and vice versa).
        /// </remarks>
        /// <param name="condition">Condition to invert</param>
        /// <returns>The paired condition</returns>
        public static ArmCondition Invert(this ArmCondition condition) => condition ^ (ArmCondition)1;
    }
}

View file

@ -0,0 +1,14 @@
namespace Ryujinx.Cpu.LightningJit.CodeGen.Arm64
{
/// <summary>
/// Register extension kinds, numbered 0 to 7.
/// The first four members are the unsigned variants and the last four the signed ones,
/// so bit 2 of the value separates the two groups while bits 0-1 select the b/h/w/x size suffix.
/// </summary>
// NOTE(review): the values presumably match the instruction encoding field used by the assembler - confirm there.
enum ArmExtensionType
{
    Uxtb = 0b000,
    Uxth = 0b001,
    Uxtw = 0b010,
    Uxtx = 0b011,
    Sxtb = 0b100,
    Sxth = 0b101,
    Sxtw = 0b110,
    Sxtx = 0b111,
}
}

View file

@ -0,0 +1,11 @@
namespace Ryujinx.Cpu.LightningJit.CodeGen.Arm64
{
/// <summary>
/// Shift kinds, numbered 0 to 3: logical shift left, logical shift right,
/// arithmetic shift right and rotate right.
/// </summary>
// NOTE(review): the values presumably match the instruction encoding field used by the assembler - confirm there.
enum ArmShiftType
{
    Lsl = 0b00,
    Lsr = 0b01,
    Asr = 0b10,
    Ror = 0b11,
}
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,67 @@
using System.Numerics;
namespace Ryujinx.Cpu.LightningJit.CodeGen.Arm64
{
/// <summary>
/// Helpers shared by the Arm64 code generator.
/// </summary>
static class CodeGenCommon
{
    /// <summary>
    /// Tries to encode the value of <paramref name="operand"/> as a bit mask immediate,
    /// using the operand type to decide whether the value is 32 or 64 bits wide.
    /// </summary>
    /// <param name="operand">Operand holding the value and its type</param>
    /// <param name="immN">N field of the encoding, 0 on failure</param>
    /// <param name="immS">6-bit imms field of the encoding, 0 on failure</param>
    /// <param name="immR">immr (rotation) field of the encoding, 0 on failure</param>
    /// <returns>True if the value can be encoded, false otherwise</returns>
    public static bool TryEncodeBitMask(Operand operand, out int immN, out int immS, out int immR)
        => TryEncodeBitMask(operand.Type, operand.Value, out immN, out immS, out immR);

    /// <summary>
    /// Tries to encode <paramref name="value"/> as a bit mask immediate.
    /// For <see cref="OperandType.I32"/>, only the low 32 bits are considered.
    /// </summary>
    /// <param name="type">Type of the value</param>
    /// <param name="value">Value to encode</param>
    /// <param name="immN">N field of the encoding, 0 on failure</param>
    /// <param name="immS">6-bit imms field of the encoding, 0 on failure</param>
    /// <param name="immR">immr (rotation) field of the encoding, 0 on failure</param>
    /// <returns>True if the value can be encoded, false otherwise</returns>
    public static bool TryEncodeBitMask(OperandType type, ulong value, out int immN, out int immS, out int immR)
    {
        ulong pattern = value;

        if (type == OperandType.I32)
        {
            // Replicate the low word into the high word, so that a 32-bit value is handled
            // as a 64-bit pattern that repeats every 32 bits (or less).
            ulong lowWord = pattern & uint.MaxValue;

            pattern = lowWord | (lowWord << 32);
        }

        return TryEncodeBitMask(pattern, out immN, out immS, out immR);
    }

    /// <summary>
    /// Tries to encode a 64-bit <paramref name="value"/> as a bit mask immediate: a run of ones,
    /// rotated inside an element, with the element repeated to fill all 64 bits.
    /// </summary>
    /// <param name="value">Value to encode</param>
    /// <param name="immN">N field of the encoding (1 only for 64-bit elements), 0 on failure</param>
    /// <param name="immS">6-bit imms field of the encoding, 0 on failure</param>
    /// <param name="immR">immr (rotation) field of the encoding, 0 on failure</param>
    /// <returns>True if the value can be encoded, false otherwise</returns>
    public static bool TryEncodeBitMask(ulong value, out int immN, out int immS, out int immR)
    {
        // All outputs are zero unless the encoding succeeds.
        immN = 0;
        immS = 0;
        immR = 0;

        // Two special values can't be encoded:
        // - 0, because the encoding stores the number of ones minus 1, which would be negative.
        // - All ones, because that encoding is reserved according to the spec. It would be useless anyway:
        //   AND with all ones is a no-op, OR with all ones is just a MOV, and XOR with all ones is just a MVN.
        if (value is 0 or ulong.MaxValue)
        {
            return false;
        }

        // "value & (value + 1)" clears the run of ones that starts at bit 0 (if there is one), so the
        // trailing zero count is the position of the first one that comes right after a zero.
        // Rotating by that amount leaves a complete run of ones at the bottom, with the LSB set,
        // rather than a run that is cut in half across the word boundary.
        // When nothing is left after the clear, the count is 64 and the rotation is a no-op.
        int rotation = BitOperations.TrailingZeroCount(value & (value + 1));
        ulong normalized = ulong.RotateRight(value, rotation);

        // With a complete run at the bottom, the element ends where the next run of ones starts,
        // and the number of ones in it is the length of the run at the bottom.
        int elementSize = BitOperations.TrailingZeroCount(normalized & (normalized + 1));
        int onesInElement = BitOperations.TrailingZeroCount(~normalized);

        // The value must repeat every elementSize bits. This also rejects element sizes that are
        // not a power of two.
        if (ulong.RotateRight(value, elementSize) != value)
        {
            return false;
        }

        immN = (elementSize >> 6) & 1;
        immS = ((-elementSize << 1) | (onesInElement - 1)) & 0x3f;
        immR = (elementSize - rotation) & (elementSize - 1);

        return true;
    }
}
}

View file

@ -0,0 +1,252 @@
using System.Numerics;
namespace Ryujinx.Cpu.LightningJit.CodeGen.Arm64
{
// Emits the function prologue and epilogue that save and restore a set of registers on the stack.
//
// Stack layout produced by WritePrologue, starting at the final SP and going up:
//   - X29/X30 pair (16 bytes), only when hasCall is set.
//   - General purpose registers from gprMask, 8 bytes each.
//   - 8 bytes of padding, only when the FP/SIMD type is V128 and the general purpose register count is odd.
//   - FP/SIMD registers from fpSimdMask, each with the size of fpSimdType.
//   - Padding up to a multiple of 16 bytes.
//   - Reserved stack area (reservedStackSize rounded up to a multiple of 16 bytes).
//
// WriteEpilogue undoes this in the reverse order.
readonly struct RegisterSaveRestore
{
private const int FpRegister = 29;
private const int LrRegister = 30;
// Stack adjustments smaller than this are folded into a pre-indexed store (or post-indexed load),
// larger ones are emitted as a separate Sub (or Add) on SP.
// NOTE(review): going by the name, this is the magnitude limit of a signed 9-bit immediate - confirm against the assembler.
public const int Encodable9BitsOffsetLimit = 0x100;
private readonly uint _gprMask;
private readonly uint _fpSimdMask;
private readonly OperandType _fpSimdType;
private readonly int _reservedStackSize;
private readonly bool _hasCall;
// gprMask: bit N set means general purpose register N is saved and restored.
// fpSimdMask: bit N set means FP/SIMD register N is saved and restored.
// fpSimdType: type used for the FP/SIMD stores and loads, its size determines the stack slot size.
// hasCall: when set, the X29/X30 pair is also saved and X29 is set to the final SP.
// reservedStackSize: extra stack space, in bytes, allocated above the saved registers.
public RegisterSaveRestore(
uint gprMask,
uint fpSimdMask = 0,
OperandType fpSimdType = OperandType.FP64,
bool hasCall = false,
int reservedStackSize = 0)
{
_gprMask = gprMask;
_fpSimdMask = fpSimdMask;
_fpSimdType = fpSimdType;
_reservedStackSize = reservedStackSize;
_hasCall = hasCall;
}
// Returns the size of the X29/X30 pair (if saved) plus the 16-byte aligned size of the saved registers.
// With the layout written by WritePrologue, that is the offset of the reserved stack area from the final SP.
public int GetReservedStackOffset()
{
int gprCalleeSavedRegsCount = BitOperations.PopCount(_gprMask);
int fpSimdCalleeSavedRegsCount = BitOperations.PopCount(_fpSimdMask);
return (_hasCall ? 16 : 0) + Align16(gprCalleeSavedRegsCount * 8 + fpSimdCalleeSavedRegsCount * _fpSimdType.GetSizeInBytes());
}
// Writes the code that allocates the stack space and stores the registers.
// The first store emitted is the one that moves SP, every other store addresses the stack relative to the new SP.
public void WritePrologue(ref Assembler asm)
{
uint gprMask = _gprMask;
uint fpSimdMask = _fpSimdMask;
int gprCalleeSavedRegsCount = BitOperations.PopCount(gprMask);
int fpSimdCalleeSavedRegsCount = BitOperations.PopCount(fpSimdMask);
int reservedStackSize = Align16(_reservedStackSize);
// Total amount SP moves by, not counting the X29/X30 pair.
int calleeSaveRegionSize = Align16(gprCalleeSavedRegsCount * 8 + fpSimdCalleeSavedRegsCount * _fpSimdType.GetSizeInBytes()) + reservedStackSize;
// Offset of the next free slot from the new SP. Zero also means that SP was not moved yet.
int offset = 0;
WritePrologueCalleeSavesPreIndexed(ref asm, ref gprMask, ref offset, calleeSaveRegionSize, OperandType.I64);
// An odd amount of 8-byte registers leaves the offset at 8 modulo 16,
// skip 8 bytes so that the vector slots start at a multiple of 16.
if (_fpSimdType == OperandType.V128 && (gprCalleeSavedRegsCount & 1) != 0)
{
offset += 8;
}
WritePrologueCalleeSavesPreIndexed(ref asm, ref fpSimdMask, ref offset, calleeSaveRegionSize, _fpSimdType);
if (_hasCall)
{
Operand rsp = Register(Assembler.SpRegister);
if (offset != 0 || calleeSaveRegionSize + 16 < Encodable9BitsOffsetLimit)
{
// If something was stored already, SP was moved and only the 16 bytes of the pair are left to allocate.
// Otherwise, this store allocates the whole region (only the reserved area here) in addition to the pair.
asm.StpRiPre(Register(FpRegister), Register(LrRegister), rsp, offset == 0 ? -(calleeSaveRegionSize + 16) : -16);
}
else
{
// Nothing was stored and the region is too large to be folded into the store, allocate it separately.
asm.Sub(rsp, rsp, new Operand(OperandKind.Constant, OperandType.I64, (ulong)calleeSaveRegionSize));
asm.StpRiPre(Register(FpRegister), Register(LrRegister), rsp, -16);
}
// X29 now points to the saved X29/X30 pair.
asm.MovSp(Register(FpRegister), rsp);
}
}
// Stores all registers on mask, from the lowest to the highest numbered one, and clears the mask.
// If the register count is odd, the lowest register is stored alone first, the rest is stored in pairs.
// offset is advanced by the size of each store. If it is zero on entry, the first store also allocates
// calleeSaveRegionSize bytes of stack.
private static void WritePrologueCalleeSavesPreIndexed(
ref Assembler asm,
ref uint mask,
ref int offset,
int calleeSaveRegionSize,
OperandType type)
{
if ((BitOperations.PopCount(mask) & 1) != 0)
{
int reg = BitOperations.TrailingZeroCount(mask);
mask &= ~(1u << reg);
if (offset != 0)
{
// SP was already moved, store relative to it.
asm.StrRiUn(Register(reg, type), Register(Assembler.SpRegister), offset);
}
else if (calleeSaveRegionSize < Encodable9BitsOffsetLimit)
{
// First store, moves SP and stores at the new SP.
asm.StrRiPre(Register(reg, type), Register(Assembler.SpRegister), -calleeSaveRegionSize);
}
else
{
// First store, but the region is too large to be folded into it.
asm.Sub(Register(Assembler.SpRegister), Register(Assembler.SpRegister), new Operand(OperandType.I64, (ulong)calleeSaveRegionSize));
asm.StrRiUn(Register(reg, type), Register(Assembler.SpRegister), 0);
}
offset += type.GetSizeInBytes();
}
// The amount of registers left is even here, so there is always a second register for the pair.
while (mask != 0)
{
int reg = BitOperations.TrailingZeroCount(mask);
mask &= ~(1u << reg);
int reg2 = BitOperations.TrailingZeroCount(mask);
mask &= ~(1u << reg2);
if (offset != 0)
{
asm.StpRiUn(Register(reg, type), Register(reg2, type), Register(Assembler.SpRegister), offset);
}
else if (calleeSaveRegionSize < Encodable9BitsOffsetLimit)
{
asm.StpRiPre(Register(reg, type), Register(reg2, type), Register(Assembler.SpRegister), -calleeSaveRegionSize);
}
else
{
asm.Sub(Register(Assembler.SpRegister), Register(Assembler.SpRegister), new Operand(OperandType.I64, (ulong)calleeSaveRegionSize));
asm.StpRiUn(Register(reg, type), Register(reg2, type), Register(Assembler.SpRegister), 0);
}
offset += type.GetSizeInBytes() * 2;
}
}
// Writes the code that loads the registers back and frees the stack space, mirroring WritePrologue.
// The last load emitted is the one that moves SP back.
public void WriteEpilogue(ref Assembler asm)
{
uint gprMask = _gprMask;
uint fpSimdMask = _fpSimdMask;
int gprCalleeSavedRegsCount = BitOperations.PopCount(gprMask);
int fpSimdCalleeSavedRegsCount = BitOperations.PopCount(fpSimdMask);
// Same condition WritePrologue uses to insert 8 bytes of padding before the vector slots.
bool misalignedVector = _fpSimdType == OperandType.V128 && (gprCalleeSavedRegsCount & 1) != 0;
// Offset of the end of the last saved register from SP, the registers are restored going down from there.
int offset = gprCalleeSavedRegsCount * 8 + fpSimdCalleeSavedRegsCount * _fpSimdType.GetSizeInBytes();
if (misalignedVector)
{
offset += 8;
}
int calleeSaveRegionSize = Align16(offset) + Align16(_reservedStackSize);
if (_hasCall)
{
Operand rsp = Register(Assembler.SpRegister);
if (offset != 0 || calleeSaveRegionSize + 16 < Encodable9BitsOffsetLimit)
{
// If there are registers left to restore, only the pair is freed here, the last load frees the rest.
// Otherwise, this load frees the whole region (only the reserved area here) in addition to the pair.
asm.LdpRiPost(Register(FpRegister), Register(LrRegister), rsp, offset == 0 ? calleeSaveRegionSize + 16 : 16);
}
else
{
// Nothing else to restore and the region is too large to be folded into the load, free it separately.
asm.LdpRiPost(Register(FpRegister), Register(LrRegister), rsp, 16);
asm.Add(rsp, rsp, new Operand(OperandKind.Constant, OperandType.I64, (ulong)calleeSaveRegionSize));
}
}
WriteEpilogueCalleeSavesPostIndexed(ref asm, ref fpSimdMask, ref offset, calleeSaveRegionSize, _fpSimdType);
// Skip the padding between the general purpose registers and the vector slots.
if (misalignedVector)
{
offset -= 8;
}
WriteEpilogueCalleeSavesPostIndexed(ref asm, ref gprMask, ref offset, calleeSaveRegionSize, OperandType.I64);
}
// Loads all registers on mask, from the highest to the lowest numbered one, and clears the mask.
// Registers are loaded in pairs while there are at least two left, the last one of an odd count is loaded alone.
// offset is decremented by the size of each load before it is used. The load that brings it down to zero
// also frees calleeSaveRegionSize bytes of stack.
private static void WriteEpilogueCalleeSavesPostIndexed(
ref Assembler asm,
ref uint mask,
ref int offset,
int calleeSaveRegionSize,
OperandType type)
{
while (mask != 0)
{
int reg = HighestBitSet(mask);
mask &= ~(1u << reg);
if (mask != 0)
{
int reg2 = HighestBitSet(mask);
mask &= ~(1u << reg2);
offset -= type.GetSizeInBytes() * 2;
// reg2 is the lower numbered register of the two, it goes first as it was stored at the lower address.
if (offset != 0)
{
asm.LdpRiUn(Register(reg2, type), Register(reg, type), Register(Assembler.SpRegister), offset);
}
else if (calleeSaveRegionSize < Encodable9BitsOffsetLimit)
{
// Last load, loads from SP and then moves SP back.
asm.LdpRiPost(Register(reg2, type), Register(reg, type), Register(Assembler.SpRegister), calleeSaveRegionSize);
}
else
{
// Last load, but the region is too large to be folded into it.
asm.LdpRiUn(Register(reg2, type), Register(reg, type), Register(Assembler.SpRegister), 0);
asm.Add(Register(Assembler.SpRegister), Register(Assembler.SpRegister), new Operand(OperandType.I64, (ulong)calleeSaveRegionSize));
}
}
else
{
offset -= type.GetSizeInBytes();
if (offset != 0)
{
asm.LdrRiUn(Register(reg, type), Register(Assembler.SpRegister), offset);
}
else if (calleeSaveRegionSize < Encodable9BitsOffsetLimit)
{
asm.LdrRiPost(Register(reg, type), Register(Assembler.SpRegister), calleeSaveRegionSize);
}
else
{
asm.LdrRiUn(Register(reg, type), Register(Assembler.SpRegister), 0);
asm.Add(Register(Assembler.SpRegister), Register(Assembler.SpRegister), new Operand(OperandType.I64, (ulong)calleeSaveRegionSize));
}
}
}
}
// Index of the most significant bit set on value. The value must not be zero.
private static int HighestBitSet(uint value)
{
return 31 - BitOperations.LeadingZeroCount(value);
}
// Creates a register operand with the given register number and type.
// Note that the register type is always Integer here, even when this is used for the FP/SIMD registers.
private static Operand Register(int register, OperandType type = OperandType.I64)
{
return new Operand(register, RegisterType.Integer, type);
}
// Rounds value up to the next multiple of 16.
private static int Align16(int value)
{
return (value + 0xf) & ~0xf;
}
}
}

View file

@ -0,0 +1,30 @@
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
namespace Ryujinx.Cpu.LightningJit.CodeGen.Arm64
{
/// <summary>
/// Stack walker that follows the chain of frame records (saved frame pointer, return address)
/// left on the stack by the generated code.
/// </summary>
class StackWalker : IStackWalker
{
    /// <summary>
    /// Collects the call stack, starting at the frame record pointed to by <paramref name="framePointer"/>.
    /// The walk stops at the first return address that is outside both code regions, or when the end
    /// of the frame record chain (a null frame pointer) is reached.
    /// </summary>
    /// <param name="framePointer">Address of the first frame record, may be null</param>
    /// <param name="codeRegionStart">Start address of the first code region</param>
    /// <param name="codeRegionSize">Size in bytes of the first code region</param>
    /// <param name="codeRegion2Start">Start address of the second code region</param>
    /// <param name="codeRegion2Size">Size in bytes of the second code region</param>
    /// <returns>Return address minus 4 for each frame visited, innermost frame first</returns>
    public IEnumerable<ulong> GetCallStack(IntPtr framePointer, IntPtr codeRegionStart, int codeRegionSize, IntPtr codeRegion2Start, int codeRegion2Size)
    {
        List<ulong> functionPointers = new();

        // A null frame pointer terminates the frame record chain. Reading through it would be an
        // invalid memory access in native memory, so the walk must stop there.
        while (framePointer != IntPtr.Zero)
        {
            // The return address is stored right after the saved frame pointer.
            IntPtr functionPointer = Marshal.ReadIntPtr(framePointer, IntPtr.Size);

            if ((functionPointer < codeRegionStart || functionPointer >= codeRegionStart + codeRegionSize) &&
                (functionPointer < codeRegion2Start || functionPointer >= codeRegion2Start + codeRegion2Size))
            {
                // The caller is not inside any of the code regions, this is the end of the call stack.
                break;
            }

            // Subtract 4 (presumably the size of one instruction) so that the address refers to
            // the call instruction rather than the instruction it returns to.
            functionPointers.Add((ulong)functionPointer - 4);

            // Move on to the frame record of the caller.
            framePointer = Marshal.ReadIntPtr(framePointer);
        }

        return functionPointers;
    }
}
}

View file

@ -0,0 +1,120 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
namespace Ryujinx.Cpu.LightningJit.CodeGen.Arm64
{
/// <summary>
/// Makes all the returns of a function share a single epilogue.
/// Each return is emitted as a branch with a placeholder offset, and once the function body is complete,
/// <see cref="WriteReturn"/> writes the epilogue once and patches all those branches to jump to it.
/// </summary>
class TailMerger
{
    private enum BranchType
    {
        Conditional,
        Unconditional,
    }

    // Offset field of the conditional branches: 19 bits, starting at bit 5 of the instruction.
    private const uint Imm19Mask = 0x7ffff;

    // The conditional branch offset is signed, so it reaches half of the field range in each direction.
    private const int Imm19Reach = (int)(Imm19Mask + 1) / 2;

    // Offset field of the unconditional branch: the low 26 bits of the instruction.
    private const int Imm26Mask = 0x3ffffff;

    // Type and instruction index of each branch waiting to be patched.
    private readonly List<(BranchType, int)> _branchPointers;

    public TailMerger()
    {
        _branchPointers = new List<(BranchType, int)>();
    }

    /// <summary>
    /// Emits a return that is only taken if <paramref name="returnCondition"/> holds.
    /// </summary>
    /// <param name="writer">Code writer the assembler writes to</param>
    /// <param name="asm">Assembler used to emit the branch</param>
    /// <param name="returnCondition">Condition for the return to be taken</param>
    public void AddConditionalReturn(CodeWriter writer, in Assembler asm, ArmCondition returnCondition)
    {
        Track(BranchType.Conditional, writer);
        asm.B(returnCondition, 0);
    }

    /// <summary>
    /// Emits a return that is only taken if <paramref name="value"/> is zero.
    /// </summary>
    /// <param name="writer">Code writer the assembler writes to</param>
    /// <param name="asm">Assembler used to emit the branch</param>
    /// <param name="value">Operand that is compared against zero</param>
    public void AddConditionalZeroReturn(CodeWriter writer, in Assembler asm, Operand value)
    {
        Track(BranchType.Conditional, writer);
        asm.Cbz(value, 0);
    }

    /// <summary>
    /// Emits a return that is always taken.
    /// </summary>
    /// <param name="writer">Code writer the assembler writes to</param>
    /// <param name="asm">Assembler used to emit the branch</param>
    public void AddUnconditionalReturn(CodeWriter writer, in Assembler asm)
    {
        Track(BranchType.Unconditional, writer);
        asm.B(0);
    }

    /// <summary>
    /// Writes the epilogue followed by a return instruction, and patches all the branches added so far
    /// to jump to it. Nothing is written if no return was added.
    /// </summary>
    /// <param name="writer">Code writer holding the function code</param>
    /// <param name="writeEpilogue">Callback that writes the epilogue code at the current position</param>
    public void WriteReturn(CodeWriter writer, Action writeEpilogue)
    {
        int pendingCount = _branchPointers.Count;

        if (pendingCount == 0)
        {
            return;
        }

        int returnIndex = writer.InstructionPointer;

        (BranchType lastType, int lastIndex) = _branchPointers[pendingCount - 1];

        if (lastType == BranchType.Unconditional && lastIndex == returnIndex - 1)
        {
            // The last instruction is an unconditional branch to the very next instruction.
            // It is redundant, so it is removed and the epilogue takes its place.
            writer.RemoveLastInstruction();
            pendingCount--;
            returnIndex--;
        }

        Assembler asm = new(writer);

        writeEpilogue();
        asm.Ret();

        for (int i = pendingCount - 1; i >= 0; i--)
        {
            (BranchType type, int branchIndex) = _branchPointers[i];

            uint encoding = writer.ReadInstructionAt(branchIndex);
            int delta = returnIndex - branchIndex;

            if (type != BranchType.Conditional)
            {
                Debug.Assert(type == BranchType.Unconditional);

                writer.WriteInstructionAt(branchIndex, (encoding & ~0x3ffffffu) | (uint)(delta & Imm26Mask));

                continue;
            }

            if (delta >= -Imm19Reach && delta < Imm19Reach)
            {
                writer.WriteInstructionAt(branchIndex, PatchImm19(encoding, delta));

                continue;
            }

            // The epilogue is out of reach of the conditional branch.
            // Replace it with an unconditional branch, which has a much higher range, to the end of the code.
            // The conditional branch is moved there: if taken, it jumps to the epilogue, otherwise
            // the branch that follows it goes back to the instruction right after the original branch.
            int farDelta = writer.InstructionPointer - branchIndex;
            uint branchInst = 0x14000000u | ((uint)farDelta & 0x3ffffff);

            Debug.Assert(ExtractSImm26Times4(branchInst) == farDelta * 4);

            writer.WriteInstructionAt(branchIndex, branchInst);

            int movedBranchIndex = writer.InstructionPointer;

            writer.WriteInstruction(0u); // Placeholder for the moved conditional branch.
            asm.B((branchIndex + 1 - writer.InstructionPointer) * 4);

            writer.WriteInstructionAt(movedBranchIndex, PatchImm19(encoding, returnIndex - movedBranchIndex));
        }
    }

    // Records the current position of the writer as a branch of the given type that needs patching.
    private void Track(BranchType type, CodeWriter writer)
    {
        _branchPointers.Add((type, writer.InstructionPointer));
    }

    // Replaces the 19-bit offset field of a conditional branch with delta (in instructions).
    private static uint PatchImm19(uint encoding, int delta)
    {
        return (encoding & ~(Imm19Mask << 5)) | (uint)((delta & Imm19Mask) << 5);
    }

    // Sign extends the low 26 bits of the encoding and multiplies the result by 4.
    private static int ExtractSImm26Times4(uint encoding) => (int)(encoding << 6) >> 4;
}
}