aboutsummaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/hittable.rb6
-rw-r--r--lib/material.rb31
-rw-r--r--lib/vec3.rb7
3 files changed, 41 insertions, 3 deletions
diff --git a/lib/hittable.rb b/lib/hittable.rb
index facdd14..2a433d8 100644
--- a/lib/hittable.rb
+++ b/lib/hittable.rb
@@ -2,12 +2,12 @@ class HitRecord
def initialize(point, t, ray, out_normal, material)
@point = point
@t = t
- front_face = ray.direction.dot(out_normal) < 0
- @normal = front_face ? out_normal : -out_normal
+ @front_face = ray.direction.dot(out_normal) < 0
+ @normal = @front_face ? out_normal : -out_normal
@material = material
end
- attr_accessor :point, :normal, :t, :material
+ attr_reader :point, :normal, :t, :front_face, :material
end
class Hittable
diff --git a/lib/material.rb b/lib/material.rb
index c935147..752b85f 100644
--- a/lib/material.rb
+++ b/lib/material.rb
@@ -36,3 +36,34 @@ class Metal < Material
end
end
end
+
+class Dielectric < Material
+ def initialize(ref_index)
+ @ref_index = ref_index
+ end
+
+ def attenuation
+ Colour.new(1.0, 1.0, 1.0)
+ end
+
+ def scatter(ray, record)
+ ri = record.front_face ? (1.0 / @ref_index) : @ref_index
+ unit_dir = ray.direction.unit
+ costheta = [(-unit_dir).dot(record.normal), 1.0].min
+ sintheta = (1.0 - costheta ** 2) ** 0.5
+
+ cannot_refract = ri * sintheta > 1.0
+ maybe_reflect_anyway = Dielectric.reflectance(costheta, ri) > rand
+
+ refr = cannot_refract || maybe_reflect_anyway ?
+ unit_dir.reflect(record.normal) :
+ unit_dir.refract(record.normal, ri)
+
+ Ray.new(record.point, refr)
+ end
+
+ def self.reflectance(costheta, ri)
+ r0 = ((1.0 - ri) / (1.0 + ri)) ** 2
+ r0 + (1.0 - r0) * (1.0 - costheta) ** 5
+ end
+end
diff --git a/lib/vec3.rb b/lib/vec3.rb
index 3069b65..478fe4c 100644
--- a/lib/vec3.rb
+++ b/lib/vec3.rb
@@ -60,6 +60,13 @@ class Vec3
self - normal * dot(normal) * 2
end
+ def refract(normal, etaratio)
+ costheta = [(-self).dot(normal), 1.0].min
+ rout_perp = (self + normal * costheta) * etaratio
+ rout_parr = normal * -((1.0 - rout_perp.mag_sqr).abs ** 0.5)
+ rout_perp + rout_parr
+ end
+
def in_unit_sphere?
mag_sqr < 1
end