Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/google/protobuf/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -2152,10 +2152,13 @@ cc_test(
"//src/google/protobuf/io",
"//src/google/protobuf/io:io_win32",
"//src/google/protobuf/stubs",
"//src/google/protobuf/test_protos:test_cc_protos",
"//src/google/protobuf/testing",
"//src/google/protobuf/testing:file",
"//src/google/protobuf/util:differencer",
"@abseil-cpp//absl/base:config",
"@abseil-cpp//absl/flags:flag",
"@abseil-cpp//absl/flags:marshalling",
"@abseil-cpp//absl/hash:hash_testing",
"@abseil-cpp//absl/log:absl_check",
"@abseil-cpp//absl/log:scoped_mock_log",
Expand Down
11 changes: 11 additions & 0 deletions src/google/protobuf/message.h
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,14 @@ class PROTOBUF_EXPORT Message : public MessageLite {
return GetMetadata().reflection;
}

friend bool AbslParseFlag(absl::string_view text, Message* msg,
std::string* error) {
return msg->AbslParseFlagImpl(text, *error);
}
friend std::string AbslUnparseFlag(const Message& msg) {
return msg.AbslUnparseFlagImpl();
}

protected:
#if !defined(PROTOBUF_CUSTOM_VTABLE)
constexpr Message() {}
Expand All @@ -401,6 +409,9 @@ class PROTOBUF_EXPORT Message : public MessageLite {
// For CODE_SIZE types
static bool IsInitializedImpl(const MessageLite&);

bool AbslParseFlagImpl(absl::string_view text, std::string& error);
std::string AbslUnparseFlagImpl() const;

size_t ComputeUnknownFieldsSize(
size_t total_size, const internal::CachedSize* cached_size) const;
size_t MaybeComputeUnknownFieldsSize(
Expand Down
6 changes: 6 additions & 0 deletions src/google/protobuf/message_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,18 @@
#include <memory>
#include <string>
#include <tuple>
#include <variant>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/flags/flag.h"
#include "absl/flags/marshalling.h"
#include "absl/hash/hash_testing.h"
#include "absl/log/absl_check.h"
#include "absl/strings/escaping.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/dynamic_message.h"
Expand All @@ -32,6 +37,7 @@
#include "google/protobuf/internal_visibility.h"
#include "google/protobuf/message_lite.h"
#include "google/protobuf/port.h"
#include "test_protos/abseil_flag_test.pb.h"
#include "google/protobuf/unittest.pb.h"
#include "google/protobuf/unittest_import.pb.h"
#include "google/protobuf/unittest_lite.pb.h"
Expand Down
18 changes: 18 additions & 0 deletions src/google/protobuf/test_protos/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
load("//bazel:cc_proto_library.bzl", "cc_proto_library")
load("//bazel:proto_library.bzl", "proto_library")

package(
default_testonly = 1,
default_visibility = ["//src/google/protobuf:__subpackages__"],
)

proto_library(
name = "test_protos",
srcs = glob(["*.proto"]),
strip_import_prefix = "/src",
)

cc_proto_library(
name = "test_cc_protos",
deps = [":test_protos"],
)
23 changes: 23 additions & 0 deletions src/google/protobuf/test_protos/abseil_flag_test.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
edition = "2024";

package proto2_unittest;

enum OpenEnumForFlagTest {
OPEN_ENUM_FOR_FLAG_TEST_DEFAULT = 0;
}

message AbseilFlagTestProto {
int32 i = 1;
OpenEnumForFlagTest e = 100;

extensions 1000 to 10000;
extend AbseilFlagTestProto {
int32 ext = 1000;
}
}

message MessageWithConflictingFlagPrefixes {
int32 text = 1;
int32 base64text = 2;
int32 base64serialized = 3;
}
204 changes: 204 additions & 0 deletions src/google/protobuf/text_format.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,20 @@
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <limits>
#include <memory>
#include <random>
#include <string>
#include <utility>
#include <variant>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/base/macros.h"
#include "absl/container/btree_set.h"
#include "absl/log/absl_check.h"
#include "absl/log/absl_log.h"
#include "absl/memory/memory.h"
#include "absl/strings/ascii.h"
#include "absl/strings/cord.h"
Expand All @@ -38,7 +42,9 @@
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "google/protobuf/any.h"
Expand All @@ -53,7 +59,9 @@
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
#include "google/protobuf/map_field.h"
#include "google/protobuf/message.h"
#include "google/protobuf/port.h"
#include "google/protobuf/reflection_mode.h"
#include "google/protobuf/reflection_visit_fields.h"
#include "google/protobuf/repeated_field.h"
#include "google/protobuf/unknown_field_set.h"
#include "google/protobuf/wire_format_lite.h"
Expand Down Expand Up @@ -174,6 +182,202 @@ PROTOBUF_EXPORT std::string Utf8Format(const Message& message) {
}


namespace {

enum class AbslFlagFormat {
kTextFormat,
kSerialized,
};

struct AbslFlagHeader {
AbslFlagFormat format;
absl::string_view format_name;
std::vector<absl::string_view> options;
bool uses_dead_char = false;
bool uses_prefix = false;
};

std::variant<AbslFlagHeader, std::string> ConsumeAbslFlagHeader(
absl::string_view& text) {
AbslFlagHeader header;

if (text.empty()) {
// Whatever format is fine.
header.format = AbslFlagFormat::kTextFormat;
return header;
}

if (absl::ConsumePrefix(&text, ":")) {
header.uses_dead_char = true;
}

auto pos = text.find(':');
if (pos == text.npos) {
header.format = AbslFlagFormat::kTextFormat;
return header;
}

header.uses_prefix = true;

absl::string_view format_spec = text.substr(0, pos);
if (!header.uses_dead_char) {
header.format_name = format_spec;
// Legacy specs.
if (format_spec == "text") {
header.format = AbslFlagFormat::kTextFormat;
} else if (format_spec == "base64serialized") {
header.format = AbslFlagFormat::kSerialized;
header.options = {"base64"};
} else {
if (absl::StrContains(format_spec, ",")) {
return absl::StrFormat(
"Format options are only allowed with delimited format specifier. "
"Use `:%1$s:` instead of `%1$s:`",
format_spec);
}
header.uses_prefix = false;
header.format = AbslFlagFormat::kTextFormat;
return header;
}
} else {
std::vector<absl::string_view> parts = absl::StrSplit(format_spec, ',');
header.format_name = parts[0];

if (header.format_name == "text") {
header.format = AbslFlagFormat::kTextFormat;
} else if (header.format_name == "serialized") {
header.format = AbslFlagFormat::kSerialized;
} else {
return absl::StrFormat("Invalid format `%s`.", header.format_name);
}

header.options.assign(parts.begin() + 1, parts.end());
}

if (header.uses_prefix) {
text.remove_prefix(pos + 1);
}
return header;
}

} // namespace

bool Message::AbslParseFlagImpl(absl::string_view text, std::string& error) {
Clear();

auto header_or_error = ConsumeAbslFlagHeader(text);
if (std::holds_alternative<std::string>(header_or_error)) {
error = std::get<std::string>(header_or_error);
return false;
}
auto header = std::get<AbslFlagHeader>(std::move(header_or_error));

if (!header.uses_dead_char) {
error = "Prefix must start with a `:`. Eg `:text:`.";
return false;
}

// If we have a prefix without a dead char, verify that the message does not
// have a field by that name as that would be ambiguous.
if (!header.uses_dead_char && header.uses_prefix &&
GetDescriptor()->FindFieldByName(header.format_name) != nullptr) {
error = absl::StrFormat(
"Prefix `%s:` used is ambiguous with message fields. If you meant to "
"use this prefix, use `:%s:` instead. If you meant to use text "
"format, use `:text:` as a prefix.",
header.format_name, header.format_name);
return false;
}

const auto verify_options =
[&](std::initializer_list<absl::string_view> valid_options) -> bool {
for (absl::string_view o : header.options) {
if (!absl::c_linear_search(valid_options, o)) {
error = absl::StrFormat("Unknown option `%s` for format `%s`.", o,
header.format_name);
return false;
}
}
return true;
};

static constexpr absl::string_view kBase64 = "base64";

std::string unescaped;
const auto unescape_if_needed = [&] {
if (absl::c_linear_search(header.options, kBase64)) {
if (!absl::Base64Unescape(text, &unescaped)) {
error = absl::StrFormat("Invalid base64 input.");
return false;
}
text = unescaped;
}
return true;
};

switch (header.format) {
case AbslFlagFormat::kTextFormat: {
static constexpr absl::string_view kIgnoreUnknown = "ignore_unknown";
if (!verify_options({kIgnoreUnknown, kBase64})) return false;
if (!unescape_if_needed()) return false;
TextFormat::Parser parser;
struct StringErrorCollector : io::ErrorCollector {
explicit StringErrorCollector(std::string& error) : error(error) {}
std::string& error;
void RecordError(int line, io::ColumnNumber column,
absl::string_view message) override {
error = absl::StrFormat("(Line %v, Column %v): %v", line, column,
message);
}
} collector(error);
if (absl::c_linear_search(header.options, kIgnoreUnknown)) {
parser.AllowUnknownField(true);
parser.AllowUnknownExtension(true);
}
parser.RecordErrorsTo(&collector);
return parser.ParseFromString(text, this);
}

case AbslFlagFormat::kSerialized: {
if (!verify_options({kBase64})) return false;
if (!unescape_if_needed()) return false;
return ParseFromString(text);
}

default:
internal::Unreachable();
}
}

std::string Message::AbslUnparseFlagImpl() const {
bool has_ufs = !GetReflection()->GetUnknownFields(*this).empty();
internal::VisitMessageFields(*this, [&](const auto& msg) {
has_ufs = has_ufs || !msg.GetReflection()->GetUnknownFields(msg).empty();
});

if (has_ufs) {
// We can't use text format because it won't round trip
// Use binary instead.
return absl::StrCat(":serialized,base64:",
absl::Base64Escape(SerializeAsString()));
} else {
TextFormat::Printer printer;
printer.SetSingleLineMode(true);
printer.SetUseShortRepeatedPrimitives(true);
std::string str;
// PrintToString can't really fail.
(void)printer.PrintToString(*this, &str);

// If completely empty, just return the empty string.
// It is usually the default and nicer to read.
if (str.empty()) {
return str;
}

return absl::StrCat(":text:", str);
}
}

// ===========================================================================
// Implementation of the parse information tree class.
void TextFormat::ParseInfoTree::RecordLocation(
Expand Down
Loading