EXPERIMENTAL: Metal backend (#441)

This is not a continuation of the Metal backend; this is simply bringing
the branch up to date and merging it as-is behind an experiment.

---------

Co-authored-by: Isaac Marovitz <isaacryu@icloud.com>
Co-authored-by: Samuliak <samuliak77@gmail.com>
Co-authored-by: SamoZ256 <96914946+SamoZ256@users.noreply.github.com>
Co-authored-by: Isaac Marovitz <42140194+IsaacMarovitz@users.noreply.github.com>
Co-authored-by: riperiperi <rhy3756547@hotmail.com>
Co-authored-by: Gabriel A <gab.dark.100@gmail.com>
This commit is contained in:
Evan Husted 2024-12-24 00:55:16 -06:00 committed by GitHub
parent 3094df54dd
commit 852823104f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
131 changed files with 14992 additions and 140 deletions

View file

@ -0,0 +1,108 @@
using Ryujinx.Graphics.Shader.StructuredIr;
using Ryujinx.Graphics.Shader.Translation;
using System.Text;
namespace Ryujinx.Graphics.Shader.CodeGen.Msl
{
class CodeGenContext
{
public const string Tab = " ";
// The number of additional arguments that every function (except for the main one) must have (for instance support_buffer)
public const int AdditionalArgCount = 2;
public StructuredFunction CurrentFunction { get; set; }
public StructuredProgramInfo Info { get; }
public AttributeUsage AttributeUsage { get; }
public ShaderDefinitions Definitions { get; }
public ShaderProperties Properties { get; }
public HostCapabilities HostCapabilities { get; }
public ILogger Logger { get; }
public TargetApi TargetApi { get; }
public OperandManager OperandManager { get; }
private readonly StringBuilder _sb;
private int _level;
private string _indentation;
public CodeGenContext(StructuredProgramInfo info, CodeGenParameters parameters)
{
Info = info;
AttributeUsage = parameters.AttributeUsage;
Definitions = parameters.Definitions;
Properties = parameters.Properties;
HostCapabilities = parameters.HostCapabilities;
Logger = parameters.Logger;
TargetApi = parameters.TargetApi;
OperandManager = new OperandManager();
_sb = new StringBuilder();
}
public void AppendLine()
{
_sb.AppendLine();
}
public void AppendLine(string str)
{
_sb.AppendLine(_indentation + str);
}
public string GetCode()
{
return _sb.ToString();
}
public void EnterScope(string prefix = "")
{
AppendLine(prefix + "{");
_level++;
UpdateIndentation();
}
public void LeaveScope(string suffix = "")
{
if (_level == 0)
{
return;
}
_level--;
UpdateIndentation();
AppendLine("}" + suffix);
}
public StructuredFunction GetFunction(int id)
{
return Info.Functions[id];
}
private void UpdateIndentation()
{
_indentation = GetIndentation(_level);
}
private static string GetIndentation(int level)
{
string indentation = string.Empty;
for (int index = 0; index < level; index++)
{
indentation += Tab;
}
return indentation;
}
}
}

View file

@ -0,0 +1,578 @@
using Ryujinx.Common;
using Ryujinx.Common.Logging;
using Ryujinx.Graphics.Shader.IntermediateRepresentation;
using Ryujinx.Graphics.Shader.StructuredIr;
using Ryujinx.Graphics.Shader.Translation;
using System;
using System.Collections.Generic;
using System.Linq;
namespace Ryujinx.Graphics.Shader.CodeGen.Msl
{
static class Declarations
{
/*
* Description of MSL Binding Model
*
* There are a few fundamental differences between how GLSL and MSL handle I/O.
* This comment will set out to describe the reasons why things are done certain ways
* and to describe the overall binding model that we're striving for here.
*
* Main I/O Structs
*
* Each stage has a main input and output struct (if applicable) labeled as [Stage][In/Out], i.e VertexIn.
* Every field within these structs is labeled with an [[attribute(n)]] property,
* and the overall struct is labeled with [[stage_in]] for input structs, and defined as the
* output type of the main shader function for the output struct. This struct also contains special
* attribute-based properties like [[position]] that would be "built-ins" in a GLSL context.
*
* These structs are passed as inputs to all inline functions due to containing "built-ins"
* that inline functions assume access to.
*
* Vertex & Zero Buffers
*
* Binding indices 0-16 are reserved for vertex buffers, and binding 18 is reserved for the zero buffer.
*
* Uniforms & Storage Buffers
*
* Uniforms and storage buffers are tightly packed into their respective argument buffers
* (effectively ignoring binding indices at shader level), with each pointer to the corresponding
* struct that defines the layout and fields of these buffers (usually just a single data array), laid
* out one after the other in ascending order of their binding index.
*
* The uniforms argument buffer is always bound at a fixed index of 20.
* The storage buffers argument buffer is always bound at a fixed index of 21.
*
* These structs are passed as inputs to all inline functions as in GLSL or SPIRV,
* uniforms and storage buffers would be globals, and inline functions assume access to these buffers.
*
* Samplers & Textures
*
* Metal does not have a combined image sampler like sampler2D in GLSL, as a result we need to bind
* an individual texture and a sampler object for each instance of a combined image sampler.
* Samplers and textures are bound in a shared argument buffer. This argument buffer is tightly packed
* (effectively ignoring binding indices at shader level), with texture and their samplers (if present)
* laid out one after the other in ascending order of their binding index.
*
* The samplers and textures argument buffer is always bound at a fixed index of 22.
*
*/
public static int[] Declare(CodeGenContext context, StructuredProgramInfo info)
{
// TODO: Re-enable this warning
context.AppendLine("#pragma clang diagnostic ignored \"-Wunused-variable\"");
context.AppendLine();
context.AppendLine("#include <metal_stdlib>");
context.AppendLine("#include <simd/simd.h>");
context.AppendLine();
context.AppendLine("using namespace metal;");
context.AppendLine();
var fsi = (info.HelperFunctionsMask & HelperFunctionsMask.FSI) != 0;
DeclareInputAttributes(context, info.IoDefinitions.Where(x => IsUserDefined(x, StorageKind.Input)));
context.AppendLine();
DeclareOutputAttributes(context, info.IoDefinitions.Where(x => x.StorageKind == StorageKind.Output));
context.AppendLine();
DeclareBufferStructures(context, context.Properties.ConstantBuffers.Values.OrderBy(x => x.Binding).ToArray(), true, fsi);
DeclareBufferStructures(context, context.Properties.StorageBuffers.Values.OrderBy(x => x.Binding).ToArray(), false, fsi);
// We need to declare each set as a new struct
var textureDefinitions = context.Properties.Textures.Values
.GroupBy(x => x.Set)
.ToDictionary(x => x.Key, x => x.OrderBy(y => y.Binding).ToArray());
var imageDefinitions = context.Properties.Images.Values
.GroupBy(x => x.Set)
.ToDictionary(x => x.Key, x => x.OrderBy(y => y.Binding).ToArray());
var textureSets = textureDefinitions.Keys.ToArray();
var imageSets = imageDefinitions.Keys.ToArray();
var sets = textureSets.Union(imageSets).ToArray();
foreach (var set in textureDefinitions)
{
DeclareTextures(context, set.Value, set.Key);
}
foreach (var set in imageDefinitions)
{
DeclareImages(context, set.Value, set.Key, fsi);
}
if ((info.HelperFunctionsMask & HelperFunctionsMask.FindLSB) != 0)
{
AppendHelperFunction(context, "Ryujinx.Graphics.Shader/CodeGen/Msl/HelperFunctions/FindLSB.metal");
}
if ((info.HelperFunctionsMask & HelperFunctionsMask.FindMSBS32) != 0)
{
AppendHelperFunction(context, "Ryujinx.Graphics.Shader/CodeGen/Msl/HelperFunctions/FindMSBS32.metal");
}
if ((info.HelperFunctionsMask & HelperFunctionsMask.FindMSBU32) != 0)
{
AppendHelperFunction(context, "Ryujinx.Graphics.Shader/CodeGen/Msl/HelperFunctions/FindMSBU32.metal");
}
if ((info.HelperFunctionsMask & HelperFunctionsMask.SwizzleAdd) != 0)
{
AppendHelperFunction(context, "Ryujinx.Graphics.Shader/CodeGen/Msl/HelperFunctions/SwizzleAdd.metal");
}
if ((info.HelperFunctionsMask & HelperFunctionsMask.Precise) != 0)
{
AppendHelperFunction(context, "Ryujinx.Graphics.Shader/CodeGen/Msl/HelperFunctions/Precise.metal");
}
return sets;
}
static bool IsUserDefined(IoDefinition ioDefinition, StorageKind storageKind)
{
return ioDefinition.StorageKind == storageKind && ioDefinition.IoVariable == IoVariable.UserDefined;
}
public static void DeclareLocals(CodeGenContext context, StructuredFunction function, ShaderStage stage, bool isMainFunc = false)
{
if (isMainFunc)
{
// TODO: Support OaIndexing
if (context.Definitions.IaIndexing)
{
context.EnterScope($"array<float4, {Constants.MaxAttributes}> {Defaults.IAttributePrefix} = ");
for (int i = 0; i < Constants.MaxAttributes; i++)
{
context.AppendLine($"in.{Defaults.IAttributePrefix}{i},");
}
context.LeaveScope(";");
}
DeclareMemories(context, context.Properties.LocalMemories.Values, isShared: false);
DeclareMemories(context, context.Properties.SharedMemories.Values, isShared: true);
switch (stage)
{
case ShaderStage.Vertex:
context.AppendLine("VertexOut out = {};");
// TODO: Only add if necessary
context.AppendLine("uint instance_index = instance_id + base_instance;");
break;
case ShaderStage.Fragment:
context.AppendLine("FragmentOut out = {};");
break;
}
// TODO: Only add if necessary
if (stage != ShaderStage.Compute)
{
// MSL does not give us access to [[thread_index_in_simdgroup]]
// outside compute. But we may still need to provide this value in frag/vert.
context.AppendLine("uint thread_index_in_simdgroup = simd_prefix_exclusive_sum(1);");
}
}
foreach (AstOperand decl in function.Locals)
{
string name = context.OperandManager.DeclareLocal(decl);
context.AppendLine(GetVarTypeName(decl.VarType) + " " + name + ";");
}
}
public static string GetVarTypeName(AggregateType type, bool atomic = false)
{
var s32 = atomic ? "atomic_int" : "int";
var u32 = atomic ? "atomic_uint" : "uint";
return type switch
{
AggregateType.Void => "void",
AggregateType.Bool => "bool",
AggregateType.FP32 => "float",
AggregateType.S32 => s32,
AggregateType.U32 => u32,
AggregateType.Vector2 | AggregateType.Bool => "bool2",
AggregateType.Vector2 | AggregateType.FP32 => "float2",
AggregateType.Vector2 | AggregateType.S32 => "int2",
AggregateType.Vector2 | AggregateType.U32 => "uint2",
AggregateType.Vector3 | AggregateType.Bool => "bool3",
AggregateType.Vector3 | AggregateType.FP32 => "float3",
AggregateType.Vector3 | AggregateType.S32 => "int3",
AggregateType.Vector3 | AggregateType.U32 => "uint3",
AggregateType.Vector4 | AggregateType.Bool => "bool4",
AggregateType.Vector4 | AggregateType.FP32 => "float4",
AggregateType.Vector4 | AggregateType.S32 => "int4",
AggregateType.Vector4 | AggregateType.U32 => "uint4",
_ => throw new ArgumentException($"Invalid variable type \"{type}\"."),
};
}
private static void DeclareMemories(CodeGenContext context, IEnumerable<MemoryDefinition> memories, bool isShared)
{
string prefix = isShared ? "threadgroup " : string.Empty;
foreach (var memory in memories)
{
string arraySize = "";
if ((memory.Type & AggregateType.Array) != 0)
{
arraySize = $"[{memory.ArrayLength}]";
}
var typeName = GetVarTypeName(memory.Type & ~AggregateType.Array);
context.AppendLine($"{prefix}{typeName} {memory.Name}{arraySize};");
}
}
private static void DeclareBufferStructures(CodeGenContext context, BufferDefinition[] buffers, bool constant, bool fsi)
{
var name = constant ? "ConstantBuffers" : "StorageBuffers";
var addressSpace = constant ? "constant" : "device";
string[] bufferDec = new string[buffers.Length];
for (int i = 0; i < buffers.Length; i++)
{
BufferDefinition buffer = buffers[i];
var needsPadding = buffer.Layout == BufferLayout.Std140;
string fsiSuffix = !constant && fsi ? " [[raster_order_group(0)]]" : "";
bufferDec[i] = $"{addressSpace} {Defaults.StructPrefix}_{buffer.Name}* {buffer.Name}{fsiSuffix};";
context.AppendLine($"struct {Defaults.StructPrefix}_{buffer.Name}");
context.EnterScope();
foreach (StructureField field in buffer.Type.Fields)
{
var type = field.Type;
type |= (needsPadding && (field.Type & AggregateType.Array) != 0)
? AggregateType.Vector4
: AggregateType.Invalid;
type &= ~AggregateType.Array;
string typeName = GetVarTypeName(type);
string arraySuffix = "";
if (field.Type.HasFlag(AggregateType.Array))
{
if (field.ArrayLength > 0)
{
arraySuffix = $"[{field.ArrayLength}]";
}
else
{
// Probably UB, but this is the approach that MVK takes
arraySuffix = "[1]";
}
}
context.AppendLine($"{typeName} {field.Name}{arraySuffix};");
}
context.LeaveScope(";");
context.AppendLine();
}
context.AppendLine($"struct {name}");
context.EnterScope();
foreach (var declaration in bufferDec)
{
context.AppendLine(declaration);
}
context.LeaveScope(";");
context.AppendLine();
}
private static void DeclareTextures(CodeGenContext context, TextureDefinition[] textures, int set)
{
var setName = GetNameForSet(set);
context.AppendLine($"struct {setName}");
context.EnterScope();
List<string> textureDec = [];
foreach (TextureDefinition texture in textures)
{
if (texture.Type != SamplerType.None)
{
var textureTypeName = texture.Type.ToMslTextureType(texture.Format.GetComponentType());
if (texture.ArrayLength > 1)
{
textureTypeName = $"array<{textureTypeName}, {texture.ArrayLength}>";
}
textureDec.Add($"{textureTypeName} tex_{texture.Name};");
}
if (!texture.Separate && texture.Type != SamplerType.TextureBuffer)
{
var samplerType = "sampler";
if (texture.ArrayLength > 1)
{
samplerType = $"array<{samplerType}, {texture.ArrayLength}>";
}
textureDec.Add($"{samplerType} samp_{texture.Name};");
}
}
foreach (var declaration in textureDec)
{
context.AppendLine(declaration);
}
context.LeaveScope(";");
context.AppendLine();
}
private static void DeclareImages(CodeGenContext context, TextureDefinition[] images, int set, bool fsi)
{
var setName = GetNameForSet(set);
context.AppendLine($"struct {setName}");
context.EnterScope();
string[] imageDec = new string[images.Length];
for (int i = 0; i < images.Length; i++)
{
TextureDefinition image = images[i];
var imageTypeName = image.Type.ToMslTextureType(image.Format.GetComponentType(), true);
if (image.ArrayLength > 1)
{
imageTypeName = $"array<{imageTypeName}, {image.ArrayLength}>";
}
string fsiSuffix = fsi ? " [[raster_order_group(0)]]" : "";
imageDec[i] = $"{imageTypeName} {image.Name}{fsiSuffix};";
}
foreach (var declaration in imageDec)
{
context.AppendLine(declaration);
}
context.LeaveScope(";");
context.AppendLine();
}
private static void DeclareInputAttributes(CodeGenContext context, IEnumerable<IoDefinition> inputs)
{
if (context.Definitions.Stage == ShaderStage.Compute)
{
return;
}
switch (context.Definitions.Stage)
{
case ShaderStage.Vertex:
context.AppendLine("struct VertexIn");
break;
case ShaderStage.Fragment:
context.AppendLine("struct FragmentIn");
break;
}
context.EnterScope();
if (context.Definitions.Stage == ShaderStage.Fragment)
{
// TODO: check if it's needed
context.AppendLine("float4 position [[position, invariant]];");
context.AppendLine("bool front_facing [[front_facing]];");
context.AppendLine("float2 point_coord [[point_coord]];");
context.AppendLine("uint primitive_id [[primitive_id]];");
}
if (context.Definitions.IaIndexing)
{
// MSL does not support arrays in stage I/O
// We need to use the SPIRV-Cross workaround
for (int i = 0; i < Constants.MaxAttributes; i++)
{
var suffix = context.Definitions.Stage == ShaderStage.Fragment ? $"[[user(loc{i})]]" : $"[[attribute({i})]]";
context.AppendLine($"float4 {Defaults.IAttributePrefix}{i} {suffix};");
}
}
if (inputs.Any())
{
foreach (var ioDefinition in inputs.OrderBy(x => x.Location))
{
if (context.Definitions.IaIndexing && ioDefinition.IoVariable == IoVariable.UserDefined)
{
continue;
}
string iq = string.Empty;
if (context.Definitions.Stage == ShaderStage.Fragment)
{
iq = context.Definitions.ImapTypes[ioDefinition.Location].GetFirstUsedType() switch
{
PixelImap.Constant => "[[flat]] ",
PixelImap.ScreenLinear => "[[center_no_perspective]] ",
_ => string.Empty,
};
}
string type = ioDefinition.IoVariable switch
{
// IoVariable.Position => "float4",
IoVariable.GlobalId => "uint3",
IoVariable.VertexId => "uint",
IoVariable.VertexIndex => "uint",
// IoVariable.PointCoord => "float2",
_ => GetVarTypeName(context.Definitions.GetUserDefinedType(ioDefinition.Location, isOutput: false))
};
string name = ioDefinition.IoVariable switch
{
// IoVariable.Position => "position",
IoVariable.GlobalId => "global_id",
IoVariable.VertexId => "vertex_id",
IoVariable.VertexIndex => "vertex_index",
// IoVariable.PointCoord => "point_coord",
_ => $"{Defaults.IAttributePrefix}{ioDefinition.Location}"
};
string suffix = ioDefinition.IoVariable switch
{
// IoVariable.Position => "[[position, invariant]]",
IoVariable.GlobalId => "[[thread_position_in_grid]]",
IoVariable.VertexId => "[[vertex_id]]",
// TODO: Avoid potential redeclaration
IoVariable.VertexIndex => "[[vertex_id]]",
// IoVariable.PointCoord => "[[point_coord]]",
IoVariable.UserDefined => context.Definitions.Stage == ShaderStage.Fragment ? $"[[user(loc{ioDefinition.Location})]]" : $"[[attribute({ioDefinition.Location})]]",
_ => ""
};
context.AppendLine($"{type} {name} {iq}{suffix};");
}
}
context.LeaveScope(";");
}
private static void DeclareOutputAttributes(CodeGenContext context, IEnumerable<IoDefinition> outputs)
{
switch (context.Definitions.Stage)
{
case ShaderStage.Vertex:
context.AppendLine("struct VertexOut");
break;
case ShaderStage.Fragment:
context.AppendLine("struct FragmentOut");
break;
case ShaderStage.Compute:
context.AppendLine("struct KernelOut");
break;
}
context.EnterScope();
if (context.Definitions.OaIndexing)
{
// MSL does not support arrays in stage I/O
// We need to use the SPIRV-Cross workaround
for (int i = 0; i < Constants.MaxAttributes; i++)
{
context.AppendLine($"float4 {Defaults.OAttributePrefix}{i} [[user(loc{i})]];");
}
}
if (outputs.Any())
{
outputs = outputs.OrderBy(x => x.Location);
if (context.Definitions.Stage == ShaderStage.Fragment && context.Definitions.DualSourceBlend)
{
IoDefinition firstOutput = outputs.ElementAtOrDefault(0);
IoDefinition secondOutput = outputs.ElementAtOrDefault(1);
var type1 = GetVarTypeName(context.Definitions.GetFragmentOutputColorType(firstOutput.Location));
var type2 = GetVarTypeName(context.Definitions.GetFragmentOutputColorType(secondOutput.Location));
var name1 = $"color{firstOutput.Location}";
var name2 = $"color{firstOutput.Location + 1}";
context.AppendLine($"{type1} {name1} [[color({firstOutput.Location}), index(0)]];");
context.AppendLine($"{type2} {name2} [[color({firstOutput.Location}), index(1)]];");
outputs = outputs.Skip(2);
}
foreach (var ioDefinition in outputs)
{
if (context.Definitions.OaIndexing && ioDefinition.IoVariable == IoVariable.UserDefined)
{
continue;
}
string type = ioDefinition.IoVariable switch
{
IoVariable.Position => "float4",
IoVariable.PointSize => "float",
IoVariable.FragmentOutputColor => GetVarTypeName(context.Definitions.GetFragmentOutputColorType(ioDefinition.Location)),
IoVariable.FragmentOutputDepth => "float",
IoVariable.ClipDistance => "float",
_ => GetVarTypeName(context.Definitions.GetUserDefinedType(ioDefinition.Location, isOutput: true))
};
string name = ioDefinition.IoVariable switch
{
IoVariable.Position => "position",
IoVariable.PointSize => "point_size",
IoVariable.FragmentOutputColor => $"color{ioDefinition.Location}",
IoVariable.FragmentOutputDepth => "depth",
IoVariable.ClipDistance => "clip_distance",
_ => $"{Defaults.OAttributePrefix}{ioDefinition.Location}"
};
string suffix = ioDefinition.IoVariable switch
{
IoVariable.Position => "[[position, invariant]]",
IoVariable.PointSize => "[[point_size]]",
IoVariable.UserDefined => $"[[user(loc{ioDefinition.Location})]]",
IoVariable.FragmentOutputColor => $"[[color({ioDefinition.Location})]]",
IoVariable.FragmentOutputDepth => "[[depth(any)]]",
IoVariable.ClipDistance => $"[[clip_distance]][{Defaults.TotalClipDistances}]",
_ => ""
};
context.AppendLine($"{type} {name} {suffix};");
}
}
context.LeaveScope(";");
}
private static void AppendHelperFunction(CodeGenContext context, string filename)
{
string code = EmbeddedResources.ReadAllText(filename);
code = code.Replace("\t", CodeGenContext.Tab);
context.AppendLine(code);
context.AppendLine();
}
public static string GetNameForSet(int set, bool forVar = false)
{
return (uint)set switch
{
Defaults.TexturesSetIndex => forVar ? "textures" : "Textures",
Defaults.ImagesSetIndex => forVar ? "images" : "Images",
_ => $"{(forVar ? "set" : "Set")}{set}"
};
}
}
}

View file

@ -0,0 +1,34 @@
namespace Ryujinx.Graphics.Shader.CodeGen.Msl
{
static class Defaults
{
public const string LocalNamePrefix = "temp";
public const string PerPatchAttributePrefix = "patchAttr";
public const string IAttributePrefix = "inAttr";
public const string OAttributePrefix = "outAttr";
public const string StructPrefix = "struct";
public const string ArgumentNamePrefix = "a";
public const string UndefinedName = "0";
public const int MaxVertexBuffers = 16;
public const uint ZeroBufferIndex = MaxVertexBuffers;
public const uint BaseSetIndex = MaxVertexBuffers + 1;
public const uint ConstantBuffersIndex = BaseSetIndex;
public const uint StorageBuffersIndex = BaseSetIndex + 1;
public const uint TexturesIndex = BaseSetIndex + 2;
public const uint ImagesIndex = BaseSetIndex + 3;
public const uint ConstantBuffersSetIndex = 0;
public const uint StorageBuffersSetIndex = 1;
public const uint TexturesSetIndex = 2;
public const uint ImagesSetIndex = 3;
public const int TotalClipDistances = 8;
}
}

View file

@ -0,0 +1,5 @@
template<typename T>
inline T findLSB(T x)
{
return select(ctz(x), T(-1), x == T(0));
}

View file

@ -0,0 +1,5 @@
template<typename T>
inline T findMSBS32(T x)
{
return select(clz(T(0)) - (clz(x) + T(1)), T(-1), x == T(0));
}

View file

@ -0,0 +1,6 @@
template<typename T>
inline T findMSBU32(T x)
{
T v = select(x, T(-1) - x, x < T(0));
return select(clz(T(0)) - (clz(v) + T(1)), T(-1), v == T(0));
}

View file

@ -0,0 +1,10 @@
namespace Ryujinx.Graphics.Shader.CodeGen.Msl
{
static class HelperFunctionNames
{
public static string FindLSB = "findLSB";
public static string FindMSBS32 = "findMSBS32";
public static string FindMSBU32 = "findMSBU32";
public static string SwizzleAdd = "swizzleAdd";
}
}

View file

@ -0,0 +1,14 @@
template<typename T>
[[clang::optnone]] T PreciseFAdd(T l, T r) {
return fma(T(1), l, r);
}
template<typename T>
[[clang::optnone]] T PreciseFSub(T l, T r) {
return fma(T(-1), r, l);
}
template<typename T>
[[clang::optnone]] T PreciseFMul(T l, T r) {
return fma(l, r, T(0));
}

View file

@ -0,0 +1,7 @@
float swizzleAdd(float x, float y, int mask, uint thread_index_in_simdgroup)
{
float4 xLut = float4(1.0, -1.0, 1.0, 0.0);
float4 yLut = float4(1.0, 1.0, -1.0, 1.0);
int lutIdx = (mask >> (int(thread_index_in_simdgroup & 3u) * 2)) & 3;
return x * xLut[lutIdx] + y * yLut[lutIdx];
}

View file

@ -0,0 +1,185 @@
using Ryujinx.Graphics.Shader.IntermediateRepresentation;
using Ryujinx.Graphics.Shader.StructuredIr;
using Ryujinx.Graphics.Shader.Translation;
using System;
using System.Text;
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenBallot;
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenBarrier;
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenCall;
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenHelper;
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenMemory;
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenVector;
using static Ryujinx.Graphics.Shader.StructuredIr.InstructionInfo;
namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
{
static class InstGen
{
public static string GetExpression(CodeGenContext context, IAstNode node)
{
if (node is AstOperation operation)
{
return GetExpression(context, operation);
}
else if (node is AstOperand operand)
{
return context.OperandManager.GetExpression(context, operand);
}
throw new ArgumentException($"Invalid node type \"{node?.GetType().Name ?? "null"}\".");
}
private static string GetExpression(CodeGenContext context, AstOperation operation)
{
Instruction inst = operation.Inst;
InstInfo info = GetInstructionInfo(inst);
if ((info.Type & InstType.Call) != 0)
{
bool atomic = (info.Type & InstType.Atomic) != 0;
int arity = (int)(info.Type & InstType.ArityMask);
StringBuilder builder = new();
if (atomic && (operation.StorageKind == StorageKind.StorageBuffer || operation.StorageKind == StorageKind.SharedMemory))
{
AggregateType dstType = operation.Inst == Instruction.AtomicMaxS32 || operation.Inst == Instruction.AtomicMinS32
? AggregateType.S32
: AggregateType.U32;
var shared = operation.StorageKind == StorageKind.SharedMemory;
builder.Append($"({(shared ? "threadgroup" : "device")} {Declarations.GetVarTypeName(dstType, true)}*)&{GenerateLoadOrStore(context, operation, isStore: false)}");
for (int argIndex = operation.SourcesCount - arity + 2; argIndex < operation.SourcesCount; argIndex++)
{
builder.Append($", {GetSourceExpr(context, operation.GetSource(argIndex), dstType)}, memory_order_relaxed");
}
}
else
{
for (int argIndex = 0; argIndex < arity; argIndex++)
{
if (argIndex != 0)
{
builder.Append(", ");
}
AggregateType dstType = GetSrcVarType(inst, argIndex);
builder.Append(GetSourceExpr(context, operation.GetSource(argIndex), dstType));
}
if ((operation.Inst & Instruction.Mask) == Instruction.SwizzleAdd)
{
// SwizzleAdd takes one last argument, the thread_index_in_simdgroup
builder.Append(", thread_index_in_simdgroup");
}
}
return $"{info.OpName}({builder})";
}
else if ((info.Type & InstType.Op) != 0)
{
string op = info.OpName;
if (inst == Instruction.Return && operation.SourcesCount != 0)
{
return $"{op} {GetSourceExpr(context, operation.GetSource(0), context.CurrentFunction.ReturnType)}";
}
if (inst == Instruction.Return && context.Definitions.Stage is ShaderStage.Vertex or ShaderStage.Fragment)
{
return $"{op} out";
}
int arity = (int)(info.Type & InstType.ArityMask);
string[] expr = new string[arity];
for (int index = 0; index < arity; index++)
{
IAstNode src = operation.GetSource(index);
string srcExpr = GetSourceExpr(context, src, GetSrcVarType(inst, index));
bool isLhs = arity == 2 && index == 0;
expr[index] = Enclose(srcExpr, src, inst, info, isLhs);
}
switch (arity)
{
case 0:
return op;
case 1:
return op + expr[0];
case 2:
if (operation.ForcePrecise)
{
var func = (inst & Instruction.Mask) switch
{
Instruction.Add => "PreciseFAdd",
Instruction.Subtract => "PreciseFSub",
Instruction.Multiply => "PreciseFMul",
};
return $"{func}({expr[0]}, {expr[1]})";
}
return $"{expr[0]} {op} {expr[1]}";
case 3:
return $"{expr[0]} {op[0]} {expr[1]} {op[1]} {expr[2]}";
}
}
else if ((info.Type & InstType.Special) != 0)
{
switch (inst & Instruction.Mask)
{
case Instruction.Ballot:
return Ballot(context, operation);
case Instruction.Call:
return Call(context, operation);
case Instruction.FSIBegin:
case Instruction.FSIEnd:
return "// FSI implemented with raster order groups in MSL";
case Instruction.GroupMemoryBarrier:
case Instruction.MemoryBarrier:
case Instruction.Barrier:
return Barrier(context, operation);
case Instruction.ImageLoad:
case Instruction.ImageStore:
case Instruction.ImageAtomic:
return ImageLoadOrStore(context, operation);
case Instruction.Load:
return Load(context, operation);
case Instruction.Lod:
return Lod(context, operation);
case Instruction.Store:
return Store(context, operation);
case Instruction.TextureSample:
return TextureSample(context, operation);
case Instruction.TextureQuerySamples:
return TextureQuerySamples(context, operation);
case Instruction.TextureQuerySize:
return TextureQuerySize(context, operation);
case Instruction.PackHalf2x16:
return PackHalf2x16(context, operation);
case Instruction.UnpackHalf2x16:
return UnpackHalf2x16(context, operation);
case Instruction.VectorExtract:
return VectorExtract(context, operation);
case Instruction.VoteAllEqual:
return VoteAllEqual(context, operation);
}
}
// TODO: Return this to being an error
return $"Unexpected instruction type \"{info.Type}\".";
}
}
}

View file

@ -0,0 +1,30 @@
using Ryujinx.Graphics.Shader.StructuredIr;
using Ryujinx.Graphics.Shader.Translation;
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenHelper;
using static Ryujinx.Graphics.Shader.StructuredIr.InstructionInfo;
namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
{
static class InstGenBallot
{
public static string Ballot(CodeGenContext context, AstOperation operation)
{
AggregateType dstType = GetSrcVarType(operation.Inst, 0);
string arg = GetSourceExpr(context, operation.GetSource(0), dstType);
char component = "xyzw"[operation.Index];
return $"uint4(as_type<uint2>((simd_vote::vote_t)simd_ballot({arg})), 0, 0).{component}";
}
public static string VoteAllEqual(CodeGenContext context, AstOperation operation)
{
AggregateType dstType = GetSrcVarType(operation.Inst, 0);
string arg = GetSourceExpr(context, operation.GetSource(0), dstType);
return $"simd_all({arg}) || !simd_any({arg})";
}
}
}

View file

@ -0,0 +1,15 @@
using Ryujinx.Graphics.Shader.IntermediateRepresentation;
using Ryujinx.Graphics.Shader.StructuredIr;
namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
{
static class InstGenBarrier
{
public static string Barrier(CodeGenContext context, AstOperation operation)
{
var device = (operation.Inst & Instruction.Mask) == Instruction.MemoryBarrier;
return $"threadgroup_barrier(mem_flags::mem_{(device ? "device" : "threadgroup")})";
}
}
}

View file

@ -0,0 +1,60 @@
using Ryujinx.Graphics.Shader.StructuredIr;
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenHelper;
namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
{
static class InstGenCall
{
public static string Call(CodeGenContext context, AstOperation operation)
{
AstOperand funcId = (AstOperand)operation.GetSource(0);
var function = context.GetFunction(funcId.Value);
int argCount = operation.SourcesCount - 1;
int additionalArgCount = CodeGenContext.AdditionalArgCount + (context.Definitions.Stage != ShaderStage.Compute ? 1 : 0);
bool needsThreadIndex = false;
// TODO: Replace this with a proper flag
if (function.Name.Contains("Shuffle"))
{
needsThreadIndex = true;
additionalArgCount++;
}
string[] args = new string[argCount + additionalArgCount];
// Additional arguments
if (context.Definitions.Stage != ShaderStage.Compute)
{
args[0] = "in";
args[1] = "constant_buffers";
args[2] = "storage_buffers";
if (needsThreadIndex)
{
args[3] = "thread_index_in_simdgroup";
}
}
else
{
args[0] = "constant_buffers";
args[1] = "storage_buffers";
if (needsThreadIndex)
{
args[2] = "thread_index_in_simdgroup";
}
}
int argIndex = additionalArgCount;
for (int i = 0; i < argCount; i++)
{
args[argIndex++] = GetSourceExpr(context, operation.GetSource(i + 1), function.GetArgumentType(i));
}
return $"{function.Name}({string.Join(", ", args)})";
}
}
}

View file

@ -0,0 +1,222 @@
using Ryujinx.Graphics.Shader.IntermediateRepresentation;
using Ryujinx.Graphics.Shader.StructuredIr;
using Ryujinx.Graphics.Shader.Translation;
using static Ryujinx.Graphics.Shader.CodeGen.Msl.TypeConversion;
namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
{
static class InstGenHelper
{
private static readonly InstInfo[] _infoTable;
static InstGenHelper()
{
_infoTable = new InstInfo[(int)Instruction.Count];
#pragma warning disable IDE0055 // Disable formatting
Add(Instruction.AtomicAdd, InstType.AtomicBinary, "atomic_fetch_add_explicit");
Add(Instruction.AtomicAnd, InstType.AtomicBinary, "atomic_fetch_and_explicit");
Add(Instruction.AtomicCompareAndSwap, InstType.AtomicBinary, "atomic_compare_exchange_weak_explicit");
Add(Instruction.AtomicMaxU32, InstType.AtomicBinary, "atomic_fetch_max_explicit");
Add(Instruction.AtomicMinU32, InstType.AtomicBinary, "atomic_fetch_min_explicit");
Add(Instruction.AtomicOr, InstType.AtomicBinary, "atomic_fetch_or_explicit");
Add(Instruction.AtomicSwap, InstType.AtomicBinary, "atomic_exchange_explicit");
Add(Instruction.AtomicXor, InstType.AtomicBinary, "atomic_fetch_xor_explicit");
Add(Instruction.Absolute, InstType.CallUnary, "abs");
Add(Instruction.Add, InstType.OpBinaryCom, "+", 2);
Add(Instruction.Ballot, InstType.Special);
Add(Instruction.Barrier, InstType.Special);
Add(Instruction.BitCount, InstType.CallUnary, "popcount");
Add(Instruction.BitfieldExtractS32, InstType.CallTernary, "extract_bits");
Add(Instruction.BitfieldExtractU32, InstType.CallTernary, "extract_bits");
Add(Instruction.BitfieldInsert, InstType.CallQuaternary, "insert_bits");
Add(Instruction.BitfieldReverse, InstType.CallUnary, "reverse_bits");
Add(Instruction.BitwiseAnd, InstType.OpBinaryCom, "&", 6);
Add(Instruction.BitwiseExclusiveOr, InstType.OpBinaryCom, "^", 7);
Add(Instruction.BitwiseNot, InstType.OpUnary, "~", 0);
Add(Instruction.BitwiseOr, InstType.OpBinaryCom, "|", 8);
Add(Instruction.Call, InstType.Special);
Add(Instruction.Ceiling, InstType.CallUnary, "ceil");
Add(Instruction.Clamp, InstType.CallTernary, "clamp");
Add(Instruction.ClampU32, InstType.CallTernary, "clamp");
Add(Instruction.CompareEqual, InstType.OpBinaryCom, "==", 5);
Add(Instruction.CompareGreater, InstType.OpBinary, ">", 4);
Add(Instruction.CompareGreaterOrEqual, InstType.OpBinary, ">=", 4);
Add(Instruction.CompareGreaterOrEqualU32, InstType.OpBinary, ">=", 4);
Add(Instruction.CompareGreaterU32, InstType.OpBinary, ">", 4);
Add(Instruction.CompareLess, InstType.OpBinary, "<", 4);
Add(Instruction.CompareLessOrEqual, InstType.OpBinary, "<=", 4);
Add(Instruction.CompareLessOrEqualU32, InstType.OpBinary, "<=", 4);
Add(Instruction.CompareLessU32, InstType.OpBinary, "<", 4);
Add(Instruction.CompareNotEqual, InstType.OpBinaryCom, "!=", 5);
Add(Instruction.ConditionalSelect, InstType.OpTernary, "?:", 12);
Add(Instruction.ConvertFP32ToFP64, 0); // MSL does not have a 64-bit FP
Add(Instruction.ConvertFP64ToFP32, 0); // MSL does not have a 64-bit FP
Add(Instruction.ConvertFP32ToS32, InstType.CallUnary, "int");
Add(Instruction.ConvertFP32ToU32, InstType.CallUnary, "uint");
Add(Instruction.ConvertFP64ToS32, 0); // MSL does not have a 64-bit FP
Add(Instruction.ConvertFP64ToU32, 0); // MSL does not have a 64-bit FP
Add(Instruction.ConvertS32ToFP32, InstType.CallUnary, "float");
Add(Instruction.ConvertS32ToFP64, 0); // MSL does not have a 64-bit FP
Add(Instruction.ConvertU32ToFP32, InstType.CallUnary, "float");
Add(Instruction.ConvertU32ToFP64, 0); // MSL does not have a 64-bit FP
Add(Instruction.Cosine, InstType.CallUnary, "cos");
Add(Instruction.Ddx, InstType.CallUnary, "dfdx");
Add(Instruction.Ddy, InstType.CallUnary, "dfdy");
Add(Instruction.Discard, InstType.CallNullary, "discard_fragment");
Add(Instruction.Divide, InstType.OpBinary, "/", 1);
Add(Instruction.EmitVertex, 0); // MSL does not have geometry shaders
Add(Instruction.EndPrimitive, 0); // MSL does not have geometry shaders
Add(Instruction.ExponentB2, InstType.CallUnary, "exp2");
Add(Instruction.FSIBegin, InstType.Special);
Add(Instruction.FSIEnd, InstType.Special);
Add(Instruction.FindLSB, InstType.CallUnary, HelperFunctionNames.FindLSB);
Add(Instruction.FindMSBS32, InstType.CallUnary, HelperFunctionNames.FindMSBS32);
Add(Instruction.FindMSBU32, InstType.CallUnary, HelperFunctionNames.FindMSBU32);
Add(Instruction.Floor, InstType.CallUnary, "floor");
Add(Instruction.FusedMultiplyAdd, InstType.CallTernary, "fma");
Add(Instruction.GroupMemoryBarrier, InstType.Special);
Add(Instruction.ImageLoad, InstType.Special);
Add(Instruction.ImageStore, InstType.Special);
Add(Instruction.ImageAtomic, InstType.Special); // Metal 3.1+
Add(Instruction.IsNan, InstType.CallUnary, "isnan");
Add(Instruction.Load, InstType.Special);
Add(Instruction.Lod, InstType.Special);
Add(Instruction.LogarithmB2, InstType.CallUnary, "log2");
Add(Instruction.LogicalAnd, InstType.OpBinaryCom, "&&", 9);
Add(Instruction.LogicalExclusiveOr, InstType.OpBinaryCom, "^", 10);
Add(Instruction.LogicalNot, InstType.OpUnary, "!", 0);
Add(Instruction.LogicalOr, InstType.OpBinaryCom, "||", 11);
Add(Instruction.LoopBreak, InstType.OpNullary, "break");
Add(Instruction.LoopContinue, InstType.OpNullary, "continue");
Add(Instruction.PackDouble2x32, 0); // MSL does not have a 64-bit FP
Add(Instruction.PackHalf2x16, InstType.Special);
Add(Instruction.Maximum, InstType.CallBinary, "max");
Add(Instruction.MaximumU32, InstType.CallBinary, "max");
Add(Instruction.MemoryBarrier, InstType.Special);
Add(Instruction.Minimum, InstType.CallBinary, "min");
Add(Instruction.MinimumU32, InstType.CallBinary, "min");
Add(Instruction.Modulo, InstType.CallBinary, "fmod");
Add(Instruction.Multiply, InstType.OpBinaryCom, "*", 1);
Add(Instruction.MultiplyHighS32, InstType.CallBinary, "mulhi");
Add(Instruction.MultiplyHighU32, InstType.CallBinary, "mulhi");
Add(Instruction.Negate, InstType.OpUnary, "-");
Add(Instruction.ReciprocalSquareRoot, InstType.CallUnary, "rsqrt");
Add(Instruction.Return, InstType.OpNullary, "return");
Add(Instruction.Round, InstType.CallUnary, "round");
Add(Instruction.ShiftLeft, InstType.OpBinary, "<<", 3);
Add(Instruction.ShiftRightS32, InstType.OpBinary, ">>", 3);
Add(Instruction.ShiftRightU32, InstType.OpBinary, ">>", 3);
Add(Instruction.Shuffle, InstType.CallBinary, "simd_shuffle");
Add(Instruction.ShuffleDown, InstType.CallBinary, "simd_shuffle_down");
Add(Instruction.ShuffleUp, InstType.CallBinary, "simd_shuffle_up");
Add(Instruction.ShuffleXor, InstType.CallBinary, "simd_shuffle_xor");
Add(Instruction.Sine, InstType.CallUnary, "sin");
Add(Instruction.SquareRoot, InstType.CallUnary, "sqrt");
Add(Instruction.Store, InstType.Special);
Add(Instruction.Subtract, InstType.OpBinary, "-", 2);
Add(Instruction.SwizzleAdd, InstType.CallTernary, HelperFunctionNames.SwizzleAdd);
Add(Instruction.TextureSample, InstType.Special);
Add(Instruction.TextureQuerySamples, InstType.Special);
Add(Instruction.TextureQuerySize, InstType.Special);
Add(Instruction.Truncate, InstType.CallUnary, "trunc");
Add(Instruction.UnpackDouble2x32, 0); // MSL does not have a 64-bit FP
Add(Instruction.UnpackHalf2x16, InstType.Special);
Add(Instruction.VectorExtract, InstType.Special);
Add(Instruction.VoteAll, InstType.CallUnary, "simd_all");
Add(Instruction.VoteAllEqual, InstType.Special);
Add(Instruction.VoteAny, InstType.CallUnary, "simd_any");
#pragma warning restore IDE0055
}
private static void Add(Instruction inst, InstType flags, string opName = null, int precedence = 0)
{
_infoTable[(int)inst] = new InstInfo(flags, opName, precedence);
}
public static InstInfo GetInstructionInfo(Instruction inst)
{
return _infoTable[(int)(inst & Instruction.Mask)];
}
public static string GetSourceExpr(CodeGenContext context, IAstNode node, AggregateType dstType)
{
return ReinterpretCast(context, node, OperandManager.GetNodeDestType(context, node), dstType);
}
public static string Enclose(string expr, IAstNode node, Instruction pInst, bool isLhs)
{
InstInfo pInfo = GetInstructionInfo(pInst);
return Enclose(expr, node, pInst, pInfo, isLhs);
}
public static string Enclose(string expr, IAstNode node, Instruction pInst, InstInfo pInfo, bool isLhs = false)
{
if (NeedsParenthesis(node, pInst, pInfo, isLhs))
{
expr = "(" + expr + ")";
}
return expr;
}
public static bool NeedsParenthesis(IAstNode node, Instruction pInst, InstInfo pInfo, bool isLhs)
{
// If the node isn't an operation, then it can only be an operand,
// and those never needs to be surrounded in parentheses.
if (node is not AstOperation operation)
{
// This is sort of a special case, if this is a negative constant,
// and it is consumed by a unary operation, we need to put on the parenthesis,
// as in MSL, while a sequence like ~-1 is valid, --2 is not.
if (IsNegativeConst(node) && pInfo.Type == InstType.OpUnary)
{
return true;
}
return false;
}
if ((pInfo.Type & (InstType.Call | InstType.Special)) != 0)
{
return false;
}
InstInfo info = _infoTable[(int)(operation.Inst & Instruction.Mask)];
if ((info.Type & (InstType.Call | InstType.Special)) != 0)
{
return false;
}
if (info.Precedence < pInfo.Precedence)
{
return false;
}
if (info.Precedence == pInfo.Precedence && isLhs)
{
return false;
}
if (pInst == operation.Inst && info.Type == InstType.OpBinaryCom)
{
return false;
}
return true;
}
private static bool IsNegativeConst(IAstNode node)
{
if (node is not AstOperand operand)
{
return false;
}
return operand.Type == OperandType.Constant && operand.Value < 0;
}
}
}

View file

@ -0,0 +1,672 @@
using Ryujinx.Common.Logging;
using Ryujinx.Graphics.Shader.IntermediateRepresentation;
using Ryujinx.Graphics.Shader.StructuredIr;
using Ryujinx.Graphics.Shader.Translation;
using System;
using System.Text;
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenHelper;
using static Ryujinx.Graphics.Shader.StructuredIr.InstructionInfo;
namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
{
static class InstGenMemory
{
public static string GenerateLoadOrStore(CodeGenContext context, AstOperation operation, bool isStore)
{
StorageKind storageKind = operation.StorageKind;
string varName;
AggregateType varType;
int srcIndex = 0;
bool isStoreOrAtomic = operation.Inst == Instruction.Store || operation.Inst.IsAtomic();
int inputsCount = isStoreOrAtomic ? operation.SourcesCount - 1 : operation.SourcesCount;
bool fieldHasPadding = false;
if (operation.Inst == Instruction.AtomicCompareAndSwap)
{
inputsCount--;
}
string fieldName = "";
switch (storageKind)
{
case StorageKind.ConstantBuffer:
case StorageKind.StorageBuffer:
if (operation.GetSource(srcIndex++) is not AstOperand bindingIndex || bindingIndex.Type != OperandType.Constant)
{
throw new InvalidOperationException($"First input of {operation.Inst} with {storageKind} storage must be a constant operand.");
}
int binding = bindingIndex.Value;
BufferDefinition buffer = storageKind == StorageKind.ConstantBuffer
? context.Properties.ConstantBuffers[binding]
: context.Properties.StorageBuffers[binding];
if (operation.GetSource(srcIndex++) is not AstOperand fieldIndex || fieldIndex.Type != OperandType.Constant)
{
throw new InvalidOperationException($"Second input of {operation.Inst} with {storageKind} storage must be a constant operand.");
}
StructureField field = buffer.Type.Fields[fieldIndex.Value];
fieldHasPadding = buffer.Layout == BufferLayout.Std140
&& ((field.Type & AggregateType.Vector4) == 0)
&& ((field.Type & AggregateType.Array) != 0);
varName = storageKind == StorageKind.ConstantBuffer
? "constant_buffers"
: "storage_buffers";
varName += "." + buffer.Name;
varName += "->" + field.Name;
varType = field.Type;
break;
case StorageKind.LocalMemory:
case StorageKind.SharedMemory:
if (operation.GetSource(srcIndex++) is not AstOperand { Type: OperandType.Constant } bindingId)
{
throw new InvalidOperationException($"First input of {operation.Inst} with {storageKind} storage must be a constant operand.");
}
MemoryDefinition memory = storageKind == StorageKind.LocalMemory
? context.Properties.LocalMemories[bindingId.Value]
: context.Properties.SharedMemories[bindingId.Value];
varName = memory.Name;
varType = memory.Type;
break;
case StorageKind.Input:
case StorageKind.InputPerPatch:
case StorageKind.Output:
case StorageKind.OutputPerPatch:
if (operation.GetSource(srcIndex++) is not AstOperand varId || varId.Type != OperandType.Constant)
{
throw new InvalidOperationException($"First input of {operation.Inst} with {storageKind} storage must be a constant operand.");
}
IoVariable ioVariable = (IoVariable)varId.Value;
bool isOutput = storageKind.IsOutput();
bool isPerPatch = storageKind.IsPerPatch();
int location = -1;
int component = 0;
if (context.Definitions.HasPerLocationInputOrOutput(ioVariable, isOutput))
{
if (operation.GetSource(srcIndex++) is not AstOperand vecIndex || vecIndex.Type != OperandType.Constant)
{
throw new InvalidOperationException($"Second input of {operation.Inst} with {storageKind} storage must be a constant operand.");
}
location = vecIndex.Value;
if (operation.SourcesCount > srcIndex &&
operation.GetSource(srcIndex) is AstOperand elemIndex &&
elemIndex.Type == OperandType.Constant &&
context.Definitions.HasPerLocationInputOrOutputComponent(ioVariable, vecIndex.Value, elemIndex.Value, isOutput))
{
component = elemIndex.Value;
srcIndex++;
}
}
(varName, varType) = IoMap.GetMslBuiltIn(
context.Definitions,
ioVariable,
location,
component,
isOutput,
isPerPatch);
break;
default:
throw new InvalidOperationException($"Invalid storage kind {storageKind}.");
}
for (; srcIndex < inputsCount; srcIndex++)
{
IAstNode src = operation.GetSource(srcIndex);
if ((varType & AggregateType.ElementCountMask) != 0 &&
srcIndex == inputsCount - 1 &&
src is AstOperand elementIndex &&
elementIndex.Type == OperandType.Constant)
{
varName += "." + "xyzw"[elementIndex.Value & 3];
}
else
{
varName += $"[{GetSourceExpr(context, src, AggregateType.S32)}]";
}
}
varName += fieldName;
varName += fieldHasPadding ? ".x" : "";
if (isStore)
{
varType &= AggregateType.ElementTypeMask;
varName = $"{varName} = {GetSourceExpr(context, operation.GetSource(srcIndex), varType)}";
}
return varName;
}
public static string ImageLoadOrStore(CodeGenContext context, AstOperation operation)
{
AstTextureOperation texOp = (AstTextureOperation)operation;
bool isArray = (texOp.Type & SamplerType.Array) != 0;
var texCallBuilder = new StringBuilder();
int srcIndex = 0;
string Src(AggregateType type)
{
return GetSourceExpr(context, texOp.GetSource(srcIndex++), type);
}
string imageName = GetImageName(context, texOp, ref srcIndex);
texCallBuilder.Append(imageName);
texCallBuilder.Append('.');
if (texOp.Inst == Instruction.ImageAtomic)
{
texCallBuilder.Append((texOp.Flags & TextureFlags.AtomicMask) switch
{
TextureFlags.Add => "atomic_fetch_add",
TextureFlags.Minimum => "atomic_min",
TextureFlags.Maximum => "atomic_max",
TextureFlags.Increment => "atomic_fetch_add",
TextureFlags.Decrement => "atomic_fetch_sub",
TextureFlags.BitwiseAnd => "atomic_fetch_and",
TextureFlags.BitwiseOr => "atomic_fetch_or",
TextureFlags.BitwiseXor => "atomic_fetch_xor",
TextureFlags.Swap => "atomic_exchange",
TextureFlags.CAS => "atomic_compare_exchange_weak",
_ => "atomic_fetch_add",
});
}
else
{
texCallBuilder.Append(texOp.Inst == Instruction.ImageLoad ? "read" : "write");
}
texCallBuilder.Append('(');
var coordsBuilder = new StringBuilder();
int coordsCount = texOp.Type.GetDimensions();
if (coordsCount > 1)
{
string[] elems = new string[coordsCount];
for (int index = 0; index < coordsCount; index++)
{
elems[index] = Src(AggregateType.S32);
}
coordsBuilder.Append($"uint{coordsCount}({string.Join(", ", elems)})");
}
else
{
coordsBuilder.Append($"uint({Src(AggregateType.S32)})");
}
if (isArray)
{
coordsBuilder.Append(", ");
coordsBuilder.Append(Src(AggregateType.S32));
}
if (texOp.Inst == Instruction.ImageStore)
{
AggregateType type = texOp.Format.GetComponentType();
string[] cElems = new string[4];
for (int index = 0; index < 4; index++)
{
if (srcIndex < texOp.SourcesCount)
{
cElems[index] = Src(type);
}
else
{
cElems[index] = type switch
{
AggregateType.S32 => NumberFormatter.FormatInt(0),
AggregateType.U32 => NumberFormatter.FormatUint(0),
_ => NumberFormatter.FormatFloat(0),
};
}
}
string prefix = type switch
{
AggregateType.S32 => "int",
AggregateType.U32 => "uint",
AggregateType.FP32 => "float",
_ => string.Empty,
};
texCallBuilder.Append($"{prefix}4({string.Join(", ", cElems)})");
texCallBuilder.Append(", ");
}
texCallBuilder.Append(coordsBuilder);
if (texOp.Inst == Instruction.ImageAtomic)
{
texCallBuilder.Append(", ");
AggregateType type = texOp.Format.GetComponentType();
if ((texOp.Flags & TextureFlags.AtomicMask) == TextureFlags.CAS)
{
texCallBuilder.Append(Src(type)); // Compare value.
}
string value = (texOp.Flags & TextureFlags.AtomicMask) switch
{
TextureFlags.Increment => NumberFormatter.FormatInt(1, type), // TODO: Clamp value
TextureFlags.Decrement => NumberFormatter.FormatInt(-1, type), // TODO: Clamp value
_ => Src(type),
};
texCallBuilder.Append(value);
// This doesn't match what the MSL spec document says so either
// it is wrong or the MSL compiler has a bug.
texCallBuilder.Append(")[0]");
}
else
{
texCallBuilder.Append(')');
if (texOp.Inst == Instruction.ImageLoad)
{
texCallBuilder.Append(GetMaskMultiDest(texOp.Index));
}
}
return texCallBuilder.ToString();
}
public static string Load(CodeGenContext context, AstOperation operation)
{
return GenerateLoadOrStore(context, operation, isStore: false);
}
public static string Lod(CodeGenContext context, AstOperation operation)
{
AstTextureOperation texOp = (AstTextureOperation)operation;
int coordsCount = texOp.Type.GetDimensions();
int coordsIndex = 0;
string textureName = GetTextureName(context, texOp, ref coordsIndex);
string samplerName = GetSamplerName(context, texOp, ref coordsIndex);
string coordsExpr;
if (coordsCount > 1)
{
string[] elems = new string[coordsCount];
for (int index = 0; index < coordsCount; index++)
{
elems[index] = GetSourceExpr(context, texOp.GetSource(coordsIndex + index), AggregateType.FP32);
}
coordsExpr = "float" + coordsCount + "(" + string.Join(", ", elems) + ")";
}
else
{
coordsExpr = GetSourceExpr(context, texOp.GetSource(coordsIndex), AggregateType.FP32);
}
var clamped = $"{textureName}.calculate_clamped_lod({samplerName}, {coordsExpr})";
var unclamped = $"{textureName}.calculate_unclamped_lod({samplerName}, {coordsExpr})";
return $"float2({clamped}, {unclamped}){GetMask(texOp.Index)}";
}
public static string Store(CodeGenContext context, AstOperation operation)
{
return GenerateLoadOrStore(context, operation, isStore: true);
}
public static string TextureSample(CodeGenContext context, AstOperation operation)
{
AstTextureOperation texOp = (AstTextureOperation)operation;
bool isGather = (texOp.Flags & TextureFlags.Gather) != 0;
bool hasDerivatives = (texOp.Flags & TextureFlags.Derivatives) != 0;
bool intCoords = (texOp.Flags & TextureFlags.IntCoords) != 0;
bool hasLodBias = (texOp.Flags & TextureFlags.LodBias) != 0;
bool hasLodLevel = (texOp.Flags & TextureFlags.LodLevel) != 0;
bool hasOffset = (texOp.Flags & TextureFlags.Offset) != 0;
bool hasOffsets = (texOp.Flags & TextureFlags.Offsets) != 0;
bool isArray = (texOp.Type & SamplerType.Array) != 0;
bool isShadow = (texOp.Type & SamplerType.Shadow) != 0;
var texCallBuilder = new StringBuilder();
bool colorIsVector = isGather || !isShadow;
int srcIndex = 0;
string Src(AggregateType type)
{
return GetSourceExpr(context, texOp.GetSource(srcIndex++), type);
}
string textureName = GetTextureName(context, texOp, ref srcIndex);
string samplerName = GetSamplerName(context, texOp, ref srcIndex);
texCallBuilder.Append(textureName);
texCallBuilder.Append('.');
if (intCoords)
{
texCallBuilder.Append("read(");
}
else
{
if (isGather)
{
texCallBuilder.Append("gather");
}
else
{
texCallBuilder.Append("sample");
}
if (isShadow)
{
texCallBuilder.Append("_compare");
}
texCallBuilder.Append($"({samplerName}, ");
}
int coordsCount = texOp.Type.GetDimensions();
int pCount = coordsCount;
bool appended = false;
void Append(string str)
{
if (appended)
{
texCallBuilder.Append(", ");
}
else
{
appended = true;
}
texCallBuilder.Append(str);
}
AggregateType coordType = intCoords ? AggregateType.S32 : AggregateType.FP32;
string AssemblePVector(int count)
{
string coords;
if (count > 1)
{
string[] elems = new string[count];
for (int index = 0; index < count; index++)
{
elems[index] = Src(coordType);
}
coords = string.Join(", ", elems);
}
else
{
coords = Src(coordType);
}
string prefix = intCoords ? "uint" : "float";
return prefix + (count > 1 ? count : "") + "(" + coords + ")";
}
Append(AssemblePVector(pCount));
if (isArray)
{
Append(Src(AggregateType.S32));
}
if (isShadow)
{
Append(Src(AggregateType.FP32));
}
if (hasDerivatives)
{
Logger.Warning?.PrintMsg(LogClass.Gpu, "Unused sampler derivatives!");
}
if (hasLodBias)
{
Logger.Warning?.PrintMsg(LogClass.Gpu, "Unused sample LOD bias!");
}
if (hasLodLevel)
{
if (intCoords)
{
Append(Src(coordType));
}
else
{
Append($"level({Src(coordType)})");
}
}
string AssembleOffsetVector(int count)
{
if (count > 1)
{
string[] elems = new string[count];
for (int index = 0; index < count; index++)
{
elems[index] = Src(AggregateType.S32);
}
return "int" + count + "(" + string.Join(", ", elems) + ")";
}
else
{
return Src(AggregateType.S32);
}
}
// TODO: Support reads with offsets
if (!intCoords)
{
if (hasOffset)
{
Append(AssembleOffsetVector(coordsCount));
}
else if (hasOffsets)
{
Logger.Warning?.PrintMsg(LogClass.Gpu, "Multiple offsets on gathers are not yet supported!");
}
}
texCallBuilder.Append(')');
texCallBuilder.Append(colorIsVector ? GetMaskMultiDest(texOp.Index) : "");
return texCallBuilder.ToString();
}
private static string GetTextureName(CodeGenContext context, AstTextureOperation texOp, ref int srcIndex)
{
TextureDefinition textureDefinition = context.Properties.Textures[texOp.GetTextureSetAndBinding()];
string name = textureDefinition.Name;
string setName = Declarations.GetNameForSet(textureDefinition.Set, true);
if (textureDefinition.ArrayLength != 1)
{
name = $"{name}[{GetSourceExpr(context, texOp.GetSource(srcIndex++), AggregateType.S32)}]";
}
return $"{setName}.tex_{name}";
}
private static string GetSamplerName(CodeGenContext context, AstTextureOperation texOp, ref int srcIndex)
{
var index = texOp.IsSeparate ? texOp.GetSamplerSetAndBinding() : texOp.GetTextureSetAndBinding();
var sourceIndex = texOp.IsSeparate ? srcIndex++ : srcIndex + 1;
TextureDefinition samplerDefinition = context.Properties.Textures[index];
string name = samplerDefinition.Name;
string setName = Declarations.GetNameForSet(samplerDefinition.Set, true);
if (samplerDefinition.ArrayLength != 1)
{
name = $"{name}[{GetSourceExpr(context, texOp.GetSource(sourceIndex), AggregateType.S32)}]";
}
return $"{setName}.samp_{name}";
}
private static string GetImageName(CodeGenContext context, AstTextureOperation texOp, ref int srcIndex)
{
TextureDefinition imageDefinition = context.Properties.Images[texOp.GetTextureSetAndBinding()];
string name = imageDefinition.Name;
string setName = Declarations.GetNameForSet(imageDefinition.Set, true);
if (imageDefinition.ArrayLength != 1)
{
name = $"{name}[{GetSourceExpr(context, texOp.GetSource(srcIndex++), AggregateType.S32)}]";
}
return $"{setName}.{name}";
}
private static string GetMaskMultiDest(int mask)
{
if (mask == 0x0)
{
return "";
}
string swizzle = ".";
for (int i = 0; i < 4; i++)
{
if ((mask & (1 << i)) != 0)
{
swizzle += "xyzw"[i];
}
}
return swizzle;
}
public static string TextureQuerySamples(CodeGenContext context, AstOperation operation)
{
AstTextureOperation texOp = (AstTextureOperation)operation;
int srcIndex = 0;
string textureName = GetTextureName(context, texOp, ref srcIndex);
return $"{textureName}.get_num_samples()";
}
public static string TextureQuerySize(CodeGenContext context, AstOperation operation)
{
AstTextureOperation texOp = (AstTextureOperation)operation;
var texCallBuilder = new StringBuilder();
int srcIndex = 0;
string textureName = GetTextureName(context, texOp, ref srcIndex);
texCallBuilder.Append(textureName);
texCallBuilder.Append('.');
if (texOp.Index == 3)
{
texCallBuilder.Append("get_num_mip_levels()");
}
else
{
context.Properties.Textures.TryGetValue(texOp.GetTextureSetAndBinding(), out TextureDefinition definition);
bool hasLod = !definition.Type.HasFlag(SamplerType.Multisample) && (definition.Type & SamplerType.Mask) != SamplerType.TextureBuffer;
bool isArray = definition.Type.HasFlag(SamplerType.Array);
texCallBuilder.Append("get_");
if (texOp.Index == 0)
{
texCallBuilder.Append("width");
}
else if (texOp.Index == 1)
{
texCallBuilder.Append("height");
}
else
{
if (isArray)
{
texCallBuilder.Append("array_size");
}
else
{
texCallBuilder.Append("depth");
}
}
texCallBuilder.Append('(');
if (hasLod && !isArray)
{
IAstNode lod = operation.GetSource(0);
string lodExpr = GetSourceExpr(context, lod, GetSrcVarType(operation.Inst, 0));
texCallBuilder.Append(lodExpr);
}
texCallBuilder.Append(')');
}
return texCallBuilder.ToString();
}
public static string PackHalf2x16(CodeGenContext context, AstOperation operation)
{
IAstNode src0 = operation.GetSource(0);
IAstNode src1 = operation.GetSource(1);
string src0Expr = GetSourceExpr(context, src0, GetSrcVarType(operation.Inst, 0));
string src1Expr = GetSourceExpr(context, src1, GetSrcVarType(operation.Inst, 1));
return $"as_type<uint>(half2({src0Expr}, {src1Expr}))";
}
public static string UnpackHalf2x16(CodeGenContext context, AstOperation operation)
{
IAstNode src = operation.GetSource(0);
string srcExpr = GetSourceExpr(context, src, GetSrcVarType(operation.Inst, 0));
return $"float2(as_type<half2>({srcExpr})){GetMask(operation.Index)}";
}
private static string GetMask(int index)
{
return $".{"xy".AsSpan(index, 1)}";
}
}
}

View file

@ -0,0 +1,32 @@
using Ryujinx.Graphics.Shader.IntermediateRepresentation;
using Ryujinx.Graphics.Shader.StructuredIr;
using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenHelper;
using static Ryujinx.Graphics.Shader.StructuredIr.InstructionInfo;
namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
{
static class InstGenVector
{
public static string VectorExtract(CodeGenContext context, AstOperation operation)
{
IAstNode vector = operation.GetSource(0);
IAstNode index = operation.GetSource(1);
string vectorExpr = GetSourceExpr(context, vector, OperandManager.GetNodeDestType(context, vector));
if (index is AstOperand indexOperand && indexOperand.Type == OperandType.Constant)
{
char elem = "xyzw"[indexOperand.Value];
return $"{vectorExpr}.{elem}";
}
else
{
string indexExpr = GetSourceExpr(context, index, GetSrcVarType(operation.Inst, 1));
return $"{vectorExpr}[{indexExpr}]";
}
}
}
}

View file

@ -0,0 +1,18 @@
namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
{
readonly struct InstInfo
{
public InstType Type { get; }
public string OpName { get; }
public int Precedence { get; }
public InstInfo(InstType type, string opName, int precedence)
{
Type = type;
OpName = opName;
Precedence = precedence;
}
}
}

View file

@ -0,0 +1,35 @@
using System;
using System.Diagnostics.CodeAnalysis;
namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
{
[Flags]
[SuppressMessage("Design", "CA1069: Enums values should not be duplicated")]
public enum InstType
{
OpNullary = Op | 0,
OpUnary = Op | 1,
OpBinary = Op | 2,
OpBinaryCom = Op | 2 | Commutative,
OpTernary = Op | 3,
CallNullary = Call | 0,
CallUnary = Call | 1,
CallBinary = Call | 2,
CallTernary = Call | 3,
CallQuaternary = Call | 4,
// The atomic instructions have one extra operand,
// for the storage slot and offset pair.
AtomicBinary = Call | Atomic | 3,
AtomicTernary = Call | Atomic | 4,
Commutative = 1 << 8,
Op = 1 << 9,
Call = 1 << 10,
Atomic = 1 << 11,
Special = 1 << 12,
ArityMask = 0xff,
}
}

View file

@ -0,0 +1,83 @@
using Ryujinx.Common.Logging;
using Ryujinx.Graphics.Shader.IntermediateRepresentation;
using Ryujinx.Graphics.Shader.Translation;
using System.Globalization;
namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
{
static class IoMap
{
public static (string, AggregateType) GetMslBuiltIn(
ShaderDefinitions definitions,
IoVariable ioVariable,
int location,
int component,
bool isOutput,
bool isPerPatch)
{
var returnValue = ioVariable switch
{
IoVariable.BaseInstance => ("base_instance", AggregateType.U32),
IoVariable.BaseVertex => ("base_vertex", AggregateType.U32),
IoVariable.CtaId => ("threadgroup_position_in_grid", AggregateType.Vector3 | AggregateType.U32),
IoVariable.ClipDistance => ("out.clip_distance", AggregateType.Array | AggregateType.FP32),
IoVariable.FragmentOutputColor => ($"out.color{location}", definitions.GetFragmentOutputColorType(location)),
IoVariable.FragmentOutputDepth => ("out.depth", AggregateType.FP32),
IoVariable.FrontFacing => ("in.front_facing", AggregateType.Bool),
IoVariable.GlobalId => ("thread_position_in_grid", AggregateType.Vector3 | AggregateType.U32),
IoVariable.InstanceId => ("instance_id", AggregateType.U32),
IoVariable.InstanceIndex => ("instance_index", AggregateType.U32),
IoVariable.InvocationId => ("INVOCATION_ID", AggregateType.S32),
IoVariable.PointCoord => ("in.point_coord", AggregateType.Vector2 | AggregateType.FP32),
IoVariable.PointSize => ("out.point_size", AggregateType.FP32),
IoVariable.Position => ("out.position", AggregateType.Vector4 | AggregateType.FP32),
IoVariable.PrimitiveId => ("in.primitive_id", AggregateType.U32),
IoVariable.SubgroupEqMask => ("thread_index_in_simdgroup >= 32 ? uint4(0, (1 << (thread_index_in_simdgroup - 32)), uint2(0)) : uint4(1 << thread_index_in_simdgroup, uint3(0))", AggregateType.Vector4 | AggregateType.U32),
IoVariable.SubgroupGeMask => ("uint4(insert_bits(0u, 0xFFFFFFFF, thread_index_in_simdgroup, 32 - thread_index_in_simdgroup), uint3(0)) & (uint4((uint)((simd_vote::vote_t)simd_ballot(true) & 0xFFFFFFFF), (uint)(((simd_vote::vote_t)simd_ballot(true) >> 32) & 0xFFFFFFFF), 0, 0))", AggregateType.Vector4 | AggregateType.U32),
IoVariable.SubgroupGtMask => ("uint4(insert_bits(0u, 0xFFFFFFFF, thread_index_in_simdgroup + 1, 32 - thread_index_in_simdgroup - 1), uint3(0)) & (uint4((uint)((simd_vote::vote_t)simd_ballot(true) & 0xFFFFFFFF), (uint)(((simd_vote::vote_t)simd_ballot(true) >> 32) & 0xFFFFFFFF), 0, 0))", AggregateType.Vector4 | AggregateType.U32),
IoVariable.SubgroupLaneId => ("thread_index_in_simdgroup", AggregateType.U32),
IoVariable.SubgroupLeMask => ("uint4(extract_bits(0xFFFFFFFF, 0, min(thread_index_in_simdgroup + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)thread_index_in_simdgroup + 1 - 32, 0)), uint2(0))", AggregateType.Vector4 | AggregateType.U32),
IoVariable.SubgroupLtMask => ("uint4(extract_bits(0xFFFFFFFF, 0, min(thread_index_in_simdgroup, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)thread_index_in_simdgroup - 32, 0)), uint2(0))", AggregateType.Vector4 | AggregateType.U32),
IoVariable.ThreadKill => ("simd_is_helper_thread()", AggregateType.Bool),
IoVariable.UserDefined => GetUserDefinedVariableName(definitions, location, component, isOutput, isPerPatch),
IoVariable.ThreadId => ("thread_position_in_threadgroup", AggregateType.Vector3 | AggregateType.U32),
IoVariable.VertexId => ("vertex_id", AggregateType.S32),
// gl_VertexIndex does not have a direct equivalent in MSL
IoVariable.VertexIndex => ("vertex_id", AggregateType.U32),
IoVariable.ViewportIndex => ("viewport_array_index", AggregateType.S32),
IoVariable.FragmentCoord => ("in.position", AggregateType.Vector4 | AggregateType.FP32),
_ => (null, AggregateType.Invalid),
};
if (returnValue.Item2 == AggregateType.Invalid)
{
Logger.Warning?.PrintMsg(LogClass.Gpu, $"Unable to find type for IoVariable {ioVariable}!");
}
return returnValue;
}
private static (string, AggregateType) GetUserDefinedVariableName(ShaderDefinitions definitions, int location, int component, bool isOutput, bool isPerPatch)
{
string name = isPerPatch
? Defaults.PerPatchAttributePrefix
: (isOutput ? Defaults.OAttributePrefix : Defaults.IAttributePrefix);
if (location < 0)
{
return (name, definitions.GetUserDefinedType(0, isOutput));
}
name += location.ToString(CultureInfo.InvariantCulture);
if (definitions.HasPerLocationInputOrOutputComponent(IoVariable.UserDefined, location, component, isOutput))
{
name += "_" + "xyzw"[component & 3];
}
string prefix = isOutput ? "out" : "in";
return (prefix + "." + name, definitions.GetUserDefinedType(location, isOutput));
}
}
}

View file

@ -0,0 +1,286 @@
using Ryujinx.Common.Logging;
using Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions;
using Ryujinx.Graphics.Shader.StructuredIr;
using Ryujinx.Graphics.Shader.Translation;
using System;
using System.Linq;
using static Ryujinx.Graphics.Shader.CodeGen.Msl.TypeConversion;
namespace Ryujinx.Graphics.Shader.CodeGen.Msl
{
static class MslGenerator
{
public static string Generate(StructuredProgramInfo info, CodeGenParameters parameters)
{
if (parameters.Definitions.Stage is not (ShaderStage.Vertex or ShaderStage.Fragment or ShaderStage.Compute))
{
Logger.Warning?.Print(LogClass.Gpu, $"Attempted to generate unsupported shader type {parameters.Definitions.Stage}!");
return "";
}
CodeGenContext context = new(info, parameters);
var sets = Declarations.Declare(context, info);
if (info.Functions.Count != 0)
{
for (int i = 1; i < info.Functions.Count; i++)
{
PrintFunction(context, info.Functions[i], parameters.Definitions.Stage, sets);
context.AppendLine();
}
}
PrintFunction(context, info.Functions[0], parameters.Definitions.Stage, sets, true);
return context.GetCode();
}
private static void PrintFunction(CodeGenContext context, StructuredFunction function, ShaderStage stage, int[] sets, bool isMainFunc = false)
{
context.CurrentFunction = function;
context.AppendLine(GetFunctionSignature(context, function, stage, sets, isMainFunc));
context.EnterScope();
Declarations.DeclareLocals(context, function, stage, isMainFunc);
PrintBlock(context, function.MainBlock, isMainFunc);
// In case the shader hasn't returned, return
if (isMainFunc && stage != ShaderStage.Compute)
{
context.AppendLine("return out;");
}
context.LeaveScope();
}
private static string GetFunctionSignature(
CodeGenContext context,
StructuredFunction function,
ShaderStage stage,
int[] sets,
bool isMainFunc = false)
{
int additionalArgCount = isMainFunc ? 0 : CodeGenContext.AdditionalArgCount + (context.Definitions.Stage != ShaderStage.Compute ? 1 : 0);
bool needsThreadIndex = false;
// TODO: Replace this with a proper flag
if (function.Name.Contains("Shuffle"))
{
needsThreadIndex = true;
additionalArgCount++;
}
string[] args = new string[additionalArgCount + function.InArguments.Length + function.OutArguments.Length];
// All non-main functions need to be able to access the support_buffer as well
if (!isMainFunc)
{
if (stage != ShaderStage.Compute)
{
args[0] = stage == ShaderStage.Vertex ? "VertexIn in" : "FragmentIn in";
args[1] = "constant ConstantBuffers &constant_buffers";
args[2] = "device StorageBuffers &storage_buffers";
if (needsThreadIndex)
{
args[3] = "uint thread_index_in_simdgroup";
}
}
else
{
args[0] = "constant ConstantBuffers &constant_buffers";
args[1] = "device StorageBuffers &storage_buffers";
if (needsThreadIndex)
{
args[2] = "uint thread_index_in_simdgroup";
}
}
}
int argIndex = additionalArgCount;
for (int i = 0; i < function.InArguments.Length; i++)
{
args[argIndex++] = $"{Declarations.GetVarTypeName(function.InArguments[i])} {OperandManager.GetArgumentName(i)}";
}
for (int i = 0; i < function.OutArguments.Length; i++)
{
int j = i + function.InArguments.Length;
args[argIndex++] = $"thread {Declarations.GetVarTypeName(function.OutArguments[i])} &{OperandManager.GetArgumentName(j)}";
}
string funcKeyword = "inline";
string funcName = null;
string returnType = Declarations.GetVarTypeName(function.ReturnType);
if (isMainFunc)
{
if (stage == ShaderStage.Vertex)
{
funcKeyword = "vertex";
funcName = "vertexMain";
returnType = "VertexOut";
}
else if (stage == ShaderStage.Fragment)
{
funcKeyword = "fragment";
funcName = "fragmentMain";
returnType = "FragmentOut";
}
else if (stage == ShaderStage.Compute)
{
funcKeyword = "kernel";
funcName = "kernelMain";
returnType = "void";
}
if (stage == ShaderStage.Vertex)
{
args = args.Prepend("VertexIn in [[stage_in]]").ToArray();
}
else if (stage == ShaderStage.Fragment)
{
args = args.Prepend("FragmentIn in [[stage_in]]").ToArray();
}
// TODO: add these only if they are used
if (stage == ShaderStage.Vertex)
{
args = args.Append("uint vertex_id [[vertex_id]]").ToArray();
args = args.Append("uint instance_id [[instance_id]]").ToArray();
args = args.Append("uint base_instance [[base_instance]]").ToArray();
args = args.Append("uint base_vertex [[base_vertex]]").ToArray();
}
else if (stage == ShaderStage.Compute)
{
args = args.Append("uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]]").ToArray();
args = args.Append("uint3 thread_position_in_grid [[thread_position_in_grid]]").ToArray();
args = args.Append("uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]]").ToArray();
args = args.Append("uint thread_index_in_simdgroup [[thread_index_in_simdgroup]]").ToArray();
}
args = args.Append($"constant ConstantBuffers &constant_buffers [[buffer({Defaults.ConstantBuffersIndex})]]").ToArray();
args = args.Append($"device StorageBuffers &storage_buffers [[buffer({Defaults.StorageBuffersIndex})]]").ToArray();
foreach (var set in sets)
{
var bindingIndex = set + Defaults.BaseSetIndex;
args = args.Append($"constant {Declarations.GetNameForSet(set)} &{Declarations.GetNameForSet(set, true)} [[buffer({bindingIndex})]]").ToArray();
}
}
var funcPrefix = $"{funcKeyword} {returnType} {funcName ?? function.Name}(";
var indent = new string(' ', funcPrefix.Length);
return $"{funcPrefix}{string.Join($", \n{indent}", args)})";
}
private static void PrintBlock(CodeGenContext context, AstBlock block, bool isMainFunction)
{
AstBlockVisitor visitor = new(block);
visitor.BlockEntered += (sender, e) =>
{
switch (e.Block.Type)
{
case AstBlockType.DoWhile:
context.AppendLine("do");
break;
case AstBlockType.Else:
context.AppendLine("else");
break;
case AstBlockType.ElseIf:
context.AppendLine($"else if ({GetCondExpr(context, e.Block.Condition)})");
break;
case AstBlockType.If:
context.AppendLine($"if ({GetCondExpr(context, e.Block.Condition)})");
break;
default:
throw new InvalidOperationException($"Found unexpected block type \"{e.Block.Type}\".");
}
context.EnterScope();
};
visitor.BlockLeft += (sender, e) =>
{
context.LeaveScope();
if (e.Block.Type == AstBlockType.DoWhile)
{
context.AppendLine($"while ({GetCondExpr(context, e.Block.Condition)});");
}
};
bool supportsBarrierDivergence = context.HostCapabilities.SupportsShaderBarrierDivergence;
bool mayHaveReturned = false;
foreach (IAstNode node in visitor.Visit())
{
if (node is AstOperation operation)
{
if (!supportsBarrierDivergence)
{
if (operation.Inst == IntermediateRepresentation.Instruction.Barrier)
{
// Barrier on divergent control flow paths may cause the GPU to hang,
// so skip emitting the barrier for those cases.
if (visitor.Block.Type != AstBlockType.Main || mayHaveReturned || !isMainFunction)
{
context.Logger.Log($"Shader has barrier on potentially divergent block, the barrier will be removed.");
continue;
}
}
else if (operation.Inst == IntermediateRepresentation.Instruction.Return)
{
mayHaveReturned = true;
}
}
string expr = InstGen.GetExpression(context, operation);
if (expr != null)
{
context.AppendLine(expr + ";");
}
}
else if (node is AstAssignment assignment)
{
AggregateType dstType = OperandManager.GetNodeDestType(context, assignment.Destination);
AggregateType srcType = OperandManager.GetNodeDestType(context, assignment.Source);
string dest = InstGen.GetExpression(context, assignment.Destination);
string src = ReinterpretCast(context, assignment.Source, srcType, dstType);
context.AppendLine(dest + " = " + src + ";");
}
else if (node is AstComment comment)
{
context.AppendLine("// " + comment.Comment);
}
else
{
throw new InvalidOperationException($"Found unexpected node type \"{node?.GetType().Name ?? "null"}\".");
}
}
}
private static string GetCondExpr(CodeGenContext context, IAstNode cond)
{
AggregateType srcType = OperandManager.GetNodeDestType(context, cond);
return ReinterpretCast(context, cond, srcType, AggregateType.Bool);
}
}
}

View file

@ -0,0 +1,94 @@
using Ryujinx.Graphics.Shader.Translation;
using System;
using System.Globalization;
namespace Ryujinx.Graphics.Shader.CodeGen.Msl
{
static class NumberFormatter
{
private const int MaxDecimal = 256;
public static bool TryFormat(int value, AggregateType dstType, out string formatted)
{
switch (dstType)
{
case AggregateType.FP32:
return TryFormatFloat(BitConverter.Int32BitsToSingle(value), out formatted);
case AggregateType.S32:
formatted = FormatInt(value);
break;
case AggregateType.U32:
formatted = FormatUint((uint)value);
break;
case AggregateType.Bool:
formatted = value != 0 ? "true" : "false";
break;
default:
throw new ArgumentException($"Invalid variable type \"{dstType}\".");
}
return true;
}
public static string FormatFloat(float value)
{
if (!TryFormatFloat(value, out string formatted))
{
throw new ArgumentException("Failed to convert float value to string.");
}
return formatted;
}
public static bool TryFormatFloat(float value, out string formatted)
{
if (float.IsNaN(value) || float.IsInfinity(value))
{
formatted = null;
return false;
}
formatted = value.ToString("G9", CultureInfo.InvariantCulture);
if (!(formatted.Contains('.') ||
formatted.Contains('e') ||
formatted.Contains('E')))
{
formatted += ".0f";
}
return true;
}
public static string FormatInt(int value, AggregateType dstType)
{
return dstType switch
{
AggregateType.S32 => FormatInt(value),
AggregateType.U32 => FormatUint((uint)value),
_ => throw new ArgumentException($"Invalid variable type \"{dstType}\".")
};
}
public static string FormatInt(int value)
{
if (value <= MaxDecimal && value >= -MaxDecimal)
{
return value.ToString(CultureInfo.InvariantCulture);
}
return $"as_type<int>(0x{value.ToString("X", CultureInfo.InvariantCulture)})";
}
public static string FormatUint(uint value)
{
if (value <= MaxDecimal && value >= 0)
{
return value.ToString(CultureInfo.InvariantCulture) + "u";
}
return $"as_type<uint>(0x{value.ToString("X", CultureInfo.InvariantCulture)})";
}
}
}

View file

@ -0,0 +1,176 @@
using Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions;
using Ryujinx.Graphics.Shader.IntermediateRepresentation;
using Ryujinx.Graphics.Shader.StructuredIr;
using Ryujinx.Graphics.Shader.Translation;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using static Ryujinx.Graphics.Shader.StructuredIr.InstructionInfo;
namespace Ryujinx.Graphics.Shader.CodeGen.Msl
{
class OperandManager
{
private readonly Dictionary<AstOperand, string> _locals;
public OperandManager()
{
_locals = new Dictionary<AstOperand, string>();
}
public string DeclareLocal(AstOperand operand)
{
string name = $"{Defaults.LocalNamePrefix}_{_locals.Count}";
_locals.Add(operand, name);
return name;
}
public string GetExpression(CodeGenContext context, AstOperand operand)
{
return operand.Type switch
{
OperandType.Argument => GetArgumentName(operand.Value),
OperandType.Constant => NumberFormatter.FormatInt(operand.Value),
OperandType.LocalVariable => _locals[operand],
OperandType.Undefined => Defaults.UndefinedName,
_ => throw new ArgumentException($"Invalid operand type \"{operand.Type}\"."),
};
}
public static string GetArgumentName(int argIndex)
{
return $"{Defaults.ArgumentNamePrefix}{argIndex}";
}
public static AggregateType GetNodeDestType(CodeGenContext context, IAstNode node)
{
if (node is AstOperation operation)
{
if (operation.Inst == Instruction.Load || operation.Inst.IsAtomic())
{
switch (operation.StorageKind)
{
case StorageKind.ConstantBuffer:
case StorageKind.StorageBuffer:
if (operation.GetSource(0) is not AstOperand bindingIndex || bindingIndex.Type != OperandType.Constant)
{
throw new InvalidOperationException($"First input of {operation.Inst} with {operation.StorageKind} storage must be a constant operand.");
}
if (operation.GetSource(1) is not AstOperand fieldIndex || fieldIndex.Type != OperandType.Constant)
{
throw new InvalidOperationException($"Second input of {operation.Inst} with {operation.StorageKind} storage must be a constant operand.");
}
BufferDefinition buffer = operation.StorageKind == StorageKind.ConstantBuffer
? context.Properties.ConstantBuffers[bindingIndex.Value]
: context.Properties.StorageBuffers[bindingIndex.Value];
StructureField field = buffer.Type.Fields[fieldIndex.Value];
return field.Type & AggregateType.ElementTypeMask;
case StorageKind.LocalMemory:
case StorageKind.SharedMemory:
if (operation.GetSource(0) is not AstOperand { Type: OperandType.Constant } bindingId)
{
throw new InvalidOperationException($"First input of {operation.Inst} with {operation.StorageKind} storage must be a constant operand.");
}
MemoryDefinition memory = operation.StorageKind == StorageKind.LocalMemory
? context.Properties.LocalMemories[bindingId.Value]
: context.Properties.SharedMemories[bindingId.Value];
return memory.Type & AggregateType.ElementTypeMask;
case StorageKind.Input:
case StorageKind.InputPerPatch:
case StorageKind.Output:
case StorageKind.OutputPerPatch:
if (operation.GetSource(0) is not AstOperand varId || varId.Type != OperandType.Constant)
{
throw new InvalidOperationException($"First input of {operation.Inst} with {operation.StorageKind} storage must be a constant operand.");
}
IoVariable ioVariable = (IoVariable)varId.Value;
bool isOutput = operation.StorageKind == StorageKind.Output || operation.StorageKind == StorageKind.OutputPerPatch;
bool isPerPatch = operation.StorageKind == StorageKind.InputPerPatch || operation.StorageKind == StorageKind.OutputPerPatch;
int location = 0;
int component = 0;
if (context.Definitions.HasPerLocationInputOrOutput(ioVariable, isOutput))
{
if (operation.GetSource(1) is not AstOperand vecIndex || vecIndex.Type != OperandType.Constant)
{
throw new InvalidOperationException($"Second input of {operation.Inst} with {operation.StorageKind} storage must be a constant operand.");
}
location = vecIndex.Value;
if (operation.SourcesCount > 2 &&
operation.GetSource(2) is AstOperand elemIndex &&
elemIndex.Type == OperandType.Constant &&
context.Definitions.HasPerLocationInputOrOutputComponent(ioVariable, location, elemIndex.Value, isOutput))
{
component = elemIndex.Value;
}
}
(_, AggregateType varType) = IoMap.GetMslBuiltIn(
context.Definitions,
ioVariable,
location,
component,
isOutput,
isPerPatch);
return varType & AggregateType.ElementTypeMask;
}
}
else if (operation.Inst == Instruction.Call)
{
AstOperand funcId = (AstOperand)operation.GetSource(0);
Debug.Assert(funcId.Type == OperandType.Constant);
return context.GetFunction(funcId.Value).ReturnType;
}
else if (operation.Inst == Instruction.VectorExtract)
{
return GetNodeDestType(context, operation.GetSource(0)) & ~AggregateType.ElementCountMask;
}
else if (operation is AstTextureOperation texOp)
{
if (texOp.Inst == Instruction.ImageLoad ||
texOp.Inst == Instruction.ImageStore ||
texOp.Inst == Instruction.ImageAtomic)
{
return texOp.GetVectorType(texOp.Format.GetComponentType());
}
else if (texOp.Inst == Instruction.TextureSample)
{
return texOp.GetVectorType(GetDestVarType(operation.Inst));
}
}
return GetDestVarType(operation.Inst);
}
else if (node is AstOperand operand)
{
if (operand.Type == OperandType.Argument)
{
int argIndex = operand.Value;
return context.CurrentFunction.GetArgumentType(argIndex);
}
return OperandInfo.GetVarType(operand);
}
else
{
throw new ArgumentException($"Invalid node type \"{node?.GetType().Name ?? "null"}\".");
}
}
}
}

View file

@ -0,0 +1,93 @@
using Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions;
using Ryujinx.Graphics.Shader.IntermediateRepresentation;
using Ryujinx.Graphics.Shader.StructuredIr;
using Ryujinx.Graphics.Shader.Translation;
using System;
namespace Ryujinx.Graphics.Shader.CodeGen.Msl
{
static class TypeConversion
{
public static string ReinterpretCast(
CodeGenContext context,
IAstNode node,
AggregateType srcType,
AggregateType dstType)
{
if (node is AstOperand operand && operand.Type == OperandType.Constant)
{
if (NumberFormatter.TryFormat(operand.Value, dstType, out string formatted))
{
return formatted;
}
}
string expr = InstGen.GetExpression(context, node);
return ReinterpretCast(expr, node, srcType, dstType);
}
private static string ReinterpretCast(string expr, IAstNode node, AggregateType srcType, AggregateType dstType)
{
if (srcType == dstType)
{
return expr;
}
if (srcType == AggregateType.FP32)
{
switch (dstType)
{
case AggregateType.Bool:
return $"(as_type<int>({expr}) != 0)";
case AggregateType.S32:
return $"as_type<int>({expr})";
case AggregateType.U32:
return $"as_type<uint>({expr})";
}
}
else if (dstType == AggregateType.FP32)
{
switch (srcType)
{
case AggregateType.Bool:
return $"as_type<float>({ReinterpretBoolToInt(expr, node, AggregateType.S32)})";
case AggregateType.S32:
return $"as_type<float>({expr})";
case AggregateType.U32:
return $"as_type<float>({expr})";
}
}
else if (srcType == AggregateType.Bool)
{
return ReinterpretBoolToInt(expr, node, dstType);
}
else if (dstType == AggregateType.Bool)
{
expr = InstGenHelper.Enclose(expr, node, Instruction.CompareNotEqual, isLhs: true);
return $"({expr} != 0)";
}
else if (dstType == AggregateType.S32)
{
return $"int({expr})";
}
else if (dstType == AggregateType.U32)
{
return $"uint({expr})";
}
throw new ArgumentException($"Invalid reinterpret cast from \"{srcType}\" to \"{dstType}\".");
}
private static string ReinterpretBoolToInt(string expr, IAstNode node, AggregateType dstType)
{
string trueExpr = NumberFormatter.FormatInt(IrConsts.True, dstType);
string falseExpr = NumberFormatter.FormatInt(IrConsts.False, dstType);
expr = InstGenHelper.Enclose(expr, node, Instruction.ConditionalSelect, isLhs: false);
return $"({expr} ? {trueExpr} : {falseExpr})";
}
}
}

View file

@ -15,4 +15,11 @@
<EmbeddedResource Include="CodeGen\Glsl\HelperFunctions\SwizzleAdd.glsl" />
</ItemGroup>
<ItemGroup>
<EmbeddedResource Include="CodeGen\Msl\HelperFunctions\FindLSB.metal" />
<EmbeddedResource Include="CodeGen\Msl\HelperFunctions\FindMSBS32.metal" />
<EmbeddedResource Include="CodeGen\Msl\HelperFunctions\FindMSBU32.metal" />
<EmbeddedResource Include="CodeGen\Msl\HelperFunctions\SwizzleAdd.metal" />
<EmbeddedResource Include="CodeGen\Msl\HelperFunctions\Precise.metal" />
</ItemGroup>
</Project>

View file

@ -155,5 +155,51 @@ namespace Ryujinx.Graphics.Shader
return typeName;
}
public static string ToMslTextureType(this SamplerType type, AggregateType aggregateType, bool image = false)
{
string typeName;
if ((type & SamplerType.Shadow) != 0)
{
typeName = (type & SamplerType.Mask) switch
{
SamplerType.Texture2D => "depth2d",
SamplerType.TextureCube => "depthcube",
_ => throw new ArgumentException($"Invalid shadow texture type \"{type}\"."),
};
}
else
{
typeName = (type & SamplerType.Mask) switch
{
SamplerType.Texture1D => "texture1d",
SamplerType.TextureBuffer => "texture_buffer",
SamplerType.Texture2D => "texture2d",
SamplerType.Texture3D => "texture3d",
SamplerType.TextureCube => "texturecube",
_ => throw new ArgumentException($"Invalid texture type \"{type}\"."),
};
}
if ((type & SamplerType.Multisample) != 0)
{
typeName += "_ms";
}
if ((type & SamplerType.Array) != 0)
{
typeName += "_array";
}
var format = aggregateType switch
{
AggregateType.S32 => "int",
AggregateType.U32 => "uint",
_ => "float"
};
return $"{typeName}<{format}{(image ? ", access::read_write" : "")}>";
}
}
}

