Python中动态创建类的方法

0x00 前言

在Python中,类也是作为一种对象存在的,因此可以在运行时动态创建类,这也是Python灵活性的一种体现。

本文介绍了如何使用type动态创建类,以及相关的一些使用方法与技巧。

0x01 类的本质

何为类?类是对现实生活中一类具有共同特征的事物的抽象,它描述了所创建的对象共同的属性和方法。在常见的编译型语言(如C++)中,类在编译的时候就已经确定了,运行时是无法动态创建的。那么Python是如何做到的呢?

来看下面这段代码:

  1. class A(object):
  2. pass
  3. print(A)
  4. print(A.__class__)
COPY

在Python2中执行结果如下:

  1. <class '__main__.A'>
  2. <type 'type'>
COPY

在Python3中执行结果如下:

  1. <class '__main__.A'>
  2. <class 'type'>
COPY

可以看出,类A的类型是type,也就是说:type实例化后是实例化后是对象

0x02 使用type动态创建类

type的参数定义如下:

type(name, bases, dict)

name: 生成的类名
bases: 生成的类基类列表,类型为tuple
dict: 生成的类中包含的属性或方法

例如:可以使用以下方法创建一个类A

  1. cls = type('A', (object,), {'__doc__': 'class created by type'})
  2. print(cls)
  3. print(cls.__doc__)
COPY

输出结果如下:

  1. <class '__main__.A'>
  2. class created by type
COPY

可以看出,这样创建的类与静态定义的类基本没有什么差别,使用上还更灵活。

这种方法的使用场景之一是:

有些地方需要传入一个类作为参数,但是类中会用到某些受外界影响的变量;虽然使用全局变量可以解决这个问题,但是比较丑陋。此时,就可以使用这种方法动态创建一个类来使用。

以下是一个使用的示例:

  1. import socket
  2. try:
  3. import SocketServer
  4. except ImportError:
  5. # python3
  6. import socketserver as SocketServer
  7. class PortForwardingRequestHandler(SocketServer.BaseRequestHandler):
  8. '''处理端口转发请求
  9. '''
  10. def handle(self):
  11. sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  12. sock.connect(self.server) # self.server是在动态创建类时传入的
  13. # 连接目标服务器,并转发数据
  14. # 以下代码省略...
  15. def gen_cls(server):
  16. '''动态创建子类
  17. '''
  18. return type('%s_%s' % (ProxyRequestHandler.__name__, server), (PortForwardingRequestHandler, object), {'server': server})
  19. server = SocketServer.ThreadingTCPServer(('127.0.0.1', 8080), gen_cls(('www.qq.com', 80)))
  20. server.serve_forever()
COPY

在上面的例子中,由于目标服务器地址是由用户传入的,而PortForwardingRequestHandler类的实例化是在ThreadingTCPServer里实现的,我们没法控制。因此,使用动态创建类的方法可以很好地解决这个问题。

