Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ public override async Task SendMessageAsync(
messageId = messageWithId.Id.ToString();
}

LogTransportSendingMessageSensitive(message);

using var httpRequestMessage = new HttpRequestMessage(HttpMethod.Post, _messageEndpoint);
StreamableHttpClientSessionTransport.CopyAdditionalHeaders(httpRequestMessage.Headers, _options.AdditionalHeaders, sessionId: null, protocolVersion: null);
var response = await _httpClient.SendAsync(httpRequestMessage, message, cancellationToken).ConfigureAwait(false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ public override async Task SendMessageAsync(JsonRpcMessage message, Cancellation

var json = JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage);

LogTransportSendingMessageSensitive(Name, json);

using var _ = await _sendLock.LockAsync(cancellationToken).ConfigureAwait(false);
try
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ internal async Task<HttpResponseMessage> SendHttpRequestAsync(JsonRpcMessage mes
$"Call {nameof(McpClient)}.{nameof(McpClient.ResumeSessionAsync)} to resume existing sessions.");
}

LogTransportSendingMessageSensitive(message);

using var sendCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _connectionCts.Token);
cancellationToken = sendCts.Token;

Expand Down
16 changes: 16 additions & 0 deletions src/ModelContextProtocol.Core/Protocol/TransportBase.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using System.Diagnostics;
using System.Text.Json;
using System.Threading.Channels;

namespace ModelContextProtocol.Protocol;
Expand Down Expand Up @@ -166,6 +167,21 @@ protected void SetDisconnected(Exception? error = null)
[LoggerMessage(Level = LogLevel.Error, Message = "{EndpointName} transport send failed for message ID '{MessageId}'.")]
private protected partial void LogTransportSendFailed(string endpointName, string messageId, Exception exception);

[LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} transport sending message. Message: '{Message}'.")]
private protected partial void LogTransportSendingMessageSensitive(string endpointName, string message);

/// <summary>
/// Logs a sending message at Trace level if trace logging is enabled.
/// </summary>
/// <param name="message">The JSON-RPC message to log.</param>
private protected void LogTransportSendingMessageSensitive(JsonRpcMessage message)
{
if (_logger.IsEnabled(LogLevel.Trace))
{
LogTransportSendingMessageSensitive(Name, JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage));
}
}

[LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} transport reading messages.")]
private protected partial void LogTransportEnteringReadMessagesLoop(string endpointName);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ public override async Task SendMessageAsync(JsonRpcMessage message, Cancellation

try
{
await JsonSerializer.SerializeAsync(_outputStream, message, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage)), cancellationToken).ConfigureAwait(false);
var json = JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage);
LogTransportSendingMessageSensitive(Name, json);
await _outputStream.WriteAsync(Encoding.UTF8.GetBytes(json), cancellationToken).ConfigureAwait(false);
await _outputStream.WriteAsync(s_newlineBytes, cancellationToken).ConfigureAwait(false);
await _outputStream.FlushAsync(cancellationToken).ConfigureAwait(false);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using ModelContextProtocol.Protocol;
using Microsoft.Extensions.Logging;
using ModelContextProtocol.Protocol;
using ModelContextProtocol.Server;
using ModelContextProtocol.Tests.Utils;
using System.IO.Pipelines;
Expand All @@ -21,6 +22,14 @@ public StdioServerTransportTests(ITestOutputHelper testOutputHelper)
InitializationTimeout = TimeSpan.FromSeconds(10),
ServerInstructions = "Test Instructions"
};

// Override the LoggerFactory to use Trace level for testing Trace-level logging
LoggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(builder =>
{
builder.AddProvider(XunitLoggerProvider);
builder.AddProvider(MockLoggerProvider);
builder.SetMinimumLevel(LogLevel.Trace);
});
}

[Fact(Skip="https://github.com/modelcontextprotocol/csharp-sdk/issues/143")]
Expand Down Expand Up @@ -193,4 +202,59 @@ public async Task SendMessageAsync_Should_Preserve_Unicode_Characters()
Assert.True(magnifyingGlassFound, "Magnifying glass emoji not found in result");
Assert.True(rocketFound, "Rocket emoji not found in result");
}

[Fact]
public async Task SendMessageAsync_Should_Log_At_Trace_Level()
{
// Arrange
using var output = new MemoryStream();

await using var transport = new StreamServerTransport(
new Pipe().Reader.AsStream(),
output,
loggerFactory: LoggerFactory);

// Act
var message = new JsonRpcRequest { Method = "test", Id = new RequestId(44) };
await transport.SendMessageAsync(message, TestContext.Current.CancellationToken);

// Assert
var traceLogMessages = MockLoggerProvider.LogMessages
.Where(x => x.LogLevel == LogLevel.Trace && x.Message.Contains("transport sending message"))
.ToList();

Assert.NotEmpty(traceLogMessages);
Assert.Contains(traceLogMessages, x => x.Message.Contains("\"method\":\"test\"") && x.Message.Contains("\"id\":44"));
}

[Fact]
public async Task ReadMessagesAsync_Should_Log_Received_At_Trace_Level()
{
// Arrange
var message = new JsonRpcRequest { Method = "test", Id = new RequestId(99) };
var json = JsonSerializer.Serialize(message, McpJsonUtilities.DefaultOptions);

Pipe pipe = new();
using var input = pipe.Reader.AsStream();

await using var transport = new StreamServerTransport(
input,
Stream.Null,
loggerFactory: LoggerFactory);

// Act
await pipe.Writer.WriteAsync(Encoding.UTF8.GetBytes($"{json}\n"), TestContext.Current.CancellationToken);

// Wait for the message to be processed
var canRead = await transport.MessageReader.WaitToReadAsync(TestContext.Current.CancellationToken);
Assert.True(canRead, "Nothing to read here from transport message reader");

// Assert
var traceLogMessages = MockLoggerProvider.LogMessages
.Where(x => x.LogLevel == LogLevel.Trace && x.Message.Contains("transport received message"))
.ToList();

Assert.NotEmpty(traceLogMessages);
Assert.Contains(traceLogMessages, x => x.Message.Contains("\"method\":\"test\"") && x.Message.Contains("\"id\":99"));
}
}
Loading