Optimization
A Journal of Mathematical Programming and Operations Research
Research Article

Iteration and stochastic first-order oracle complexities of stochastic gradient descent using constant and decaying learning rates

Received 23 Feb 2024, Accepted 08 Jun 2024, Published online: 19 Jun 2024

Abstract

The performance of stochastic gradient descent (SGD), which is the simplest first-order optimizer for training deep neural networks, depends not only on the learning rate but also on the batch size. Both affect the number of iterations and the stochastic first-order oracle (SFO) complexity needed for training. In particular, previous numerical results indicated two things for SGD using a constant learning rate. First, the number of iterations needed for training decreases as the batch size increases. Second, the SFO complexity needed for training is minimized at a critical batch size and increases once the batch size exceeds that size. Here, we study the relationship between batch size and the iteration and SFO complexities needed for nonconvex optimization in deep learning with SGD using constant or decaying learning rates. We show that SGD using the critical batch size minimizes the SFO complexity. We also provide numerical comparisons of SGD with existing first-order optimizers and show the usefulness of SGD using a critical batch size. Moreover, we show that measured critical batch sizes are close to the sizes estimated from our theoretical results.

2020 Mathematics Subject Classifications:

1. Introduction

1.1. Background

First-order optimizers can train deep neural networks by minimizing loss functions called the expected and empirical risks. They use stochastic first-order derivatives (stochastic gradients), which are estimates of the full gradient of the loss function. The simplest first-order optimizer is stochastic gradient descent (SGD) [1–5]. SGD has a number of variants, including momentum variants [6, 7] and numerous adaptive variants, such as adaptive gradient (AdaGrad) [8], root mean square propagation (RMSProp) [9], adaptive moment estimation (Adam) [10], adaptive mean square gradient (AMSGrad) [11], and Adam with decoupled weight decay (AdamW) [12].

SGD can be applied to nonconvex optimization [13–22], where its performance strongly depends on the learning rate $\alpha_k$. Consider, for example, the bounded variance assumption. Under it, SGD using a constant learning rate $\alpha_k = \alpha$ satisfies $\frac{1}{K}\sum_{k=0}^{K-1}\|\nabla f(\theta_k)\|^2 = O\left(\frac{1}{K}\right) + \sigma^2$ [18, Theorem 12]. SGD using a decaying learning rate (i.e. $\alpha_k \to 0$) satisfies $\frac{1}{K}\sum_{k=0}^{K-1}\mathbb{E}[\|\nabla f(\theta_k)\|^2] = O\left(\frac{1}{\sqrt{K}}\right)$ [18, Theorem 11]. Here $(\theta_k)_{k\in\mathbb{N}}$ is the sequence generated by SGD to find a local minimizer of $f$, $K$ is the number of iterations, and $\sigma^2$ is the upper bound of the variance.

The performance of SGD also depends on the batch size $b$. The convergence analyses reported in [14, 17, 21, 23, 24] indicated that SGD with a decaying learning rate and a large batch size converges to a local minimizer of the loss function. In [25], it was numerically shown that using an enormous batch reduces both the number of parameter updates and the model training time. Moreover, appropriate batch sizes for the optimizers used to train generative adversarial networks were investigated in [26].

1.2. Motivation

The previous numerical results in [27] indicated that, for SGD using constant or linearly decaying learning rates, the number of iterations $K$ needed to train a deep neural network decreases as the batch size $b$ increases. Motivated by these results, we set out to clarify the theoretical iteration complexity of SGD with a constant or decaying learning rate in training a deep neural network. We use the performance measure of previous theoretical analyses of SGD, i.e. $\min_{k\in[0:K-1]}\mathbb{E}[\|\nabla f(\theta_k)\|] \le \epsilon$, where $\epsilon$ ($>0$) is the precision and $[0:K-1] := \{0, 1, \ldots, K-1\}$. If SGD is an $\epsilon$–approximation, i.e. $\min_{k\in[0:K-1]}\mathbb{E}[\|\nabla f(\theta_k)\|] \le \epsilon$, then we regard it as able to train a deep neural network in $K$ iterations.

In addition, the numerical results in [27] revealed an interesting fact: diminishing returns exist beyond a critical batch size. That is, the number of iterations needed to train a deep neural network does not strictly decrease beyond the critical batch size. Here, we define the stochastic first-order oracle (SFO) complexity as $N := Kb$. As stated above, $K$ is the number of iterations needed to train a deep neural network and $b$ is the batch size. The model uses $b$ gradients of the loss functions per iteration, so its stochastic gradient computation cost is $N = Kb$. The numerical results in [27, Figures 4 and 5] show that the SFO complexity $N(b)$ is minimized at $b = b^\star$ and increases once the batch size exceeds $b^\star$. Hence the critical batch size $b^\star$ (if it exists) is useful for SGD. Building on the first motivation stated above, we therefore set out to clarify the SFO complexities needed for SGD using a constant or decaying learning rate to be an $\epsilon$–approximation.

1.3. Contribution

1.3.1. Upper bound of theoretical performance measure

To clarify the iteration and SFO complexities needed for SGD to be an $\epsilon$–approximation, we first give upper bounds of $\min_{k\in[0:K-1]}\mathbb{E}[\|\nabla f(\theta_k)\|^2]$ for the sequence $(\theta_k)_{k\in\mathbb{N}}$ generated by SGD with constant or decaying learning rates (see Theorem 3.1 for the definitions of $C_i$ and $D_i$). Our aim is to show that SGD is an $\epsilon$–approximation, i.e. $\min_{k\in[0:K-1]}\mathbb{E}[\|\nabla f(\theta_k)\|^2] \le \epsilon^2$, so it is desirable that these upper bounds be small. Table 1 indicates that the upper bounds become small when the number of iterations and the batch size are large. The table also indicates that the convergence of SGD strongly depends on the batch size. The reason is that the variance terms in the upper bounds, which include $\sigma^2$ and $b$ (see Theorem 3.1 for the definitions of $C_2$ and $D_2$), decrease as the batch size becomes larger.

Table 1. Upper bounds of $\min_{k\in[0:K-1]}\mathbb{E}[\|\nabla f(\theta_k)\|^2]$ for SGD using a constant or decaying learning rate, and the critical batch sizes that minimize the SFO complexities needed to achieve $\min_{k\in[0:K-1]}\mathbb{E}[\|\nabla f(\theta_k)\|] \le \epsilon$ ($C_i$ and $D_i$ are positive constants, $K$ is the number of iterations, $b$ is the batch size, $T \ge 1$, $\epsilon > 0$, and $L$ is the Lipschitz constant of $\nabla f$).

1.3.2. Critical batch size to reduce SFO complexity

Section 1.3.1 showed that using large batches is appropriate for SGD in the sense of minimizing the upper bound of the performance measure. Here, we look for appropriate batch sizes from the viewpoint of computation cost, because the SFO complexity increases with the batch size. As indicated in Section 1.2, the critical batch size $b^\star$ minimizes the SFO complexity $N = Kb$. Hence, we investigate the properties of the SFO complexity $N = Kb$ needed to achieve an $\epsilon$–approximation. Consider first SGD using a constant learning rate. From the "Upper Bound" row in Table 1, we have

$$\min_{k\in[0:K-1]}\mathbb{E}[\|\nabla f(\theta_k)\|^2] \le \frac{C_1}{K} + \frac{C_2}{b} \le \epsilon^2 \iff K \ge K(b) := \frac{C_1 b}{\epsilon^2 b - C_2} \quad \left(b > \frac{C_2}{\epsilon^2}\right).$$

We can check that the number of iterations $K(b)$ needed to achieve an $\epsilon$–approximation is monotone decreasing and convex with respect to the batch size (Theorem 3.2). Accordingly, $K(b) \ge \inf\{K : \min_{k\in[0:K-1]}\mathbb{E}[\|\nabla f(\theta_k)\|] \le \epsilon\}$, where SGD using the batch size $b$ generates $(\theta_k)_{k=0}^{K-1}$. Moreover, the SFO complexity is $N(b) = K(b)b = \frac{C_1 b^2}{\epsilon^2 b - C_2}$. $N(b)$ is convex (Theorem 3.3), and this ensures that a critical batch size $b^\star = \frac{2C_2}{\epsilon^2}$ with $N'(b^\star) = 0$ exists and minimizes $N(b)$ (see the "Critical Batch Size" row in Table 1). A similar argument guarantees that a critical batch size also exists for SGD using a decaying learning rate $\alpha_k = \frac{1}{(\lfloor k/T \rfloor + 1)^a}$. Here $T \ge 1$, $a \in (0, \frac{1}{2}) \cup (\frac{1}{2}, 1)$, and $\lfloor\cdot\rfloor$ is the floor function (see the "Critical Batch Size" row in Table 1).
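The constant-learning-rate formulas can be checked numerically. The following is a minimal sketch that assumes hypothetical values for the unknown constants $C_1$, $C_2$ and the precision $\epsilon$. In practice these constants depend on $L$, $\sigma^2$ and $f(\theta_0) - f^\star$, which are not known in advance. The function names are illustrative only. The sketch evaluates $K(b)$ and $N(b) = K(b)b$ and confirms on a grid that $N$ is smallest at $b^\star = 2C_2/\epsilon^2$.

```python
def iterations_constant(b: float, C1: float, C2: float, eps: float) -> float:
    """K(b) = C1*b / (eps^2*b - C2), defined only for b > C2/eps^2."""
    if b <= C2 / eps**2:
        raise ValueError("K(b) is defined only for b > C2 / eps^2")
    return C1 * b / (eps**2 * b - C2)


def sfo_constant(b: float, C1: float, C2: float, eps: float) -> float:
    """SFO complexity N(b) = K(b) * b = C1*b^2 / (eps^2*b - C2)."""
    return iterations_constant(b, C1, C2, eps) * b


def critical_batch_constant(C2: float, eps: float) -> float:
    """Root of N'(b) = C1*b*(eps^2*b - 2*C2) / (eps^2*b - C2)^2, i.e. b* = 2*C2/eps^2."""
    return 2 * C2 / eps**2


# Hypothetical constants (the true C1, C2 depend on L, sigma^2 and f(theta_0) - f*).
C1, C2, eps = 10.0, 1.0, 0.5
b_star = critical_batch_constant(C2, eps)
grid = [C2 / eps**2 + 0.25 * i for i in range(1, 400)]  # batch sizes inside the domain
b_best = min(grid, key=lambda b: sfo_constant(b, C1, C2, eps))
print(b_star, b_best, sfo_constant(b_star, C1, C2, eps))  # 8.0 8.0 640.0
```

At $b^\star$ the sketch also reproduces $K(b^\star) = 2C_1/\epsilon^2 = 80$ and $N(b^\star) = 4C_1C_2/\epsilon^4 = 640$ for these illustrative constants.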

Meanwhile, consider the decaying learning rate $\alpha_k = \frac{1}{\sqrt{\lfloor k/T \rfloor + 1}}$. Although $N(b)$ is convex with respect to $b$ in this case, we have $N'(b) > 0$ for all $b > \frac{D_2}{\epsilon^2}$ (Theorem 3.3(iii)). Hence, a critical batch size $b^\star$ defined by $N'(b^\star) = 0$ does not exist. However, since the critical batch size should minimize the SFO complexity $N$, we can define one as $b^\star \approx \frac{D_2}{\epsilon^2}$. Accordingly, $N(b^\star) \ge \inf\{Kb^\star : \min_{k\in[0:K-1]}\mathbb{E}[\|\nabla f(\theta_k)\|] \le \epsilon\}$, where SGD using $b^\star$ generates $(\theta_k)_{k=0}^{K-1}$.

1.3.3. Iteration and SFO complexities

Let $\mathcal{F}(n, \Delta_0, L)$ be an $L$–smooth function class with $f := \frac{1}{n}\sum_{i=1}^n f_i$ and $f(\theta_0) - f^\star \le \Delta_0$ (see (C1)). Let $\mathcal{O}(b, \sigma^2)$ be a stochastic first-order oracle class (see (C2) and (C3)). The iteration complexity $K_\epsilon$ [21, (7)] and the SFO complexity $N_\epsilon$ needed for SGD to be an $\epsilon$–approximation are defined as

$$K_\epsilon(n, b, \alpha_k, \Delta_0, L, \sigma^2) := \sup_{O \in \mathcal{O}(b,\sigma^2)}\ \sup_{f \in \mathcal{F}(n,\Delta_0,L)} \inf\left\{K : \min_{k\in[0:K-1]}\mathbb{E}[\|\nabla f(\theta_k)\|] \le \epsilon\right\}, \tag{1}$$

$$N_\epsilon(n, b, \alpha_k, \Delta_0, L, \sigma^2) := K_\epsilon(n, b, \alpha_k, \Delta_0, L, \sigma^2)\, b. \tag{2}$$

Table 2 summarizes the iteration and SFO complexities (see also Theorem 3.4). Corollaries 6 and 7 in [14] coincide with our results for SGD with a constant learning rate in Theorems 3.1 and 3.3. The reason is that the randomized stochastic projected gradient free algorithm in [14] coincides with SGD. That algorithm is a stochastic zeroth-order (SZO) method that can be applied to situations where only noisy function values are available. In particular, Corollary 6 in [14] gave the convergence rate of the SZO method using a fixed batch size. Corollary 7 indicated that the SZO complexity of the SZO method is the same as the SFO complexity. Hence, Corollaries 6 and 7 in [14] show that SGD using a constant learning rate has iteration complexity $O(1/\epsilon^2)$ and SFO complexity $O(1/\epsilon^4)$.

Table 2. Iteration and SFO complexities needed for SGD using a constant or decaying learning rate to be an $\epsilon$–approximation (the critical batch sizes are used to compute $K_\epsilon$ and $N_\epsilon$).

Since the positive constants Ci and Di depend on the learning rate, we need to compare numerically the performance of SGD with a constant learning rate with that of SGD with a decaying learning rate. Moreover, we also need to compare SGD with the existing first-order optimizers in order to verify its usefulness. Section 4 presents numerical comparisons showing that SGD using the critical batch size outperforms the existing first-order optimizers. We also show that the measured critical batch sizes are close to the theoretical sizes.

2. Nonconvex optimization and SGD

2.1. Nonconvex optimization in deep learning

Let $\mathbb{R}^d$ be a $d$-dimensional Euclidean space with inner product $\langle x, y \rangle := x^\top y$ inducing the norm $\|x\|$, and let $\mathbb{N}$ be the set of nonnegative integers. Define $[0:n] := \{0, 1, \ldots, n\}$ for $n \ge 1$. Let $(x_k)_{k\in\mathbb{N}}$ and $(y_k)_{k\in\mathbb{N}}$ be positive real sequences and let $x(\epsilon), y(\epsilon) > 0$, where $\epsilon > 0$. $O$ denotes Landau's symbol. That is, $y_k = O(x_k)$ if there exist $c > 0$ and $k_0 \in \mathbb{N}$ such that $y_k \le c x_k$ for all $k \ge k_0$. Similarly, $y(\epsilon) = O(x(\epsilon))$ if there exists $c > 0$ such that $y(\epsilon) \le c x(\epsilon)$. Given a parameter $\theta \in \mathbb{R}^d$ and a data point $z$ in a data domain $Z$, a machine-learning model provides a prediction. Its quality can be measured by a differentiable nonconvex loss function $\ell(\theta; z)$. We aim to minimize the empirical loss defined for all $\theta \in \mathbb{R}^d$ by $f(\theta) = \frac{1}{n}\sum_{i=1}^n \ell(\theta; z_i) = \frac{1}{n}\sum_{i=1}^n f_i(\theta)$. Here $S = (z_1, z_2, \ldots, z_n)$ denotes the training set (we assume that the number of training data $n$ is large). The function $f_i(\cdot) := \ell(\cdot; z_i)$ is the loss corresponding to the $i$-th training data point $z_i$.

2.2. SGD

2.2.1. Conditions and algorithm

We assume that a stochastic first-order oracle (SFO) exists such that, for a given $\theta \in \mathbb{R}^d$, it returns a stochastic gradient $G_\xi(\theta)$ of the function $f$, where a random variable $\xi$ is independent of $\theta$. Let $\mathbb{E}_\xi[\cdot]$ be the expectation taken with respect to $\xi$. The following are standard conditions.

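  • (C1) $f := \frac{1}{n}\sum_{i=1}^n f_i \colon \mathbb{R}^d \to \mathbb{R}$ is $L$–smooth, i.e. $\nabla f \colon \mathbb{R}^d \to \mathbb{R}^d$ is $L$–Lipschitz continuous ($\|\nabla f(x) - \nabla f(y)\| \le L\|x - y\|$). $f$ is bounded below by $f^\star \in \mathbb{R}$. Let $\Delta_0 > 0$ satisfy $f(\theta_0) - f^\star \le \Delta_0$, where $\theta_0$ is an initial point.

  • (C2) Let $(\theta_k)_{k\in\mathbb{N}} \subset \mathbb{R}^d$ be the sequence generated by SGD. For each iteration $k$, $\mathbb{E}_{\xi_k}[G_{\xi_k}(\theta_k)] = \nabla f(\theta_k)$. Here $\xi_0, \xi_1, \ldots$ are independent samples, and the random variable $\xi_k$ is independent of $(\theta_l)_{l=0}^{k}$. There exists a nonnegative constant $\sigma^2$ such that $\mathbb{E}_{\xi_k}[\|G_{\xi_k}(\theta_k) - \nabla f(\theta_k)\|^2] \le \sigma^2$.

  • (C3) For each iteration $k$, SGD samples a batch $B_k$ of size $b$ independently of $k$. It estimates the full gradient $\nabla f$ as $\nabla f_{B_k}(\theta_k) := \frac{1}{b}\sum_{i\in[b]} G_{\xi_{k,i}}(\theta_k)$, where $b \le n$ and $\xi_{k,i}$ is a random variable generated by the $i$-th sampling in the $k$-th iteration.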

Algorithm 1 is the SGD optimizer under (C1)–(C3).
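Algorithm 1 itself is not reproduced in this text. As a rough illustration, the sketch below implements the update $\theta_{k+1} = \theta_k - \alpha_k \nabla f_{B_k}(\theta_k)$ used in the proof of Lemma A.1. It runs on a hypothetical one-dimensional finite sum $f_i(\theta) = \frac{1}{2}(\theta - z_i)^2$, for which $L = 1$. Two assumptions are made for the sake of a runnable example. First, the $b$ indices are sampled uniformly with replacement, to obtain independent draws as in (C3). Second, all function and variable names are illustrative.

```python
import random
from typing import Callable, List


def minibatch_sgd(
    grad_i: Callable[[float, int], float],  # gradient of f_i at theta
    n: int,                                  # number of training examples
    theta0: float,                           # initial point theta_0
    lr: Callable[[int], float],              # learning-rate schedule k -> alpha_k
    b: int,                                  # batch size
    K: int,                                  # number of iterations
    seed: int = 0,
) -> List[float]:
    """Return the iterates theta_0, ..., theta_K of mini-batch SGD."""
    rng = random.Random(seed)
    theta = theta0
    iterates = [theta]
    for k in range(K):
        # Draw b indices independently (uniformly, with replacement).
        batch = [rng.randrange(n) for _ in range(b)]
        # Mini-batch gradient: average of the b stochastic gradients.
        g = sum(grad_i(theta, i) for i in batch) / b
        # SGD step: theta_{k+1} = theta_k - alpha_k * g.
        theta = theta - lr(k) * g
        iterates.append(theta)
    return iterates


# Hypothetical toy data: f_i(theta) = 0.5 * (theta - z[i])**2, minimized at mean(z) = 3.
z = [1.0, 2.0, 3.0, 4.0, 5.0]
grad = lambda theta, i: theta - z[i]
path = minibatch_sgd(grad, n=len(z), theta0=0.0, lr=lambda k: 0.5, b=64, K=200)
print(path[-1])  # fluctuates near mean(z) = 3.0
```

The constant rate $0.5$ satisfies $\alpha < 2/L = 2$. With a constant rate the iterates keep fluctuating around the minimizer, and the size of that fluctuation shrinks as $b$ grows. This matches the $C_2/b$ variance term.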

2.2.2. Learning rates

We use the following learning rates:

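(Constant rate) $\alpha_k$ does not depend on $k \in \mathbb{N}$, i.e. $\alpha_k = \alpha < \frac{2}{L}$ ($k \in \mathbb{N}$). The upper bound $\frac{2}{L}$ of $\alpha$ is needed to analyse SGD (see Appendix A.2).

(Decaying rate) $(\alpha_k)_{k\in\mathbb{N}} \subset (0, +\infty)$ is monotone decreasing in $k$ (i.e. $\alpha_k \ge \alpha_{k+1}$) and converges to $0$. In particular, we use $\alpha_k = \frac{1}{(\lfloor k/T \rfloor + 1)^a}$, where $T \ge 1$ and (Decay 1) $a \in (0, \frac{1}{2})$, (Decay 2) $a = \frac{1}{2}$, (Decay 3) $a \in (\frac{1}{2}, 1)$. It is guaranteed that there exists $k_0 \in \mathbb{N}$ such that $\alpha_k < \frac{2}{L}$ for all $k \ge k_0$. Furthermore, we assume that $k_0 = 0$. This is possible because we can replace $\alpha_k$ with $\frac{\alpha}{(\lfloor k/T \rfloor + 1)^a} \le \alpha < \frac{2}{L}$ ($k \in \mathbb{N}$), where $\alpha \in (0, \frac{2}{L})$ is defined as in (Constant).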
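The two families of learning rates can be written as small schedule functions. This is a minimal sketch. It assumes the rescaled form $\alpha/(\lfloor k/T \rfloor + 1)^a$ described above and that $k$ counts iterations. The function names are illustrative and are not taken from the authors' code.

```python
import math
from typing import Callable


def constant_lr(alpha: float) -> Callable[[int], float]:
    """(Constant): alpha_k = alpha for every iteration k."""
    return lambda k: alpha


def decaying_lr(alpha: float, T: float, a: float) -> Callable[[int], float]:
    """alpha_k = alpha / (floor(k / T) + 1) ** a.

    a in (0, 1/2) -> (Decay 1), a = 1/2 -> (Decay 2), a in (1/2, 1) -> (Decay 3).
    """
    if T < 1:
        raise ValueError("T must satisfy T >= 1")
    if not 0 < a < 1:
        raise ValueError("a must lie in (0, 1)")
    # The rate stays constant for T consecutive iterations, then drops.
    return lambda k: alpha / (math.floor(k / T) + 1) ** a


def satisfies_step_bound(alpha: float, L: float) -> bool:
    """Check alpha < 2 / L, the condition assumed in the analysis."""
    return alpha < 2.0 / L


schedule = decaying_lr(alpha=0.5, T=10, a=0.5)  # hypothetical (Decay 2) settings
print([schedule(k) for k in (0, 9, 10, 30)])
```

Because every decaying rate is bounded by its initial value $\alpha$, checking `satisfies_step_bound(alpha, L)` once covers all $k$. This is the reason the text can take $k_0 = 0$.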

3. Our results

3.1. Upper bound of the squared norm of the full gradient

Here, we give an upper bound of $\min_{k\in[0:K-1]}\mathbb{E}[\|\nabla f(\theta_k)\|^2]$, where $\mathbb{E}[\cdot]$ stands for the total expectation, for the sequence generated by SGD using each of the learning rates defined in Section 2.2.2.

Theorem 3.1

Upper bound of the squared norm of the full gradient

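The sequence $(\theta_k)_{k\in\mathbb{N}}$ generated by Algorithm 1 under (C1)–(C3) satisfies that, for all $K \ge 1$,

$$\min_{k\in[0:K-1]}\mathbb{E}[\|\nabla f(\theta_k)\|^2] \le \begin{cases} \dfrac{C_1}{K} + \dfrac{C_2}{b} & \text{(Constant)}\\[2mm] \dfrac{D_1}{K^a} + \dfrac{D_2}{(1-2a)K^a b} & \text{(Decay 1)}\\[2mm] \dfrac{D_1}{\sqrt{K}} + \left(\dfrac{1}{\sqrt{K}} + 1\right)\dfrac{D_2}{b} & \text{(Decay 2)}\\[2mm] \dfrac{D_1}{K^{1-a}} + \dfrac{2aD_2}{(2a-1)K^{1-a} b} & \text{(Decay 3)}\end{cases}$$

where

$$C_1 := \frac{2(f(\theta_0) - f^\star)}{(2 - L\alpha)\alpha}, \quad C_2 := \frac{L\sigma^2\alpha}{2 - L\alpha}, \quad D_1 := \frac{2(f(\theta_0) - f^\star)}{\alpha(2 - L\alpha)}, \quad D_2 := \frac{T\alpha^2 L\sigma^2}{2 - L\alpha}.$$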
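For (Constant), the first bound can be traced back to Lemma A.1 in the Appendix. The short derivation below sets $\alpha_k = \alpha$ and uses the standing assumption $\alpha < 2/L$ from Section 2.2.2, so that $1 - L\alpha/2 > 0$. It also uses the fact that a minimum never exceeds an average. It is a sketch of one case only and is not the full proof of Theorem 3.1.

```latex
\begin{aligned}
K\alpha\left(1 - \frac{L\alpha}{2}\right)\min_{k\in[0:K-1]}\mathbb{E}[\|\nabla f(\theta_k)\|^2]
&\le \sum_{k=0}^{K-1}\alpha\left(1 - \frac{L\alpha}{2}\right)\mathbb{E}[\|\nabla f(\theta_k)\|^2]
 \le f(\theta_0) - f^\star + \frac{L\sigma^2 K\alpha^2}{2b},\\
\min_{k\in[0:K-1]}\mathbb{E}[\|\nabla f(\theta_k)\|^2]
&\le \frac{2(f(\theta_0) - f^\star)}{(2 - L\alpha)\alpha K} + \frac{L\sigma^2\alpha}{(2 - L\alpha)b}
 = \frac{C_1}{K} + \frac{C_2}{b}.
\end{aligned}
```

The second line follows from dividing the first by $K\alpha(2 - L\alpha)/2 > 0$.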

Theorem 3.1 indicates that the upper bound of $\min_{k\in[0:K-1]}\mathbb{E}[\|\nabla f(\theta_k)\|^2]$ consists of two parts: a bias term including $f(\theta_0) - f^\star$ and a variance term including $\sigma^2$. Both terms become small when the number of iterations and the batch size are large. In particular, the bias term using (Constant) is $O(\frac{1}{K})$. This rate is better than the bias terms $O(\frac{1}{K^a})$, $O(\frac{1}{\sqrt{K}})$ and $O(\frac{1}{K^{1-a}})$ obtained with (Decay 1)–(Decay 3).

3.2. Number of iterations needed for SGD to be an ϵ–approximation

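Let us suppose that SGD is an $\epsilon$–approximation defined as follows:

$$\mathbb{E}[\|\nabla f(\theta_{K^\star})\|^2] := \min_{k\in[0:K-1]}\mathbb{E}[\|\nabla f(\theta_k)\|^2] \le \epsilon^2, \tag{3}$$

where $\epsilon > 0$ is the precision and $K^\star \in [0:K-1]$. By Jensen's inequality, $\mathbb{E}[\|\nabla f(\theta_{K^\star})\|] \le \sqrt{\mathbb{E}[\|\nabla f(\theta_{K^\star})\|^2]}$, so condition (3) implies that $\mathbb{E}[\|\nabla f(\theta_{K^\star})\|] \le \epsilon$. Theorem 3.2 below gives the number of iterations needed for SGD to be an $\epsilon$–approximation (3).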

Theorem 3.2

Numbers of iterations needed for nonconvex optimization using SGD

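Let $(\theta_k)_{k\in\mathbb{N}}$ be the sequence generated by Algorithm 1 under (C1)–(C3) and let $K \colon \mathbb{R} \to \mathbb{R}$ be

$$K(b) = \begin{cases} \dfrac{C_1 b}{\epsilon^2 b - C_2} & \text{(Constant)}\\[2mm] \left\{\dfrac{1}{\epsilon^2}\left(\dfrac{D_2}{(1-2a)b} + D_1\right)\right\}^{1/a} & \text{(Decay 1)}\\[2mm] \left(\dfrac{D_1 b + D_2}{\epsilon^2 b - D_2}\right)^2 & \text{(Decay 2)}\\[2mm] \left\{\dfrac{1}{\epsilon^2}\left(\dfrac{2aD_2}{(2a-1)b} + D_1\right)\right\}^{1/(1-a)} & \text{(Decay 3)}\end{cases}$$

where $C_1$, $C_2$, $D_1$, and $D_2$ are defined as in Theorem 3.1. The domain of $K$ is $b > \frac{C_2}{\epsilon^2}$ in (Constant) and $b > \frac{D_2}{\epsilon^2}$ in (Decay 2). Then, we have the following: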

  1. The above $K$ achieves an $\epsilon$–approximation (3).

  2. The above $K$ is a monotone decreasing and convex function with respect to the batch size $b$.
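The monotonicity and convexity claimed in Theorem 3.2(ii) can be spot-checked numerically. The sketch below assumes hypothetical values $C_1 = 10$, $C_2 = 1$, $D_1 = 1$, $D_2 = 2$ and $\epsilon = 0.5$, with $a = 1/4$ for (Decay 1) and $a = 3/4$ for (Decay 3). It evaluates the four formulas for $K(b)$ on integer batch sizes inside each domain and tests the signs of the first and second finite differences. Such a check only illustrates the theorem on a grid. It does not prove it.

```python
def K_constant(b, C1, C2, eps):           # domain: b > C2 / eps^2
    return C1 * b / (eps**2 * b - C2)


def K_decay1(b, D1, D2, eps, a):           # a in (0, 1/2)
    return ((D2 / ((1 - 2 * a) * b) + D1) / eps**2) ** (1 / a)


def K_decay2(b, D1, D2, eps):              # a = 1/2, domain: b > D2 / eps^2
    return ((D1 * b + D2) / (eps**2 * b - D2)) ** 2


def K_decay3(b, D1, D2, eps, a):           # a in (1/2, 1)
    return ((2 * a * D2 / ((2 * a - 1) * b) + D1) / eps**2) ** (1 / (1 - a))


def decreasing_and_convex(values, tol=1e-9):
    """True if all first differences are < 0 and all second differences are >= -tol."""
    d1 = [v2 - v1 for v1, v2 in zip(values, values[1:])]
    d2 = [e2 - e1 for e1, e2 in zip(d1, d1[1:])]
    return all(d < 0 for d in d1) and all(d >= -tol for d in d2)


# Hypothetical constants, for illustration only.
C1, C2, D1, D2, eps = 10.0, 1.0, 1.0, 2.0, 0.5
checks = {
    "Constant": [K_constant(b, C1, C2, eps) for b in range(5, 105)],  # C2/eps^2 = 4
    "Decay 1": [K_decay1(b, D1, D2, eps, 0.25) for b in range(1, 101)],
    "Decay 2": [K_decay2(b, D1, D2, eps) for b in range(9, 109)],     # D2/eps^2 = 8
    "Decay 3": [K_decay3(b, D1, D2, eps, 0.75) for b in range(1, 101)],
}
for name, values in checks.items():
    print(name, decreasing_and_convex(values))
```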

Theorem 3.2 indicates that, for SGD using constant or decaying learning rates, the number of iterations needed to be an $\epsilon$–approximation is small when the batch size is large. Hence, a large batch size is appropriate for minimizing the number of iterations needed for an $\epsilon$–approximation (3). However, the SFO complexity, which is the cost of the stochastic gradient computation, grows with $b$. Hence, an appropriate batch size should also minimize the SFO complexity.

3.3. SFO complexity needed for SGD to be an ϵ–approximation

Theorem 3.2 leads to the following theorem on the properties of the SFO complexity $N$ needed for SGD to be an $\epsilon$–approximation (3).

Theorem 3.3

SFO complexity needed for nonconvex optimization of SGD

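Let $(\theta_k)_{k\in\mathbb{N}}$ be the sequence generated by Algorithm 1 under (C1)–(C3) and define $N \colon \mathbb{R} \to \mathbb{R}$ by

$$N(b) = K(b)b = \begin{cases} \dfrac{C_1 b^2}{\epsilon^2 b - C_2} & \text{(Constant)}\\[2mm] \left\{\dfrac{1}{\epsilon^2}\left(\dfrac{D_2}{(1-2a)b} + D_1\right)\right\}^{1/a} b & \text{(Decay 1)}\\[2mm] \left(\dfrac{D_1 b + D_2}{\epsilon^2 b - D_2}\right)^2 b & \text{(Decay 2)}\\[2mm] \left\{\dfrac{1}{\epsilon^2}\left(\dfrac{2aD_2}{(2a-1)b} + D_1\right)\right\}^{1/(1-a)} b & \text{(Decay 3)}\end{cases}$$

where $C_1$, $C_2$, $D_1$, and $D_2$ are as in Theorem 3.1. The domain of $N$ is $b > \frac{C_2}{\epsilon^2}$ in (Constant) and $b > \frac{D_2}{\epsilon^2}$ in (Decay 2). Then, we have the following: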

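  1. The above $N$ is convex with respect to the batch size $b$.

  2. There exists a critical batch size

$$b^\star = \begin{cases} \dfrac{2C_2}{\epsilon^2} & \text{(Constant)}\\[2mm] \dfrac{(1-a)D_2}{a(1-2a)D_1} & \text{(Decay 1)}\\[2mm] \dfrac{2a^2 D_2}{(1-a)(2a-1)D_1} & \text{(Decay 3)}\end{cases} \tag{4}$$

  satisfying $N'(b^\star) = 0$ such that $b^\star$ minimizes the SFO complexity $N$.

  3. For (Decay 2), $N'(b) > 0$ holds for all $b > \frac{D_2}{\epsilon^2}$.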
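The closed forms in (4) for (Decay 1) and (Decay 3) can be compared with a brute-force minimization of $N(b)$. The sketch below assumes hypothetical constants $D_1 = 1$, $D_2 = 2$ and $\epsilon = 0.5$, together with $a = 1/4$ and $a = 3/4$ (the values of $a$ used in Section 4). The helper names are illustrative.

```python
def N_decay1(b, D1, D2, eps, a):   # SFO complexity for (Decay 1), a in (0, 1/2)
    return ((D2 / ((1 - 2 * a) * b) + D1) / eps**2) ** (1 / a) * b


def N_decay3(b, D1, D2, eps, a):   # SFO complexity for (Decay 3), a in (1/2, 1)
    return ((2 * a * D2 / ((2 * a - 1) * b) + D1) / eps**2) ** (1 / (1 - a)) * b


def b_star_decay1(D1, D2, a):      # closed form from (4)
    return (1 - a) * D2 / (a * (1 - 2 * a) * D1)


def b_star_decay3(D1, D2, a):      # closed form from (4)
    return 2 * a**2 * D2 / ((1 - a) * (2 * a - 1) * D1)


D1, D2, eps = 1.0, 2.0, 0.5                # hypothetical constants
grid = [0.5 * i for i in range(1, 201)]    # candidate batch sizes 0.5, 1.0, ..., 100.0
best1 = min(grid, key=lambda b: N_decay1(b, D1, D2, eps, 0.25))
best3 = min(grid, key=lambda b: N_decay3(b, D1, D2, eps, 0.75))
print(b_star_decay1(D1, D2, 0.25), best1)  # 12.0 12.0
print(b_star_decay3(D1, D2, 0.75), best3)  # 18.0 18.0
```

Rescaling `eps` multiplies each $N$ by a constant factor, so the minimizers do not move. This agrees with (4), where $\epsilon$ does not appear for (Decay 1) and (Decay 3).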

Theorem 3.3(ii) indicates that, if we can set a critical batch size (4) for each of (Constant), (Decay 1), and (Decay 3), then the SFO complexity will be minimized. However, it would be difficult to set $b^\star$ in (4) before running SGD. The reason is that $b^\star$ in (4) involves unknown parameters, such as $L$ and $\sigma^2$ (computing $L$ is NP-hard [28]). Hence, we estimate the critical batch sizes by using Theorem 3.3(ii) (see Section 4.3). Theorem 3.3(ii) also indicates that the smaller $\epsilon$ is, the larger the critical batch size $b^\star$ in (Constant) becomes. Theorem 3.3(iii) indicates that, when using (Decay 2), the batch size minimizing the SFO complexity $N$ is close to $\frac{D_2}{\epsilon^2}$.

3.4. Iteration and SFO complexities of SGD

Theorems 3.2 and 3.3 lead to the following theorem, which gives the iteration and SFO complexities needed for SGD to be an $\epsilon$–approximation (see also Table 2).

Theorem 3.4

Iteration and SFO complexities of SGD

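The iteration and SFO complexities such that Algorithm 1 under (C1)–(C3) is an $\epsilon$–approximation (3) are as follows:

$$\left(K_\epsilon(n,b,\alpha_k,\Delta_0,L,\sigma^2),\ N_\epsilon(n,b,\alpha_k,\Delta_0,L,\sigma^2)\right) = \begin{cases} \left(O\left(\frac{1}{\epsilon^2}\right), O\left(\frac{1}{\epsilon^4}\right)\right) & \text{(Constant)}\\[1mm] \left(O\left(\frac{1}{\epsilon^{2/a}}\right), O\left(\frac{1}{\epsilon^{2/a}}\right)\right) & \text{(Decay 1)}\\[1mm] \left(O\left(\frac{1}{\epsilon^4}\right), O\left(\frac{1}{\epsilon^6}\right)\right) & \text{(Decay 2)}\\[1mm] \left(O\left(\frac{1}{\epsilon^{2/(1-a)}}\right), O\left(\frac{1}{\epsilon^{2/(1-a)}}\right)\right) & \text{(Decay 3)}\end{cases}$$

Here $K_\epsilon$ and $N_\epsilon$ are defined as in (1) and (2), and the critical batch sizes in Theorem 3.3 are used to compute them. In (Decay 2), we assume that $b^\star = \frac{D_2 + 1}{\epsilon^2}$ (see also (4)).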
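To see where these rates come from, the formulas of Theorems 3.2 and 3.3 can be evaluated at the critical batch sizes. The worked substitution below covers (Constant), (Decay 1), and (Decay 2). (Decay 3) follows the same pattern as (Decay 1) with exponent $1/(1-a)$. This is a sketch of the bookkeeping only. It treats $C_i$, $D_i$ and $a$ as constants independent of $\epsilon$.

```latex
\begin{aligned}
\text{(Constant)}:\quad & K(b^\star) = \frac{C_1 \cdot 2C_2/\epsilon^2}{\epsilon^2 \cdot 2C_2/\epsilon^2 - C_2} = \frac{2C_1}{\epsilon^2},
\qquad N(b^\star) = K(b^\star)\,b^\star = \frac{4C_1C_2}{\epsilon^4};\\
\text{(Decay 1)}:\quad & b^\star \text{ does not depend on } \epsilon \;\Rightarrow\;
K(b^\star) = \frac{1}{\epsilon^{2/a}}\left(\frac{D_2}{(1-2a)b^\star} + D_1\right)^{1/a} = O\left(\frac{1}{\epsilon^{2/a}}\right),
\quad N(b^\star) = b^\star K(b^\star) = O\left(\frac{1}{\epsilon^{2/a}}\right);\\
\text{(Decay 2)}:\quad & b = \frac{D_2 + 1}{\epsilon^2} \;\Rightarrow\; \epsilon^2 b - D_2 = 1,
\quad K(b) = \left(\frac{D_1(D_2+1)}{\epsilon^2} + D_2\right)^2 = O\left(\frac{1}{\epsilon^4}\right),
\quad N(b) = K(b)\,b = O\left(\frac{1}{\epsilon^6}\right).
\end{aligned}
```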

Theorem 3.4 indicates that the iteration and SFO complexities for (Constant) are smaller than those for (Decay 1)–(Decay 3).

4. Numerical results

We numerically verified the number of iterations and the SFO complexities needed to achieve high test accuracy for different batch sizes in training ResNet [29] and Wide-ResNet [30]. The parameter $\alpha$ used in (Constant) was determined by a grid search over $\{0.001, 0.005, 0.01, 0.05, 0.1, 0.5\}$. The decaying learning rates (Decay 1)–(Decay 3) are defined by $\alpha_k = \frac{\alpha}{(\lfloor k/T \rfloor + 1)^a}$. Their parameters $\alpha$ and $T$ were determined by a grid search over $\alpha \in \{0.001, 0.1, 0.125, 0.25, 0.5, 1.0\}$ and $T \in \{5, 10, 20, 30, 40, 50\}$. The parameter $a$ was set to $a = \frac{1}{4}$ in (Decay 1) and $a = \frac{3}{4}$ in (Decay 3). We compared SGD with SGD with momentum (momentum), Adam, AdamW, and RMSProp. The learning rates and hyperparameters of these four optimizers were determined on the basis of previous results [9, 10, 12] (the weight decay used in momentum was $5 \times 10^{-4}$). The experimental environment consisted of an NVIDIA DGX A100 with 8 GPUs and dual AMD Rome 7742 2.25-GHz CPUs (128 cores × 2). The software environment was Python 3.10.6, PyTorch 1.13.1, and CUDA 11.6. The code is available at https://github.com/imakn0907/SGD_using_decaying.

4.1. Training ResNet-18 on the CIFAR-10 and CIFAR-100 datasets

First, we trained ResNet-18 on the CIFAR-10 dataset. The stopping condition of the optimizers was 200 epochs. Figure 1 indicates that the number of iterations is monotone decreasing and convex with respect to batch size for SGD using a constant or a decaying learning rate. Figure 2 indicates that a critical batch size $b^\star = 2^4$, at which the SFO complexity is minimized, exists for each of (Constant)–(Decay 3).

Figure 1. Number of iterations needed for SGD with (Constant), (Decay 1), (Decay 2), and (Decay 3) to achieve a test accuracy of 0.9 versus batch size (ResNet-18 on CIFAR-10).

Figure 2. SFO complexity needed for SGD with (Constant), (Decay 1), (Decay 2), and (Decay 3) to achieve a test accuracy of 0.9 versus batch size (ResNet-18 on CIFAR-10).

Figures 3 and 4 show the number of iterations and the SFO complexity needed for SGD with the four learning rates to achieve a test accuracy of 0.6 when training ResNet-18 on the CIFAR-100 dataset. The figures indicate that critical batch sizes existed when using (Constant)–(Decay 3).

Figure 3. Number of iterations needed for SGD with (Constant), (Decay 1), (Decay 2), and (Decay 3) to achieve a test accuracy of 0.6 versus batch size (ResNet-18 on CIFAR-100).

Figure 4. SFO complexity needed for SGD with (Constant), (Decay 1), (Decay 2), and (Decay 3) to achieve a test accuracy of 0.6 versus batch size (ResNet-18 on CIFAR-100).

Figures 5 and 6 compare SGD with (Decay 1) with the other optimizers in training ResNet-18 on the CIFAR-100 dataset. These figures indicate that SGD with (Decay 1) and a critical batch size ($b^\star = 2^4$) outperformed the other optimizers in the sense of minimizing the number of iterations and the SFO complexity. Figure 6 also indicates that the existing optimizers using constant learning rates had critical batch sizes minimizing the SFO complexities. In particular, AdamW using the critical batch size $b^\star = 2^5$ performed well.

Figure 5. Number of iterations needed for SGD with (Decay 1), momentum, Adam, AdamW, and RMSProp to achieve a test accuracy of 0.6 versus batch size (ResNet-18 on CIFAR-100).

Figure 6. SFO complexity needed for SGD with (Decay 1), momentum, Adam, AdamW, and RMSProp to achieve a test accuracy of 0.6 versus batch size (ResNet-18 on CIFAR-100).

4.2. Training Wide-ResNet on the CIFAR-10 and CIFAR-100 datasets

Next, we trained Wide-ResNet-28 [30] on the CIFAR-10 and CIFAR-100 datasets. The stopping condition of the optimizers was 200 epochs. Figures 7 and 8 show the number of iterations and the SFO complexity of SGD needed to achieve a test accuracy of 0.9 (CIFAR-10) versus batch size. These figures indicate that the critical batch size was $b^\star = 2^4$ for each of (Constant)–(Decay 3).

Figure 7. Number of iterations needed for SGD with (Constant), (Decay 1), (Decay 2), and (Decay 3) to achieve a test accuracy of 0.9 versus batch size (Wide-ResNet-28-10 on CIFAR-10).

Figure 8. SFO complexity needed for SGD with (Constant), (Decay 1), (Decay 2), and (Decay 3) to achieve a test accuracy of 0.9 versus batch size (Wide-ResNet-28-10 on CIFAR-10).

Figures 9 and 10 show the number of iterations and the SFO complexity of SGD needed to achieve a test accuracy of 0.6 (CIFAR-100) versus batch size. They indicate that a critical batch size existed for (Constant)–(Decay 3).

Figure 9. Number of iterations needed for SGD with (Constant), (Decay 1), (Decay 2), and (Decay 3) to achieve a test accuracy of 0.6 versus batch size (Wide-ResNet-28-12 on CIFAR-100).

Figure 10. SFO complexity needed for SGD with (Constant), (Decay 1), (Decay 2), and (Decay 3) to achieve a test accuracy of 0.6 versus batch size (Wide-ResNet-28-12 on CIFAR-100).

Figures  and  indicate that SGD using (Decay 1) and the existing optimizers using constant learning rates had critical batch sizes minimizing the SFO complexities. Also, Figures 11 and 12 show similar results for SGD using (Decay 3).

Figure 11. Number of iterations needed for SGD with (Decay 3), momentum, Adam, AdamW, and RMSProp to achieve a test accuracy of 0.9 versus batch size (Wide-ResNet-28-10 on CIFAR-10).

Figure 12. SFO complexity needed for SGD with (Decay 3), momentum, Adam, AdamW, and RMSProp to achieve a test accuracy of 0.9 versus batch size (Wide-ResNet-28-10 on CIFAR-10).

4.3. Estimation of critical batch sizes

We estimated the critical batch sizes of (Decay 3) using Theorem 3.3 and the measured critical batch sizes of (Decay 1). From (4) and $a = \frac{1}{4}$ (Decay 1), we have that, for training ResNet-18 on the CIFAR-10 dataset,

$$b^\star = 2^4 = \frac{(1-a)D_2}{a(1-2a)D_1} = \frac{6D_2}{D_1}, \quad\text{i.e.}\quad \frac{D_2}{D_1} = \frac{8}{3}.$$

Then, the estimated critical batch size of SGD using (Decay 3) ($a = \frac{3}{4}$) for training ResNet-18 on the CIFAR-10 dataset is

$$b^\star = \frac{2a^2}{(1-a)(2a-1)}\cdot\frac{D_2}{D_1} = \frac{2a^2}{(1-a)(2a-1)}\cdot\frac{8}{3} = 24 \in (2^4, 2^5).$$

Hence the estimated critical batch size $b^\star = 24$ is close to the measured size $b^\star = 2^4$. We also found that the estimated critical batch sizes are close to the measured ones (see Table 3).
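The arithmetic of this estimate can be reproduced exactly with rational numbers. The sketch below uses Python's standard `fractions` module and only the coefficients in (4). The measured value $2^4$ for (Decay 1) is taken from the text above. The function names are illustrative.

```python
from fractions import Fraction


def coeff_decay1(a: Fraction) -> Fraction:
    """b* / (D2/D1) for (Decay 1), from (4): (1 - a) / (a * (1 - 2a))."""
    return (1 - a) / (a * (1 - 2 * a))


def coeff_decay3(a: Fraction) -> Fraction:
    """b* / (D2/D1) for (Decay 3), from (4): 2a^2 / ((1 - a) * (2a - 1))."""
    return 2 * a**2 / ((1 - a) * (2 * a - 1))


a1, a3 = Fraction(1, 4), Fraction(3, 4)
measured_b_decay1 = 2**4                       # measured critical batch size, (Decay 1)
ratio = measured_b_decay1 / coeff_decay1(a1)   # inferred D2 / D1
estimate = coeff_decay3(a3) * ratio            # estimated critical batch size, (Decay 3)
print(ratio, estimate)  # 8/3 24
```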

Table 3. Measured (left) and estimated (right; bold) critical batch sizes (D1 and D3 stand for (Decay 1) and (Decay 3)).

5. Conclusion and future work

This paper investigated the number of iterations and the SFO complexities required for SGD using constant or decaying learning rates to achieve an $\epsilon$–approximation. Our theoretical analyses showed two things. First, the number of iterations needed for an $\epsilon$–approximation is monotone decreasing and convex with respect to the batch size. Second, the SFO complexity needed for an $\epsilon$–approximation is convex with respect to the batch size. Moreover, we showed that SGD using a critical batch size minimizes the SFO complexity. The numerical results indicated that SGD using the critical batch size performs better than the existing optimizers in the sense of minimizing the SFO complexity. We also estimated critical batch sizes of SGD using our theoretical results and showed that they are close to the measured critical batch sizes.

The results in this paper apply only to SGD, which is a limitation of our work. Hence, in the future, we should investigate whether our results can be applied to variants of SGD, such as momentum methods and adaptive methods.

Disclosure statement

No potential conflict of interest was reported by the author(s).

References

  • Robbins H, Monro H. A stochastic approximation method. Ann Math Stat. 1951;22:400–407. doi: 10.1214/aoms/1177729586
  • Zinkevich M. Online convex programming and generalized infinitesimal gradient ascent. In: Proceedings of the 20th International Conference on Machine Learning; Washington, DC, USA; 2003. p. 928–936.
  • Nemirovski A, Juditsky A, Lan G, et al. Robust stochastic approximation approach to stochastic programming. SIAM J Optim. 2009;19:1574–1609. doi: 10.1137/070704277
  • Ghadimi S, Lan G. Optimal stochastic approximation algorithms for strongly convex stochastic composite optimization I: A generic algorithmic framework. SIAM J Optim. 2012;22:1469–1492. doi: 10.1137/110848864
  • Ghadimi S, Lan G. Optimal stochastic approximation algorithms for strongly convex stochastic composite optimization II: shrinking procedures and optimal algorithms. SIAM J Optim. 2013;23:2061–2089. doi: 10.1137/110848876
  • Polyak BT. Some methods of speeding up the convergence of iteration methods. USSR Comput Math Math Phys. 1964;4:1–17. doi: 10.1016/0041-5553(64)90137-5
  • Nesterov Y. A method for unconstrained convex minimization problem with the rate of convergence O(1/k2). Doklady AN USSR. 1983;269:543–547.
  • Duchi J, Hazan E, Singer Y. Adaptive subgradient methods for online learning and stochastic optimization. J Mach Learn Res. 2011;12:2121–2159.
  • Tieleman T, Hinton G. RMSProp: divide the gradient by a running average of its recent magnitude. COURSERA Neural Netw Mach Learn. 2012;4:26–31.
  • Kingma DP, Ba J. Adam: a method for stochastic optimization. In: Proceedings of The International Conference on Learning Representations; San Diego, CA, USA; 2015.
  • Reddi SJ, Kale S, Kumar S. On the convergence of Adam and beyond. In: Proceedings of The International Conference on Learning Representations; Vancouver, British Columbia, Canada; 2018.
  • Loshchilov I, Hutter F. Decoupled weight decay regularization. In: Proceedings of The International Conference on Learning Representations; New Orleans, Louisiana, USA; 2019.
  • Ghadimi S, Lan G. Stochastic first- and zeroth-order methods for nonconvex stochastic programming. SIAM J Optim. 2013;23(4):2341–2368. doi: 10.1137/120880811
  • Ghadimi S, Lan G, Zhang H. Mini-batch stochastic approximation methods for nonconvex stochastic composite optimization. Math Program. 2016;155(1):267–305. doi: 10.1007/s10107-014-0846-1
  • Vaswani S, Mishkin A, Laradji I, et al. Painless stochastic gradient: interpolation, line-search, and convergence rates. In: Advances in Neural Information Processing Systems; Vancouver, British Columbia, Canada; Vol. 32; 2019.
  • Fehrman B, Gess B, Jentzen A. Convergence rates for the stochastic gradient descent method for non-convex objective functions. J Mach Learn Res. 2020;21:1–48.
  • Chen H, Zheng L, AL Kontar R, et al. Stochastic gradient descent in correlated settings: a study on Gaussian processes. In: Advances in Neural Information Processing Systems; Virtual conference, Vol. 33; 2020.
  • Scaman K, Malherbe C. Robustness analysis of non-convex stochastic gradient descent using biased expectations. In: Advances in Neural Information Processing Systems; Virtual conference, Vol. 33; 2020.
  • Loizou N, Vaswani S, Laradji I, et al. Stochastic polyak step-size for SGD: an adaptive learning rate for fast convergence. In: Proceedings of the 24th International Conference on Artificial Intelligence and Statistics; Virtual conference, Vol. 130; 2021.
  • Wang X, Magnússon S, Johansson M. On the convergence of step decay step-size for stochastic optimization. In: Beygelzimer A, Dauphin Y, Liang P, et al., editors. Advances in Neural Information Processing Systems; Virtual conference, 2021. Available from: https://openreview.net/forum?id=M-W0asp3fD.
  • Arjevani Y, Carmon Y, Duchi JC, et al. Lower bounds for non-convex stochastic optimization. Math Program. 2023;199(1):165–214. doi: 10.1007/s10107-022-01822-7
  • Khaled A, Richtárik P. Better theory for SGD in the nonconvex world. Trans Mach Learn Res. 2023.
  • Jain P, Kakade SM, Kidambi R, et al. Parallelizing stochastic gradient descent for least squares regression: mini-batching, averaging, and model misspecification. J Mach Learn Res. 2018;18(223):1–42.
  • Cotter A, Shamir O, Srebro N, et al. Better mini-batch algorithms via accelerated gradient methods. In: Advances in Neural Information Processing Systems; Granada, Spain; Vol. 24; 2011.
  • Smith SL, Kindermans PJ, Le QV. Don't decay the learning rate, increase the batch size. In: International Conference on Learning Representations; Vancouver, British Columbia, Canada; 2018.
  • Sato N, Iiduka H. Existence and estimation of critical batch size for training generative adversarial networks with two time-scale update rule. In: Proceedings of the 40th International Conference on Machine Learning; (Proceedings of Machine Learning Research; Vol. 202). PMLR; 2023. p. 30080–30104.
  • Shallue CJ, Lee J, Antognini J, et al. Measuring the effects of data parallelism on neural network training. J Mach Learn Res. 2019;20:1–49.
  • Virmaux A, Scaman K. Lipschitz regularity of deep neural networks: analysis and efficient estimation. In: Advances in Neural Information Processing Systems; Vol. 31; 2018.
  • He K, Zhang X, Ren S, et al. Deep residual learning for image recognition. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition; 2016. p. 770–778.
  • Zagoruyko S, Komodakis N. Wide residual networks. arXiv preprint arXiv:1605.07146. 2016.

Appendix

A.1. Lemma

First, we will prove the following lemma.

Lemma A.1

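The sequence $(\theta_k)_{k\in\mathbb{N}}$ generated by Algorithm 1 under (C1)–(C3) satisfies that, for all $K\geq 1$,
$$\sum_{k=0}^{K-1}\alpha_k\left(1-\frac{L\alpha_k}{2}\right)\mathbb{E}\left[\|\nabla f(\theta_k)\|^2\right]\leq\mathbb{E}\left[f(\theta_0)-f_\star\right]+\frac{L\sigma^2}{2b}\sum_{k=0}^{K-1}\alpha_k^2,$$
where $\mathbb{E}$ stands for the total expectation.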

Proof.

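Condition (C1) ($L$-smoothness of $f$) implies that the descent lemma holds. That is, for all $k\in\mathbb{N}$,
$$f(\theta_{k+1})\leq f(\theta_k)+\langle\nabla f(\theta_k),\theta_{k+1}-\theta_k\rangle+\frac{L}{2}\|\theta_{k+1}-\theta_k\|^2.$$
Together with $\theta_{k+1}:=\theta_k-\alpha_k\nabla f_{B_k}(\theta_k)$, this implies that
$$f(\theta_{k+1})\leq f(\theta_k)-\alpha_k\langle\nabla f(\theta_k),\nabla f_{B_k}(\theta_k)\rangle+\frac{L\alpha_k^2}{2}\|\nabla f_{B_k}(\theta_k)\|^2.\tag{A1}$$
Condition (C2) guarantees that
$$\mathbb{E}_{\xi_k}\left[\nabla f_{B_k}(\theta_k)\,\middle|\,\theta_k\right]=\nabla f(\theta_k)\quad\text{and}\quad\mathbb{E}_{\xi_k}\left[\|\nabla f_{B_k}(\theta_k)-\nabla f(\theta_k)\|^2\,\middle|\,\theta_k\right]\leq\frac{\sigma^2}{b}.\tag{A2}$$
Hence, since the cross term vanishes by the first equality in (A2), we have
$$\begin{aligned}\mathbb{E}_{\xi_k}\left[\|\nabla f_{B_k}(\theta_k)\|^2\,\middle|\,\theta_k\right]&=\mathbb{E}_{\xi_k}\left[\|\nabla f_{B_k}(\theta_k)-\nabla f(\theta_k)+\nabla f(\theta_k)\|^2\,\middle|\,\theta_k\right]\\&=\mathbb{E}_{\xi_k}\left[\|\nabla f_{B_k}(\theta_k)-\nabla f(\theta_k)\|^2\,\middle|\,\theta_k\right]+2\,\mathbb{E}_{\xi_k}\left[\langle\nabla f_{B_k}(\theta_k)-\nabla f(\theta_k),\nabla f(\theta_k)\rangle\,\middle|\,\theta_k\right]+\|\nabla f(\theta_k)\|^2\\&\leq\frac{\sigma^2}{b}+\|\nabla f(\theta_k)\|^2.\end{aligned}\tag{A3}$$
Take the expectation conditioned on $\theta_k$ on both sides of (A1). Together with (A2) and (A3), this guarantees that, for all $k\in\mathbb{N}$,
$$\begin{aligned}\mathbb{E}_{\xi_k}\left[f(\theta_{k+1})\,\middle|\,\theta_k\right]&\leq f(\theta_k)-\alpha_k\mathbb{E}_{\xi_k}\left[\langle\nabla f(\theta_k),\nabla f_{B_k}(\theta_k)\rangle\,\middle|\,\theta_k\right]+\frac{L\alpha_k^2}{2}\mathbb{E}_{\xi_k}\left[\|\nabla f_{B_k}(\theta_k)\|^2\,\middle|\,\theta_k\right]\\&\leq f(\theta_k)-\alpha_k\|\nabla f(\theta_k)\|^2+\frac{L\alpha_k^2}{2}\left(\frac{\sigma^2}{b}+\|\nabla f(\theta_k)\|^2\right).\end{aligned}$$
Hence, taking the total expectation on both sides of the above inequality ensures that, for all $k\in\mathbb{N}$,
$$\alpha_k\left(1-\frac{L\alpha_k}{2}\right)\mathbb{E}\left[\|\nabla f(\theta_k)\|^2\right]\leq\mathbb{E}\left[f(\theta_k)-f(\theta_{k+1})\right]+\frac{L\sigma^2\alpha_k^2}{2b}.$$
Let $K\geq 1$. Summing the above inequality from $k=0$ to $k=K-1$ ensures that
$$\sum_{k=0}^{K-1}\alpha_k\left(1-\frac{L\alpha_k}{2}\right)\mathbb{E}\left[\|\nabla f(\theta_k)\|^2\right]\leq\mathbb{E}\left[f(\theta_0)-f(\theta_K)\right]+\frac{L\sigma^2}{2b}\sum_{k=0}^{K-1}\alpha_k^2.$$
Together with (C1) (the lower bound $f_\star$ of $f$), this implies that the assertion in Lemma A.1 holds.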
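Lemma A.1 can be sanity-checked numerically. This is our own illustration, not part of the original proof. The sketch below assumes a hypothetical one-dimensional quadratic $f(\theta)=\frac{L}{2}\theta^2$ (so $f_\star=0$). It also assumes a mini-batch gradient equal to $L\theta$ plus zero-mean noise of variance $\sigma^2/b$, independent across steps. For a quadratic the descent lemma holds with equality, so the second moments $\mathbb{E}[\theta_k^2]$ can be propagated exactly. The two sides of the lemma should then differ exactly by the term $\mathbb{E}[f(\theta_K)-f_\star]$ that is dropped in the last step. The function name `lemma_a1_sides` is hypothetical.

```python
def lemma_a1_sides(theta0, L, sigma, b, alphas):
    """Return (lhs, rhs, slack) of Lemma A.1 for a hypothetical 1-D quadratic.

    Assumptions (not from the paper): f(theta) = (L/2) * theta**2, so f_* = 0
    and grad f(theta) = L * theta; the mini-batch gradient is L * theta + noise,
    with zero-mean noise of variance sigma**2 / b, independent across steps.
    The exact second moment m_k = E[theta_k**2] then evolves as
        m_{k+1} = (1 - alpha_k * L)**2 * m_k + alpha_k**2 * sigma**2 / b.
    """
    m = theta0 ** 2  # theta_0 is deterministic
    lhs = 0.0
    for step in alphas:
        grad_sq = L ** 2 * m  # E[||grad f(theta_k)||^2]
        lhs += step * (1 - L * step / 2) * grad_sq
        m = (1 - step * L) ** 2 * m + step ** 2 * sigma ** 2 / b
    rhs = L / 2 * theta0 ** 2 + L * sigma ** 2 / (2 * b) * sum(s * s for s in alphas)
    slack = L / 2 * m  # E[f(theta_K) - f_*], the term dropped at the end of the proof
    return lhs, rhs, slack
```

For example, `lemma_a1_sides(1.0, 2.0, 0.5, 4, [0.3] * 50)` gives a left-hand side no larger than the right-hand side, and `lhs + slack` agrees with `rhs` up to rounding. In general `slack` is nonnegative, so the inequality may be strict.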

A.2. Proof of Theorem 3.1

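(Constant): Lemma A.1 with $\alpha_k=\alpha$ implies that
$$\alpha\left(1-\frac{L\alpha}{2}\right)\sum_{k=0}^{K-1}\mathbb{E}\left[\|\nabla f(\theta_k)\|^2\right]\leq\mathbb{E}\left[f(\theta_0)-f_\star\right]+\frac{L\sigma^2\alpha^2K}{2b}.$$
Since $\alpha<\frac{2}{L}$, we have that
$$\min_{k\in[0:K-1]}\mathbb{E}\left[\|\nabla f(\theta_k)\|^2\right]\leq\frac{1}{K}\sum_{k=0}^{K-1}\mathbb{E}\left[\|\nabla f(\theta_k)\|^2\right]\leq\underbrace{\frac{2(f(\theta_0)-f_\star)}{(2-L\alpha)\alpha}}_{C_1}\frac{1}{K}+\underbrace{\frac{L\sigma^2\alpha}{2-L\alpha}}_{C_2}\frac{1}{b}.$$
(Decay): Since $(\alpha_k)_{k\in\mathbb{N}}$ converges to 0, there exists $k_0\in\mathbb{N}$ such that $\alpha_k<\frac{2}{L}$ for all $k\geq k_0$. We assume that $k_0=0$ (see Section 2.2.2). Lemma A.1 ensures that, for all $K\geq 1$,
$$\sum_{k=0}^{K-1}\alpha_k\left(1-\frac{L\alpha_k}{2}\right)\mathbb{E}\left[\|\nabla f(\theta_k)\|^2\right]\leq\mathbb{E}\left[f(\theta_0)-f_\star\right]+\frac{L\sigma^2}{2b}\sum_{k=0}^{K-1}\alpha_k^2.$$
Together with $\alpha_{k+1}\leq\alpha_k<\frac{2}{L}$ ($k\in\mathbb{N}$), this implies that
$$\alpha_{K-1}\left(1-\frac{L\alpha_0}{2}\right)\sum_{k=0}^{K-1}\mathbb{E}\left[\|\nabla f(\theta_k)\|^2\right]\leq\mathbb{E}\left[f(\theta_0)-f_\star\right]+\frac{L\sigma^2}{2b}\sum_{k=0}^{K-1}\alpha_k^2.$$
Hence, we have that
$$\sum_{k=0}^{K-1}\mathbb{E}\left[\|\nabla f(\theta_k)\|^2\right]\leq\frac{2(f(\theta_0)-f_\star)}{(2-L\alpha_0)\alpha_{K-1}}+\frac{L\sigma^2}{b(2-L\alpha_0)\alpha_{K-1}}\sum_{k=0}^{K-1}\alpha_k^2,$$
which implies that
$$\min_{k\in[0:K-1]}\mathbb{E}\left[\|\nabla f(\theta_k)\|^2\right]\leq\frac{2(f(\theta_0)-f_\star)}{2-L\alpha_0}\frac{1}{K\alpha_{K-1}}+\frac{1}{b}\frac{L\sigma^2}{2-L\alpha_0}\frac{1}{K\alpha_{K-1}}\sum_{k=0}^{K-1}\alpha_k^2.$$
Meanwhile, we have that
$$\sum_{k=0}^{K-1}\alpha_k^2\leq\sum_{k=0}^{K-1}\frac{T\alpha^2}{(k+1)^{2a}}\leq T\alpha^2\left(1+\int_0^{K-1}\frac{dt}{(t+1)^{2a}}\right)\leq\begin{cases}\dfrac{T\alpha^2}{1-2a}K^{1-2a}&\text{(Decay 1)}\\T\alpha^2(1+\log K)&\text{(Decay 2)}\\\dfrac{2aT\alpha^2}{2a-1}&\text{(Decay 3)}\end{cases}$$
and
$$\alpha_{K-1}=\frac{\alpha}{\left(\left\lfloor\frac{K-1}{T}\right\rfloor+1\right)^a}\geq\frac{\alpha}{\left(\frac{K-1}{T}+1\right)^a}\geq\frac{\alpha}{K^a}.$$
Here, we define
$$D_1:=\frac{2(f(\theta_0)-f_\star)}{\alpha(2-L\alpha_0)}\quad\text{and}\quad D_2:=\frac{T\alpha L\sigma^2}{2-L\alpha_0}.$$
Accordingly, we have that
$$\min_{k\in[0:K-1]}\mathbb{E}\left[\|\nabla f(\theta_k)\|^2\right]\leq\begin{cases}\dfrac{D_1}{K^{1-a}}+\dfrac{D_2}{(1-2a)K^ab}&\text{(Decay 1)}\\\dfrac{D_1}{\sqrt{K}}+\dfrac{D_2(1+\log K)}{\sqrt{K}\,b}&\text{(Decay 2)}\\\dfrac{D_1}{K^{1-a}}+\dfrac{2aD_2}{(2a-1)K^{1-a}b}&\text{(Decay 3)}\end{cases}$$
The next step uses $\log K<\sqrt{K}$ and the condition on $a$: for (Decay 1), $a<\frac{1}{2}$ gives $K^{1-a}\geq K^a$. Together these imply that
$$\min_{k\in[0:K-1]}\mathbb{E}\left[\|\nabla f(\theta_k)\|^2\right]\leq\begin{cases}\dfrac{D_1}{K^{a}}+\dfrac{D_2}{(1-2a)K^ab}&\text{(Decay 1)}\\\dfrac{D_1}{\sqrt{K}}+\left(\dfrac{1}{\sqrt{K}}+1\right)\dfrac{D_2}{b}&\text{(Decay 2)}\\\dfrac{D_1}{K^{1-a}}+\dfrac{2aD_2}{(2a-1)K^{1-a}b}&\text{(Decay 3)}\end{cases}$$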
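The two learning-rate estimates used in this proof can be spot-checked numerically. The sketch below assumes the step-decay form $\alpha_k=\alpha/(\lfloor k/T\rfloor+1)^a$ with an integer $T\geq 1$; this is our reading of the (Decay) schedule, consistent with the estimate of $\alpha_{K-1}$. For a given $K$, it tests two things. First, $\sum_{k<K}\alpha_k^2$ stays below the case-wise bound. Second, $\alpha_{K-1}\geq\alpha/K^a$. The helper names are hypothetical.

```python
import math


def decay_lr(k, alpha, T, a):
    """Assumed step-decay schedule: alpha_k = alpha / (floor(k / T) + 1) ** a."""
    return alpha / (k // T + 1) ** a


def sum_sq_bound(K, alpha, T, a):
    """Case-wise upper bound on sum_{k<K} alpha_k**2 used in the proof of Theorem 3.1."""
    if a < 0.5:   # (Decay 1): 0 < a < 1/2
        return T * alpha ** 2 * K ** (1 - 2 * a) / (1 - 2 * a)
    if a == 0.5:  # (Decay 2)
        return T * alpha ** 2 * (1 + math.log(K))
    return 2 * a * T * alpha ** 2 / (2 * a - 1)  # (Decay 3): 1/2 < a < 1


def check(K, alpha, T, a):
    """True if both schedule estimates hold for this particular K."""
    total = sum(decay_lr(k, alpha, T, a) ** 2 for k in range(K))
    last = decay_lr(K - 1, alpha, T, a)
    return total <= sum_sq_bound(K, alpha, T, a) and last >= alpha / K ** a
```

A passing check on a few values of $K$ is evidence, not a proof; the argument above covers every $K\geq 1$.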

A.3. Proof of Theorem 3.2

  1. Let us consider the case of (Constant). Setting the upper bound $\frac{C_1}{K}+\frac{C_2}{b}$ in Theorem 3.1 equal to $\epsilon^2$ implies that $K=\frac{C_1b}{\epsilon^2b-C_2}$ achieves an $\epsilon$–approximation. Setting the upper bounds for (Decay 1)–(Decay 3) in Theorem 3.1 equal to $\epsilon^2$ in the same way gives the corresponding values of $K$. This proves the assertion in Theorem 3.2(i).

  2. It is sufficient to prove that $K'(b):=\frac{dK(b)}{db}<0$ and $K''(b):=\frac{d^2K(b)}{db^2}>0$ hold.

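(Constant): Let $K=\frac{C_1b}{\epsilon^2b-C_2}$. Then, we have that
$$K'=\frac{C_1(\epsilon^2b-C_2)-\epsilon^2C_1b}{(\epsilon^2b-C_2)^2}=-\frac{C_1C_2}{(\epsilon^2b-C_2)^2}<0,$$
$$K''=\frac{2\epsilon^2C_1C_2(\epsilon^2b-C_2)}{(\epsilon^2b-C_2)^4}=\frac{2\epsilon^2C_1C_2\left(\left(\frac{C_1}{K}+\frac{C_2}{b}\right)b-C_2\right)}{(\epsilon^2b-C_2)^4}=\frac{2\epsilon^2C_1^2C_2b}{K(\epsilon^2b-C_2)^4}>0.$$
(Decay 1): Let $K=\left\{\frac{1}{\epsilon^2}\left(D_1+\frac{D_2}{(1-2a)b}\right)\right\}^{\frac{1}{a}}$. Then, we have that
$$K'=\frac{1}{a}\left\{\frac{1}{\epsilon^2}\left(D_1+\frac{D_2}{(1-2a)b}\right)\right\}^{\frac{1}{a}-1}\left(-\frac{D_2}{\epsilon^2(1-2a)b^2}\right)=-\frac{D_2}{a\epsilon^2(1-2a)b^2}\left\{\frac{1}{\epsilon^2}\left(D_1+\frac{D_2}{(1-2a)b}\right)\right\}^{\frac{1-a}{a}}<0,$$
$$K''=\frac{2D_2}{a\epsilon^2(1-2a)b^3}\left\{\frac{1}{\epsilon^2}\left(D_1+\frac{D_2}{(1-2a)b}\right)\right\}^{\frac{1-a}{a}}+\frac{(1-a)D_2}{a^2\epsilon^2(1-2a)b^2}\left\{\frac{1}{\epsilon^2}\left(D_1+\frac{D_2}{(1-2a)b}\right)\right\}^{\frac{1-a}{a}-1}\frac{D_2}{\epsilon^2(1-2a)b^2}>0.$$
(Decay 2): Let $K=\left(\frac{bD_1+D_2}{b\epsilon^2-D_2}\right)^2$. Then, we have that
$$K'=\frac{2D_1(bD_1+D_2)(b\epsilon^2-D_2)^2-2\epsilon^2(b\epsilon^2-D_2)(bD_1+D_2)^2}{(b\epsilon^2-D_2)^4}.$$
Together with $b\epsilon^2-D_2>0$, this implies that
$$\begin{aligned}(b\epsilon^2-D_2)^3K'&=2D_1(bD_1+D_2)(b\epsilon^2-D_2)-2\epsilon^2(bD_1+D_2)^2\\&=2(bD_1+D_2)\left\{D_1(b\epsilon^2-D_2)-\epsilon^2(bD_1+D_2)\right\}\\&=2(bD_1+D_2)(-D_1D_2-\epsilon^2D_2)\\&=-2D_2(bD_1+D_2)(D_1+\epsilon^2)<0.\end{aligned}$$
Moreover,
$$K''=\frac{-2D_1D_2(D_1+\epsilon^2)(b\epsilon^2-D_2)^3+6D_2\epsilon^2(b\epsilon^2-D_2)^2(bD_1+D_2)(D_1+\epsilon^2)}{(b\epsilon^2-D_2)^6},$$
which implies that
$$\begin{aligned}(b\epsilon^2-D_2)^4K''&=-2D_1D_2(D_1+\epsilon^2)(b\epsilon^2-D_2)+6D_2\epsilon^2(bD_1+D_2)(D_1+\epsilon^2)\\&=2D_2(D_1+\epsilon^2)\left\{-D_1(b\epsilon^2-D_2)+3\epsilon^2(bD_1+D_2)\right\}\\&=2D_2(D_1+\epsilon^2)(2D_1\epsilon^2b+D_1D_2+3D_2\epsilon^2)>0.\end{aligned}$$
(Decay 3): Let $K=\left\{\frac{1}{\epsilon^2}\left(D_1+\frac{2aD_2}{(2a-1)b}\right)\right\}^{\frac{1}{1-a}}$. Then, we have that
$$K'=\frac{1}{1-a}\left\{\frac{1}{\epsilon^2}\left(D_1+\frac{2aD_2}{(2a-1)b}\right)\right\}^{\frac{1}{1-a}-1}\left(-\frac{2aD_2}{\epsilon^2(2a-1)b^2}\right)=-\frac{2aD_2}{\epsilon^2(1-a)(2a-1)b^2}\left\{\frac{1}{\epsilon^2}\left(D_1+\frac{2aD_2}{(2a-1)b}\right)\right\}^{\frac{a}{1-a}}<0,$$
$$K''=\frac{4aD_2}{\epsilon^2(1-a)(2a-1)b^3}\left\{\frac{1}{\epsilon^2}\left(D_1+\frac{2aD_2}{(2a-1)b}\right)\right\}^{\frac{a}{1-a}}+\frac{2a^2D_2}{\epsilon^2(1-a)^2(2a-1)b^2}\left\{\frac{1}{\epsilon^2}\left(D_1+\frac{2aD_2}{(2a-1)b}\right)\right\}^{\frac{a}{1-a}-1}\frac{2aD_2}{\epsilon^2(2a-1)b^2}>0.$$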
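The shape claimed in Theorem 3.2(ii) can be checked on concrete numbers. The sketch below evaluates the four expressions for $K(b)$ from this proof on an equally spaced grid. It then checks that first differences are negative (decreasing) and second differences are positive (convex). The constants $\epsilon=0.5$, $C_1=D_1=2$, $C_2=D_2=1$ and $a\in\{0.25, 0.75\}$ are hypothetical. The grid starts above $C_2/\epsilon^2=D_2/\epsilon^2=4$, so the (Constant) and (Decay 2) expressions are defined and positive there.

```python
def k_constant(b, eps, C1, C2):
    """(Constant): K(b) = C1 b / (eps^2 b - C2), for b > C2 / eps^2."""
    return C1 * b / (eps ** 2 * b - C2)


def k_decay1(b, eps, D1, D2, a):
    """(Decay 1), 0 < a < 1/2: K(b) = {(D1 + D2 / ((1 - 2a) b)) / eps^2}^(1/a)."""
    return ((D1 + D2 / ((1 - 2 * a) * b)) / eps ** 2) ** (1 / a)


def k_decay2(b, eps, D1, D2):
    """(Decay 2): K(b) = ((b D1 + D2) / (b eps^2 - D2))^2, for b > D2 / eps^2."""
    return ((b * D1 + D2) / (b * eps ** 2 - D2)) ** 2


def k_decay3(b, eps, D1, D2, a):
    """(Decay 3), 1/2 < a < 1: K(b) = {(D1 + 2a D2 / ((2a - 1) b)) / eps^2}^(1/(1-a))."""
    return ((D1 + 2 * a * D2 / ((2 * a - 1) * b)) / eps ** 2) ** (1 / (1 - a))


def is_decreasing_and_convex(K, bs):
    """Finite-difference check of K' < 0 and K'' > 0 on an equally spaced grid bs."""
    v = [K(b) for b in bs]
    decreasing = all(v[i + 1] < v[i] for i in range(len(v) - 1))
    convex = all(v[i] - 2 * v[i + 1] + v[i + 2] > 0 for i in range(len(v) - 2))
    return decreasing and convex


bs = [5 + 0.5 * i for i in range(100)]  # equally spaced grid, all above 4
print(is_decreasing_and_convex(lambda b: k_decay2(b, 0.5, 2.0, 1.0), bs))
```

Finite differences only probe the grid. They illustrate the signs of $K'$ and $K''$ but do not replace the derivative computations above.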

A.4. Proof of Theorem 3.3

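(Constant): Let $N=\frac{C_1b^2}{\epsilon^2b-C_2}$. Then, we have that
$$N'=\frac{2C_1b(\epsilon^2b-C_2)-\epsilon^2C_1b^2}{(\epsilon^2b-C_2)^2}=\frac{C_1b(\epsilon^2b-2C_2)}{(\epsilon^2b-C_2)^2}.$$
If $N'=0$, we have that $\epsilon^2b-2C_2=0$, i.e. $b=\frac{2C_2}{\epsilon^2}$. Moreover,
$$N''=\frac{(2\epsilon^2C_1b-2C_1C_2)(\epsilon^2b-C_2)^2-2\epsilon^2(\epsilon^2b-C_2)(\epsilon^2C_1b^2-2C_1C_2b)}{(\epsilon^2b-C_2)^4},$$
so that
$$(\epsilon^2b-C_2)^3N''=(2\epsilon^2C_1b-2C_1C_2)(\epsilon^2b-C_2)-2\epsilon^2(\epsilon^2C_1b^2-2C_1C_2b)=2C_1C_2^2>0.$$
Since $\epsilon^2b-C_2>0$, this implies that $N$ is convex. Hence, there is a critical batch size $b^\star=\frac{2C_2}{\epsilon^2}>0$ at which $N$ is minimized.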
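The (Constant) case can be confirmed by brute force. Take hypothetical constants $\epsilon=0.5$, $C_1=2$ and $C_2=1$. Then the grid minimizer of $N(b)=C_1b^2/(\epsilon^2b-C_2)$ should sit next to $b^\star=2C_2/\epsilon^2=8$. Also, $N(b^\star)$ should equal $4C_1C_2/\epsilon^4=128$, the value used in the proof of Theorem 3.4. The function name `sfo_constant` is our own.

```python
def sfo_constant(b, eps, C1, C2):
    """(Constant): N(b) = b * K(b) = C1 b^2 / (eps^2 b - C2), for b > C2 / eps^2."""
    return C1 * b ** 2 / (eps ** 2 * b - C2)


eps, C1, C2 = 0.5, 2.0, 1.0        # hypothetical constants
b_star = 2 * C2 / eps ** 2         # critical batch size from the proof
lo = C2 / eps ** 2                 # N(b) is only defined for b > lo
grid = [lo + 0.01 * i for i in range(1, 5000)]
b_grid = min(grid, key=lambda b: sfo_constant(b, eps, C1, C2))
n_star = sfo_constant(b_star, eps, C1, C2)
print(b_star, b_grid, n_star, 4 * C1 * C2 / eps ** 4)
```

Because $N$ is convex on $b>C_2/\epsilon^2$, the grid minimizer is one of the two grid points bracketing $b^\star$.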

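(Decay 1): Let $N=Kb$. Then, we have that
$$\begin{aligned}N'=K+bK'&=\left\{\frac{1}{\epsilon^2}\left(D_1+\frac{D_2}{(1-2a)b}\right)\right\}^{\frac{1}{a}}+\frac{1}{a}\left\{\frac{1}{\epsilon^2}\left(D_1+\frac{D_2}{(1-2a)b}\right)\right\}^{\frac{1}{a}-1}\left(-\frac{D_2}{\epsilon^2(1-2a)b^2}\right)b\\&=\left\{\frac{1}{\epsilon^2}\left(D_1+\frac{D_2}{(1-2a)b}\right)\right\}^{\frac{1}{a}-1}\left\{\frac{1}{\epsilon^2}\left(D_1+\frac{D_2}{(1-2a)b}\right)-\frac{D_2}{a\epsilon^2(1-2a)b}\right\}.\end{aligned}$$
If $N'=0$, we have that $\frac{1}{\epsilon^2}\left(D_1+\frac{D_2}{(1-2a)b}\right)-\frac{D_2}{a\epsilon^2(1-2a)b}=0$, i.e. $b=\frac{D_2(a-1)}{aD_1(2a-1)}$. Moreover,
$$\begin{aligned}N''=K'+(K'+bK'')=2K'+bK''&=-\frac{2D_2}{a\epsilon^2(1-2a)b^2}\left\{\frac{1}{\epsilon^2}\left(D_1+\frac{D_2}{(1-2a)b}\right)\right\}^{\frac{1-a}{a}}+\frac{2D_2}{a\epsilon^2(1-2a)b^2}\left\{\frac{1}{\epsilon^2}\left(D_1+\frac{D_2}{(1-2a)b}\right)\right\}^{\frac{1-a}{a}}\\&\quad+\frac{(1-a)D_2}{a^2\epsilon^2(1-2a)b}\left\{\frac{1}{\epsilon^2}\left(D_1+\frac{D_2}{(1-2a)b}\right)\right\}^{\frac{1-a}{a}-1}\frac{D_2}{\epsilon^2(1-2a)b^2}\\&=\frac{(1-a)D_2}{a^2\epsilon^2(1-2a)b}\left\{\frac{1}{\epsilon^2}\left(D_1+\frac{D_2}{(1-2a)b}\right)\right\}^{\frac{1-a}{a}-1}\frac{D_2}{\epsilon^2(1-2a)b^2}>0,\end{aligned}$$
which implies that $N$ is convex. Hence, there is a critical batch size $b^\star=\frac{D_2(a-1)}{aD_1(2a-1)}>0$ at which $N$ is minimized.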

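(Decay 2): Let $N=bK$. Then, we have that
$$\begin{aligned}N'=K+bK'&=\frac{(bD_1+D_2)^2}{(b\epsilon^2-D_2)^2}-\frac{2D_2b(bD_1+D_2)(D_1+\epsilon^2)}{(b\epsilon^2-D_2)^3}\\&=\frac{bD_1+D_2}{(b\epsilon^2-D_2)^3}\left\{(bD_1+D_2)(b\epsilon^2-D_2)-2D_2b(D_1+\epsilon^2)\right\}\\&=\frac{bD_1+D_2}{(b\epsilon^2-D_2)^3}\left\{D_1\epsilon^2b^2-D_2(3D_1+\epsilon^2)b-D_2^2\right\}.\end{aligned}$$
The factor $bD_1+D_2$ vanishes only at $b=-\frac{D_2}{D_1}<0$. Hence, on the admissible range $b>\frac{D_2}{\epsilon^2}$, $N'=0$ holds exactly when $D_1\epsilon^2b^2-D_2(3D_1+\epsilon^2)b-D_2^2=0$. Moreover,
$$\begin{aligned}N''=2K'+bK''&=-\frac{4D_2(bD_1+D_2)(D_1+\epsilon^2)}{(b\epsilon^2-D_2)^3}+\frac{2D_2b(D_1+\epsilon^2)(2D_1\epsilon^2b+D_1D_2+3D_2\epsilon^2)}{(b\epsilon^2-D_2)^4}\\&=\frac{2D_2(D_1+\epsilon^2)}{(b\epsilon^2-D_2)^4}\left\{-2(bD_1+D_2)(b\epsilon^2-D_2)+b(2D_1\epsilon^2b+D_1D_2+3D_2\epsilon^2)\right\}\\&=\frac{2D_2(D_1+\epsilon^2)}{(b\epsilon^2-D_2)^4}(3D_1D_2b+D_2\epsilon^2b+2D_2^2)>0,\end{aligned}$$
which implies that $N$ is convex. The quadratic $D_1\epsilon^2b^2-D_2(3D_1+\epsilon^2)b-D_2^2$ equals $-\frac{2D_2^2(D_1+\epsilon^2)}{\epsilon^2}<0$ at $b=\frac{D_2}{\epsilon^2}$ and tends to $+\infty$ as $b\to\infty$. It therefore has exactly one root in $b>\frac{D_2}{\epsilon^2}$, namely $b^\star=\frac{D_2\left(3D_1+\epsilon^2+\sqrt{(3D_1+\epsilon^2)^2+4D_1\epsilon^2}\right)}{2D_1\epsilon^2}$. By convexity, $N$ is minimized at $b^\star$.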

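(Decay 3): Let $N=bK$. Then, we have that
$$\begin{aligned}N'=K+bK'&=\left\{\frac{1}{\epsilon^2}\left(D_1+\frac{2aD_2}{(2a-1)b}\right)\right\}^{\frac{1}{1-a}}-\frac{2aD_2}{\epsilon^2(1-a)(2a-1)b}\left\{\frac{1}{\epsilon^2}\left(D_1+\frac{2aD_2}{(2a-1)b}\right)\right\}^{\frac{1}{1-a}-1}\\&=\left\{\frac{1}{\epsilon^2}\left(D_1+\frac{2aD_2}{(2a-1)b}\right)\right\}^{\frac{1}{1-a}-1}\left\{\frac{1}{\epsilon^2}\left(D_1+\frac{2aD_2}{(2a-1)b}\right)-\frac{2aD_2}{\epsilon^2(1-a)(2a-1)b}\right\}.\end{aligned}$$
If $N'=0$, we have that $\frac{1}{\epsilon^2}\left(D_1+\frac{2aD_2}{(2a-1)b}\right)-\frac{2aD_2}{\epsilon^2(1-a)(2a-1)b}=0$, i.e. $b=\frac{2a^2D_2}{(2a-1)(1-a)D_1}$. Moreover,
$$\begin{aligned}N''=2K'+bK''&=-\frac{4aD_2}{\epsilon^2(1-a)(2a-1)b^2}\left\{\frac{1}{\epsilon^2}\left(D_1+\frac{2aD_2}{(2a-1)b}\right)\right\}^{\frac{a}{1-a}}+\frac{4aD_2}{\epsilon^2(1-a)(2a-1)b^2}\left\{\frac{1}{\epsilon^2}\left(D_1+\frac{2aD_2}{(2a-1)b}\right)\right\}^{\frac{a}{1-a}}\\&\quad+\frac{2a^2D_2}{\epsilon^2(1-a)^2(2a-1)b}\left\{\frac{1}{\epsilon^2}\left(D_1+\frac{2aD_2}{(2a-1)b}\right)\right\}^{\frac{2a-1}{1-a}}\frac{2aD_2}{\epsilon^2(2a-1)b^2}\\&=\frac{2a^2D_2}{\epsilon^2(1-a)^2(2a-1)b}\left\{\frac{1}{\epsilon^2}\left(D_1+\frac{2aD_2}{(2a-1)b}\right)\right\}^{\frac{2a-1}{1-a}}\frac{2aD_2}{\epsilon^2(2a-1)b^2}>0,\end{aligned}$$
which implies that $N$ is convex. Hence, there is a critical batch size $b^\star=\frac{2a^2D_2}{(2a-1)(1-a)D_1}>0$ at which $N$ is minimized.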
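For (Decay 1) and (Decay 3), the closed-form critical batch sizes can be compared against a grid search over $N(b)=bK(b)$. The sketch below uses hypothetical values $D_1=2$, $D_2=1$ and $\epsilon=0.5$. It takes $a=0.25$ for (Decay 1) and $a=0.75$ for (Decay 3), for which the formulas give $b^\star=3$ and $b^\star=4.5$, respectively. The helper names are our own.

```python
def sfo_decay1(b, eps, D1, D2, a):
    """(Decay 1), 0 < a < 1/2: N(b) = b * K(b)."""
    K = ((D1 + D2 / ((1 - 2 * a) * b)) / eps ** 2) ** (1 / a)
    return b * K


def sfo_decay3(b, eps, D1, D2, a):
    """(Decay 3), 1/2 < a < 1: N(b) = b * K(b)."""
    K = ((D1 + 2 * a * D2 / ((2 * a - 1) * b)) / eps ** 2) ** (1 / (1 - a))
    return b * K


def b_star_decay1(D1, D2, a):
    """Closed form from the proof: b* = D2 (a - 1) / (a D1 (2a - 1))."""
    return D2 * (a - 1) / (a * D1 * (2 * a - 1))


def b_star_decay3(D1, D2, a):
    """Closed form from the proof: b* = 2 a^2 D2 / ((2a - 1)(1 - a) D1)."""
    return 2 * a ** 2 * D2 / ((2 * a - 1) * (1 - a) * D1)


def grid_argmin(f, lo, hi, n):
    """Return the point of an (n + 1)-point uniform grid on [lo, hi] minimizing f."""
    grid = [lo + (hi - lo) * i / n for i in range(n + 1)]
    return min(grid, key=f)


eps, D1, D2 = 0.5, 2.0, 1.0  # hypothetical constants
print(b_star_decay1(D1, D2, 0.25),
      grid_argmin(lambda b: sfo_decay1(b, eps, D1, D2, 0.25), 0.01, 20.0, 2000))
print(b_star_decay3(D1, D2, 0.75),
      grid_argmin(lambda b: sfo_decay3(b, eps, D1, D2, 0.75), 0.01, 20.0, 2000))
```

In both cases $\epsilon$ enters $N(b)$ only as a constant factor, so the location of the minimizer does not depend on it. This matches the fact that neither closed form contains $\epsilon$.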

A.5. Proof of Theorem 3.4

Using $K$ defined in Theorem 3.2 leads to the iteration complexity. For example, SGD using (Constant) satisfies $N(b)=\frac{C_1b^2}{\epsilon^2b-C_2}$ (Theorem 3.3). Using the critical batch size $b^\star=\frac{2C_2}{\epsilon^2}$ in (4) leads to
$$\inf\left\{N:\min_{k\in[0:K-1]}\mathbb{E}\left[\|\nabla f(\theta_k)\|\right]\leq\epsilon\right\}\leq N(b^\star)=\frac{4C_1C_2}{\epsilon^4},\quad\text{i.e.}\quad N_\epsilon=O\left(\frac{1}{\epsilon^4}\right).$$
A similar discussion, together with $N$ defined in Theorem 3.3 and the critical batch size $b^\star$ in (4), leads to the SFO complexities of (Decay 1) and (Decay 3). Using $N$ defined in Theorem 3.3 and a batch size $b=\frac{D_2+1}{\epsilon^2}$ leads to the SFO complexity of (Decay 2).
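The (Decay 2) substitution in the last sentence can be carried out explicitly. The following is our own arithmetic from the (Decay 2) expression $K=\left(\frac{bD_1+D_2}{b\epsilon^2-D_2}\right)^2$ in the proof of Theorem 3.2, not a formula quoted from the paper. The point of the choice $b=\frac{D_2+1}{\epsilon^2}$ is that it makes the denominator exactly 1:

```latex
\begin{aligned}
b &= \frac{D_2+1}{\epsilon^2} \;\Longrightarrow\; b\epsilon^2 - D_2 = 1,
\qquad K = \left(\frac{bD_1+D_2}{b\epsilon^2-D_2}\right)^2 = (bD_1+D_2)^2,\\
N &= bK = \frac{D_2+1}{\epsilon^2}\left(\frac{(D_2+1)D_1}{\epsilon^2}+D_2\right)^2
= O\!\left(\frac{1}{\epsilon^6}\right) \quad (\epsilon \to 0,\ D_1, D_2 \text{ fixed}).
\end{aligned}
```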