View file

@ -7,7 +7,14 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
{
MultiplyHighS32 = 1 << 2,
MultiplyHighU32 = 1 << 3,
FindLSB = 1 << 5,
FindMSBS32 = 1 << 6,
FindMSBU32 = 1 << 7,
SwizzleAdd = 1 << 10,
FSI = 1 << 11,
Precise = 1 << 13
}
}

View file

@ -18,9 +18,10 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
ShaderDefinitions definitions,
ResourceManager resourceManager,
TargetLanguage targetLanguage,
bool precise,
bool debugMode)
{
StructuredProgramContext context = new(attributeUsage, definitions, resourceManager, debugMode);
StructuredProgramContext context = new(attributeUsage, definitions, resourceManager, precise, debugMode);
for (int funcIndex = 0; funcIndex < functions.Count; funcIndex++)
{
@ -321,8 +322,9 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
}
// Those instructions needs to be emulated by using helper functions,
// because they are NVIDIA specific. Those flags helps the backend to
// decide which helper functions are needed on the final generated code.
// because they are NVIDIA specific or because the target language has
// no direct equivalent. Those flags helps the backend to decide which
// helper functions are needed on the final generated code.
switch (operation.Inst)
{
case Instruction.MultiplyHighS32:
@ -331,6 +333,15 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
case Instruction.MultiplyHighU32:
context.Info.HelperFunctionsMask |= HelperFunctionsMask.MultiplyHighU32;
break;
case Instruction.FindLSB:
context.Info.HelperFunctionsMask |= HelperFunctionsMask.FindLSB;
break;
case Instruction.FindMSBS32:
context.Info.HelperFunctionsMask |= HelperFunctionsMask.FindMSBS32;
break;
case Instruction.FindMSBU32:
context.Info.HelperFunctionsMask |= HelperFunctionsMask.FindMSBU32;
break;
case Instruction.SwizzleAdd:
context.Info.HelperFunctionsMask |= HelperFunctionsMask.SwizzleAdd;
break;

View file

@ -36,9 +36,10 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
AttributeUsage attributeUsage,
ShaderDefinitions definitions,
ResourceManager resourceManager,
bool precise,
bool debugMode)
{
Info = new StructuredProgramInfo();
Info = new StructuredProgramInfo(precise);
Definitions = definitions;
ResourceManager = resourceManager;

View file

@ -10,11 +10,16 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
public HelperFunctionsMask HelperFunctionsMask { get; set; }
public StructuredProgramInfo()
public StructuredProgramInfo(bool precise)
{
Functions = new List<StructuredFunction>();
IoDefinitions = new HashSet<IoDefinition>();
if (precise)
{
HelperFunctionsMask |= HelperFunctionsMask.Precise;
}
}
}
}

