Skip to content

Commit

Permalink
Add support for slices of slices
Browse files Browse the repository at this point in the history
  • Loading branch information
cesaref committed Nov 7, 2024
1 parent 53fbcd2 commit 308fa8c
Show file tree
Hide file tree
Showing 18 changed files with 305 additions and 67 deletions.
2 changes: 1 addition & 1 deletion modules/compiler/src/AST/cmaj_AST.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ using ref = choc::ObjectReference<Type>;
X(ForwardBranch) \
X(Function) \
X(FunctionCall) \
X(GetArraySlice) \
X(GetArrayOrVectorSlice) \
X(GetElement) \
X(GetStructMember) \
X(Graph) \
Expand Down
11 changes: 8 additions & 3 deletions modules/compiler/src/AST/cmaj_AST_Classes_Values.h
Original file line number Diff line number Diff line change
Expand Up @@ -1370,11 +1370,11 @@ struct GetElement : public ValueBase
};

//==============================================================================
struct GetArraySlice : public ValueBase
struct GetArrayOrVectorSlice : public ValueBase
{
GetArraySlice (const ObjectContext& c) : ValueBase (c) {}
GetArrayOrVectorSlice (const ObjectContext& c) : ValueBase (c) {}

CMAJ_AST_DECLARE_STANDARD_METHODS(GetArraySlice, 34)
CMAJ_AST_DECLARE_STANDARD_METHODS(GetArrayOrVectorSlice, 34)

ptr<VariableDeclaration> getSourceVariable() const override
{
Expand Down Expand Up @@ -1404,6 +1404,11 @@ struct GetArraySlice : public ValueBase
return result;
}
}
else
{
if (auto arrayType = parentType->getAsArrayType())
return createSliceOfType (context, arrayType->getInnermostElementTypeRef());
}
}
}

Expand Down
2 changes: 2 additions & 0 deletions modules/compiler/src/AST/cmaj_AST_StringPool.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ struct Strings
out { stringPool.get ("out") },
value { stringPool.get ("value") },
index { stringPool.get ("index") },
start { stringPool.get ("start") },
end { stringPool.get ("end") },
array { stringPool.get ("array") },
run { stringPool.get ("run") },
increment { stringPool.get ("increment") },
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1599,6 +1599,15 @@ struct EndpointInfo
return createReaderNoParensNeeded (typeName + " { " + sourceExpression + " }");
}

ValueReader createSliceOfSlice (const AST::TypeBase& elementType, ValueReader sourceSlice, ValueReader start, ValueReader end)
{
(void) elementType;

return createReaderNoParensNeeded (sourceSlice.getWithoutParens() + ".slice ("
+ start.getWithoutParens() + ", "
+ end.getWithoutParens() + ")");
}

ValueReader createSliceFromArray (const AST::TypeBase& elementType, ValueReader sourceArray,
uint32_t offset, uint32_t numElements)
{
Expand Down Expand Up @@ -2010,9 +2019,13 @@ struct Slice
ElementType operator[] (IndexType index) const noexcept { return numElements == 0 ? ElementType() : elements[index]; }
ElementType& operator[] (IndexType index) noexcept { return numElements == 0 ? emptyValue : elements[index]; }
Slice slice (IndexType start, IndexType end) noexcept { if (numElements == 0) return {}; return { elements + start, end }; }
Slice sliceFrom (IndexType start) noexcept { if (numElements == 0) return {}; return { elements + start, numElements - start }; }
Slice sliceUpTo (IndexType end) noexcept { if (numElements == 0) return {}; return { elements, end }; }
Slice slice (IndexType start, IndexType end) noexcept
{
if (numElements == 0) return {};
if (start >= numElements) return {};
return { elements + start, std::min (static_cast<SizeType> (end - start), numElements - start) };
}
ElementType* elements = nullptr;
SizeType numElements = 0;
Expand Down
26 changes: 26 additions & 0 deletions modules/compiler/src/backends/LLVM/cmaj_LLVMGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -1849,6 +1849,32 @@ struct LLVMCodeGenerator
return makeReader (fatPointer, AST::createSliceOfType (elementType.context, elementType));
}

ValueReader createSliceOfSlice (const AST::TypeBase& elementType, ValueReader sourceSlice, ValueReader start, ValueReader end)
{
auto sourcePtr = getPointer (sourceSlice);

auto fatPointerType = getFatPointerType (elementType);
auto fatPointer = functionEntryBlockBuilder->CreateAlloca (fatPointerType);

auto& b = getBlockBuilder();

auto dataField = b.CreateConstInBoundsGEP2_32 (fatPointerType, sourcePtr, 0, 0);
auto startIndex = b.CreateLoad (getInt32Type(), getPointer (start));

auto sliceLength = createBinaryOp (AST::BinaryOpTypeEnum::Enum::subtract,
AST::TypeRules::getBinaryOperatorTypes (AST::BinaryOpTypeEnum::Enum::subtract, allocator.int32Type, allocator.int32Type),
end,
start);

auto ptr = b.CreateLoad (getLLVMType (elementType)->getPointerTo(), dataField);
auto sliceDataField = b.CreateInBoundsGEP (getLLVMType (elementType), ptr, { startIndex });

b.CreateStore (sliceDataField, b.CreateConstInBoundsGEP2_32 (fatPointerType, fatPointer, 0, 0));
b.CreateStore (sliceLength.value, b.CreateConstInBoundsGEP2_32 (fatPointerType, fatPointer, 0, 1));

return makeReader (fatPointer, AST::createSliceOfType (elementType.context, elementType));
}

