はじめに #
Pythonのdataclassを継承して派生クラスを作る方法を解説します。
dataclassとは、データを格納するための特殊なクラスです。
また、初期値・型ヒント・メソッドのオーバーライドや、型ヒントの付け方についても解説します。
※この記事では、基本的なdataclassの使い方(定義やオブジェクトの作成方法)について分かっていることを前提とします。
検証環境
- Python 3.11.6
- mypy 1.10.0
基本的な継承 #
dataclassを継承する基本的な方法を以下に示します。まず、変数x, yを持つ Point2D を定義します。さらに、このクラスを継承するサブクラス Point3Dを定義します。
from dataclasses import dataclass
@dataclass
class Point2D:
x: float
y: float
@dataclass
class Point3D(Point2D):
z: float
Point3Dの後ろの()の中にPoint2Dを記述することで継承できます。また、Point3Dにもデコレータ@dataclassを付けます。
以下の通り、Point3DはスーパークラスPoint2D の変数x, yに加え、変数zを持ちます。
>>> Point3D?
Init signature: Point3D(x: float, y: float, z: float) -> None
Docstring: Point3D(x: float, y: float, z: float)
File: ...
Type: type
Subclasses:
>>> Point3D(1, 2, 3)
Point3D(x=1, y=2, z=3)
初期値・型ヒントのオーバーライド #
サブクラスでは、スーパークラスで設定した初期値や型ヒントをオーバーライド(上書き)できます。
@dataclass
class Point2D:
x: float = 0
y: float = 0
@dataclass
class Point3D(Point2D):
x: int = 3
z: float = 3
以下のようにxの初期値を3, 型ヒントをintに上書きできました。
>>> Point3D?
Init signature: Point3D(x: int = 3, y: float = 0, z: float = 3) -> None
Docstring: Point3D(x: int = 3, y: float = 0, z: float = 3)
File: ...
Type: type
Subclasses:
メソッドのオーバーライド #
dataclassには、__post_init__などのメソッドを記述できます。なお、__post_init__はオブジェクトを作成するときに自動実行される、初期化処理を記述するメソッドです。
dataclassを継承するとき、メソッドを再記述することでオーバーライドできます。
以下のように、点の原点からの距離(ユークリッド距離)を格納する属性distanceを定義します。さらに、__post_init__メソッドでdistanceを計算します。
2次元座標と3次元座標で計算方法が変わるため、Point3Dクラスで__post_init__を再定義します。
from dataclasses import dataclass, field
@dataclass
class Point2D:
x: float
y: float
distance: float = field(init=False)
def __post_init__(self):
self.distance = (self.x**2 + self.y**2)**0.5
@dataclass
class Point3D(Point2D):
z: float
def __post_init__(self):
self.distance = (self.x**2 + self.y**2 + self.z**2)**0.5
以下のように、distanceが正しく計算されています。
>>> p2 = Point2D(3, 4)
>>> p2.distance
5.0
>>> p3 = Point3D(1, 2, 2)
>>> p3.distance
3.0
親クラスのメソッド呼び出し #
通常のクラスの継承のように、dataclassのサブクラスにおいても、親クラスのメソッドをsuper()を使用して呼び出すことができます。
以下のPoint1D, Point2Dクラスでは、それぞれ引数を正としたいです。Point1Dクラスでは、__post_init__メソッドでxが正であることをチェックしています。
一方、Point2Dクラスでは、super().__post_init__()を実行することで、スーパークラス (Point1D) の__post_init__メソッドを呼び出すことができます。
そのため、Point2Dクラスにはxの正負をチェックする記述を書き直す必要がなく、yの正負をチェックする処理を書くだけで済みます。
(この例ではあまりメリットはありませんが、スーパークラス内の処理が長く複雑になるほど恩恵は大きくなります)
@dataclass
class Point1D:
"""引数xは正とすること"""
x: float
def __post_init__(self):
"""xが正であることをチェック"""
if self.x <= 0:
raise ValueError
@dataclass
class Point2D(Point1D):
"""引数x, yは正とすること"""
y: float
def __post_init__(self):
"""x, yが正であることをチェック"""
super().__post_init__()
if self.y <= 0:
raise ValueError
以下にエラーを発生させた例を示します。
サブクラスPoint2Dのxに-1を与えると、Point1Dクラスの__post_init__メソッドでValueErrorを発生したことが分かります。
>>> p1 = Point1D(x=-1) # ValueErrorが発生する。
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
----> 1 p1 = Point1D(x=-1)
5 def __post_init__(self):
6 if self.x < 0:
----> 7 raise ValueError
>>> p2 = Point2D(x=-1, y=1) # x=-1に対してValueErrorが発生する。
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
----> 1 p2 = Point2D(x=-1, y=1)
13 def __post_init__(self):
---> 14 super().__post_init__()
16 if self.y < 0:
17 raise ValueError
5 def __post_init__(self):
6 if self.x < 0:
----> 7 raise ValueError
継承と型チェック #
Pythonでは変数に型ヒントを付けることができます。dataclassで定義したクラスについても、以下のように型ヒントに出来ます。
def my_func(point: Point2D):
...
Pythonの有名な型チェッカとしてmypyがあります。mypyでは、dataclassの継承について以下のように判定します。
- 型ヒントがスーパークラスのとき、スーパークラスとサブクラスのどちらを与えても問題ない。
- 型ヒントがサブクラスのとき、スーパークラスを与えるとエラーとして検出。
スーパークラスよりもそれを継承したサブクラスの方が一般に属性が多いため、この仕様は自明です。
OKな例とNGな例をそれぞれ示します。
以下はOKな例です。スーパークラスPoint1Dの型ヒントを持つ関数print_1xに対し、サブクラスPoint2Dのオブジェクトを与えています。
mypyによるチェックを実行してもエラーは検出されませんでした。
from dataclasses import dataclass
@dataclass
class Point1D:
x: float
@dataclass
class Point2D(Point1D):
y: float
def print_1x(p: Point1D):
print(p.x)
if __name__=="__main__":
p2 = Point2D(1, 2)
print_1x(p2)
> mypy test_ok.py
Success: no issues found in 1 source file
次に、以下はNGな例です。サブクラスPoint2Dの型ヒントを持つ関数print_x2に対し、スーパークラスPoint1Dのオブジェクトを与えています。
mypyによってエラーが検出されています。
from dataclasses import dataclass
@dataclass
class Point1D:
x: float
@dataclass
class Point2D(Point1D):
y: float
def print_x2(p: Point2D):
print(p.x)
if __name__=="__main__":
p1 = Point1D(1)
print_x2(p1)
> mypy test_ng.py
mypy_test_ng.py:16: error: Argument 1 to
"print_x2" has incompatible type "Point1D";
expected "Point2D" [arg-type]
Found 1 error in 1 file (checked 1 source file)
なお、関数print_x2ではx属性にしかアクセスしていないため、上記のtest_ng.pyというコードを実行すること自体は可能です。
ただし、将来コードを修正したときにバグを埋め込んでしまう可能性があるため、望ましい状態ではありません。