Skip to content

Commit 366d7a1

Browse files
NthPortallrytz
authored andcommitted
Fix CVE-2022-36944 for LazyList
Backport fix for CVE-2022-36944 from 2.13. Code copy-pasted in a browser.
1 parent 53b8c17 commit 366d7a1

File tree

2 files changed

+88
-7
lines changed

2 files changed

+88
-7
lines changed

compat/src/main/scala-2.11_2.12/scala/collection/compat/immutable/LazyList.scala

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ import scala.collection.generic.{
3333
SeqFactory
3434
}
3535
import scala.collection.immutable.{LinearSeq, NumericRange}
36-
import scala.collection.mutable.{ArrayBuffer, Builder, StringBuilder}
36+
import scala.collection.mutable.{Builder, StringBuilder}
3737
import scala.language.implicitConversions
3838

3939
/** This class implements an immutable linked list that evaluates elements
@@ -516,10 +516,6 @@ final class LazyList[+A] private (private[this] var lazyState: () => LazyList.St
516516
else newLL(stateFromIteratorConcatSuffix(prefix.toIterator)(state))
517517
} else super.++:(prefix)(bf)
518518

519-
private def prependedAllToLL[B >: A](prefix: Traversable[B]): LazyList[B] =
520-
if (knownIsEmpty) LazyList.from(prefix)
521-
else newLL(stateFromIteratorConcatSuffix(prefix.toIterator)(state))
522-
523519
/** @inheritdoc
524520
*
525521
* $preservesLaziness
@@ -1512,14 +1508,17 @@ object LazyList extends SeqFactory[LazyList] {
15121508

15131509
private[this] def readObject(in: ObjectInputStream): Unit = {
15141510
in.defaultReadObject()
1515-
val init = new ArrayBuffer[A]
1511+
val init = new mutable.ListBuffer[A]
15161512
var initRead = false
15171513
while (!initRead) in.readObject match {
15181514
case SerializeEnd => initRead = true
15191515
case a => init += a.asInstanceOf[A]
15201516
}
15211517
val tail = in.readObject().asInstanceOf[LazyList[A]]
1522-
coll = tail.prependedAllToLL(init)
1518+
// scala/scala#10118: caution that no code path can evaluate `tail.state`
1519+
// before the resulting LazyList is returned
1520+
val it = init.toList.iterator
1521+
coll = newLL(stateFromIteratorConcatSuffix(it)(tail.state))
15231522
}
15241523

15251524
private[this] def readResolve(): Any = coll

compat/src/test/scala-jvm/test/scala/collection/LazyListGCTest.scala

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,4 +125,86 @@ class LazyListGCTest {
125125
def tapEach_takeRight_headOption_allowsGC(): Unit = {
126126
assertLazyListOpAllowsGC(_.tapEach(_).takeRight(2).headOption, _ => ())
127127
}
128+
129+
@Test
130+
def serialization(): Unit =
131+
if (scala.util.Properties.releaseVersion.exists(_.startsWith("2.12"))) {
132+
import java.io._
133+
134+
def serialize(obj: AnyRef): Array[Byte] = {
135+
val buffer = new ByteArrayOutputStream
136+
val out = new ObjectOutputStream(buffer)
137+
out.writeObject(obj)
138+
buffer.toByteArray
139+
}
140+
141+
def deserialize(a: Array[Byte]): AnyRef = {
142+
val in = new ObjectInputStream(new ByteArrayInputStream(a))
143+
in.readObject
144+
}
145+
146+
def serializeDeserialize[T <: AnyRef](obj: T) = deserialize(serialize(obj)).asInstanceOf[T]
147+
148+
val l = LazyList.from(10)
149+
150+
val ld1 = serializeDeserialize(l)
151+
assertEquals(l.take(10).toList, ld1.take(10).toList)
152+
153+
l.tail.head
154+
val ld2 = serializeDeserialize(l)
155+
assertEquals(l.take(10).toList, ld2.take(10).toList)
156+
157+
LazyListGCTest.serializationForceCount = 0
158+
val u = LazyList
159+
.from(10)
160+
.map(x => {
161+
LazyListGCTest.serializationForceCount += 1; x
162+
})
163+
164+
def printDiff(): Unit = {
165+
val a = serialize(u)
166+
classOf[LazyList[_]]
167+
.getDeclaredField("scala$collection$compat$immutable$LazyList$$stateEvaluated")
168+
.setBoolean(u, true)
169+
val b = serialize(u)
170+
val i = a.zip(b).indexWhere(p => p._1 != p._2)
171+
println("difference: ")
172+
println(s"val from = ${a.slice(i - 10, i + 10).mkString("List[Byte](", ", ", ")")}")
173+
println(s"val to = ${b.slice(i - 10, i + 10).mkString("List[Byte](", ", ", ")")}")
174+
}
175+
176+
// to update this test, comment-out `LazyList.writeReplace` and run `printDiff`
177+
// printDiff()
178+
179+
val from = List[Byte](83, 116, 97, 116, 101, 59, 120, 112, 0, 0, 0, 115, 114, 0, 33, 106, 97,
180+
118, 97, 46)
181+
val to = List[Byte](83, 116, 97, 116, 101, 59, 120, 112, 0, 0, 1, 115, 114, 0, 33, 106, 97,
182+
118, 97, 46)
183+
184+
assertEquals(LazyListGCTest.serializationForceCount, 0)
185+
186+
u.head
187+
assertEquals(LazyListGCTest.serializationForceCount, 1)
188+
189+
val data = serialize(u)
190+
var i = data.indexOfSlice(from)
191+
to.foreach(x => {
192+
data(i) = x; i += 1
193+
})
194+
195+
val ud1 = deserialize(data).asInstanceOf[LazyList[Int]]
196+
197+
// this check failed before scala/scala#10118, deserialization triggered evaluation
198+
assertEquals(LazyListGCTest.serializationForceCount, 1)
199+
200+
ud1.tail.head
201+
assertEquals(LazyListGCTest.serializationForceCount, 2)
202+
203+
u.tail.head
204+
assertEquals(LazyListGCTest.serializationForceCount, 3)
205+
}
206+
}
207+
208+
object LazyListGCTest {
209+
var serializationForceCount = 0
128210
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy