Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CIR][ABI] Add X86_64 float and double CC lowering #714

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion clang/lib/CIR/Dialect/Transforms/CallConvLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ LowerModule createLowerModule(FuncOp op, PatternRewriter &rewriter) {
auto context = CIRLowerContext(module, langOpts);
context.initBuiltinTypes(*targetInfo);

return LowerModule(context, module, dataLayoutStr, *targetInfo, rewriter);
return LowerModule(context, module, dataLayoutStr, std::move(targetInfo), rewriter);
}

} // namespace
Expand Down
5 changes: 5 additions & 0 deletions clang/lib/CIR/Dialect/Transforms/TargetLowering/ABIInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "CIRCXXABI.h"
#include "CIRLowerContext.h"
#include "LowerTypes.h"
#include "clang/CIR/Dialect/IR/CIRDataLayout.h"

namespace mlir {
namespace cir {
Expand All @@ -26,6 +27,10 @@ CIRCXXABI &ABIInfo::getCXXABI() const { return LT.getCXXABI(); }

CIRLowerContext &ABIInfo::getContext() const { return LT.getContext(); }

const ::cir::CIRDataLayout &ABIInfo::getDataLayout() const {
return LT.getDataLayout();
}

bool ABIInfo::isPromotableIntegerTypeForABI(Type Ty) const {
if (getContext().isPromotableIntegerType(Ty))
return true;
Expand Down
5 changes: 4 additions & 1 deletion clang/lib/CIR/Dialect/Transforms/TargetLowering/ABIInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "CIRCXXABI.h"
#include "CIRLowerContext.h"
#include "LowerFunctionInfo.h"
#include "clang/CIR/Dialect/IR/CIRDataLayout.h"
#include "llvm/IR/CallingConv.h"

namespace mlir {
Expand All @@ -38,7 +39,9 @@ class ABIInfo {

CIRCXXABI &getCXXABI() const;

CIRLowerContext &getContext() const;
CIRLowerContext &getContext() const;\

const ::cir::CIRDataLayout &getDataLayout() const;

virtual void computeInfo(LowerFunctionInfo &FI) const = 0;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ clang::TypeInfo CIRLowerContext::getTypeInfoImpl(const Type T) const {
// TODO(cir): We should implement a better way to identify type kinds and use
// builting data layout interface for this.
auto typeKind = clang::Type::Builtin;
if (isa<IntType>(T)) {
if (isa<IntType, SingleType, DoubleType>(T)) {
typeKind = clang::Type::Builtin;
} else {
llvm_unreachable("Unhandled type class");
Expand All @@ -74,6 +74,16 @@ clang::TypeInfo CIRLowerContext::getTypeInfoImpl(const Type T) const {
Align = std::ceil((float)Width / 8) * 8;
break;
}
if (auto floatTy = dyn_cast<SingleType>(T)) {
Width = Target->getFloatWidth();
Align = Target->getFloatAlign();
break;
}
if (auto doubleTy = dyn_cast<DoubleType>(T)) {
Width = Target->getDoubleWidth();
Align = Target->getDoubleAlign();
break;
}
llvm_unreachable("Unknown builtin type!");
break;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,11 @@ createTargetLoweringInfo(LowerModule &LM) {
}

LowerModule::LowerModule(CIRLowerContext &C, ModuleOp &module, StringAttr DL,
const clang::TargetInfo &target,
std::unique_ptr<clang::TargetInfo> target,
PatternRewriter &rewriter)
: context(C), module(module), Target(target), ABI(createCXXABI(*this)),
types(*this, DL.getValue()), rewriter(rewriter) {}
: context(C), module(module), Target(std::move(target)),
ABI(createCXXABI(*this)), types(*this, DL.getValue()),
rewriter(rewriter) {}

const TargetLoweringInfo &LowerModule::getTargetLoweringInfo() {
if (!TheTargetCodeGenInfo)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ namespace cir {
class LowerModule {
CIRLowerContext &context;
ModuleOp module;
const clang::TargetInfo &Target;
const std::unique_ptr<clang::TargetInfo> Target;
mutable std::unique_ptr<TargetLoweringInfo> TheTargetCodeGenInfo;
std::unique_ptr<CIRCXXABI> ABI;

Expand All @@ -41,21 +41,22 @@ class LowerModule {

public:
LowerModule(CIRLowerContext &C, ModuleOp &module, StringAttr DL,
const clang::TargetInfo &target, PatternRewriter &rewriter);
std::unique_ptr<clang::TargetInfo> target,
PatternRewriter &rewriter);
~LowerModule() = default;

// Trivial getters.
LowerTypes &getTypes() { return types; }
CIRLowerContext &getContext() { return context; }
CIRCXXABI &getCXXABI() const { return *ABI; }
const clang::TargetInfo &getTarget() const { return Target; }
const clang::TargetInfo &getTarget() const { return *Target; }
MLIRContext *getMLIRContext() { return module.getContext(); }
ModuleOp &getModule() { return module; }

const TargetLoweringInfo &getTargetLoweringInfo();

// FIXME(cir): This would be in ASTContext, not CodeGenModule.
const clang::TargetInfo &getTargetInfo() const { return Target; }
const clang::TargetInfo &getTargetInfo() const { return *Target; }

// FIXME(cir): This would be in ASTContext, not CodeGenModule.
clang::TargetCXXABI::Kind getCXXABIKind() const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class LowerTypes {
LowerTypes(LowerModule &LM, StringRef DLString);
~LowerTypes() = default;

const ::cir::CIRDataLayout &getDataLayout() const { return DL; }
LowerModule &getLM() const { return LM; }
CIRCXXABI &getCXXABI() const { return CXXABI; }
CIRLowerContext &getContext() { return context; }
Expand Down
62 changes: 62 additions & 0 deletions clang/lib/CIR/Dialect/Transforms/TargetLowering/Targets/X86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "LowerTypes.h"
#include "TargetInfo.h"
#include "clang/CIR/ABIArgInfo.h"
#include "clang/CIR/Dialect/IR/CIRDataLayout.h"
#include "clang/CIR/MissingFeatures.h"
#include "llvm/Support/ErrorHandling.h"
#include <memory>
Expand Down Expand Up @@ -37,6 +38,15 @@ static bool BitsContainNoUserData(Type Ty, unsigned StartBit, unsigned EndBit,
llvm_unreachable("NYI");
}

/// Return a floating point type at the specified offset.
Type getFPTypeAtOffset(Type IRType, unsigned IROffset,
const ::cir::CIRDataLayout &TD) {
if (IROffset == 0 && isa<SingleType, DoubleType>(IRType))
return IRType;

llvm_unreachable("NYI");
}

} // namespace

class X86_64ABIInfo : public ABIInfo {
Expand Down Expand Up @@ -71,6 +81,9 @@ class X86_64ABIInfo : public ABIInfo {
void classify(Type T, uint64_t OffsetBase, Class &Lo, Class &Hi,
bool isNamedArg, bool IsRegCall = false) const;

Type GetSSETypeAtOffset(Type IRType, unsigned IROffset, Type SourceTy,
unsigned SourceOffset) const;

Type GetINTEGERTypeAtOffset(Type DestTy, unsigned IROffset, Type SourceTy,
unsigned SourceOffset) const;

Expand Down Expand Up @@ -125,6 +138,10 @@ void X86_64ABIInfo::classify(Type Ty, uint64_t OffsetBase, Class &Lo, Class &Hi,
}
return;

} else if (isa<SingleType>(Ty) || isa<DoubleType>(Ty)) {
Current = Class::SSE;
return;

} else {
llvm::outs() << "Missing X86 classification for type " << Ty << "\n";
llvm_unreachable("NYI");
Expand All @@ -138,6 +155,37 @@ void X86_64ABIInfo::classify(Type Ty, uint64_t OffsetBase, Class &Lo, Class &Hi,
llvm_unreachable("NYI");
}

/// Return a type that will be passed by the backend in the low 8 bytes of an
/// XMM register, corresponding to the SSE class.
Type X86_64ABIInfo::GetSSETypeAtOffset(Type IRType, unsigned IROffset,
Type SourceTy,
unsigned SourceOffset) const {
const ::cir::CIRDataLayout &TD = getDataLayout();
unsigned SourceSize =
(unsigned)getContext().getTypeSize(SourceTy) / 8 - SourceOffset;
Type T0 = getFPTypeAtOffset(IRType, IROffset, TD);
if (!T0 || isa<Float64Type>(T0))
return T0; // NOTE(cir): Not sure if this is correct.

Type T1 = {};
unsigned T0Size = TD.getTypeAllocSize(T0);
if (SourceSize > T0Size)
llvm_unreachable("NYI");
if (T1 == nullptr) {
// Check if IRType is a half/bfloat + float. float type will be in
// IROffset+4 due to its alignment.
if (isa<Float16Type>(T0) && SourceSize > 4)
llvm_unreachable("NYI");
// If we can't get a second FP type, return a simple half or float.
// avx512fp16-abi.c:pr51813_2 shows it works to return float for
// {float, i8} too.
if (T1 == nullptr)
return T0;
}

llvm_unreachable("NYI");
}

/// The ABI specifies that a value should be passed in an 8-byte GPR. This
/// means that we either have a scalar or we are talking about the high or low
/// part of an up-to-16-byte struct. This routine picks the best CIR type
Expand Down Expand Up @@ -236,6 +284,12 @@ ::cir::ABIArgInfo X86_64ABIInfo::classifyReturnType(Type RetTy) const {
}
break;

// AMD64-ABI 3.2.3p4: Rule 4. If the class is SSE, the next
// available SSE register of the sequence %xmm0, %xmm1 is used.
case Class::SSE:
resType = GetSSETypeAtOffset(RetTy, 0, RetTy, 0);
break;

default:
llvm_unreachable("NYI");
}
Expand Down Expand Up @@ -302,6 +356,14 @@ ABIArgInfo X86_64ABIInfo::classifyArgumentType(Type Ty, unsigned freeIntRegs,

break;

// AMD64-ABI 3.2.3p3: Rule 3. If the class is SSE, the next
// available SSE register is used, the registers are taken in the
// order from %xmm0 to %xmm7.
case Class::SSE: {
ResType = GetSSETypeAtOffset(Ty, 0, Ty, 0);
++neededSSE;
break;
}
default:
llvm_unreachable("NYI");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,16 @@ long long LongLong(long long l) {
// CHECK: cir.call @_Z8LongLongx(%{{.+}}) : (!s64i) -> !s64i
return LongLong(l);
}

/// Test call conv lowering for floating point. ///

// CHECK: cir.func @_Z5Floatf(%arg0: !cir.float loc({{.+}})) -> !cir.float
float Float(float f) {
// cir.call @_Z5Floatf(%{{.+}}) : (!cir.float) -> !cir.float
return Float(f);
}
// CHECK: cir.func @_Z6Doubled(%arg0: !cir.double loc({{.+}})) -> !cir.double
double Double(double d) {
// cir.call @_Z6Doubled(%{{.+}}) : (!cir.double) -> !cir.double
return Double(d);
}
Loading