Refactoring a big if block into a simple command processor using attributes

Written by Troy Howard

15 May 2008

Recently someone had a problem where they had some massive control block full of if statements looking at a string, dispatching one of a variety of functions. The if block was massive. Hundreds of if statements, hundreds of magic strings.

Interestingly, all the functions had the same signature... So I gave him this example of how to use attributes on the methods to specify the corresponding token, use Reflection to scan the assembly for all the functions with that attribute, then create a function table keyed by their token, to provide fast lookup. This example shows how to create an object instance and then invoke the method via reflection, but this could be made much simpler if the methods were all static and the function prototype was part of an interface instead of just an unspoken convention.

Here's the "Before" example from the original question...

 1string tag; 
 2string cmdLine; 
 3State state; 
 4string outData; 
 5
 6// ... etc ... 
 7
 8if (token == "ABCSearch") {
 9    ABC abc = new ABC(); 
10    abc.SearchFor(tag, state, cmdLine, ref outData); 
11} 
12else if (token == "JklmDoSomething") {
13    JKLM jklm = new JKLM(); 
14    jklm.Dowork1(tag, state, cmdLine, ref outData); 
15} 

_A couple of notes:_

  • There is no correlation between the token and the class name (ABC, JKLM, ...) or the method (SearchFor, Dowork1).

  • The methods do have the same signature:

    1void func(string tag, State state, string cmdLine, ref string outData) 
    
  • The if ()... block is 500+ lines and growing.

