はじめに
Pythonのdataclass
を継承して派生クラスを作る方法を解説します。
dataclass
とは、データを格納するための特殊なクラスです。
また、初期値・型ヒント・メソッドのオーバーライドや、型ヒントの付け方についても解説します。
※この記事では、基本的なdataclass
の使い方(定義やオブジェクトの作成方法)について分かっていることを前提とします。
検証環境
- Python 3.11.6
- mypy 1.10.0
基本的な継承
dataclass
を継承する基本的な方法を以下に示します。まず、変数x
, y
を持つ Point2D
を定義します。さらに、このクラスを継承するサブクラス Point3D
を定義します。
1
2
3
4
5
6
7
8
9
10
|
from dataclasses import dataclass
@dataclass
class Point2D:
x: float
y: float
@dataclass
class Point3D(Point2D):
z: float
|
Point3D
の後ろの()
の中にPoint2D
を記述することで継承できます。また、Point3D
にもデコレータ@dataclass
を付けます。
以下の通り、Point3D
はスーパークラスPoint2D
の変数x
, y
に加え、変数z
を持ちます。
1
2
3
4
5
6
7
8
9
|
>>> Point3D?
Init signature: Point3D(x: float, y: float, z: float) -> None
Docstring: Point3D(x: float, y: float, z: float)
File: ...
Type: type
Subclasses:
>>> Point3D(1, 2, 3)
Point3D(x=1, y=2, z=3)
|
初期値・型ヒントのオーバーライド
サブクラスでは、スーパークラスで設定した初期値や型ヒントをオーバーライド(上書き)できます。
1
2
3
4
5
6
7
8
9
|
@dataclass
class Point2D:
x: float = 0
y: float = 0
@dataclass
class Point3D(Point2D):
x: int = 3
z: float = 3
|
以下のようにx
の初期値を3
, 型ヒントをint
に上書きできました。
1
2
3
4
5
6
|
>>> Point3D?
Init signature: Point3D(x: int = 3, y: float = 0, z: float = 3) -> None
Docstring: Point3D(x: int = 3, y: float = 0, z: float = 3)
File: ...
Type: type
Subclasses:
|
メソッドのオーバーライド
dataclass
には、__post_init__
などのメソッドを記述できます。なお、__post_init__
はオブジェクトを作成するときに自動実行される、初期化処理を記述するメソッドです。
dataclass
を継承するとき、メソッドを再記述することでオーバーライドできます。
以下のように、点の原点からの距離(ユークリッド距離)を格納する属性distance
を定義します。さらに、__post_init__
メソッドでdistance
を計算します。
2次元座標と3次元座標で計算方法が変わるため、Point3D
クラスで__post_init__
を再定義します。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
|
from dataclasses import dataclass, field
@dataclass
class Point2D:
x: float
y: float
distance: float = field(init=False)
def __post_init__(self):
self.distance = (self.x**2 + self.y**2)**0.5
@dataclass
class Point3D(Point2D):
z: float
def __post_init__(self):
self.distance = (self.x**2 + self.y**2 + self.z**2)**0.5
|
以下のように、distance
が正しく計算されています。
1
2
3
4
5
6
7
|
>>> p2 = Point2D(3, 4)
>>> p2.distance
5.0
>>> p3 = Point3D(1, 2, 2)
>>> p3.distance
3.0
|
親クラスのメソッド呼び出し
通常のクラスの継承のように、dataclass
のサブクラスにおいても、親クラスのメソッドをsuper()
を使用して呼び出すことができます。
以下のPoint1D
, Point2D
クラスでは、それぞれ引数を正としたいです。Point1D
クラスでは、__post_init__
メソッドでx
が正であることをチェックしています。
一方、Point2D
クラスでは、super().__post_init__()
を実行することで、スーパークラス (Point1D
) の__post_init__
メソッドを呼び出すことができます。
そのため、Point2D
クラスにはx
の正負をチェックする記述を書き直す必要がなく、y
の正負をチェックする処理を書くだけで済みます。
(この例ではあまりメリットはありませんが、スーパークラス内の処理が長く複雑になるほど恩恵は大きくなります)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
|
@dataclass
class Point1D:
"""引数xは正とすること"""
x: float
def __post_init__(self):
"""xが正であることをチェック"""
if self.x <= 0:
raise ValueError
@dataclass
class Point2D(Point1D):
"""引数x, yは正とすること"""
y: float
def __post_init__(self):
"""x, yが正であることをチェック"""
super().__post_init__()
if self.y <= 0:
raise ValueError
|
以下にエラーを発生させた例を示します。
サブクラスPoint2D
のx
に-1
を与えると、Point1D
クラスの__post_init__
メソッドでValueError
を発生したことが分かります。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
|
>>> p1 = Point1D(x=-1) # ValueErrorが発生する。
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
----> 1 p1 = Point1D(x=-1)
5 def __post_init__(self):
6 if self.x < 0:
----> 7 raise ValueError
>>> p2 = Point2D(x=-1, y=1) # x=-1に対してValueErrorが発生する。
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
----> 1 p2 = Point2D(x=-1, y=1)
13 def __post_init__(self):
---> 14 super().__post_init__()
16 if self.y < 0:
17 raise ValueError
5 def __post_init__(self):
6 if self.x < 0:
----> 7 raise ValueError
|
継承と型チェック
Pythonでは変数に型ヒントを付けることができます。dataclass
で定義したクラスについても、以下のように型ヒントに出来ます。
1
2
|
def my_func(point: Point2D):
...
|
Pythonの有名な型チェッカとしてmypyがあります。mypyでは、dataclass
の継承について以下のように判定します。
- 型ヒントがスーパークラスのとき、スーパークラスとサブクラスのどちらを与えても問題ない。
- 型ヒントがサブクラスのとき、スーパークラスを与えるとエラーとして検出。
スーパークラスよりもそれを継承したサブクラスの方が一般に属性が多いため、この仕様は自明です。
OKな例とNGな例をそれぞれ示します。
以下はOKな例です。スーパークラスPoint1D
の型ヒントを持つ関数print_1x
に対し、サブクラスPoint2D
のオブジェクトを与えています。
mypyによるチェックを実行してもエラーは検出されませんでした。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
|
from dataclasses import dataclass
@dataclass
class Point1D:
x: float
@dataclass
class Point2D(Point1D):
y: float
def print_1x(p: Point1D):
print(p.x)
if __name__=="__main__":
p2 = Point2D(1, 2)
print_1x(p2)
|
> mypy test_ok.py
Success: no issues found in 1 source file
次に、以下はNGな例です。サブクラスPoint2D
の型ヒントを持つ関数print_x2
に対し、スーパークラスPoint1D
のオブジェクトを与えています。
mypyによってエラーが検出されています。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
|
from dataclasses import dataclass
@dataclass
class Point1D:
x: float
@dataclass
class Point2D(Point1D):
y: float
def print_x2(p: Point2D):
print(p.x)
if __name__=="__main__":
p1 = Point1D(1)
print_x2(p1)
|
> mypy test_ng.py
mypy_test_ng.py:16: error: Argument 1 to
"print_x2" has incompatible type "Point1D";
expected "Point2D" [arg-type]
Found 1 error in 1 file (checked 1 source file)
なお、関数print_x2
ではx
属性にしかアクセスしていないため、上記のtest_ng.py
というコードを実行すること自体は可能です。
ただし、将来コードを修正したときにバグを埋め込んでしまう可能性があるため、望ましい状態ではありません。
参考