Refactoring a big if block into a simple command processor using attributes

Written by Troy Howard

15 May 2008

Recently someone asked for help with a massive control block: a chain of if statements that examined a string and dispatched to one of many functions. Hundreds of if statements, hundreds of magic strings.

Interestingly, all the functions had the same signature… So I gave him this example of how to put attributes on the methods to specify the corresponding token, use reflection to scan the assembly for every method with that attribute, and then build a function table keyed by token for fast lookup. This example shows how to create an object instance and then invoke the method via reflection, but it could be made much simpler if the methods were all static and the function prototype were part of an interface instead of just an unspoken convention.

Here's the "Before" example from the original question…

// "Before": a hand-written dispatch chain. Every command needs its own
// branch, and each branch hard-codes the token string, the class to
// instantiate, and the method to call on it.
string tag; 
string cmdLine; 
State state; 
string outData; 

// ... etc ... 

// One branch per token: compare `token` against a magic string, then
// construct the handler object and call its method with the shared
// (tag, state, cmdLine, ref outData) argument list.
if (token == "ABCSearch") {
    ABC abc = new ABC(); 
    abc.SearchFor(tag, state, cmdLine, ref outData); 
} 
else if (token == "JklmDoSomething") {
    JKLM jklm = new JKLM(); 
    jklm.Dowork1(tag, state, cmdLine, ref outData); 
}

A couple of notes:

  • There is no correlation between the token and the class name (ABC, JKLM, …) or the method (SearchFor, Dowork1).
  • The methods do have the same signature:

    void func(string tag, State state, string cmdLine, ref string outData)
    
  • The if ()... block is 500+ lines and growing

And here is my example command processor (as a console app):

using System;
using System.Collections.Generic;
using System.Reflection;

namespace ConsoleApplication2
{
    /// <summary>
    /// Interactive console front end for the command processor: lists the
    /// registered tokens and executes a token with user-supplied arguments.
    /// </summary>
    public class Program
    {
        /// <summary>
        /// Menu loop. Reads a single key per iteration: 'x' executes a token,
        /// 't' lists known tokens, 'q' quits. Anything else is reported and ignored.
        /// </summary>
        static void Main(string[] args)
        {
            while (true)
            {
                Console.Write("[e(x)ecute, (t)okens, (q)uit] -> ");
                // Invariant casing: the menu keys are identifiers, not user-facing text.
                string s = Console.ReadKey().KeyChar.ToString().ToLowerInvariant();
                Console.WriteLine();

                switch (s)
                {
                case "q":
                    Console.WriteLine("Finished.");
                    return;

                case "t":
                    ListTokens();
                    break;

                case "x":
                    ExecuteToken();
                    break;

                default:
                    Console.WriteLine("Unknown command: {0}", s);
                    break;
                }
            }
        }

        /// <summary>Prints every token currently registered with the command processor.</summary>
        private static void ListTokens()
        {
            Console.WriteLine("Known tokens:");
            foreach (string tokenName in CommandProcessor.GetTokens())
            {
                Console.WriteLine(tokenName);
            }
        }

        /// <summary>
        /// Prompts for a token and its arguments, runs it through the command
        /// processor, and prints either the command's output or the error.
        /// </summary>
        private static void ExecuteToken()
        {
            Console.Write("token: ");
            string token = Console.ReadLine();
            Console.Write("tag: ");
            string tag = Console.ReadLine();
            Console.Write("cmdLine: ");
            string cmdLine = Console.ReadLine();
            Console.Write("state: ");
            string state = Console.ReadLine();

            try
            {
                string output = CommandProcessor.DoCommand(token, tag, cmdLine, State.GetStateFromString(state));
                Console.WriteLine("Output:");
                Console.WriteLine(output);
            }
            catch (TokenNotFoundException ex)
            {
                Console.WriteLine(ex.Message);
            }
            catch (TargetInvocationException ex)
            {
                // MethodInfo.Invoke wraps anything the command itself throws. Report the
                // command's real error instead of the generic "target of an invocation" text.
                Exception cause = ex.InnerException ?? ex;
                Console.WriteLine("Unknown error occurred during execution. Exception was: " + cause.Message);
            }
            catch (Exception ex)
            {
                Console.WriteLine("Unknown error occurred during execution. Exception was: " + ex.Message);
            }
        }
    }