And here is my example command processor (as a console app):

  1using System;
  2using System.Collections.Generic;
  3using System.Reflection;
  4 
  5namespace ConsoleApplication2
  6{
  /// <summary>
  /// Console front end for <see cref="CommandProcessor"/>: lists the known tokens
  /// or executes a command whose token and arguments are typed in by the user.
  /// </summary>
  public class Program
  {
      /// <summary>
      /// Runs the interactive prompt until the user presses 'q'.
      /// </summary>
      /// <param name="args">Command-line arguments (not used).</param>
      static void Main(string[] args)
      {
          while (true)
          {
              Console.Write("[e(x)ecute, (t)okens, (q)uit] -> ");

              // ToLowerInvariant: the menu keys are fixed ASCII letters, so matching
              // them must not depend on the casing rules of the current culture.
              string s = Console.ReadKey().KeyChar.ToString().ToLowerInvariant();
              Console.WriteLine();

              switch (s)
              {
                  case "q":
                      Console.WriteLine("Finished.");
                      return;

                  case "t":
                      ListTokens();
                      break;

                  case "x":
                      ExecuteCommand();
                      break;

                  default:
                      Console.WriteLine("Unknown command: {0}", s);
                      break;
              }
          }
      }

      /// <summary>
      /// Prints every token the command processor has registered, one per line.
      /// </summary>
      private static void ListTokens()
      {
          Console.WriteLine("Known tokens:");
          foreach (string tokenName in CommandProcessor.GetTokens())
          {
              Console.WriteLine(tokenName);
          }
      }

      /// <summary>
      /// Prompts for token, tag, cmdLine and state (in that order), runs the
      /// command and prints either its output or a description of the failure.
      /// </summary>
      private static void ExecuteCommand()
      {
          string token = Prompt("token: ");
          string tag = Prompt("tag: ");
          string cmdLine = Prompt("cmdLine: ");
          string state = Prompt("state: ");

          try
          {
              string output = CommandProcessor.DoCommand(token, tag, cmdLine, State.GetStateFromString(state));
              Console.WriteLine("Output:");
              Console.WriteLine(output);
          }
          catch (TokenNotFoundException ex)
          {
              Console.WriteLine(ex.Message);
          }
          catch (TargetInvocationException ex)
          {
              // Reflection wraps whatever the command itself threw; report the
              // underlying cause instead of the generic wrapper message.
              Exception cause = ex.InnerException ?? ex;
              Console.WriteLine("Unknown error occurred during execution. Exception was: " + cause.Message);
          }
          catch (Exception ex)
          {
              Console.WriteLine("Unknown error occurred during execution. Exception was: " + ex.Message);
          }
      }

      /// <summary>
      /// Writes <paramref name="label"/> and reads one line of input.
      /// </summary>
      /// <param name="label">Text shown before the cursor.</param>
      /// <returns>The line entered, or an empty string when input has ended.</returns>
      private static string Prompt(string label)
      {
          Console.Write(label);

          // ReadLine returns null at end of input; normalise so callers never see null.
          return Console.ReadLine() ?? string.Empty;
      }
  }
 69 
 /// <summary>
 /// Maps string tokens to the methods marked with <see cref="TokenAttribute"/>
 /// in the executing assembly and invokes them on request.
 /// </summary>
 public class CommandProcessor
 {
     // Token name -> method registered for it. Filled once by the static constructor.
     internal static readonly Dictionary<string, MethodInfo> availableFunctions =
         new Dictionary<string, MethodInfo>();

     static CommandProcessor()
     {
         SetupMethodCallDictionary();
     }

     /// <summary>
     /// Scans every type in the executing assembly and registers each public
     /// method that carries one or more <see cref="TokenAttribute"/>s.
     /// </summary>
     private static void SetupMethodCallDictionary()
     {
         Assembly assembly = Assembly.GetExecutingAssembly();

         foreach (Type type in assembly.GetTypes())
         {
             foreach (MethodInfo method in type.GetMethods())
             {
                 // A method may carry several tokens, which leaves room for
                 // backwards compatibility when tokens are renamed or methods
                 // are consolidated.
                 object[] tokens = method.GetCustomAttributes(typeof(TokenAttribute), true);

                 foreach (TokenAttribute token in tokens)
                 {
                     // A null name cannot be a dictionary key; skip it rather
                     // than fail inside the static constructor.
                     if (token.TokenName == null)
                     {
                         continue;
                     }

                     // The indexer adds or overwrites, so when more than one
                     // method claims the same token the last one found wins.
                     availableFunctions[token.TokenName] = method;
                 }
             }
         }
     }

     /// <summary>
     /// Executes the method registered for <paramref name="token"/>.
     /// </summary>
     /// <param name="token">Token identifying the method to run.</param>
     /// <param name="tag">Passed to the method as its first argument.</param>
     /// <param name="cmdLine">Passed to the method as its third argument.</param>
     /// <param name="state">Passed to the method as its second argument.</param>
     /// <returns>The value the method left in its <c>ref string outData</c> parameter.</returns>
     /// <exception cref="ArgumentNullException"><paramref name="token"/> is null.</exception>
     /// <exception cref="TokenNotFoundException">No method is registered for the token.</exception>
     /// <exception cref="TargetInvocationException">
     /// The invoked method threw; the original exception is in <c>InnerException</c>.
     /// </exception>
     public static string DoCommand(string token, string tag, string cmdLine, State state)
     {
         if (token == null)
         {
             throw new ArgumentNullException("token");
         }

         MethodInfo method;
         if (!availableFunctions.TryGetValue(token, out method))
         {
             throw new TokenNotFoundException(string.Format("Token {0} not found. Cannot execute.", token));
         }

         // Instance methods need a target object; this uses the default
         // constructor. Static methods are invoked with a null target.
         object instance = method.IsStatic ? null : Activator.CreateInstance(method.ReflectedType);

         // The last slot stands in for the ref parameter: Invoke writes the
         // updated value back into the array.
         object[] args = new object[] { tag, state, cmdLine, string.Empty };
         method.Invoke(instance, args);

         return (string)args[3];
     }

     /// <summary>
     /// Enumerates the names of all registered tokens.
     /// </summary>
     /// <returns>A lazily evaluated sequence of token names.</returns>
     public static IEnumerable<string> GetTokens()
     {
         foreach (string tokenName in availableFunctions.Keys)
         {
             yield return tokenName;
         }
     }
 }
159 
/// <summary>
/// Holds the state value handed to a command; currently just the raw text.
/// </summary>
public class State
{
    // Backing store for the Text property.
    private string _stateText;

    /// <summary>
    /// Creates a state wrapping <paramref name="text"/>.
    /// </summary>
    /// <param name="text">The text to store.</param>
    public State(string text)
    {
        this._stateText = text;
    }

    /// <summary>
    /// Builds a <see cref="State"/> from its string form.
    /// </summary>
    /// <param name="state">The string to convert.</param>
    /// <returns>A new state whose <see cref="Text"/> is <paramref name="state"/>.</returns>
    public static State GetStateFromString(string state)
    {
        // implement parsing of string to build State object here.
        State result = new State(state);
        return result;
    }

    /// <summary>
    /// Gets or sets the stored text.
    /// </summary>
    public string Text
    {
        get
        {
            return this._stateText;
        }
        set
        {
            this._stateText = value;
        }
    }
}
181 
/// <summary>
/// Marks a method as the handler for a command token.
/// </summary>
/// <remarks>
/// AllowMultiple is enabled so one method can answer to several tokens, for
/// example to keep an old token working after a rename. Without it, a second
/// [Token] on the same method is rejected by the compiler.
/// </remarks>
[AttributeUsage(AttributeTargets.Method, AllowMultiple = true)]
public class TokenAttribute : Attribute
{
    private string _tokenName;

    /// <summary>
    /// Creates the attribute for the given token.
    /// </summary>
    /// <param name="tokenName">The token that selects the decorated method.</param>
    public TokenAttribute(string tokenName)
    {
        _tokenName = tokenName;
    }

    /// <summary>
    /// Gets or sets the token that selects the decorated method.
    /// </summary>
    public string TokenName
    {
        get { return _tokenName; }
        set { _tokenName = value; }
    }
}
198 
/// <summary>
/// Thrown when a command is requested for a token that has no registered method.
/// </summary>
/// <remarks>
/// Provides the conventional set of exception constructors. For guidelines
/// regarding the creation of new exception types, see
/// http://msdn.microsoft.com/library/default.asp?url=/library/en-us/cpgenref/html/cpconerrorraisinghandlingguidelines.asp
/// and
/// http://msdn.microsoft.com/library/default.asp?url=/library/en-us/dncscol/html/csharp07192001.asp
/// </remarks>
[Serializable]
public class TokenNotFoundException : Exception
{
    /// <summary>Creates the exception with the default message.</summary>
    public TokenNotFoundException()
        : base()
    {
    }

    /// <summary>Creates the exception with a specific message.</summary>
    /// <param name="message">Text describing the failure.</param>
    public TokenNotFoundException(string message)
        : base(message)
    {
    }

    /// <summary>Creates the exception with a message and the exception that caused it.</summary>
    /// <param name="message">Text describing the failure.</param>
    /// <param name="inner">The underlying cause.</param>
    public TokenNotFoundException(string message, Exception inner)
        : base(message, inner)
    {
    }

    /// <summary>Creates the exception from serialized data.</summary>
    /// <param name="info">The serialized object data.</param>
    /// <param name="context">The source or destination of the serialized stream.</param>
    protected TokenNotFoundException(
        System.Runtime.Serialization.SerializationInfo info,
        System.Runtime.Serialization.StreamingContext context)
        : base(info, context)
    {
    }
}
216 
/// <summary>
/// Sample command class; its method is reached through the "ABCSearch" token.
/// </summary>
public class ABC
{
    /// <summary>
    /// Writes a message echoing the received arguments into <paramref name="outData"/>.
    /// </summary>
    /// <param name="tag">Echoed in the message.</param>
    /// <param name="state">Its Text is echoed in the message.</param>
    /// <param name="cmdLine">Echoed in the message.</param>
    /// <param name="outData">Receives the formatted message.</param>
    [Token("ABCSearch")]
    public void SearchFor(string tag, State state, string cmdLine, ref string outData)
    {
        // do some stuff.
        string stateText = state.Text;
        object[] values = new object[] { tag, stateText, cmdLine };

        outData = string.Format(
            "You called ABC.SearchFor. Parameters were [tag: {0}, state: {1}, cmdLine: {2}]",
            values);
    }
}
228 
/// <summary>
/// Sample command class; its method is reached through the "JklmDoSomething" token.
/// </summary>
public class JKLM
{
    /// <summary>
    /// Writes a message echoing the received arguments into <paramref name="outData"/>.
    /// </summary>
    /// <param name="tag">Echoed in the message.</param>
    /// <param name="state">Its Text is echoed in the message.</param>
    /// <param name="cmdLine">Echoed in the message.</param>
    /// <param name="outData">Receives the formatted message.</param>
    [Token("JklmDoSomething")]
    public void Dowork1(string tag, State state, string cmdLine, ref string outData)
    {
        // do some other stuff.
        string stateText = state.Text;
        object[] values = new object[] { tag, stateText, cmdLine };

        outData = string.Format(
            "You called JKLM.Dowork1. Parameters were [tag: {0}, state: {1}, cmdLine: {2}]",
            values);
    }
}
239}
240