py: Catch all cases of integer (big and small) division by zero.
diff --git a/py/compile.c b/py/compile.c
index e00af60..797d7a9 100644
--- a/py/compile.c
+++ b/py/compile.c
@@ -293,7 +293,9 @@
// pass
} else if (MP_PARSE_NODE_IS_TOKEN_KIND(pns->nodes[1], MP_TOKEN_OP_PERCENT)) {
// int%int
- pn = mp_parse_node_new_leaf(MP_PARSE_NODE_SMALL_INT, mp_small_int_modulo(arg0, arg1));
+ if (arg1 != 0) {
+ pn = mp_parse_node_new_leaf(MP_PARSE_NODE_SMALL_INT, mp_small_int_modulo(arg0, arg1));
+ }
} else {
assert(MP_PARSE_NODE_IS_TOKEN_KIND(pns->nodes[1], MP_TOKEN_OP_DBL_SLASH)); // should be
if (arg1 != 0) {
diff --git a/py/objint_mpz.c b/py/objint_mpz.c
index 69a81d2..73469f3 100644
--- a/py/objint_mpz.c
+++ b/py/objint_mpz.c
@@ -193,6 +193,9 @@
if (0) {
#if MICROPY_PY_BUILTINS_FLOAT
} else if (op == MP_BINARY_OP_TRUE_DIVIDE || op == MP_BINARY_OP_INPLACE_TRUE_DIVIDE) {
+ if (mpz_is_zero(zrhs)) {
+ goto zero_division_error;
+ }
mp_float_t flhs = mpz_as_float(zlhs);
mp_float_t frhs = mpz_as_float(zrhs);
return mp_obj_new_float(flhs / frhs);
@@ -216,6 +219,11 @@
break;
case MP_BINARY_OP_FLOOR_DIVIDE:
case MP_BINARY_OP_INPLACE_FLOOR_DIVIDE: {
+ if (mpz_is_zero(zrhs)) {
+ zero_division_error:
+ nlr_raise(mp_obj_new_exception_msg(&mp_type_ZeroDivisionError,
+ "division by zero"));
+ }
mpz_t rem; mpz_init_zero(&rem);
mpz_divmod_inpl(&res->mpz, &rem, zlhs, zrhs);
if (zlhs->neg != zrhs->neg) {
@@ -229,6 +237,9 @@
}
case MP_BINARY_OP_MODULO:
case MP_BINARY_OP_INPLACE_MODULO: {
+ if (mpz_is_zero(zrhs)) {
+ goto zero_division_error;
+ }
mpz_t quo; mpz_init_zero(&quo);
mpz_divmod_inpl(&quo, &res->mpz, zlhs, zrhs);
mpz_deinit(&quo);
@@ -274,6 +285,9 @@
break;
case MP_BINARY_OP_DIVMOD: {
+ if (mpz_is_zero(zrhs)) {
+ goto zero_division_error;
+ }
mp_obj_int_t *quo = mp_obj_int_new_mpz();
mpz_divmod_inpl(&quo->mpz, &res->mpz, zlhs, zrhs);
// Check signs and do Python style modulo
diff --git a/py/runtime.c b/py/runtime.c
index 69ac754..98cde83 100644
--- a/py/runtime.c
+++ b/py/runtime.c
@@ -389,6 +389,9 @@
case MP_BINARY_OP_MODULO:
case MP_BINARY_OP_INPLACE_MODULO: {
+ if (rhs_val == 0) {
+ goto zero_division;
+ }
lhs_val = mp_small_int_modulo(lhs_val, rhs_val);
break;
}
diff --git a/tests/basics/builtin_divmod.py b/tests/basics/builtin_divmod.py
index e3eff30..c3b8658 100644
--- a/tests/basics/builtin_divmod.py
+++ b/tests/basics/builtin_divmod.py
@@ -10,6 +10,11 @@
print("ZeroDivisionError")
try:
+ divmod(1 << 65, 0)
+except ZeroDivisionError:
+ print("ZeroDivisionError")
+
+try:
divmod('a', 'b')
except TypeError:
print("TypeError")
diff --git a/tests/basics/int_big_error.py b/tests/basics/int_big_error.py
index b7875ee..e036525 100644
--- a/tests/basics/int_big_error.py
+++ b/tests/basics/int_big_error.py
@@ -29,3 +29,13 @@
i << (-(i >> 40))
except ValueError:
print('ValueError')
+
+try:
+ i // 0
+except ZeroDivisionError:
+ print('ZeroDivisionError')
+
+try:
+ i % 0
+except ZeroDivisionError:
+ print('ZeroDivisionError')
diff --git a/tests/basics/int_divzero.py b/tests/basics/int_divzero.py
index 28ec2a6..aa38eee 100644
--- a/tests/basics/int_divzero.py
+++ b/tests/basics/int_divzero.py
@@ -2,3 +2,8 @@
1 // 0
except ZeroDivisionError:
print("ZeroDivisionError")
+
+try:
+ 1 % 0
+except ZeroDivisionError:
+ print("ZeroDivisionError")
diff --git a/tests/float/int_big_float.py b/tests/float/int_big_float.py
index 2c40418..b1a26ca 100644
--- a/tests/float/int_big_float.py
+++ b/tests/float/int_big_float.py
@@ -17,3 +17,8 @@
# this should delegate to complex
print("%.5g" % (i * 1.2j).imag)
+
+try:
+ i / 0
+except ZeroDivisionError:
+ print("ZeroDivisionError")