0x03 使用元类(metaclass

类是实例的模版,而元类是类的模版。通过元类可以创建出类,类的默认元类是type,所有元类必须是type的子类。

下面是元类的一个例子:

  1. import struct
  2. class MetaClass(type):
  3. def __init__(cls, name, bases, attrd):
  4. super(MetaClass, cls).__init__(name, bases, attrd)
  5. def __mul__(self, num):
  6. return type('%s_Array_%d' % (self.__name__, num), (ArrayTypeBase,), {'obj_type': self, 'array_size': num, 'size': self.size * num})
  7. class IntTypeBase(object):
  8. '''类型基类
  9. '''
  10. __metaclass__ = MetaClass
  11. size = 0
  12. format = '' # strcut格式
  13. def __init__(self, val=0):
  14. if isinstance(val, str): val = int(val)
  15. if not isinstance(val, int):
  16. raise TypeError('类型错误:%s' % type(val))
  17. self._net_order = True # 默认存储的为网络序数据
  18. self.value = val
  19. self._num = 1
  20. def __str__(self):
  21. return '%d(%s)' % (self._val, self.__class__.__name__)
  22. def __cmp__(self, val):
  23. if isinstance(val, IntTypeBase):
  24. return cmp(self.value, val.value)
  25. elif isinstance(val, (int, long)):
  26. return cmp(self.value, val)
  27. elif isinstance(val, type(None)):
  28. return cmp(int(self.value), None)
  29. else:
  30. raise TypeError('类型错误:%s' % type(val))
  31. def __int__(self):
  32. return int(self.value)
  33. def __hex__(self):
  34. return hex(self.value)
  35. def __index__(self):
  36. return self.value
  37. def __add__(self, val):
  38. return int(self.value + val)
  39. def __radd__(self, val):
  40. return int(val + self.value)
  41. def __sub__(self, val):
  42. return self.value - val
  43. def __rsub__(self, val):
  44. return val - self.value
  45. def __mul__(self, val):
  46. return self.value * val
  47. def __div__(self, val):
  48. return self.value / val
  49. def __mod__(self, val):
  50. return self.value % val
  51. def __rshift__(self, val):
  52. return self.value >> val
  53. def __and__(self, val):
  54. return self.value & val
  55. @property
  56. def net_order(self):
  57. return self._net_order
  58. @net_order.setter
  59. def net_order(self, _net_order):
  60. self._net_order = _net_order
  61. @property
  62. def value(self):
  63. return self._val
  64. @value.setter
  65. def value(self, val):
  66. if not isinstance(val, int):
  67. raise TypeError('类型错误:%s' % type(val))
  68. if val < 0: raise ValueError(val)
  69. max_val = 256 ** (self.size) - 1
  70. if val > max_val: raise ValueError('%d超过最大大小%d' % (val, max_val))
  71. self._val = val
  72. def unpack(self, buff, net_order=True):
  73. '''从buffer中提取出数据
  74. '''
  75. if len(buff) < self.size: raise ValueError(repr(buff))
  76. buff = buff[:self.size]
  77. fmt = self.format
  78. if not net_order: fmt = '<' + fmt[1]
  79. self._val = struct.unpack(fmt, buff)[0]
  80. return self._val
  81. def pack(self, net_order=True):
  82. '''返回内存数据
  83. '''
  84. fmt = self.format
  85. if not net_order: fmt = '<' + fmt[1]
  86. return struct.pack(fmt, self._val)
  87. @staticmethod
  88. def cls_from_size(size):
  89. '''从整型大小返回对应的类
  90. '''
  91. if size == 1:
  92. return c_uint8
  93. elif size == 2:
  94. return c_uint16
  95. elif size == 4:
  96. return c_uint32
  97. elif size == 8:
  98. return c_uint64
  99. else:
  100. raise RuntimeError('不支持的整型数据长度:%d' % size)
  101. @classmethod
  102. def unpack_from(cls, str, net_order=True):
  103. obj = cls()
  104. obj.unpack(str, net_order)
  105. return int(obj)
  106. class ArrayTypeBase(object):
  107. '''数组类型基类
  108. '''
  109. def __init__(self, val=''):
  110. init_val = 0
  111. if isinstance(val, int):
  112. init_val = val
  113. else:
  114. val = str(val)
  115. self._obj_array = [self.obj_type(init_val) for _ in range(self.array_size)] # 初始化
  116. self.value = val
  117. def __str__(self):
  118. return str(self.value)
  119. def __repr__(self):
  120. return repr(self.value)
  121. def __getitem__(self, idx):
  122. return self._obj_array[idx].value
  123. def __setitem__(self, idx, val):
  124. self._obj_array[idx].value = val
  125. def __getslice__(self, i, j):
  126. result = [obj.value for obj in self._obj_array[i:j]]
  127. if self.obj_type == c_ubyte:
  128. result = [chr(val) for val in result]
  129. result = ''.join(result)
  130. return result
  131. def __add__(self, oval):
  132. if not isinstance(oval, str):
  133. raise NotImplementedError('%s还不支持%s类型' % (self.__class__.__name__, type(oval)))
  134. return self.value + oval
  135. def __radd__(self, oval):
  136. return oval + self.value
  137. def __iter__(self):
  138. '''迭代器
  139. '''
  140. for i in range(self.length):
  141. yield self[i]
  142. @property
  143. def value(self):
  144. result = [obj.value for obj in self._obj_array]
  145. if self.obj_type == c_ubyte:
  146. result = [chr(val) for val in result]
  147. result = ''.join(result)
  148. return result
  149. @value.setter
  150. def value(self, val):
  151. if isinstance(val, list):
  152. raise NotImplementedError('ArrayType还不支持list')
  153. elif isinstance(val, str):
  154. self.unpack(val)
  155. def unpack(self, buff, net_order=True):
  156. '''
  157. '''
  158. if len(buff) == 0: return
  159. if len(buff) < self.size: raise ValueError('unpack数据长度错误:%d %d' % (len(buff), self.size))
  160. for i in range(self.array_size):
  161. self._obj_array[i].unpack(buff[i * self.obj_type.size:], net_order)
  162. def pack(self, net_order=True):
  163. '''
  164. '''
  165. result = ''
  166. for i in range(self.array_size):
  167. result += self._obj_array[i].pack()
  168. return result
  169. class c_uint8(IntTypeBase):
  170. '''unsigned char
  171. '''
  172. size = 1
  173. format = '!B'
  174. class c_uint16(IntTypeBase):
  175. '''unsigned short
  176. '''
  177. size = 2
  178. format = '!H'
  179. class c_uint32(IntTypeBase):
  180. '''unsigned int32
  181. '''
  182. size = 4
  183. format = '!I'
  184. class c_uint64(IntTypeBase):
  185. '''unsigned int64
  186. '''
  187. size = 8
  188. format = '!Q'
  189. cls = c_ubyte * 5
  190. print(cls)
  191. val = cls(65)
  192. print(val)
COPY

以上代码在Python2.7中输出结果如下:

  1. <class '__main__.c_ubyte_Array_5'>
  2. AAAAA
COPY

在Python3中,metaclass的定义方法做了修改,变成了:

  1. class IntTypeBase(object, metaclass=MetaClass):
  2. pass
COPY

为了兼容性。可以使用six库中的方法:

  1. import six
  2. @six.add_metaclass(MetaClass)
  3. class IntTypeBase(object):
  4. pass
COPY

使用元类的优点是可以使用更加优雅的方式创建类,如上面的c_ubyte * 5,提升了代码可读性和技巧性。

0x04 重写__new__方法

每个继承自object的类都有__new__方法,这是个在类实例化时优先调用的方法,时机早于__init__。它返回的类型决定了最终创建出来的对象的类型。

请看以下代码:

  1. class A(object):
  2. def __new__(self, *args, **kwargs):
  3. return B()
  4. class B(object):
  5. pass
  6. a = A()
  7. print(a)
COPY

输出结果如下:

  1. <__main__.B object at 0x023576D0>
COPY

可以看到,明明实例化的是A,但是返回的对象类型却是B,这里主要就是__new__在起作用。

下面的例子展示了在__new__中动态创建类的过程:

  1. class B(object):
  2. def __init__(self, var):
  3. self._var = var
  4. def test(self):
  5. print(self._var)
  6. class A(object):
  7. def __new__(self, *args, **kwargs):
  8. if len(args) == 1 and isinstance(args[0], type):
  9. return type('%s_%s' % (self.__name__, args[0].__name__), (self, args[0]), {})
  10. else:
  11. return object.__new__(self, *args, **kwargs)
  12. def output(self):
  13. print('output from new class %s' % self.__class__.__name__)
  14. obj = A(B)('Hello World')
  15. obj.test()
  16. obj.output()
COPY

结果输出如下:

  1. Hello World
  2. output from new class A_B
COPY

这个例子实现了动态创建两个类的子类,比较适合存在很多类需要排列组合生成N多子类的场景,可以避免要写一堆子类代码的痛苦。

0x05 总结

动态创建类必须要使用type实现,但是,根据不同的使用场景,可以选择不同的使用方法。

这样做对静态分析工具其实是不友好的,因为在运行过程中类型发生了变化。而且,这也会降低代码的可读性,一般情况下也不推荐用户使用这样存在一定技巧性的代码。

分享

Related Issues not found

Please contact @drunkdream to initialize the comment