ValueReader createGetSliceSize (ValueReader slice)
{
auto& b = getBlockBuilder();
Expand Down
30 changes: 22 additions & 8 deletions modules/compiler/src/codegen/cmaj_CodeGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ struct CodeGenerator

void emitAssignment (const AST::Object& target, const AST::Property& newValue)
{
if (auto s = target.getAsGetArraySlice())
if (auto s = target.getAsGetArrayOrVectorSlice())
return emitWriteToSlice (*s, AST::castToValueRef (newValue));

auto targetType = target.getAsValueBase()->getResultType();
Expand All @@ -372,7 +372,7 @@ struct CodeGenerator
if (auto v = target.getAsVariableDeclaration()) return emitWriteToVariable (*v, std::move (newValue));
if (auto v = target.getAsNamedReference()) return emitAssignment (v->target, std::move (newValue), canSourceBeReadMultipleTimes);
if (auto e = target.getAsGetElement()) return emitWriteToElement (*e, std::move (newValue));
if (auto s = target.getAsGetArraySlice()) return emitWriteToSlice (*s, std::move (newValue), false, canSourceBeReadMultipleTimes);
if (auto s = target.getAsGetArrayOrVectorSlice()) return emitWriteToSlice (*s, std::move (newValue), false, canSourceBeReadMultipleTimes);
if (auto m = target.getAsGetStructMember()) return emitWriteToStructMember (*m, std::move (newValue));
}

Expand Down Expand Up @@ -439,7 +439,7 @@ struct CodeGenerator
return v.isCompileTimeConstant() || v.isVariableReference();
}

void emitWriteToSlice (const AST::GetArraySlice& slice, ValueReader newValue,
void emitWriteToSlice (const AST::GetArrayOrVectorSlice& slice, ValueReader newValue,
bool valueIsElement, bool canSourceBeReadMultipleTimes)
{
auto parentSize = slice.getParentSize();
Expand Down Expand Up @@ -467,7 +467,7 @@ struct CodeGenerator
});
}