View file

@ -26,5 +26,6 @@ namespace Ryujinx.Graphics.Shader.Translation
SharedMemory = 1 << 11,
Store = 1 << 12,
VtgAsCompute = 1 << 13,
Precise = 1 << 14,
}
}

View file

@ -43,6 +43,11 @@ namespace Ryujinx.Graphics.Shader.Translation
private readonly Dictionary<TextureInfo, TextureMeta> _usedTextures;
private readonly Dictionary<TextureInfo, TextureMeta> _usedImages;
private readonly List<BufferDefinition> _vacConstantBuffers;
private readonly List<BufferDefinition> _vacStorageBuffers;
private readonly List<TextureDefinition> _vacTextures;
private readonly List<TextureDefinition> _vacImages;
public int LocalMemoryId { get; private set; }
public int SharedMemoryId { get; private set; }
@ -78,6 +83,11 @@ namespace Ryujinx.Graphics.Shader.Translation
_usedTextures = new();
_usedImages = new();
_vacConstantBuffers = new();
_vacStorageBuffers = new();
_vacTextures = new();
_vacImages = new();
Properties.AddOrUpdateConstantBuffer(new(BufferLayout.Std140, 0, SupportBuffer.Binding, "support_buffer", SupportBuffer.GetStructureType()));
LocalMemoryId = -1;
@ -563,6 +573,75 @@ namespace Ryujinx.Graphics.Shader.Translation
return descriptors.ToArray();
}
public ShaderProgramInfo GetVertexAsComputeInfo(bool isVertex = false)
{
var cbDescriptors = new BufferDescriptor[_vacConstantBuffers.Count];
int cbDescriptorIndex = 0;
foreach (BufferDefinition definition in _vacConstantBuffers)
{
cbDescriptors[cbDescriptorIndex++] = new BufferDescriptor(definition.Set, definition.Binding, 0, 0, 0, BufferUsageFlags.None);
}
var sbDescriptors = new BufferDescriptor[_vacStorageBuffers.Count];
int sbDescriptorIndex = 0;
foreach (BufferDefinition definition in _vacStorageBuffers)
{
sbDescriptors[sbDescriptorIndex++] = new BufferDescriptor(definition.Set, definition.Binding, 0, 0, 0, BufferUsageFlags.Write);
}
var tDescriptors = new TextureDescriptor[_vacTextures.Count];
int tDescriptorIndex = 0;
foreach (TextureDefinition definition in _vacTextures)
{
tDescriptors[tDescriptorIndex++] = new TextureDescriptor(
definition.Set,
definition.Binding,
definition.Type,
definition.Format,
0,
0,
definition.ArrayLength,
definition.Separate,
definition.Flags);
}
var iDescriptors = new TextureDescriptor[_vacImages.Count];
int iDescriptorIndex = 0;
foreach (TextureDefinition definition in _vacImages)
{
iDescriptors[iDescriptorIndex++] = new TextureDescriptor(
definition.Set,
definition.Binding,
definition.Type,
definition.Format,
0,
0,
definition.ArrayLength,
definition.Separate,
definition.Flags);
}
return new ShaderProgramInfo(
cbDescriptors,
sbDescriptors,
tDescriptors,
iDescriptors,
isVertex ? ShaderStage.Vertex : ShaderStage.Compute,
0,
0,
0,
false,
false,
false,
false,
0,
0);
}
public bool TryGetCbufSlotAndHandleForTexture(int binding, out int cbufSlot, out int handle)
{
foreach ((TextureInfo info, TextureMeta meta) in _usedTextures)
@ -629,6 +708,30 @@ namespace Ryujinx.Graphics.Shader.Translation
Properties.AddOrUpdateStorageBuffer(new(BufferLayout.Std430, setIndex, binding, name, type));
}
public void AddVertexAsComputeConstantBuffer(BufferDefinition definition)
{
_vacConstantBuffers.Add(definition);
Properties.AddOrUpdateConstantBuffer(definition);
}
public void AddVertexAsComputeStorageBuffer(BufferDefinition definition)
{
_vacStorageBuffers.Add(definition);
Properties.AddOrUpdateStorageBuffer(definition);
}
public void AddVertexAsComputeTexture(TextureDefinition definition)
{
_vacTextures.Add(definition);
Properties.AddOrUpdateTexture(definition);
}
public void AddVertexAsComputeImage(TextureDefinition definition)
{
_vacImages.Add(definition);
Properties.AddOrUpdateImage(definition);
}
public static string GetShaderStagePrefix(ShaderStage stage)
{
uint index = (uint)stage;

View file

@ -4,5 +4,6 @@ namespace Ryujinx.Graphics.Shader.Translation
{
OpenGL,
Vulkan,
Metal
}
}

