Skip to content
Open
62 changes: 58 additions & 4 deletions cpp/src/gandiva/precompiled/string_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1896,7 +1896,14 @@ const char* replace_with_max_len_utf8_utf8_utf8(gdv_int64 context, const char* t

for (; text_index <= text_len - from_str_len;) {
if (memcmp(text + text_index, from_str, from_str_len) == 0) {
if (out_index + text_index - last_match_index + to_str_len > max_length) {
// Compute the prospective length in gdv_int64: now that the wrapper may
// pass a max_length near INT_MAX, out_index can approach INT_MAX and a
// 32-bit sum would overflow before this guard runs -- precisely the case
// the guard exists to catch. (text_index - last_match_index) is a bounded
// non-negative span.
gdv_int64 prospective_len = static_cast<gdv_int64>(out_index) +
(text_index - last_match_index) + to_str_len;
if (prospective_len > max_length) {
gdv_fn_context_set_error_msg(context,
"REPLACE: Buffer overflow for output string");
*out_len = 0;
Expand Down Expand Up @@ -1932,7 +1939,8 @@ const char* replace_with_max_len_utf8_utf8_utf8(gdv_int64 context, const char* t
return text;
}

if (out_index + text_len - last_match_index > max_length) {
gdv_int64 final_len = static_cast<gdv_int64>(out_index) + (text_len - last_match_index);
if (final_len > max_length) {
gdv_fn_context_set_error_msg(context, "REPLACE: Buffer overflow for output string");
*out_len = 0;
return "";
Expand All @@ -1948,9 +1956,55 @@ const char* replace_utf8_utf8_utf8(gdv_int64 context, const char* text,
gdv_int32 text_len, const char* from_str,
gdv_int32 from_str_len, const char* to_str,
gdv_int32 to_str_len, gdv_int32* out_len) {
// Size the output buffer so large results are not capped by an arbitrary
// limit, while avoiding a second pass over the input in the common case.
// - No replacement possible, or the result can only shrink/stay equal:
// text_len is a safe exact-or-upper bound, no scan.
// - Small expansion (replacement at most ~2x the match): use an O(1) upper
// bound that assumes every position matches. This over-allocates by at
// most ~text_len bytes but skips the match-counting scan entirely.
// - Large expansion: that upper bound could be many times the input for
// sparse matches, so count non-overlapping matches for the exact size.
gdv_int64 max_length;
if (from_str_len <= 0 || from_str_len > text_len || to_str_len <= from_str_len) {
max_length = text_len;
} else {
gdv_int32 delta = to_str_len - from_str_len; // > 0
gdv_int64 upper_bound = static_cast<gdv_int64>(text_len) +
(static_cast<gdv_int64>(text_len) / from_str_len) * delta;
if (delta <= from_str_len && upper_bound <= INT_MAX) {
max_length = upper_bound;
} else {
gdv_int64 num_matches = 0;
for (gdv_int32 i = 0; i <= text_len - from_str_len;) {
if (memcmp(text + i, from_str, from_str_len) == 0) {
num_matches++;
i += from_str_len;
} else {
i++;
}
}
// No matches: the result is the input unchanged; return it without calling
// the helper (which would otherwise scan the text a second time).
if (num_matches == 0) {
*out_len = text_len;
return text;
}
max_length = static_cast<gdv_int64>(text_len) + num_matches * delta;
}
}
// Gandiva variable-length output uses int32 offsets, so a single output string
// cannot exceed INT_MAX bytes. Report this explicitly instead of letting the
// cast below wrap silently.
if (max_length > INT_MAX) {
gdv_fn_context_set_error_msg(context,
"REPLACE: output string exceeds maximum size of 2GB");
*out_len = 0;
return "";
}
return replace_with_max_len_utf8_utf8_utf8(context, text, text_len, from_str,
from_str_len, to_str, to_str_len, 65535,
out_len);
from_str_len, to_str, to_str_len,
static_cast<gdv_int32>(max_length), out_len);
Comment thread
lriggs marked this conversation as resolved.
}

// Returns the quoted string (Includes escape character for any single quotes)
Expand Down
62 changes: 62 additions & 0 deletions cpp/src/gandiva/precompiled/string_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1991,6 +1991,68 @@ TEST(TestStringOps, TestReplace) {
EXPECT_EQ(std::string(out_str, out_len), "TestString");
EXPECT_FALSE(ctx.has_error());

// No match on the large-expansion (counting) path: from "z" to "zzz" expands
// by more than from_len, so this exercises the count branch's zero-match
// early return.
out_str = replace_utf8_utf8_utf8(ctx_ptr, "TestString", 10, "z", 1, "zzz", 3, &out_len);
EXPECT_EQ(std::string(out_str, out_len), "TestString");
EXPECT_FALSE(ctx.has_error());

// Large output (>64 KB) must not overflow: buffer is sized to the exact result.
Comment thread
lriggs marked this conversation as resolved.
std::string large_in(35000, 'X');
std::string large_expected(70000, '\0');
for (int i = 0; i < 35000; ++i) {
large_expected[2 * i] = 'X';
large_expected[2 * i + 1] = 'Y';
}
out_str = replace_utf8_utf8_utf8(ctx_ptr, large_in.data(),
static_cast<int32_t>(large_in.size()), "X", 1, "XY", 2,
&out_len);
EXPECT_EQ(out_len, 70000);
EXPECT_EQ(std::string(out_str, out_len), large_expected);
EXPECT_FALSE(ctx.has_error());

// Large shrinking output ("XX" -> "X") on a >64 KB input.
std::string large_shrink_in(70000, 'X');
std::string large_shrink_expected(35000, 'X');
out_str = replace_utf8_utf8_utf8(ctx_ptr, large_shrink_in.data(),
static_cast<int32_t>(large_shrink_in.size()), "XX", 2,
"X", 1, &out_len);
EXPECT_EQ(out_len, 35000);
EXPECT_EQ(std::string(out_str, out_len), large_shrink_expected);
EXPECT_FALSE(ctx.has_error());

// Edge case: result size of exactly 0 (every byte of text is removed). Takes
// the no-scan shrink path (to_str_len <= from_str_len).
out_str = replace_utf8_utf8_utf8(ctx_ptr, "aaaa", 4, "a", 1, "", 0, &out_len);
EXPECT_EQ(out_len, 0);
EXPECT_EQ(std::string(out_str, out_len), "");
EXPECT_FALSE(ctx.has_error());

// Edge case: result size one past the INT_MAX boundary. 65536 single-char
// matches each expanding to 32768 bytes gives max_length = 65536 * 32768 =
// 2^31 = INT_MAX + 1, so it is reported cleanly (guard fires before any alloc).
std::string boundary_in(65536, 'a');
std::string boundary_to(32768, 'b');
replace_utf8_utf8_utf8(
ctx_ptr, boundary_in.data(), static_cast<int32_t>(boundary_in.size()), "a", 1,
boundary_to.data(), static_cast<int32_t>(boundary_to.size()), &out_len);
EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("exceeds maximum size"));
EXPECT_EQ(out_len, 0);
ctx.Reset();

// Output that would exceed INT_MAX (2GB) is reported cleanly rather than
// silently wrapping the int32 size. 50000 matches each expanding to 50000
// bytes implies max_length = 2.5e9; the guard fires before any large alloc.
std::string huge_in(50000, 'X');
std::string huge_to(50000, 'Z');
replace_utf8_utf8_utf8(ctx_ptr, huge_in.data(), static_cast<int32_t>(huge_in.size()),
"X", 1, huge_to.data(), static_cast<int32_t>(huge_to.size()),
&out_len);
EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("exceeds maximum size"));
EXPECT_EQ(out_len, 0);
ctx.Reset();

replace_with_max_len_utf8_utf8_utf8(ctx_ptr, "Hell", 4, "ell", 3, "ollow", 5, 5,
&out_len);
EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Buffer overflow for output string"));
Expand Down
19 changes: 19 additions & 0 deletions cpp/src/gandiva/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,25 @@ if(ARROW_BUILD_STATIC)
"gandiva"
EXTRA_LINK_LIBS
gandiva_static)

# Calls the precompiled REPLACE functions directly, so it compiles
# string_ops.cc/context_helper.cc with GANDIVA_UNIT_TEST=1 (which exposes them
# as linkable symbols). Only built when ARROW_BUILD_BENCHMARKS is ON.
add_arrow_benchmark(string_ops_benchmark
SOURCES
string_ops_benchmark.cc
../precompiled/string_ops.cc
../context_helper.cc
PREFIX
"gandiva"
EXTRA_LINK_LIBS
gandiva_static)
if(TARGET gandiva-string-ops-benchmark)
target_compile_definitions(gandiva-string-ops-benchmark
PRIVATE GANDIVA_UNIT_TEST=1 ARROW_STATIC GANDIVA_STATIC)
target_include_directories(gandiva-string-ops-benchmark SYSTEM
PRIVATE ${CMAKE_SOURCE_DIR}/src)
endif()
endif()

add_subdirectory(external_functions)
162 changes: 162 additions & 0 deletions cpp/src/gandiva/tests/string_ops_benchmark.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

// Microbenchmark comparing the current REPLACE implementation against the
// pre-change one, to measure the cost of the match-counting scan the fix added
// to size the output buffer.
//
// BM_ReplaceNew = replace_utf8_utf8_utf8 (upper bound or counting scan, then a
// single write pass)
// BM_ReplaceOld = replace_with_max_len_utf8_utf8_utf8(..., capacity, ...) given
// an exact buffer: the pre-change algorithm with no counting
// scan. Compare the two rows per case to read the scan's cost.
//
// Unlike the projector-level micro_benchmarks, this calls the precompiled
// functions directly, so the build compiles string_ops.cc with GANDIVA_UNIT_TEST.

#include <cstdint>
#include <cstring>
#include <string>
#include <vector>

#include "benchmark/benchmark.h"

#include "gandiva/execution_context.h"
#include "gandiva/precompiled/types.h"

namespace gandiva {
namespace {

struct ReplaceCase {
const char* name;
int64_t text_len;
int stride; // a match (the first byte of `from`) every `stride` bytes
const char* from;
const char* to;
};

const std::vector<ReplaceCase>& Cases() {
static const std::vector<ReplaceCase> cases = {
// Small expansion (to_len - from_len <= from_len): no scan, upper bound.
{"small/dense expand a->ab", 256, 1, "a", "ab"},
{"small/sparse expand a->ab", 256, 64, "a", "ab"},
{"medium/dense expand a->ab", 64 * 1024, 1, "a", "ab"},
{"medium/sparse expand a->ab", 64 * 1024, 64, "a", "ab"},
{"large/dense expand a->ab", 4 * 1024 * 1024, 1, "a", "ab"},
{"large/sparse expand a->ab", 4 * 1024 * 1024, 64, "a", "ab"},
// Big expansion (to_len - from_len > from_len): falls back to the scan.
{"large/dense bigexp a->abcd", 4 * 1024 * 1024, 1, "a", "abcd"},
{"large/sparse bigexp a->abcd", 4 * 1024 * 1024, 64, "a", "abcd"},
// Shrink (to_len <= from_len): no scan.
{"large/dense shrink ab->a", 4 * 1024 * 1024, 2, "ab", "a"},
};
return cases;
}

// Builds a `len`-byte string with `match` once every `stride` bytes.
std::string MakeText(int64_t len, int stride, char match, char filler) {
std::string s(static_cast<size_t>(len), filler);
for (int64_t i = 0; i < len; i += stride) {
s[static_cast<size_t>(i)] = match;
}
return s;
}

// Exact output size, so the "old" arm gets a buffer large enough to complete.
int32_t ExactCapacity(const std::string& text, const char* from, int flen, int olen) {
int64_t matches = 0;
auto tlen = static_cast<int32_t>(text.size());
if (flen > 0 && flen <= tlen) {
for (int32_t i = 0; i <= tlen - flen;) {
if (memcmp(text.data() + i, from, flen) == 0) {
++matches;
i += flen;
} else {
++i;
}
}
}
return static_cast<int32_t>(tlen + matches * (olen - flen));
}

void BM_ReplaceNew(benchmark::State& state) {
const ReplaceCase& c = Cases()[state.range(0)];
auto flen = static_cast<int>(strlen(c.from));
auto olen = static_cast<int>(strlen(c.to));
std::string text = MakeText(c.text_len, c.stride, c.from[0], 'x');
auto tlen = static_cast<int32_t>(text.size());
ExecutionContext ctx;
auto ctx_ptr = reinterpret_cast<int64_t>(&ctx);

// One warm-up call doubling as a correctness guard.
int32_t out_len = 0;
replace_utf8_utf8_utf8(ctx_ptr, text.data(), tlen, c.from, flen, c.to, olen, &out_len);
if (ctx.has_error()) {
state.SkipWithError(ctx.get_error().c_str());
return;
}

for (auto _ : state) {
ctx.Reset();
const char* out = replace_utf8_utf8_utf8(ctx_ptr, text.data(), tlen, c.from, flen,
c.to, olen, &out_len);
benchmark::DoNotOptimize(out);
benchmark::DoNotOptimize(out_len);
}
state.SetBytesProcessed(state.iterations() * tlen);
state.SetLabel(c.name);
}

void BM_ReplaceOld(benchmark::State& state) {
const ReplaceCase& c = Cases()[state.range(0)];
auto flen = static_cast<int>(strlen(c.from));
auto olen = static_cast<int>(strlen(c.to));
std::string text = MakeText(c.text_len, c.stride, c.from[0], 'x');
auto tlen = static_cast<int32_t>(text.size());
int32_t capacity = ExactCapacity(text, c.from, flen, olen);
ExecutionContext ctx;
auto ctx_ptr = reinterpret_cast<int64_t>(&ctx);

int32_t out_len = 0;
replace_with_max_len_utf8_utf8_utf8(ctx_ptr, text.data(), tlen, c.from, flen, c.to,
olen, capacity, &out_len);
if (ctx.has_error()) {
state.SkipWithError(ctx.get_error().c_str());
return;
}

for (auto _ : state) {
ctx.Reset();
const char* out = replace_with_max_len_utf8_utf8_utf8(
ctx_ptr, text.data(), tlen, c.from, flen, c.to, olen, capacity, &out_len);
benchmark::DoNotOptimize(out);
benchmark::DoNotOptimize(out_len);
}
state.SetBytesProcessed(state.iterations() * tlen);
state.SetLabel(c.name);
}

} // namespace

BENCHMARK(BM_ReplaceNew)
->DenseRange(0, static_cast<int64_t>(Cases().size()) - 1)
->Unit(benchmark::kMicrosecond);
BENCHMARK(BM_ReplaceOld)
->DenseRange(0, static_cast<int64_t>(Cases().size()) - 1)
->Unit(benchmark::kMicrosecond);

} // namespace gandiva
Loading