Skip to content

Performance: Implement non-wildcard topic subscription management in MqttClientSessionsManager. #2175

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ namespace MQTTnet.Server.Internal
{
public interface ISubscriptionChangedNotification
{
void OnSubscriptionsAdded(MqttSession clientSession, List<string> subscriptionsTopics);
void OnSubscriptionsAdded(MqttSession clientSession, List<MqttSubscription> subscriptionsTopics);

void OnSubscriptionsRemoved(MqttSession clientSession, List<string> subscriptionTopics);
}
Expand Down
91 changes: 74 additions & 17 deletions Source/MQTTnet.Server/Internal/MqttClientSessionsManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ public sealed class MqttClientSessionsManager : ISubscriptionChangedNotification
// The _sessions dictionary contains all session, the _subscriberSessions hash set contains subscriber sessions only.
// See the MqttSubscription object for a detailed explanation.
readonly MqttSessionsStorage _sessionsStorage = new();
readonly HashSet<MqttSession> _subscriberSessions = [];
readonly HashSet<MqttSession> _subscriberSessionsWithWildcards = [];
readonly Dictionary<string, HashSet<MqttSession>> _simpleTopicToSessions = [];

public MqttClientSessionsManager(MqttServerOptions options, MqttRetainedMessagesManager retainedMessagesManager, MqttServerEventContainer eventContainer, IMqttNetLogger logger)
{
Expand Down Expand Up @@ -77,7 +78,7 @@ public async Task DeleteSessionAsync(string clientId)
{
if (_sessionsStorage.TryRemoveSession(clientId, out session))
{
_subscriberSessions.Remove(session);
CleanupClientSessionUnsafe(session);
}
}
finally
Expand Down Expand Up @@ -161,11 +162,30 @@ public async Task<DispatchApplicationMessageResult> DispatchApplicationMessage(
await _retainedMessagesManager.UpdateMessage(senderId, applicationMessage).ConfigureAwait(false);
}

List<MqttSession> subscriberSessions;
HashSet<MqttSession> subscriberSessions;
_sessionsManagementLock.EnterReadLock();
try
{
subscriberSessions = _subscriberSessions.ToList();
if (_simpleTopicToSessions.TryGetValue(applicationMessage.Topic, out var matchedSimpleTopicSessions))
{
// Create the initial subscriberSessions from whichever set is larger to take advantage
// of the internal ConstructFrom other HashSet optimizations
if (matchedSimpleTopicSessions.Count > _subscriberSessionsWithWildcards.Count)
{
subscriberSessions = new HashSet<MqttSession>(matchedSimpleTopicSessions);
subscriberSessions.UnionWith(_subscriberSessionsWithWildcards);
}
else
{
subscriberSessions = new HashSet<MqttSession>(_subscriberSessionsWithWildcards);
subscriberSessions.UnionWith(matchedSimpleTopicSessions);
}
}
else
{
// Always include the sessions with wildcards. They need to be properly matched against the topic filter.
subscriberSessions = new HashSet<MqttSession>(_subscriberSessionsWithWildcards);
}
}
finally
{
Expand Down Expand Up @@ -446,20 +466,32 @@ public async Task HandleClientConnectionAsync(IMqttChannelAdapter channelAdapter
}
}

public void OnSubscriptionsAdded(MqttSession clientSession, List<string> topics)
public void OnSubscriptionsAdded(MqttSession clientSession, List<MqttSubscription> subscriptions)
{
_sessionsManagementLock.EnterWriteLock();
try
{
if (!clientSession.HasSubscribedTopics)
foreach (var subscription in subscriptions)
{
// first subscribed topic
_subscriberSessions.Add(clientSession);
}

foreach (var topic in topics)
{
clientSession.AddSubscribedTopic(topic);
if (subscription.TopicHasWildcard)
{
if (!clientSession.HasSubscribedWildcardTopics)
{
_subscriberSessionsWithWildcards.Add(clientSession);
}
}
else
{
if (_simpleTopicToSessions.TryGetValue(subscription.Topic, out var simpleTopicSessions))
{
simpleTopicSessions.Add(clientSession);
}
else
{
_simpleTopicToSessions[subscription.Topic] = [clientSession];
}
}
clientSession.AddSubscribedTopic(subscription.Topic, subscription.TopicHasWildcard);
}
}
finally
Expand All @@ -475,13 +507,21 @@ public void OnSubscriptionsRemoved(MqttSession clientSession, List<string> subsc
{
foreach (var subscriptionTopic in subscriptionTopics)
{
if (_simpleTopicToSessions.TryGetValue(subscriptionTopic, out var simpleTopicSessions))
{
simpleTopicSessions.Remove(clientSession);
if (simpleTopicSessions.Count == 0)
{
_simpleTopicToSessions.Remove(subscriptionTopic);
}
}
clientSession.RemoveSubscribedTopic(subscriptionTopic);
}

if (!clientSession.HasSubscribedTopics)
if (!clientSession.HasSubscribedWildcardTopics)
{
// last subscription removed
_subscriberSessions.Remove(clientSession);
// Last wildcard subscription removed
_subscriberSessionsWithWildcards.Remove(clientSession);
}
}
finally
Expand Down Expand Up @@ -564,7 +604,7 @@ async Task<MqttConnectedClient> CreateClientConnection(
if (connectPacket.CleanSession)
{
_logger.Verbose("Deleting existing session of client '{0}' due to clean start", connectPacket.ClientId);
_subscriberSessions.Remove(oldSession);
CleanupClientSessionUnsafe(oldSession);
session = CreateSession(connectPacket, validatingConnectionEventArgs);
}
else
Expand Down Expand Up @@ -669,6 +709,23 @@ MqttSession GetClientSession(string clientId)
}
}

//* Must be called with the _sessionsManagementLock held.
void CleanupClientSessionUnsafe(MqttSession session)
{
_subscriberSessionsWithWildcards.Remove(session);
foreach (var simpleTopic in session.SubscribedSimpleTopics)
{
if (_simpleTopicToSessions.TryGetValue(simpleTopic, out var simpleTopicSessions))
{
simpleTopicSessions.Remove(session);
if (simpleTopicSessions.Count == 0)
{
_simpleTopicToSessions.Remove(simpleTopic);
}
}
}
}

async Task<MqttConnectPacket> ReceiveConnectPacket(IMqttChannelAdapter channelAdapter, CancellationToken cancellationToken)
{
try
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ public async Task<SubscribeResult> Subscribe(MqttSubscribePacket subscribePacket
var retainedApplicationMessages = await _retainedMessagesManager.GetMessages().ConfigureAwait(false);
var result = new SubscribeResult(subscribePacket.TopicFilters.Count);

var addedSubscriptions = new List<string>();
var addedSubscriptions = new List<MqttSubscription>();
var finalTopicFilters = new List<MqttTopicFilter>();

// The topic filters are order by its QoS so that the higher QoS will win over a
Expand Down Expand Up @@ -195,7 +195,7 @@ public async Task<SubscribeResult> Subscribe(MqttSubscribePacket subscribePacket

var createSubscriptionResult = CreateSubscription(topicFilter, subscribePacket.SubscriptionIdentifier, interceptorEventArgs.Response.ReasonCode);

addedSubscriptions.Add(topicFilter.Topic);
addedSubscriptions.Add(createSubscriptionResult.Subscription);
finalTopicFilters.Add(topicFilter);

FilterRetainedApplicationMessages(retainedApplicationMessages, createSubscriptionResult, result);
Expand Down
23 changes: 14 additions & 9 deletions Source/MQTTnet.Server/Internal/MqttSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ public sealed class MqttSession : IDisposable
// Do not use a dictionary in order to keep the ordering of the messages.
readonly List<MqttPublishPacket> _unacknowledgedPublishPackets = new();

// Bookkeeping to know if this is a subscribing client; lazy initialize later.
HashSet<string> _subscribedTopics;
readonly HashSet<string> _subscribedSimpleTopics = [];
readonly HashSet<string> _subscribedWildcardTopics = [];

public MqttSession(
MqttConnectPacket connectPacket,
Expand All @@ -50,7 +50,9 @@ public MqttSession(

public uint ExpiryInterval => _connectPacket.SessionExpiryInterval;

public bool HasSubscribedTopics => _subscribedTopics != null && _subscribedTopics.Count > 0;
public bool HasSubscribedWildcardTopics => _subscribedWildcardTopics.Count > 0;

public HashSet<string> SubscribedSimpleTopics => _subscribedSimpleTopics;

public string Id => _connectPacket.ClientId;

Expand Down Expand Up @@ -79,14 +81,16 @@ public MqttPublishPacket AcknowledgePublishPacket(ushort packetIdentifier)
return publishPacket;
}

public void AddSubscribedTopic(string topic)
public void AddSubscribedTopic(string topic, bool isWildcardTopic)
{
if (_subscribedTopics == null)
if (isWildcardTopic)
{
_subscribedTopics = new HashSet<string>();
_subscribedWildcardTopics.Add(topic);
}
else
{
_subscribedSimpleTopics.Add(topic);
}

_subscribedTopics.Add(topic);
}

public Task DeleteAsync()
Expand Down Expand Up @@ -208,7 +212,8 @@ public void Recover()

public void RemoveSubscribedTopic(string topic)
{
_subscribedTopics?.Remove(topic);
_subscribedSimpleTopics.Remove(topic);
_subscribedWildcardTopics.Remove(topic);
}

public Task<SubscribeResult> Subscribe(MqttSubscribePacket subscribePacket, CancellationToken cancellationToken)
Expand Down
1 change: 0 additions & 1 deletion Source/MQTTnet.Tests/TopicFilterComparer_Tests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
// See the LICENSE file in the project root for more information.

using Microsoft.VisualStudio.TestTools.UnitTesting;
using MQTTnet.Server;
using MQTTnet.Server.Internal;

namespace MQTTnet.Tests
Expand Down