    /// <summary>
    /// Dispatches string tokens to methods decorated with <see cref="TokenAttribute"/>.
    /// The token-to-method table is built once, by reflection over the executing
    /// assembly, in the static constructor.
    /// </summary>
    /// <remarks>
    /// Every registered method is invoked with the argument list
    /// (string tag, State state, string cmdLine, ref string outData). This is a
    /// convention that is not checked at registration time. A tagged method with a
    /// different signature is only detected when its token is executed, and
    /// <see cref="MethodBase.Invoke(object, object[])"/> then throws.
    /// </remarks>
    public class CommandProcessor
    {
        // Maps token name -> method to invoke. Filled by the static constructor and
        // only read afterwards.
        internal static Dictionary<string, MethodInfo> availableFunctions = new Dictionary<string, MethodInfo>();

        static CommandProcessor()
        {
            SetupMethodCallDictionary();
        }

        /// <summary>
        /// Scans every loadable type in the executing assembly for public methods that
        /// carry <see cref="TokenAttribute"/> and records them in <see cref="availableFunctions"/>.
        /// If several methods claim the same token, the last one found wins.
        /// </summary>
        private static void SetupMethodCallDictionary()
        {
            foreach (Type type in GetLoadableTypes(Assembly.GetExecutingAssembly()))
            {
                foreach (MethodInfo method in type.GetMethods())
                {
                    // Some methods can never be invoked by DoCommand:
                    // - anything with open generic parameters cannot be late-bound;
                    // - instance methods need Activator.CreateInstance(type), which fails
                    //   for abstract types and interfaces.
                    // Registering them would only shadow a usable handler with one that always fails.
                    if (method.ContainsGenericParameters || type.ContainsGenericParameters)
                    {
                        continue;
                    }
                    if (!method.IsStatic && type.IsAbstract)
                    {
                        continue;
                    }

                    // A method may carry several Token attributes, e.g. to keep an old token
                    // working after a rename, or when two commands are merged into one method.
                    foreach (TokenAttribute token in method.GetCustomAttributes(typeof(TokenAttribute), true))
                    {
                        // A null name cannot be a dictionary key. Skip it rather than let
                        // the static constructor fail, which would make this class unusable.
                        if (token.TokenName == null)
                        {
                            continue;
                        }

                        // The indexer adds new tokens and overwrites existing ones,
                        // so the last method found for a token wins.
                        availableFunctions[token.TokenName] = method;
                    }
                }
            }
        }

        /// <summary>
        /// Returns the types of <paramref name="assembly"/> that could be loaded. If some
        /// types fail to load, the remaining ones are still returned instead of letting
        /// the exception escape the static constructor.
        /// </summary>
        private static Type[] GetLoadableTypes(Assembly assembly)
        {
            try
            {
                return assembly.GetTypes();
            }
            catch (ReflectionTypeLoadException ex)
            {
                // Types that failed to load appear as null entries in ex.Types.
                List<Type> loaded = new List<Type>();
                foreach (Type type in ex.Types)
                {
                    if (type != null)
                    {
                        loaded.Add(type);
                    }
                }
                return loaded.ToArray();
            }
        }

        /// <summary>
        /// Executes the method registered for <paramref name="token"/>.
        /// </summary>
        /// <param name="token">Token name to look up (exact, case-sensitive match).</param>
        /// <param name="tag">Passed through to the handler as its first argument.</param>
        /// <param name="cmdLine">Passed through to the handler as its third argument.</param>
        /// <param name="state">Passed through to the handler as its second argument.</param>
        /// <returns>
        /// The value the handler left in its <c>ref string outData</c> parameter
        /// (<see cref="string.Empty"/> if the handler did not assign it).
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="token"/> is null.</exception>
        /// <exception cref="TokenNotFoundException">No method is registered for <paramref name="token"/>.</exception>
        /// <exception cref="TargetInvocationException">
        /// The handler threw. The handler's exception is in <see cref="Exception.InnerException"/>.
        /// </exception>
        /// <exception cref="MissingMethodException">
        /// The handler is an instance method whose type has no public parameterless constructor.
        /// </exception>
        public static string DoCommand(string token, string tag, string cmdLine, State state)
        {
            if (token == null)
            {
                throw new ArgumentNullException("token");
            }

            MethodInfo method;
            if (!availableFunctions.TryGetValue(token, out method))
            {
                throw new TokenNotFoundException(string.Format("Token {0} not found. Cannot execute.", token));
            }

            // Static handlers are invoked without a target. Instance handlers get a fresh
            // object built with the default constructor; use another Activator overload
            // if handlers need constructor arguments.
            object instance = method.IsStatic ? null : Activator.CreateInstance(method.ReflectedType);

            // Argument order follows the (tag, state, cmdLine, ref outData) convention.
            // Reflection writes the final value of the ref parameter back into args[3].
            object[] args = new object[] { tag, state, cmdLine, string.Empty };
            method.Invoke(instance, args);
            return (string)args[3];
        }