View file

@ -4,6 +4,6 @@ namespace Ryujinx.Graphics.Shader.Translation
{
Glsl,
Spirv,
Arb,
Msl
}
}

View file

@ -27,6 +27,8 @@ namespace Ryujinx.Graphics.Shader.Translation.Transforms
addOp.Inst == (Instruction.FP32 | Instruction.Add) &&
addOp.GetSource(1).Type == OperandType.Constant)
{
context.UsedFeatures |= FeatureFlags.Precise;
addOp.ForcePrecise = true;
}

View file

@ -1,5 +1,6 @@
using Ryujinx.Graphics.Shader.CodeGen;
using Ryujinx.Graphics.Shader.CodeGen.Glsl;
using Ryujinx.Graphics.Shader.CodeGen.Msl;
using Ryujinx.Graphics.Shader.CodeGen.Spirv;
using Ryujinx.Graphics.Shader.Decoders;
using Ryujinx.Graphics.Shader.IntermediateRepresentation;
@ -331,6 +332,7 @@ namespace Ryujinx.Graphics.Shader.Translation
definitions,
resourceManager,
Options.TargetLanguage,
usedFeatures.HasFlag(FeatureFlags.Precise),
Options.Flags.HasFlag(TranslationFlags.DebugMode));
int geometryVerticesPerPrimitive = Definitions.OutputTopology switch
@ -373,6 +375,7 @@ namespace Ryujinx.Graphics.Shader.Translation
{
TargetLanguage.Glsl => new ShaderProgram(info, TargetLanguage.Glsl, GlslGenerator.Generate(sInfo, parameters)),
TargetLanguage.Spirv => new ShaderProgram(info, TargetLanguage.Spirv, SpirvGenerator.Generate(sInfo, parameters)),
TargetLanguage.Msl => new ShaderProgram(info, TargetLanguage.Msl, MslGenerator.Generate(sInfo, parameters)),
_ => throw new NotImplementedException(Options.TargetLanguage.ToString()),
};
}
@ -392,7 +395,7 @@ namespace Ryujinx.Graphics.Shader.Translation
{
int binding = resourceManager.Reservations.GetTfeBufferStorageBufferBinding(i);
BufferDefinition tfeDataBuffer = new(BufferLayout.Std430, 1, binding, $"tfe_data{i}", tfeDataStruct);
resourceManager.Properties.AddOrUpdateStorageBuffer(tfeDataBuffer);
resourceManager.AddVertexAsComputeStorageBuffer(tfeDataBuffer);
}
}
@ -400,7 +403,7 @@ namespace Ryujinx.Graphics.Shader.Translation
{
int vertexInfoCbBinding = resourceManager.Reservations.VertexInfoConstantBufferBinding;
BufferDefinition vertexInfoBuffer = new(BufferLayout.Std140, 0, vertexInfoCbBinding, "vb_info", VertexInfoBuffer.GetStructureType());
resourceManager.Properties.AddOrUpdateConstantBuffer(vertexInfoBuffer);
resourceManager.AddVertexAsComputeConstantBuffer(vertexInfoBuffer);
StructureType vertexOutputStruct = new(new StructureField[]
{
@ -409,13 +412,13 @@ namespace Ryujinx.Graphics.Shader.Translation
int vertexOutputSbBinding = resourceManager.Reservations.VertexOutputStorageBufferBinding;
BufferDefinition vertexOutputBuffer = new(BufferLayout.Std430, 1, vertexOutputSbBinding, "vertex_output", vertexOutputStruct);
resourceManager.Properties.AddOrUpdateStorageBuffer(vertexOutputBuffer);
resourceManager.AddVertexAsComputeStorageBuffer(vertexOutputBuffer);
if (Stage == ShaderStage.Vertex)
{
SetBindingPair ibSetAndBinding = resourceManager.Reservations.GetIndexBufferTextureSetAndBinding();
TextureDefinition indexBuffer = new(ibSetAndBinding.SetIndex, ibSetAndBinding.Binding, "ib_data", SamplerType.TextureBuffer);
resourceManager.Properties.AddOrUpdateTexture(indexBuffer);
resourceManager.AddVertexAsComputeTexture(indexBuffer);
int inputMap = _program.AttributeUsage.UsedInputAttributes;
@ -424,7 +427,7 @@ namespace Ryujinx.Graphics.Shader.Translation
int location = BitOperations.TrailingZeroCount(inputMap);
SetBindingPair setAndBinding = resourceManager.Reservations.GetVertexBufferTextureSetAndBinding(location);
TextureDefinition vaBuffer = new(setAndBinding.SetIndex, setAndBinding.Binding, $"vb_data{location}", SamplerType.TextureBuffer);
resourceManager.Properties.AddOrUpdateTexture(vaBuffer);
resourceManager.AddVertexAsComputeTexture(vaBuffer);
inputMap &= ~(1 << location);
}
@ -433,11 +436,11 @@ namespace Ryujinx.Graphics.Shader.Translation
{
SetBindingPair trbSetAndBinding = resourceManager.Reservations.GetTopologyRemapBufferTextureSetAndBinding();
TextureDefinition remapBuffer = new(trbSetAndBinding.SetIndex, trbSetAndBinding.Binding, "trb_data", SamplerType.TextureBuffer);
resourceManager.Properties.AddOrUpdateTexture(remapBuffer);
resourceManager.AddVertexAsComputeTexture(remapBuffer);
int geometryVbOutputSbBinding = resourceManager.Reservations.GeometryVertexOutputStorageBufferBinding;
BufferDefinition geometryVbOutputBuffer = new(BufferLayout.Std430, 1, geometryVbOutputSbBinding, "geometry_vb_output", vertexOutputStruct);
resourceManager.Properties.AddOrUpdateStorageBuffer(geometryVbOutputBuffer);
resourceManager.AddVertexAsComputeStorageBuffer(geometryVbOutputBuffer);
StructureType geometryIbOutputStruct = new(new StructureField[]
{
@ -446,7 +449,7 @@ namespace Ryujinx.Graphics.Shader.Translation
int geometryIbOutputSbBinding = resourceManager.Reservations.GeometryIndexOutputStorageBufferBinding;
BufferDefinition geometryIbOutputBuffer = new(BufferLayout.Std430, 1, geometryIbOutputSbBinding, "geometry_ib_output", geometryIbOutputStruct);
resourceManager.Properties.AddOrUpdateStorageBuffer(geometryIbOutputBuffer);
resourceManager.AddVertexAsComputeStorageBuffer(geometryIbOutputBuffer);
}
resourceManager.SetVertexAsComputeLocalMemories(Definitions.Stage, Definitions.InputTopology);
@ -479,12 +482,17 @@ namespace Ryujinx.Graphics.Shader.Translation
return new ResourceReservations(GpuAccessor, IsTransformFeedbackEmulated, vertexAsCompute: true, _vertexOutput, ioUsage);
}
public ShaderProgramInfo GetVertexAsComputeInfo()
{
return CreateResourceManager(true).GetVertexAsComputeInfo();
}
public void SetVertexOutputMapForGeometryAsCompute(TranslatorContext vertexContext)
{
_vertexOutput = vertexContext._program.GetIoUsage();
}
public ShaderProgram GenerateVertexPassthroughForCompute()
public (ShaderProgram, ShaderProgramInfo) GenerateVertexPassthroughForCompute()
{
var attributeUsage = new AttributeUsage(GpuAccessor);
var resourceManager = new ResourceManager(ShaderStage.Vertex, GpuAccessor);
@ -496,7 +504,7 @@ namespace Ryujinx.Graphics.Shader.Translation
if (Stage == ShaderStage.Vertex)
{
BufferDefinition vertexInfoBuffer = new(BufferLayout.Std140, 0, vertexInfoCbBinding, "vb_info", VertexInfoBuffer.GetStructureType());
resourceManager.Properties.AddOrUpdateConstantBuffer(vertexInfoBuffer);
resourceManager.AddVertexAsComputeConstantBuffer(vertexInfoBuffer);
}
StructureType vertexInputStruct = new(new StructureField[]
@ -506,7 +514,7 @@ namespace Ryujinx.Graphics.Shader.Translation
int vertexDataSbBinding = reservations.VertexOutputStorageBufferBinding;
BufferDefinition vertexOutputBuffer = new(BufferLayout.Std430, 1, vertexDataSbBinding, "vb_input", vertexInputStruct);
resourceManager.Properties.AddOrUpdateStorageBuffer(vertexOutputBuffer);
resourceManager.AddVertexAsComputeStorageBuffer(vertexOutputBuffer);
var context = new EmitterContext();
@ -564,14 +572,14 @@ namespace Ryujinx.Graphics.Shader.Translation
LastInVertexPipeline = true
};
return Generate(
return (Generate(
new[] { function },
attributeUsage,
definitions,
definitions,
resourceManager,
FeatureFlags.None,
0);
0), resourceManager.GetVertexAsComputeInfo(isVertex: true));
}
public ShaderProgram GenerateGeometryPassthrough()