void emitWriteToSlice (const AST::GetArraySlice& slice, const AST::ValueBase& newValue)
void emitWriteToSlice (const AST::GetArrayOrVectorSlice& slice, const AST::ValueBase& newValue)
{
if (auto v = newValue.getAsConstantAggregate())
{
Expand Down Expand Up @@ -696,7 +696,7 @@ struct CodeGenerator
if (auto v = value.getAsVariableDeclaration()) return builder.createVariableReader (*v);
if (auto v = value.getAsNamedReference()) return createValueReader (v->target);
if (auto e = value.getAsGetElement()) return createElementReader (*e);
if (auto s = value.getAsGetArraySlice()) return createSliceReader (*s);
if (auto s = value.getAsGetArrayOrVectorSlice()) return createSliceReader (*s);
if (auto m = value.getAsGetStructMember()) return createStructMemberReader (*m);
if (auto m = value.getAsValueMetaFunction()) return createValueMetaFunction (*m);
if (auto p = value.getAsPreOrPostIncOrDec()) return createPreOrPostIncOrDec (AST::castToValueRef (p->target), p->isIncrement, p->isPost);
Expand Down Expand Up @@ -753,7 +753,7 @@ struct CodeGenerator
}
else
{
CMAJ_ASSERT (isConst || value.getAsGetArraySlice() != nullptr);
CMAJ_ASSERT (isConst || value.getAsGetArrayOrVectorSlice() != nullptr);
}

auto& v = AST::castToValueRef (value);
Expand Down Expand Up @@ -1171,7 +1171,7 @@ struct CodeGenerator

ValueReader createSliceFromValue (const AST::TypeBase& elementType, const AST::ArrayType& sourceType, const AST::ValueBase& value)
{
if (auto slice = value.getAsGetArraySlice())
if (auto slice = value.getAsGetArrayOrVectorSlice())
{
auto& parentArray = AST::castToValueRef (slice->parent);
auto& parentArrayType = *parentArray.getResultType();
Expand Down Expand Up @@ -1362,9 +1362,23 @@ struct CodeGenerator
m.member.get(), structType.indexOfMember (m.member.get()));
}

ValueReader createSliceReader (const AST::GetArraySlice& slice)
ValueReader createSliceReader (const AST::GetArrayOrVectorSlice& slice)
{
auto parentSize = slice.getParentSize();

if (! parentSize.has_value())
{
auto& parentArray = AST::castToValueRef (slice.parent);
auto& parentArrayType = *parentArray.getResultType();
auto elementType = parentArrayType.getArrayOrVectorElementType();

// Slice of Slice
return builder.createSliceOfSlice (*elementType,
createValueReader (parentArray),
createValueReader (slice.start),
createValueReader (slice.end));
}

CMAJ_ASSERT (parentSize.has_value());
auto sliceSize = slice.getSliceSize();
CMAJ_ASSERT (sliceSize.has_value());
Expand Down
2 changes: 1 addition & 1 deletion modules/compiler/src/codegen/cmaj_ProgramPrinter.h
Original file line number Diff line number Diff line change
Expand Up @@ -1082,7 +1082,7 @@ struct ProgramPrinter : public SourceCodeFormattingHelper
.addPunctuation ("]");
}

if (auto slice = e.getAsGetArraySlice())
if (auto slice = e.getAsGetArrayOrVectorSlice())
return formatExpression (slice->parent).addParensIfNeeded()
.add (formatRangePair (slice->start, slice->end));

Expand Down
2 changes: 1 addition & 1 deletion modules/compiler/src/passes/cmaj_ConstantFolder.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ struct ConstantFolder : public PassAvoidingGenericFunctionsAndModules
void visit (AST::BinaryOperator& o) override { super::visit (o); fold (o); }
void visit (AST::TernaryOperator& o) override { super::visit (o); fold (o); }
void visit (AST::GetElement& o) override { super::visit (o); fold (o.indexes); fold (o); }
void visit (AST::GetArraySlice& o) override { super::visit (o); fold (o.start); fold (o.end); fold (o); }
void visit (AST::GetArrayOrVectorSlice& o) override { super::visit (o); fold (o.start); fold (o.end); fold (o); }
void visit (AST::GetStructMember& o) override { super::visit (o); fold (o); }
void visit (AST::FunctionCall& o) override { super::visit (o); fold (o); }
void visit (AST::VectorType& o) override { super::visit (o); fold (o.numElements); }
Expand Down
2 changes: 1 addition & 1 deletion modules/compiler/src/passes/cmaj_StrengthReduction.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ struct StrengthReduction : public PassAvoidingGenericFunctionsAndModules

CMAJ_DO_NOT_VISIT_CONSTANTS

void visit (AST::GetArraySlice& o) override
void visit (AST::GetArrayOrVectorSlice& o) override
{
super::visit (o);

Expand Down
4 changes: 2 additions & 2 deletions modules/compiler/src/passes/cmaj_TypeResolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ struct TypeResolver : public PassAvoidingGenericFunctionsAndModules

if (term.isRange)
{
auto& getSlice = replaceWithNewObject<AST::GetArraySlice> (b);
auto& getSlice = replaceWithNewObject<AST::GetArrayOrVectorSlice> (b);
getSlice.parent.referTo (*v);

if (term.startIndex != nullptr) getSlice.start.referTo (term.startIndex);
Expand Down Expand Up @@ -129,7 +129,7 @@ struct TypeResolver : public PassAvoidingGenericFunctionsAndModules

if (term.isRange)
{
auto& getSlice = replaceWithNewObject<AST::GetArraySlice> (b);
auto& getSlice = replaceWithNewObject<AST::GetArrayOrVectorSlice> (b);
getSlice.parent.referTo (parent);

if (term.startIndex != nullptr) getSlice.start.referTo (term.startIndex);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ inline void convertComplexTypes (AST::Program& program)
}
}

void visit (AST::GetArraySlice& s) override
void visit (AST::GetArrayOrVectorSlice& s) override
{
super::visit (s);

Expand Down Expand Up @@ -422,7 +422,7 @@ inline void convertComplexTypes (AST::Program& program)

size_t indexOffset = 0;

if (auto arraySlice = AST::castTo<AST::GetArraySlice> (firstArg))
if (auto arraySlice = AST::castTo<AST::GetArrayOrVectorSlice> (firstArg))
{
if (auto startPos = AST::castTo<AST::ConstantValueBase> (arraySlice->start))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ static inline void replaceMultidimensionalArrays (AST::Program& program)

auto& end = AST::foldToConstantIfPossible (AST::createAdd (context, flattenedOffset, *size));

auto& slice = g.context.allocate<AST::GetArraySlice>();
auto& slice = g.context.allocate<AST::GetArrayOrVectorSlice>();
slice.parent.referTo (g.parent.getObjectRef());
slice.start.setChildObject (flattenedOffset);
slice.end.setChildObject (end);
Expand Down
Loading

0 comments on commit 308fa8c

Please sign in to comment.