        /// <summary>Lists every registered token name, in no particular order.</summary>
        public static IEnumerable<string> GetTokens()
        {
            // Yield the keys one by one so callers cannot cast the result back to the
            // mutable key collection.
            foreach (string tokenName in availableFunctions.Keys)
            {
                yield return tokenName;
            }
        }
    }

    /// <summary>
    /// State object handed to every command handler. At present it only wraps
    /// the raw text supplied by the user.
    /// </summary>
    public class State
    {
        // Backing store for the Text property.
        private string _value;

        /// <summary>Creates a state that wraps <paramref name="text"/> unchanged.</summary>
        public State(string text)
        {
            this._value = text;
        }

        /// <summary>The raw state text. It can be replaced after construction.</summary>
        public string Text
        {
            get
            {
                return this._value;
            }
            set
            {
                this._value = value;
            }
        }

        /// <summary>
        /// Builds a <see cref="State"/> from its string form. Currently a pass-through:
        /// the result's <see cref="Text"/> is exactly <paramref name="state"/>.
        /// </summary>
        public static State GetStateFromString(string state)
        {
            // Real parsing of the state string belongs here once its format is defined.
            State parsed = new State(state);
            return parsed;
        }
    }

    /// <summary>
    /// Marks a public method as the handler for a command token.
    /// </summary>
    /// <remarks>
    /// <c>AllowMultiple = true</c> lets one method answer to several tokens, e.g. an
    /// old and a new name after a rename. Without it, the compiler rejects a second
    /// <c>[Token]</c> on the same method (CS0579), even though the command processor
    /// loops over every Token attribute it finds on a method.
    /// </remarks>
    [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)]
    public class TokenAttribute : Attribute
    {
        // Backing store for TokenName.
        private string _tokenName;

        /// <summary>Creates the attribute for the given token.</summary>
        /// <param name="tokenName">The exact token string that dispatches to the decorated method.</param>
        public TokenAttribute(string tokenName)
        {
            _tokenName = tokenName;
        }

        /// <summary>The token string that dispatches to the decorated method.</summary>
        public string TokenName
        {
            get { return _tokenName; }
            set { _tokenName = value; }
        }
    }

    /// <summary>
    /// Thrown when a command token has no registered handler (see
    /// <c>CommandProcessor.DoCommand</c>).
    /// </summary>
    /// <remarks>
    /// Provides the four constructors recommended by the .NET exception design
    /// guidelines: parameterless, message, message plus inner exception, and the
    /// protected serialization constructor.
    /// </remarks>
    [Serializable]
    public class TokenNotFoundException : Exception
    {
        /// <summary>Creates the exception with the runtime's default message.</summary>
        public TokenNotFoundException()
        {
        }

        /// <summary>Creates the exception with the given message.</summary>
        public TokenNotFoundException(string message)
            : base(message)
        {
        }

        /// <summary>Creates the exception with a message and the exception that caused it.</summary>
        public TokenNotFoundException(string message, Exception inner)
            : base(message, inner)
        {
        }

        /// <summary>Deserialization constructor.</summary>
        protected TokenNotFoundException(
            System.Runtime.Serialization.SerializationInfo info,
            System.Runtime.Serialization.StreamingContext context)
            : base(info, context)
        {
        }
    }

    /// <summary>Sample command handler bound to the <c>ABCSearch</c> token.</summary>
    public class ABC
    {
        /// <summary>
        /// Handler for <c>ABCSearch</c>. It echoes its inputs: <paramref name="outData"/>
        /// receives a message containing <paramref name="tag"/>, the state's text and
        /// <paramref name="cmdLine"/>.
        /// </summary>
        [Token("ABCSearch")]
        public void SearchFor(string tag, State state, string cmdLine, ref string outData)
        {
            // Placeholder for the real search: report what was received.
            const string template = "You called ABC.SearchFor. Parameters were [tag: {0}, state: {1}, cmdLine: {2}]";
            string stateText = state.Text;
            outData = string.Format(template, tag, stateText, cmdLine);
        }
    }

    /// <summary>Sample command handler bound to the <c>JklmDoSomething</c> token.</summary>
    public class JKLM
    {
        /// <summary>
        /// Handler for <c>JklmDoSomething</c>. It echoes its inputs: <paramref name="outData"/>
        /// receives a message containing <paramref name="tag"/>, the state's text and
        /// <paramref name="cmdLine"/>.
        /// </summary>
        [Token("JklmDoSomething")]
        public void Dowork1(string tag, State state, string cmdLine, ref string outData)
        {
            // Placeholder for the real work: report what was received.
            const string template = "You called JKLM.Dowork1. Parameters were [tag: {0}, state: {1}, cmdLine: {2}]";
            string stateText = state.Text;
            outData = string.Format(template, tag, stateText, cmdLine);
        }
    }
}