commit b65c4adc94a3888517774bcfb888fce06e4de6e5
parent 47fe5ae232f2a543083c4eeeb75f07f82889ba55
Author: Stephen Chang <stchang@ccs.neu.edu>
Date: Thu, 17 Mar 2016 19:03:02 -0400
use expected type to help infer instantiation of an app
Diffstat:
3 files changed, 38 insertions(+), 1 deletion(-)
diff --git a/tapl/mlish.rkt b/tapl/mlish.rkt
@@ -375,6 +375,9 @@
(provide →/test)
(define-syntax →/test
(syntax-parser
+ [(_ (~and Xs (X:id ...)) . rst)
+ #:when (brace? #'Xs)
+ #'(∀ (X ...) (ext-stlc:→ . rst))]
[(_ . rst)
(let L ([Xs #'()]) ; compute unbound ids; treat as tyvars
(with-handlers ([exn:fail:syntax:unbound?
@@ -433,7 +436,16 @@
(syntax->datum #'e_fn) (type->str #'τ_fn))
#:with (~∀ Xs (~ext-stlc:→ τ_inX ... τ_outX)) #'τ_fn
;; ) instantiate polymorphic fn type
- #:with (τ_solved ...) (solve #'Xs #'(τ_inX ...) (syntax/loc stx (e_fn e_arg ...)))
+ ; try to solve with expected-type first
+ #:with expected-ty (get-expected-type stx)
+ #:with maybe-solved
+ (and (syntax-e #'expected-ty)
+ (let ([cs (compute-constraints (list (list #'τ_outX ((current-type-eval) #'expected-ty))))])
+ (filter (lambda (x) x) (stx-map (λ (X) (lookup X cs)) #'Xs))))
+ ;; else use arg types
+ #:with (τ_solved ...) (if (and (syntax-e #'maybe-solved) (stx-length=? #'maybe-solved #'Xs))
+ #'maybe-solved
+ (solve #'Xs #'(τ_inX ...) (syntax/loc stx (e_fn e_arg ...))))
;; #:with cs (compute-constraints #'((τ_inX τ_arg) ...))
;; #:with (τ_solved ...) (filter (λ (x) x) (stx-map (λ (y) (lookup y #'cs)) #'(X ...)))
;; #:fail-unless (stx-length=? #'(X ...) #'(τ_solved ...))
diff --git a/tapl/tests/mlish/inst.mlish b/tapl/tests/mlish/inst.mlish
@@ -0,0 +1,24 @@
+#lang s-exp "../../mlish.rkt"
+(require "../rackunit-typechecking.rkt")
+
+;; tests for instantiation of polymorphic functions and constructors
+
+(define-type (Result A B)
+ (Ok A)
+ (Error B))
+
+(define {A B} (ok [a : A] -> (Result A B))
+ (Ok a))
+
+(check-type ok : (→/test {A B} A (Result A B))) ; test inferred
+(check-type (inst ok Int String) : (→/test Int (Result Int String)))
+
+(define (f -> (Result Int String))
+ (ok 1))
+
+(check-type f : (→/test (Result Int String)))
+
+(define (g -> (Result Int String))
+ (Ok 1))
+
+(check-type g : (→/test (Result Int String)))
diff --git a/tapl/tests/run-all-mlish-tests.rkt b/tapl/tests/run-all-mlish-tests.rkt
@@ -18,3 +18,4 @@
(require "mlish/term.mlish")
(require "mlish/find.mlish")
(require "mlish/alex.mlish")
+(require "mlish/inst.mlish")