diff --git a/src/carnot/planner/compiler/analyzer.h b/src/carnot/planner/compiler/analyzer.h index 5825db56a08138b7c6e4ed6effaf6ea62702a644..3a610c876cc7da3ea59f471fa7bbb6e564b43315 100644 --- a/src/carnot/planner/compiler/analyzer.h +++ b/src/carnot/planner/compiler/analyzer.h @@ -48,9 +48,9 @@ class Analyzer : public RuleExecutor<IR> { unique_sink_names->AddRule<UniqueSinkNameRule>(); } - void CreateAddLimitToMemorySinkBatch() { - RuleBatch* limit_to_mem_sink = CreateRuleBatch<FailOnMax>("AddLimitToMemorySink", 2); - limit_to_mem_sink->AddRule<AddLimitToMemorySinkRule>(compiler_state_); + void CreateAddLimitToBatchResultSinkBatch() { + RuleBatch* limit_to_res_sink = CreateRuleBatch<FailOnMax>("AddLimitToBatchResultSink", 2); + limit_to_res_sink->AddRule<AddLimitToBatchResultSinkRule>(compiler_state_); } void CreateOperatorCompileTimeExpressionRuleBatch() { @@ -100,7 +100,7 @@ class Analyzer : public RuleExecutor<IR> { md_handler_ = MetadataHandler::Create(); CreateSourceAndMetadataResolutionBatch(); CreateUniqueSinkNamesBatch(); - CreateAddLimitToMemorySinkBatch(); + CreateAddLimitToBatchResultSinkBatch(); CreateOperatorCompileTimeExpressionRuleBatch(); CreateCombineConsecutiveMapsRule(); CreateDataTypeResolutionBatch(); diff --git a/src/carnot/planner/compiler/test_utils.h b/src/carnot/planner/compiler/test_utils.h index 84ad12d3e8eddd6d620207afc376902ddae5a067..3caaed71894fe9809961d18f6076a7c9b944f419 100644 --- a/src/carnot/planner/compiler/test_utils.h +++ b/src/carnot/planner/compiler/test_utils.h @@ -612,6 +612,13 @@ class OperatorTests : public ::testing::Test { return grpc_sink; } + GRPCSinkIR* MakeGRPCSink(OperatorIR* parent, std::string name, + const std::vector<std::string>& out_cols) { + GRPCSinkIR* grpc_sink = + graph->CreateNode<GRPCSinkIR>(ast, parent, name, out_cols).ConsumeValueOrDie(); + return grpc_sink; + } + GRPCSourceIR* MakeGRPCSource(const table_store::schema::Relation& relation) { GRPCSourceIR* grpc_src_group = graph->CreateNode<GRPCSourceIR>(ast, relation).ConsumeValueOrDie(); @@ -1165,9 +1172,14 @@ void CompareCloneNode(GRPCSourceIR* /*new_ir*/, GRPCSourceIR* /*old_ir*/, template <> void CompareCloneNode(GRPCSinkIR* new_ir, GRPCSinkIR* old_ir, const std::string& err_string) { - EXPECT_EQ(new_ir->destination_id(), old_ir->destination_id()) << err_string; - EXPECT_EQ(new_ir->destination_address(), old_ir->destination_address()) << err_string; EXPECT_EQ(new_ir->DestinationAddressSet(), old_ir->DestinationAddressSet()) << err_string; + EXPECT_EQ(new_ir->destination_address(), old_ir->destination_address()) << err_string; + EXPECT_EQ(new_ir->destination_id(), old_ir->destination_id()) << err_string; + EXPECT_EQ(new_ir->has_destination_id(), old_ir->has_destination_id()) << err_string; + EXPECT_EQ(new_ir->has_output_table(), old_ir->has_output_table()) << err_string; + EXPECT_EQ(new_ir->out_columns(), old_ir->out_columns()) << err_string; + EXPECT_EQ(new_ir->name(), old_ir->name()) << err_string; + EXPECT_EQ(new_ir->out_columns(), old_ir->out_columns()) << err_string; } template <> diff --git a/src/carnot/planner/ir/ir_nodes.cc b/src/carnot/planner/ir/ir_nodes.cc index ace82ec15d255b517dfdfbdfa1ff7d872b931c46..8bf74a4e7bf044f292cf6caefb548daf77043bbd 100644 --- a/src/carnot/planner/ir/ir_nodes.cc +++ b/src/carnot/planner/ir/ir_nodes.cc @@ -1336,8 +1336,11 @@ Status LimitIR::CopyFromNodeImpl(const IRNode* node, absl::flat_hash_map<const I Status GRPCSinkIR::CopyFromNodeImpl(const IRNode* node, absl::flat_hash_map<const IRNode*, IRNode*>*) { const GRPCSinkIR* grpc_sink = static_cast<const GRPCSinkIR*>(node); + sink_type_ = grpc_sink->sink_type_; destination_id_ = grpc_sink->destination_id_; destination_address_ = grpc_sink->destination_address_; + name_ = grpc_sink->name_; + out_columns_ = grpc_sink->out_columns_; return Status::OK(); } @@ -1419,7 +1422,20 @@ Status GRPCSinkIR::ToProto(planpb::Operator* op) const { auto pb = op->mutable_grpc_sink_op(); op->set_op_type(planpb::GRPC_SINK_OPERATOR); pb->set_address(destination_address()); - pb->set_grpc_source_id(destination_id()); + if (has_destination_id()) { + pb->set_grpc_source_id(destination_id()); + } else if (has_output_table()) { + pb->mutable_output_table()->set_table_name(name()); + DCHECK(IsRelationInit()); + for (size_t i = 0; i < relation().NumColumns(); ++i) { + pb->mutable_output_table()->add_column_names(relation().GetColumnName(i)); + pb->mutable_output_table()->add_column_types(relation().GetColumnType(i)); + pb->mutable_output_table()->add_column_semantic_types(relation().GetColumnSemanticType(i)); + } + } else { + return error::Internal( + "Error in GRPCSinkIR::ToProto: node has no output table or destination ID"); + } return Status::OK(); } diff --git a/src/carnot/planner/ir/ir_nodes.h b/src/carnot/planner/ir/ir_nodes.h index e1326fcf0ee5a5405e6d9270529093d6dc9ded30..9535973a30b0eae9eb2df981c3febc6c18fa2f44 100644 --- a/src/carnot/planner/ir/ir_nodes.h +++ b/src/carnot/planner/ir/ir_nodes.h @@ -1622,11 +1622,32 @@ class GRPCSinkIR : public OperatorIR { public: explicit GRPCSinkIR(int64_t id) : OperatorIR(id, IRNodeType::kGRPCSink) {} + enum GRPCSinkType { + kTypeNotSet = 0, + kInternal, + kExternal, + }; + + // Init function to call to create an internal GRPCSink, which sends an intermediate + // result to a corresponding GRPC Source. Status Init(OperatorIR* parent, int64_t destination_id) { PL_RETURN_IF_ERROR(AddParent(parent)); destination_id_ = destination_id; + sink_type_ = GRPCSinkType::kInternal; + return Status::OK(); + } + + // Init function to call to create an external, final result producing GRPCSink, which + // streams the output table to a non-Carnot destination (such as the query broker). + Status Init(OperatorIR* parent, const std::string& name, + const std::vector<std::string> out_columns) { + PL_RETURN_IF_ERROR(AddParent(parent)); + sink_type_ = GRPCSinkType::kExternal; + name_ = name; + out_columns_ = out_columns; return Status::OK(); } + Status ToProto(planpb::Operator* op_pb) const override; /** @@ -1635,16 +1656,30 @@ class GRPCSinkIR : public OperatorIR { * * Once the Distributed Plan is established, you should use DistributedDestinationID(). */ + bool has_destination_id() const { return sink_type_ == GRPCSinkType::kInternal; } int64_t destination_id() const { return destination_id_; } void SetDestinationID(int64_t destination_id) { destination_id_ = destination_id; } void SetDestinationAddress(const std::string& address) { destination_address_ = address; } const std::string& destination_address() const { return destination_address_; } bool DestinationAddressSet() const { return destination_address_ != ""; } + + bool has_output_table() const { return sink_type_ == GRPCSinkType::kExternal; } + std::string name() const { return name_; } + void set_name(const std::string& name) { name_ = name; } + // When out_columns_ is empty, the full input relation will be written to the sink. + const std::vector<std::string>& out_columns() const { return out_columns_; } + inline bool IsBlocking() const override { return true; } StatusOr<std::vector<absl::flat_hash_set<std::string>>> RequiredInputColumns() const override { - return error::Unimplemented("Unexpected call to GRPCSinkIR::RequiredInputColumns"); + if (sink_type_ != GRPCSinkType::kExternal) { + return error::Unimplemented("Unexpected call to GRPCSinkIR::RequiredInputColumns"); + } + DCHECK(IsRelationInit()); + auto out_cols = relation().col_names(); + absl::flat_hash_set<std::string> outputs{out_cols.begin(), out_cols.end()}; + return std::vector<absl::flat_hash_set<std::string>>{outputs}; } static constexpr bool FailOnResolveType() { return true; } @@ -1658,8 +1693,13 @@ class GRPCSinkIR : public OperatorIR { } private: - int64_t destination_id_ = -1; std::string destination_address_ = ""; + GRPCSinkType sink_type_ = GRPCSinkType::kTypeNotSet; + // Used when GRPCSinkType = kInternal. + int64_t destination_id_ = -1; + // Used when GRPCSinkType = kExternal. + std::string name_; + std::vector<std::string> out_columns_; }; /** diff --git a/src/carnot/planner/ir/ir_nodes_test.cc b/src/carnot/planner/ir/ir_nodes_test.cc index 7f68977438c7952a84904ad5fc8a747d396a489b..a8ad08de011cc6e2504307add0eb234c3f168b98 100644 --- a/src/carnot/planner/ir/ir_nodes_test.cc +++ b/src/carnot/planner/ir/ir_nodes_test.cc @@ -745,7 +745,7 @@ TEST_F(OperatorTests, swap_parent) { EXPECT_EQ(col3->ReferenceID().ConsumeValueOrDie(), parent_map->id()); } -TEST_F(OperatorTests, grpc_ops) { +TEST_F(OperatorTests, internal_grpc_ops) { int64_t grpc_id = 123; std::string source_grpc_address = "1111"; std::string sink_physical_id = "agent-xyz"; @@ -762,11 +762,22 @@ TEST_F(OperatorTests, grpc_ops) { MakeMemSink(grpc_src_group, "out"); grpc_src_group->SetGRPCAddress(source_grpc_address); + EXPECT_TRUE(grpc_sink->has_destination_id()); + EXPECT_FALSE(grpc_sink->has_output_table()); EXPECT_EQ(grpc_sink->destination_id(), grpc_id); EXPECT_OK(grpc_src_group->AddGRPCSink(grpc_sink)); EXPECT_EQ(grpc_src_group->source_id(), grpc_id); } +TEST_F(OperatorTests, external_grpc) { + MemorySourceIR* mem_src = MakeMemSource(); + GRPCSinkIR* grpc_sink = MakeGRPCSink(mem_src, "output_table", std::vector<std::string>{"outcol"}); + EXPECT_FALSE(grpc_sink->has_destination_id()); + EXPECT_TRUE(grpc_sink->has_output_table()); + EXPECT_EQ("output_table", grpc_sink->name()); + EXPECT_THAT(grpc_sink->out_columns(), ElementsAre("outcol")); +} + using CloneTests = OperatorTests; TEST_F(CloneTests, simple_clone) { auto mem_source = MakeMemSource(); @@ -851,7 +862,7 @@ TEST_F(CloneTests, grpc_source_group) { } } -TEST_F(CloneTests, grpc_sink) { +TEST_F(CloneTests, internal_grpc_sink) { auto mem_source = MakeMemSource(); GRPCSinkIR* grpc_sink = MakeGRPCSink(mem_source, 123); grpc_sink->SetDestinationAddress("1111"); @@ -868,6 +879,23 @@ TEST_F(CloneTests, grpc_sink) { } } +TEST_F(CloneTests, external_grpc_sink) { + auto mem_source = MakeMemSource(); + GRPCSinkIR* grpc_sink = MakeGRPCSink(mem_source, "output_table", std::vector<std::string>{"foo"}); + grpc_sink->SetDestinationAddress("1111"); + + auto out = graph->Clone(); + EXPECT_OK(out.status()); + std::unique_ptr<IR> cloned_ir = out.ConsumeValueOrDie(); + + ASSERT_EQ(graph->dag().TopologicalSort(), cloned_ir->dag().TopologicalSort()); + + // Make sure that all of the columns are now part of the new graph. + for (int64_t i : cloned_ir->dag().TopologicalSort()) { + CompareClone(cloned_ir->Get(i), graph->Get(i), absl::Substitute("For index $0", i)); + } +} + TEST_F(CloneTests, grpc_source) { auto grpc_source = MakeGRPCSource(MakeRelation()); MakeMemSink(grpc_source, "sup"); @@ -1021,7 +1049,7 @@ TEST_F(ToProtoTests, grpc_source_ir) { EXPECT_THAT(pb, EqualsProto(kExpectedGRPCSourcePb)); } -constexpr char kExpectedGRPCSinkPb[] = R"proto( +constexpr char kExpectedInternalGRPCSinkPb[] = R"proto( op_type: GRPC_SINK_OPERATOR grpc_sink_op { address: "$0" @@ -1029,10 +1057,9 @@ constexpr char kExpectedGRPCSinkPb[] = R"proto( } )proto"; -TEST_F(ToProtoTests, grpc_sink_ir) { +TEST_F(ToProtoTests, internal_grpc_sink_ir) { int64_t destination_id = 123; std::string grpc_address = "1111"; - std::string physical_id = "agent-aa"; auto mem_src = MakeMemSource(); auto grpc_sink = MakeGRPCSink(mem_src, destination_id); grpc_sink->SetDestinationAddress(grpc_address); @@ -1040,7 +1067,44 @@ TEST_F(ToProtoTests, grpc_sink_ir) { planpb::Operator pb; ASSERT_OK(grpc_sink->ToProto(&pb)); - EXPECT_THAT(pb, EqualsProto(absl::Substitute(kExpectedGRPCSinkPb, grpc_address, destination_id))); + EXPECT_THAT( + pb, EqualsProto(absl::Substitute(kExpectedInternalGRPCSinkPb, grpc_address, destination_id))); +} + +constexpr char kExpectedExternalGRPCSinkPb[] = R"proto( + op_type: GRPC_SINK_OPERATOR + grpc_sink_op { + address: "$0" + output_table { + table_name: "$1" + column_names: "count" + column_names: "cpu0" + column_names: "cpu1" + column_names: "cpu2" + column_types: INT64 + column_types: FLOAT64 + column_types: FLOAT64 + column_types: FLOAT64 + column_semantic_types: ST_NONE + column_semantic_types: ST_NONE + column_semantic_types: ST_NONE + column_semantic_types: ST_NONE + } + } +)proto"; + +TEST_F(ToProtoTests, external_grpc_sink_ir) { + std::string grpc_address = "1111"; + auto mem_src = MakeMemSource(); + GRPCSinkIR* grpc_sink = MakeGRPCSink(mem_src, "output_table", std::vector<std::string>{}); + ASSERT_OK(grpc_sink->SetRelation(MakeRelation())); + grpc_sink->SetDestinationAddress(grpc_address); + + planpb::Operator pb; + ASSERT_OK(grpc_sink->ToProto(&pb)); + + EXPECT_THAT( + pb, EqualsProto(absl::Substitute(kExpectedExternalGRPCSinkPb, grpc_address, "output_table"))); } constexpr char kIRProto[] = R"proto( @@ -1668,6 +1732,24 @@ TEST_F(OperatorTests, uint128_ir_init_from_str_bad_format) { EXPECT_THAT(uint128_or_s.status(), HasCompilerError(".* is not a valid UUID")); } +TEST_F(OperatorTests, grpc_sink_required_inputs) { + auto mem_source = + graph->CreateNode<MemorySourceIR>(ast, "source", std::vector<std::string>{}).ValueOrDie(); + auto sink = graph + ->CreateNode<GRPCSinkIR>(ast, mem_source, "output_table", + std::vector<std::string>{"output1", "output2"}) + .ValueOrDie(); + + auto rel = table_store::schema::Relation( + std::vector<types::DataType>({types::DataType::INT64, types::DataType::FLOAT64}), + std::vector<std::string>({"output1", "output2"})); + EXPECT_OK(sink->SetRelation(rel)); + + auto inputs = sink->RequiredInputColumns().ConsumeValueOrDie(); + EXPECT_EQ(1, inputs.size()); + EXPECT_THAT(inputs[0], UnorderedElementsAre("output1", "output2")); +} + TEST_F(OperatorTests, map_required_inputs) { MemorySourceIR* mem_source = MakeMemSource(); ColumnIR* col1 = MakeColumn("test1", /*parent_op_idx*/ 0); diff --git a/src/carnot/planner/ir/pattern_match.h b/src/carnot/planner/ir/pattern_match.h index fd3adb0b1e1805a4004b4851d808a9a8530030bf..e1da6a8bd2813ec6d4366efd028f46f1748bb481 100644 --- a/src/carnot/planner/ir/pattern_match.h +++ b/src/carnot/planner/ir/pattern_match.h @@ -155,6 +155,36 @@ struct GRPCSinkWithSourceID : public ParentMatch { int64_t source_id_; }; +/* Match an external GRPC (which produces an output table) */ +struct GRPCSinkTypeMatch : public ParentMatch { + explicit GRPCSinkTypeMatch(bool internal) + : ParentMatch(IRNodeType::kGRPCSink), internal_(internal) {} + + bool Match(const IRNode* node) const override { + return GRPCSink().Match(node) && + (internal_ ? static_cast<const GRPCSinkIR*>(node)->has_destination_id() + : static_cast<const GRPCSinkIR*>(node)->has_output_table()); + } + + private: + bool internal_; +}; + +// Matches a GRPC which outputs a final result, streamed to a remote destination. +inline GRPCSinkTypeMatch ExternalGRPCSink() { return GRPCSinkTypeMatch(/* internal */ false); } + +// Matches a GRPC which outputs an intermediate result, streamed to another Carnot instance. +inline GRPCSinkTypeMatch InternalGRPCSink() { return GRPCSinkTypeMatch(/* internal */ true); } + +// Matches a sink that produces a final (rather than intermediate) result. +struct ResultSink : public ParentMatch { + ResultSink() : ParentMatch(IRNodeType::kAny) {} + + bool Match(const IRNode* node) const override { + return ExternalGRPCSink().Match(node) || MemorySink().Match(node); + } +}; + /** * @brief Match a specific integer value. */ diff --git a/src/carnot/planner/rules/rules.cc b/src/carnot/planner/rules/rules.cc index 48b964d483de070677e4331507fe8b3e457798ca..5bb295268e47c708885e7282a4ab5e332e2b075e 100644 --- a/src/carnot/planner/rules/rules.cc +++ b/src/carnot/planner/rules/rules.cc @@ -188,29 +188,38 @@ StatusOr<std::vector<ColumnIR*>> SourceRelationRule::GetColumnsFromRelation( StatusOr<bool> OperatorRelationRule::Apply(IRNode* ir_node) { if (Match(ir_node, UnresolvedReadyOp(BlockingAgg()))) { return SetBlockingAgg(static_cast<BlockingAggIR*>(ir_node)); - } else if (Match(ir_node, UnresolvedReadyOp(Map()))) { + } + if (Match(ir_node, UnresolvedReadyOp(Map()))) { return SetMap(static_cast<MapIR*>(ir_node)); - } else if (Match(ir_node, UnresolvedReadyOp(Union()))) { + } + if (Match(ir_node, UnresolvedReadyOp(Union()))) { return SetUnion(static_cast<UnionIR*>(ir_node)); - } else if (Match(ir_node, UnresolvedReadyOp(Join()))) { + } + if (Match(ir_node, UnresolvedReadyOp(Join()))) { JoinIR* join_node = static_cast<JoinIR*>(ir_node); if (Match(ir_node, UnsetOutputColumnsJoin())) { PL_RETURN_IF_ERROR(SetJoinOutputColumns(join_node)); } return SetOldJoin(join_node); - } else if (Match(ir_node, UnresolvedReadyOp(Drop()))) { + } + if (Match(ir_node, UnresolvedReadyOp(Drop()))) { // Another rule handles this. // TODO(philkuz) unify this rule with the drop to map rule. return false; - } else if (Match(ir_node, UnresolvedReadyOp(MemorySink()))) { + } + if (Match(ir_node, UnresolvedReadyOp(MemorySink()))) { return SetMemorySink(static_cast<MemorySinkIR*>(ir_node)); - } else if (Match(ir_node, UnresolvedReadyOp(Limit())) || - Match(ir_node, UnresolvedReadyOp(Filter())) || - Match(ir_node, UnresolvedReadyOp(GroupBy())) || - Match(ir_node, UnresolvedReadyOp(Rolling()))) { + } + if (Match(ir_node, UnresolvedReadyOp(ExternalGRPCSink()))) { + return SetGRPCSink(static_cast<GRPCSinkIR*>(ir_node)); + } + if (Match(ir_node, UnresolvedReadyOp(Limit())) || Match(ir_node, UnresolvedReadyOp(Filter())) || + Match(ir_node, UnresolvedReadyOp(GroupBy())) || + Match(ir_node, UnresolvedReadyOp(Rolling()))) { // Explicitly match because the general matcher keeps causing problems. return SetOther(static_cast<OperatorIR*>(ir_node)); - } else if (Match(ir_node, UnresolvedReadyOp())) { + } + if (Match(ir_node, UnresolvedReadyOp())) { // Fails in this path because future writers should specify the op. DCHECK(false) << ir_node->DebugString(); return SetOther(static_cast<OperatorIR*>(ir_node)); @@ -475,6 +484,21 @@ StatusOr<bool> OperatorRelationRule::SetMemorySink(MemorySinkIR* sink_ir) const return true; } +StatusOr<bool> OperatorRelationRule::SetGRPCSink(GRPCSinkIR* sink_ir) const { + DCHECK(sink_ir->has_output_table()); + if (!sink_ir->out_columns().size()) { + return SetOther(sink_ir); + } + auto input_relation = sink_ir->parents()[0]->relation(); + Relation output_relation; + for (const auto& col_name : sink_ir->out_columns()) { + output_relation.AddColumn(input_relation.GetColumnType(col_name), col_name, + input_relation.GetColumnDesc(col_name)); + } + PL_RETURN_IF_ERROR(sink_ir->SetRelation(output_relation)); + return true; +} + StatusOr<bool> OperatorRelationRule::SetOther(OperatorIR* operator_ir) const { CHECK_EQ(operator_ir->parents().size(), 1UL); PL_RETURN_IF_ERROR(operator_ir->SetRelation(operator_ir->parents()[0]->relation())); @@ -989,21 +1013,30 @@ StatusOr<bool> RemoveGroupByRule::RemoveGroupBy(GroupByIR* groupby) { return true; } -StatusOr<bool> UniqueSinkNameRule::Apply(IRNode* ir_node) { - if (!Match(ir_node, MemorySink())) { - return false; - } - auto sink = static_cast<MemorySinkIR*>(ir_node); +template <typename TSinkType> +StatusOr<bool> ApplyUniqueSinkName(IRNode* ir_node, + absl::flat_hash_map<std::string, int64_t>* sink_names_count) { + auto sink = static_cast<TSinkType*>(ir_node); bool changed_name = false; - if (sink_names_count_.contains(sink->name())) { - sink->set_name(absl::Substitute("$0_$1", sink->name(), sink_names_count_[sink->name()]++)); + if (sink_names_count->contains(sink->name())) { + sink->set_name(absl::Substitute("$0_$1", sink->name(), (*sink_names_count)[sink->name()]++)); changed_name = true; } else { - sink_names_count_[sink->name()] = 1; + (*sink_names_count)[sink->name()] = 1; } return changed_name; } +StatusOr<bool> UniqueSinkNameRule::Apply(IRNode* ir_node) { + if (Match(ir_node, MemorySink())) { + return ApplyUniqueSinkName<MemorySinkIR>(ir_node, &sink_names_count_); + } + if (Match(ir_node, ExternalGRPCSink())) { + return ApplyUniqueSinkName<GRPCSinkIR>(ir_node, &sink_names_count_); + } + return false; +} + bool ContainsChildColumn(const ExpressionIR& expr, const absl::flat_hash_set<std::string>& colnames) { if (Match(&expr, ColumnNode())) { @@ -1169,7 +1202,7 @@ StatusOr<bool> PruneUnconnectedOperatorsRule::Apply(IRNode* ir_node) { auto ir_graph = ir_node->graph(); auto node_id = ir_node->id(); - if (Match(ir_node, MemorySink()) || sink_connected_nodes_.contains(ir_node)) { + if (Match(ir_node, ResultSink()) || sink_connected_nodes_.contains(ir_node)) { for (int64_t parent_id : ir_graph->dag().ParentsOf(node_id)) { sink_connected_nodes_.insert(ir_graph->Get(parent_id)); } @@ -1201,11 +1234,11 @@ StatusOr<bool> PruneUnconnectedOperatorsRule::Apply(IRNode* ir_node) { return true; } -StatusOr<bool> AddLimitToMemorySinkRule::Apply(IRNode* ir_node) { +StatusOr<bool> AddLimitToBatchResultSinkRule::Apply(IRNode* ir_node) { if (!compiler_state_->has_max_output_rows_per_table()) { return false; } - if (!Match(ir_node, MemorySink())) { + if (!Match(ir_node, ResultSink())) { return false; } auto mem_sink = static_cast<MemorySinkIR*>(ir_node); diff --git a/src/carnot/planner/rules/rules.h b/src/carnot/planner/rules/rules.h index 8d23ed31bc8e01dc9e32d702cfcc30e2ec0b2096..761d135ee6b715950af779ae43a4a826026698a6 100644 --- a/src/carnot/planner/rules/rules.h +++ b/src/carnot/planner/rules/rules.h @@ -181,6 +181,7 @@ class OperatorRelationRule : public Rule { StatusOr<bool> SetUnion(UnionIR* union_ir) const; StatusOr<bool> SetOldJoin(JoinIR* join_op) const; StatusOr<bool> SetMemorySink(MemorySinkIR* map_ir) const; + StatusOr<bool> SetGRPCSink(GRPCSinkIR* map_ir) const; StatusOr<bool> SetRolling(RollingIR* rolling_ir) const; StatusOr<bool> SetOther(OperatorIR* op) const; @@ -526,12 +527,14 @@ class ResolveWindowSizeRollingRule : public Rule { }; /** - * @brief This rule automatically adds a limit to all memory sinks + * @brief This rule automatically adds a limit to all result sinks that are executed in batch. + * Currently, that is all of our queries, but we will need to skip this rule in persistent, + * streaming queries, when they are introduced. * */ -class AddLimitToMemorySinkRule : public Rule { +class AddLimitToBatchResultSinkRule : public Rule { public: - explicit AddLimitToMemorySinkRule(CompilerState* compiler_state) + explicit AddLimitToBatchResultSinkRule(CompilerState* compiler_state) : Rule(compiler_state, /*use_topo*/ false, /*reverse_topological_execution*/ false) {} protected: diff --git a/src/carnot/planner/rules/rules_test.cc b/src/carnot/planner/rules/rules_test.cc index ccd231f15872743b85251b40c9bbd5af9eae2ae2..821462badebab69d2a475e3446797d40cf659c03 100644 --- a/src/carnot/planner/rules/rules_test.cc +++ b/src/carnot/planner/rules/rules_test.cc @@ -729,6 +729,32 @@ TEST_F(OperatorRelationTest, mem_sink_all_columns_test) { EXPECT_EQ(src_relation, sink->relation()); } +TEST_F(OperatorRelationTest, grpc_sink_with_columns_test) { + auto src_relation = MakeRelation(); + MemorySourceIR* src = MakeMemSource(src_relation); + GRPCSinkIR* sink = MakeGRPCSink(src, "foo", {"cpu0"}); + + OperatorRelationRule rule(compiler_state_.get()); + auto result = rule.Execute(graph.get()); + ASSERT_OK(result); + EXPECT_TRUE(result.ValueOrDie()); + + EXPECT_EQ(Relation({types::DataType::FLOAT64}, {"cpu0"}), sink->relation()); +} + +TEST_F(OperatorRelationTest, grpc_sink_all_columns_test) { + auto src_relation = MakeRelation(); + MemorySourceIR* src = MakeMemSource(src_relation); + GRPCSinkIR* sink = MakeGRPCSink(src, "foo", {}); + + OperatorRelationRule rule(compiler_state_.get()); + auto result = rule.Execute(graph.get()); + ASSERT_OK(result); + EXPECT_TRUE(result.ValueOrDie()); + + EXPECT_EQ(src_relation, sink->relation()); +} + TEST_F(OperatorRelationTest, JoinCreateOutputColumns) { std::string join_key = "key"; Relation rel1({types::INT64, types::FLOAT64, types::STRING}, {join_key, "latency", "data"}); @@ -1759,7 +1785,7 @@ TEST_F(RulesTest, UniqueSinkNameRule) { MemorySourceIR* mem_src = MakeMemSource(); MemorySinkIR* foo1 = MakeMemSink(mem_src, "foo"); MemorySinkIR* foo2 = MakeMemSink(mem_src, "foo"); - MemorySinkIR* foo3 = MakeMemSink(mem_src, "foo"); + GRPCSinkIR* foo3 = MakeGRPCSink(mem_src, "foo", std::vector<std::string>{}); MemorySinkIR* bar1 = MakeMemSink(mem_src, "bar"); MemorySinkIR* bar2 = MakeMemSink(mem_src, "bar"); MemorySinkIR* abc = MakeMemSink(mem_src, "abc"); @@ -1770,9 +1796,15 @@ TEST_F(RulesTest, UniqueSinkNameRule) { ASSERT_TRUE(result.ConsumeValueOrDie()); std::vector<std::string> expected_sink_names{"foo", "foo_1", "foo_2", "bar", "bar_1", "abc"}; - std::vector<MemorySinkIR*> sinks{foo1, foo2, foo3, bar1, bar2, abc}; - for (const auto& [idx, sink] : Enumerate(sinks)) { - EXPECT_EQ(sink->name(), expected_sink_names[idx]); + std::vector<OperatorIR*> sinks{foo1, foo2, foo3, bar1, bar2, abc}; + for (const auto& [idx, op] : Enumerate(sinks)) { + std::string sink_name; + if (Match(op, MemorySink())) { + sink_name = static_cast<MemorySinkIR*>(op)->name(); + } else { + sink_name = static_cast<GRPCSinkIR*>(op)->name(); + } + EXPECT_EQ(sink_name, expected_sink_names[idx]); } } @@ -2267,14 +2299,14 @@ TEST_F(RulesTest, PruneUnconnectedOperatorsRule_unchanged) { EXPECT_EQ(nodes_before, graph->dag().TopologicalSort()); } -TEST_F(RulesTest, AddLimitToMemorySinkRuleTest_basic) { +TEST_F(RulesTest, AddLimitToBatchResultSinkRuleTest_basic) { MemorySourceIR* src = MakeMemSource(MakeRelation()); - MemorySinkIR* sink = MakeMemSink(src, "foo", {}); + GRPCSinkIR* sink = MakeGRPCSink(src, "foo", {}); auto compiler_state = std::make_unique<CompilerState>(std::make_unique<RelationMap>(), info_.get(), time_now, 1000); - AddLimitToMemorySinkRule rule(compiler_state.get()); + AddLimitToBatchResultSinkRule rule(compiler_state.get()); auto result = rule.Execute(graph.get()); ASSERT_OK(result); EXPECT_TRUE(result.ValueOrDie()); @@ -2290,7 +2322,7 @@ TEST_F(RulesTest, AddLimitToMemorySinkRuleTest_basic) { EXPECT_THAT(limit->parents(), ElementsAre(src)); } -TEST_F(RulesTest, AddLimitToMemorySinkRuleTest_overwrite_higher) { +TEST_F(RulesTest, AddLimitToBatchResultSinkRuleTest_overwrite_higher) { MemorySourceIR* src = MakeMemSource(MakeRelation()); auto limit = graph->CreateNode<LimitIR>(ast, src, 1001).ValueOrDie(); MakeMemSink(limit, "foo", {}); @@ -2298,7 +2330,7 @@ TEST_F(RulesTest, AddLimitToMemorySinkRuleTest_overwrite_higher) { auto compiler_state = std::make_unique<CompilerState>(std::make_unique<RelationMap>(), info_.get(), time_now, 1000); - AddLimitToMemorySinkRule rule(compiler_state.get()); + AddLimitToBatchResultSinkRule rule(compiler_state.get()); auto result = rule.Execute(graph.get()); ASSERT_OK(result); EXPECT_TRUE(result.ValueOrDie()); @@ -2309,7 +2341,7 @@ TEST_F(RulesTest, AddLimitToMemorySinkRuleTest_overwrite_higher) { EXPECT_EQ(1000, limit->limit_value()); } -TEST_F(RulesTest, AddLimitToMemorySinkRuleTest_dont_overwrite_lower) { +TEST_F(RulesTest, AAddLimitToBatchResultSinkRuleTest_dont_overwrite_lower) { MemorySourceIR* src = MakeMemSource(MakeRelation()); auto limit = graph->CreateNode<LimitIR>(ast, src, 999).ValueOrDie(); MakeMemSink(limit, "foo", {}); @@ -2317,20 +2349,20 @@ TEST_F(RulesTest, AddLimitToMemorySinkRuleTest_dont_overwrite_lower) { auto compiler_state = std::make_unique<CompilerState>(std::make_unique<RelationMap>(), info_.get(), time_now, 1000); - AddLimitToMemorySinkRule rule(compiler_state.get()); + AddLimitToBatchResultSinkRule rule(compiler_state.get()); auto result = rule.Execute(graph.get()); ASSERT_OK(result); EXPECT_FALSE(result.ValueOrDie()); } -TEST_F(RulesTest, AddLimitToMemorySinkRuleTest_skip_if_no_limit) { +TEST_F(RulesTest, AddLimitToBatchResultSinkRuleTest_skip_if_no_limit) { MemorySourceIR* src = MakeMemSource(MakeRelation()); MakeMemSink(src, "foo", {}); auto compiler_state = std::make_unique<CompilerState>(std::make_unique<RelationMap>(), info_.get(), time_now); - AddLimitToMemorySinkRule rule(compiler_state.get()); + AddLimitToBatchResultSinkRule rule(compiler_state.get()); auto result = rule.Execute(graph.get()); ASSERT_OK(result); EXPECT_FALSE(result.ValueOrDie());