// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Runtime.CompilerServices;
using CommunityToolkit.Mvvm.Messaging.Internals;
namespace CommunityToolkit.Mvvm.Messaging;
/// <summary>
/// Extensions for the <see cref="IMessenger"/> type.
/// </summary>
public static partial class IMessengerExtensions
{
/// <summary>
/// Holds the open generic <see cref="MethodInfo"/> for
/// <see cref="Register{TMessage,TToken}(IMessenger,IRecipient{TMessage},TToken)"/>.
/// </summary>
/// <remarks>
/// The field is kept in its own nested type so that the reflection work in its static initializer is deferred
/// until this type is actually used, instead of running as soon as <see cref="IMessengerExtensions"/> is touched
/// by extension methods that never need this <see cref="MethodInfo"/>. This relies on the runtime lazily
/// running type initializers.
/// </remarks>
private static class MethodInfos
{
    /// <summary>
    /// The generic method definition of <see cref="Register{TMessage,TToken}(IMessenger,IRecipient{TMessage},TToken)"/>,
    /// obtained by binding the method group to a closed delegate type and then reopening it.
    /// </summary>
    public static readonly MethodInfo RegisterIRecipient =
        ((Action<IMessenger, IRecipient<object>, Unit>)Register).Method.GetGenericMethodDefinition();
}
/// <summary>
/// Non-generic counterpart of <see cref="DiscoveredRecipients{TToken}"/>, used by the default-channel
/// registration path.
/// </summary>
private static class DiscoveredRecipients
{
    /// <summary>
    /// Maps each recipient <see cref="Type"/> to its preloaded registration action, or to
    /// <see langword="null"/> when no such action could be loaded for that type.
    /// </summary>
    public static readonly ConditionalWeakTable<Type, Action<IMessenger, object>?> RegistrationMethods =
        new ConditionalWeakTable<Type, Action<IMessenger, object>?>();
}
/// <summary>
/// Static container that provides one <see cref="ConditionalWeakTable{TKey,TValue}"/> per
/// <typeparamref name="TToken"/> type in use.
/// </summary>
/// <remarks>
/// A weak table can only be keyed by a single type, but registrations must be tracked per recipient type and also
/// per communication channel (identified by a token). Because the token type is a compile-time parameter, each
/// generic instantiation of this class gets its own table, so lookups only need to be keyed by the recipient type.
/// </remarks>
/// <typeparam name="TToken">The type of token used to identify the channel to use.</typeparam>
private static class DiscoveredRecipients<TToken>
    where TToken : IEquatable<TToken>
{
    /// <summary>
    /// Maps each recipient <see cref="Type"/> to the preloaded action that registers all of its handlers.
    /// </summary>
    public static readonly ConditionalWeakTable<Type, Action<IMessenger, object, TToken>> RegistrationMethods =
        new ConditionalWeakTable<Type, Action<IMessenger, object, TToken>>();
}
/// <summary>
/// Determines whether <paramref name="recipient"/> is already registered for messages of type
/// <typeparamref name="TMessage"/> on the default channel.
/// </summary>
/// <typeparam name="TMessage">The type of message to check for the given recipient.</typeparam>
/// <param name="messenger">The <see cref="IMessenger"/> instance to query.</param>
/// <param name="recipient">The recipient whose registration should be checked.</param>
/// <returns><see langword="true"/> if <paramref name="recipient"/> is registered for <typeparamref name="TMessage"/>; otherwise <see langword="false"/>.</returns>
/// <remarks>Only the default channel is inspected by this overload.</remarks>
/// <exception cref="ArgumentNullException">Thrown if <paramref name="messenger"/> or <paramref name="recipient"/> are <see langword="null"/>.</exception>
public static bool IsRegistered<TMessage>(this IMessenger messenger, object recipient)
    where TMessage : class
{
    ArgumentNullException.ThrowIfNull(messenger);
    ArgumentNullException.ThrowIfNull(recipient);

    // The default channel is identified by the default value of the Unit token type.
    Unit defaultChannel = default;

    return messenger.IsRegistered<TMessage, Unit>(recipient, defaultChannel);
}
/// <summary>
/// Registers all declared message handlers for a given recipient, using the default channel.
/// </summary>
/// <param name="messenger">The <see cref="IMessenger"/> instance to use to register the recipient.</param>
/// <param name="recipient">The recipient that will receive the messages.</param>
/// <remarks>See notes for <see cref="RegisterAll{TToken}(IMessenger,object,TToken)"/> for more info.</remarks>
/// <exception cref="ArgumentNullException">Thrown if <paramref name="messenger"/> or <paramref name="recipient"/> are <see langword="null"/>.</exception>
[RequiresUnreferencedCode(
    "This method requires the generated CommunityToolkit.Mvvm.Messaging.__Internals.__IMessengerExtensions type not to be removed to use the fast path. " +
    "If this type is removed by the linker, or if the target recipient was created dynamically and was missed by the source generator, a slower fallback " +
    "path using a compiled LINQ expression will be used. This will have more overhead in the first invocation of this method for any given recipient type.")]
[RequiresDynamicCode(
    "This method requires the generated CommunityToolkit.Mvvm.Messaging.__Internals.__IMessengerExtensions type not to be removed to use the fast path. " +
    "If that is present, the method is AOT safe, as the only methods being invoked to register the messages will be the ones produced by the source generator. " +
    "If it isn't, this method will need to dynamically create the generic methods to register messages, which might not be available at runtime.")]
public static void RegisterAll(this IMessenger messenger, object recipient)
{
    ArgumentNullException.ThrowIfNull(messenger);
    ArgumentNullException.ThrowIfNull(recipient);

    // We use this method as a callback for the conditional weak table, which will handle
    // thread-safety for us. This first callback will try to find a generated method for the
    // target recipient type, and just invoke it to get the delegate to cache and use later.
    // Returns null when the generated type or a matching method is not found.
    [RequiresUnreferencedCode("The type of the current instance cannot be statically discovered.")]
    static Action<IMessenger, object>? LoadRegistrationMethodsForType(Type recipientType)
    {
        if (recipientType.Assembly.GetType("CommunityToolkit.Mvvm.Messaging.__Internals.__IMessengerExtensions") is Type extensionsType &&
            extensionsType.GetMethod("CreateAllMessagesRegistrator", new[] { recipientType }) is MethodInfo methodInfo)
        {
            // The parameter only drives overload selection, so a null argument is passed.
            return (Action<IMessenger, object>)methodInfo.Invoke(null, new object?[] { null })!;
        }

        return null;
    }

    // Try to get the cached delegate, if the generator has run correctly.
    // A static lambda is used instead of a method group so that the compiler caches the callback
    // delegate in a static field, instead of allocating a new delegate on every invocation of this
    // method (see https://github.com/dotnet/roslyn/issues/5835).
    Action<IMessenger, object>? registrationAction = DiscoveredRecipients.RegistrationMethods.GetValue(
        recipient.GetType(),
        static t => LoadRegistrationMethodsForType(t));

    if (registrationAction is not null)
    {
        registrationAction(messenger, recipient);
    }
    else
    {
        // No generated registrator was found: fall back to the token-based overload on the default channel
        messenger.RegisterAll(recipient, default(Unit));
    }
}
/// <summary>
/// Registers all declared message handlers for a given recipient.
/// </summary>
/// <typeparam name="TToken">The type of token to identify what channel to use to receive messages.</typeparam>
/// <param name="messenger">The <see cref="IMessenger"/> instance to use to register the recipient.</param>
/// <param name="recipient">The recipient that will receive the messages.</param>
/// <param name="token">The token indicating what channel to use.</param>
/// <remarks>
/// This method will register all messages corresponding to the <see cref="IRecipient{TMessage}"/> interfaces
/// being implemented by <paramref name="recipient"/>. If none are present, this method will do nothing.
/// Note that unlike all other extensions, this method will use reflection to find the handlers to register.
/// Once the registration is complete though, the performance will be exactly the same as with handlers
/// registered directly through any of the other generic extensions for the <see cref="IMessenger"/> interface.
/// </remarks>
/// <exception cref="ArgumentNullException">Thrown if <paramref name="messenger"/>, <paramref name="recipient"/> or <paramref name="token"/> are <see langword="null"/>.</exception>
[RequiresUnreferencedCode(
    "This method requires the generated CommunityToolkit.Mvvm.Messaging.__Internals.__IMessengerExtensions type not to be removed to use the fast path. " +
    "If this type is removed by the linker, or if the target recipient was created dynamically and was missed by the source generator, a slower fallback " +
    "path using a compiled LINQ expression will be used. This will have more overhead in the first invocation of this method for any given recipient type.")]
[RequiresDynamicCode("The generic methods to register messages might not be available at runtime.")]
public static void RegisterAll<TToken>(this IMessenger messenger, object recipient, TToken token)
    where TToken : IEquatable<TToken>
{
    ArgumentNullException.ThrowIfNull(messenger);
    ArgumentNullException.ThrowIfNull(recipient);
    ArgumentNullException.For<TToken>.ThrowIfNull(token);

    // We use this method as a callback for the conditional weak table, which will handle
    // thread-safety for us. This first callback will try to find a generated method for the
    // target recipient type, and just invoke it to get the delegate to cache and use later.
    // In this case we also need to create a generic instantiation of the target method first.
    [RequiresUnreferencedCode("The type of the current instance cannot be statically discovered.")]
    [RequiresDynamicCode("The generic methods to register messages might not be available at runtime.")]
    static Action<IMessenger, object, TToken> LoadRegistrationMethodsForType(Type recipientType)
    {
        if (recipientType.Assembly.GetType("CommunityToolkit.Mvvm.Messaging.__Internals.__IMessengerExtensions") is Type extensionsType &&
            extensionsType.GetMethod("CreateAllMessagesRegistratorWithToken", new[] { recipientType }) is MethodInfo methodInfo)
        {
            MethodInfo genericMethodInfo = methodInfo.MakeGenericMethod(typeof(TToken));

            // The parameter only drives overload selection, so a null argument is passed.
            return (Action<IMessenger, object, TToken>)genericMethodInfo.Invoke(null, new object?[] { null })!;
        }

        return LoadRegistrationMethodsForTypeFallback(recipientType);
    }

    // Fallback method when a generated method is not found.
    // This method is only invoked once per recipient type and token type, so we're not
    // worried about making it super efficient, and we can use the LINQ code for clarity.
    // The LINQ codegen bloat is not really important for the same reason.
    [RequiresDynamicCode("The generic methods to register messages might not be available at runtime.")]
    static Action<IMessenger, object, TToken> LoadRegistrationMethodsForTypeFallback(Type recipientType)
    {
        // Get the closed Register<TMessage, TToken> method for every IRecipient<TMessage> the type implements
        MethodInfo[] registrationMethods = (
            from interfaceType in recipientType.GetInterfaces()
            where interfaceType.IsGenericType &&
                  interfaceType.GetGenericTypeDefinition() == typeof(IRecipient<>)
            let messageType = interfaceType.GenericTypeArguments[0]
            select MethodInfos.RegisterIRecipient.MakeGenericMethod(messageType, typeof(TToken))).ToArray();

        // Short path if there are no message handlers to register
        if (registrationMethods.Length == 0)
        {
            return static (_, _, _) => { };
        }

        // Input parameters (IMessenger instance, non-generic recipient, token)
        ParameterExpression arg0 = Expression.Parameter(typeof(IMessenger));
        ParameterExpression arg1 = Expression.Parameter(typeof(object));
        ParameterExpression arg2 = Expression.Parameter(typeof(TToken));

        // Declare a local resulting from the (RecipientType)recipient cast
        UnaryExpression inst1 = Expression.Convert(arg1, recipientType);

        // We want a single compiled LINQ expression that executes the registration for all
        // the declared message types in the input type. To do so, we create a block with the
        // unrolled invocations for the individual message registration (for each IRecipient<T>).
        // The code below will generate the following block expression:
        // ===============================================================================
        // {
        //     var inst1 = (RecipientType)arg1;
        //     IMessengerExtensions.Register<T0, TToken>(arg0, inst1, arg2);
        //     IMessengerExtensions.Register<T1, TToken>(arg0, inst1, arg2);
        //     ...
        //     IMessengerExtensions.Register<TN, TToken>(arg0, inst1, arg2);
        // }
        // ===============================================================================
        // We also add an explicit object conversion to cast the input recipient type to
        // the actual specific type, so that the exposed message handlers are accessible.
        BlockExpression body = Expression.Block(
            from registrationMethod in registrationMethods
            select Expression.Call(registrationMethod, new Expression[]
            {
                arg0,
                inst1,
                arg2
            }));

        return Expression.Lambda<Action<IMessenger, object, TToken>>(body, arg0, arg1, arg2).Compile();
    }

    // Get or compute the registration method for the current recipient type.
    // As in CommunityToolkit.Diagnostics.TypeExtensions.ToTypeString, we use a lambda
    // expression instead of a method group expression to leverage the statically initialized
    // delegate and avoid repeated allocations for each invocation of this method.
    // For more info on this, see the related issue at https://github.com/dotnet/roslyn/issues/5835.
    Action<IMessenger, object, TToken> registrationAction = DiscoveredRecipients<TToken>.RegistrationMethods.GetValue(
        recipient.GetType(),
        static t => LoadRegistrationMethodsForType(t));

    // Invoke the cached delegate to actually execute the message registration
    registrationAction(messenger, recipient, token);
}
/// <summary>
/// Registers <paramref name="recipient"/> to receive messages of type <typeparamref name="TMessage"/> on the default channel.
/// </summary>
/// <typeparam name="TMessage">The type of message to receive.</typeparam>
/// <param name="messenger">The <see cref="IMessenger"/> instance to use to register the recipient.</param>
/// <param name="recipient">The recipient that will receive the messages.</param>
/// <remarks>
/// Known messenger implementations (<see cref="WeakReferenceMessenger"/> and <see cref="StrongReferenceMessenger"/>)
/// are dispatched to their <see cref="IRecipient{TMessage}"/>-specific registration overloads. Any other
/// <see cref="IMessenger"/> gets a handler that forwards each message to <see cref="IRecipient{TMessage}.Receive"/>.
/// </remarks>
/// <exception cref="InvalidOperationException">Thrown when trying to register the same message twice.</exception>
/// <exception cref="ArgumentNullException">Thrown if <paramref name="messenger"/> or <paramref name="recipient"/> are <see langword="null"/>.</exception>
public static void Register<TMessage>(this IMessenger messenger, IRecipient<TMessage> recipient)
    where TMessage : class
{
    ArgumentNullException.ThrowIfNull(messenger);
    ArgumentNullException.ThrowIfNull(recipient);

    switch (messenger)
    {
        case WeakReferenceMessenger weak:
            weak.Register<TMessage, Unit>(recipient, default);
            break;
        case StrongReferenceMessenger strong:
            strong.Register<TMessage, Unit>(recipient, default);
            break;
        default:
            messenger.Register<IRecipient<TMessage>, TMessage, Unit>(recipient, default, static (r, m) => r.Receive(m));
            break;
    }
}
/// <summary>
/// Registers a recipient for a given type of message.
/// </summary>
/// <typeparam name="TMessage">The type of message to receive.</typeparam>
/// <typeparam name="TToken">The type of token to identify what channel to use to receive messages.</typeparam>
/// <param name="messenger">The <see cref="IMessenger"/> instance to use to register the recipient.</param>
/// <param name="recipient">The recipient that will receive the messages.</param>
/// <param name="token">The token indicating what channel to use.</param>
/// <remarks>
/// The registration is performed on the channel identified by <paramref name="token"/> (not the default channel).
/// Known messenger implementations (<see cref="WeakReferenceMessenger"/> and <see cref="StrongReferenceMessenger"/>)
/// are dispatched to their <see cref="IRecipient{TMessage}"/>-specific overloads; any other <see cref="IMessenger"/>
/// gets a handler that forwards each message to <see cref="IRecipient{TMessage}.Receive"/>.
/// </remarks>
/// <exception cref="InvalidOperationException">Thrown when trying to register the same message twice.</exception>
/// <exception cref="ArgumentNullException">Thrown if <paramref name="messenger"/>, <paramref name="recipient"/> or <paramref name="token"/> are <see langword="null"/>.</exception>
public static void Register<TMessage, TToken>(this IMessenger messenger, IRecipient<TMessage> recipient, TToken token)
    where TMessage : class
    where TToken : IEquatable<TToken>
{
    ArgumentNullException.ThrowIfNull(messenger);
    ArgumentNullException.ThrowIfNull(recipient);
    ArgumentNullException.For<TToken>.ThrowIfNull(token);

    if (messenger is WeakReferenceMessenger weakReferenceMessenger)
    {
        weakReferenceMessenger.Register<TMessage, TToken>(recipient, token);
    }
    else if (messenger is StrongReferenceMessenger strongReferenceMessenger)
    {
        strongReferenceMessenger.Register<TMessage, TToken>(recipient, token);
    }
    else
    {
        messenger.Register<IRecipient<TMessage>, TMessage, TToken>(recipient, token, static (r, m) => r.Receive(m));
    }
}
///
/// Registers a recipient for a given type of message.
///
/// The type of message to receive.
/// The instance to use to register the recipient.
/// The recipient that will receive the messages.
/// The to invoke when a message is received.
/// Thrown when trying to register the same message twice.
/// This method will use the default channel to perform the requested registration.
/// Thrown if , or are .
public static void Register(this IMessenger messenger, object recipient, MessageHandler