commit c872a1404d3f398c40856afb7f526b99b477ea3a
parent 95d960267a79ce2544a930721150fb623c4764eb
Author: Alex Knauth <alexander@knauth.org>
Date: Thu, 12 May 2016 15:32:54 -0400
Merged in allow-generalization-covariant (pull request #23)
Allow generalization for covariant type variables
Diffstat:
9 files changed, 515 insertions(+), 85 deletions(-)
diff --git a/tapl/mlish.rkt b/tapl/mlish.rkt
@@ -6,7 +6,7 @@
;(reuse [inst sysf:inst] #:from "sysf.rkt")
(require (rename-in (only-in "sysf.rkt" inst) [inst sysf:inst]))
(provide inst)
-(require (only-in "ext-stlc.rkt" →?))
+(require (only-in "ext-stlc.rkt" → →?))
(require (only-in "sysf.rkt" ~∀ ∀ ∀? Λ))
(reuse × tup proj define-type-alias #:from "stlc+rec-iso.rkt")
(require (only-in "stlc+rec-iso.rkt" ~× ×?)) ; using current-type=? from here
@@ -23,6 +23,9 @@
(require (prefix-in stlc+cons: (only-in "stlc+cons.rkt" list)))
(require (prefix-in stlc+tup: (only-in "stlc+tup.rkt" tup)))
+(module+ test
+ (require (for-syntax rackunit)))
+
(provide → →/test match2 define-type)
;; ML-like language
@@ -31,6 +34,25 @@
;; - pattern matching
;; - (local) type inference
+;; creating possibly polymorphic types
+;; ?∀ only wraps a type in a forall if there's at least one type variable
+(define-syntax ?∀
+ (lambda (stx)
+ (syntax-case stx ()
+ [(?∀ () body)
+ #'body]
+ [(?∀ (X ...) body)
+ #'(∀ (X ...) body)])))
+
+;; ?Λ only wraps an expression in a Λ if there's at least one type variable
+(define-syntax ?Λ
+ (lambda (stx)
+ (syntax-case stx ()
+ [(?Λ () body)
+ #'body]
+ [(?Λ (X ...) body)
+ #'(Λ (X ...) body)])))
+
(begin-for-syntax
;; matching possibly polymorphic types
(define-syntax ~?∀
@@ -43,26 +65,91 @@
(~parse vars-pat #'())
body-pat))]))))
- ;; type inference constraint solving
- (define (compute-constraint τ1-τ2)
- (syntax-parse τ1-τ2
- [(X:id τ) #'((X τ))]
- [((~Any tycons1 τ1 ...) (~Any tycons2 τ2 ...))
- #:when (typecheck? #'tycons1 #'tycons2)
- (compute-constraints #'((τ1 τ2) ...))]
- ; should only be monomorphic?
- [((~∀ () (~ext-stlc:→ τ1 ...)) (~∀ () (~ext-stlc:→ τ2 ...)))
- (compute-constraints #'((τ1 τ2) ...))]
- [_ #'()]))
- (define (compute-constraints τs)
- (stx-appendmap compute-constraint τs))
-
- (define (solve-constraint x-τ)
- (syntax-parse x-τ
- [(X:id τ) #'((X τ))]
- [_ #'()]))
- (define (solve-constraints cs)
- (stx-appendmap compute-constraint cs))
+ ;; add-constraints :
+ ;; (Listof Id) (Listof (List Id Type)) (Stx-Listof (Stx-List Stx Stx)) -> (Listof (List Id Type))
+ ;; Adds a new set of constaints to a substituion, using the type
+ ;; unification algorithm for local type inference.
+ (define (add-constraints Xs substs new-cs [orig-cs new-cs])
+ (define Xs* (stx->list Xs))
+ (define Ys (stx-map stx-car substs))
+ (define-syntax-class var
+ [pattern x:id #:when (member #'x Xs* free-identifier=?)])
+ (syntax-parse new-cs
+ [() substs]
+ [([a:var b] . rst)
+ (cond
+ [(member #'a Ys free-identifier=?)
+ ;; There are two cases.
+ ;; Either #'a already maps to #'b or an equivalent type,
+ ;; or #'a already maps to a type that conflicts with #'b.
+ ;; In either case, whatever #'a maps to must be equivalent
+ ;; to #'b, so add that to the constraints.
+ (add-constraints
+ Xs
+ substs
+ (cons (list (lookup #'a substs) #'b)
+ #'rst)
+ orig-cs)]
+ [else
+ (add-constraints
+ Xs*
+ ;; Add the mapping #'a -> #'b to the substitution,
+ (cons (list #'a #'b)
+ (for/list ([subst (in-list (stx->list substs))])
+ (list (stx-car subst)
+ (inst-type (list #'b) (list #'a) (stx-cadr subst)))))
+ ;; and substitute that in each of the constraints.
+ (for/list ([c (in-list (syntax->list #'rst))])
+ (list (inst-type (list #'b) (list #'a) (stx-car c))
+ (inst-type (list #'b) (list #'a) (stx-cadr c))))
+ orig-cs)])]
+ [([a b:var] . rst)
+ (add-constraints Xs*
+ substs
+ #'([b a] . rst)
+ orig-cs)]
+ [([a b] . rst)
+ ;; If #'a and #'b are base types, check that they're equal.
+ ;; Identifers not within Xs count as base types.
+ ;; If #'a and #'b are constructed types, check that the
+ ;; constructors are the same, add the sub-constraints, and
+ ;; recur.
+ ;; Otherwise, raise an error.
+ (cond
+ [(identifier? #'a)
+ ;; #'a is an identifier, but not a var, so it is considered
+ ;; a base type. We also know #'b is not a var, so #'b has
+ ;; to be the same "identifier base type" as #'a.
+ (unless (and (identifier? #'b) (free-identifier=? #'a #'b))
+ (type-error #:src (get-orig #'a)
+ #:msg (format "couldn't unify ~~a and ~~a\n expected: ~a\n given: ~a"
+ (string-join (map type->str (stx-map stx-car orig-cs)) ", ")
+ (string-join (map type->str (stx-map stx-cadr orig-cs)) ", "))
+ #'a #'b))
+ (add-constraints Xs*
+ substs
+ #'rst
+ orig-cs)]
+ [else
+ (syntax-parse #'[a b]
+ [_
+ #:when (typecheck? #'a #'b)
+ (add-constraints Xs
+ substs
+ #'rst
+ orig-cs)]
+ [((~Any tycons1 τ1 ...) (~Any tycons2 τ2 ...))
+ #:when (typecheck? #'tycons1 #'tycons2)
+ (add-constraints Xs
+ substs
+ #'((τ1 τ2) ... . rst)
+ orig-cs)]
+ [else
+ (type-error #:src (get-orig #'a)
+ #:msg (format "couldn't unify ~~a and ~~a\n expected: ~a\n given: ~a"
+ (string-join (map type->str (stx-map stx-car orig-cs)) ", ")
+ (string-join (map type->str (stx-map stx-cadr orig-cs)) ", "))
+ #'a #'b)])])]))
(define (lookup x substs)
(syntax-parse substs
@@ -72,11 +159,11 @@
[(_ . rst) (lookup x #'rst)]
[() #f]))
- ;; find-unsolved-Xs : (Stx-Listof Id) Constraints -> (Listof Id)
- ;; finds the free Xs that aren't constrained by cs
- (define (find-unsolved-Xs Xs cs)
+ ;; find-free-Xs : (Stx-Listof Id) Type -> (Listof Id)
+ ;; finds the free Xs in the type
+ (define (find-free-Xs Xs ty)
(for/list ([X (in-list (stx->list Xs))]
- #:when (not (lookup X cs)))
+ #:when (stx-contains-id? ty X))
X))
;; lookup-Xs/keep-unsolved : (Stx-Listof Id) Constraints -> (Listof Type-Stx)
@@ -90,33 +177,43 @@
;; tyXs = input and output types from fn type
;; ie (typeof e_fn) = (-> . tyXs)
;; It infers the types of arguments from left-to-right,
- ;; and it short circuits if it's done early.
+ ;; and it expands and returns all of the arguments.
;; It returns list of 3 values if successful, else throws a type error
- ;; - a list of the arguments that it expanded
- ;; - a list of the the un-constrained type variables
+ ;; - a list of all the arguments, expanded
+ ;; - a list of all the type variables
;; - the constraints for substituting the types
(define (solve Xs tyXs stx)
(syntax-parse tyXs
[(τ_inX ... τ_outX)
;; generate initial constraints with expected type and τ_outX
- #:with expected-ty (get-expected-type stx)
+ #:with (~?∀ Vs expected-ty) (and (get-expected-type stx)
+ ((current-type-eval) (get-expected-type stx)))
(define initial-cs
- (if (syntax-e #'expected-ty)
- (compute-constraint (list #'τ_outX ((current-type-eval) #'expected-ty)))
+ (if (and (syntax-e #'expected-ty) (stx-null? #'Vs))
+ (add-constraints Xs '() (list (list #'expected-ty #'τ_outX)))
#'()))
(syntax-parse stx
[(_ e_fn . args)
(define-values (as- cs)
(for/fold ([as- null] [cs initial-cs])
([a (in-list (syntax->list #'args))]
- [tyXin (in-list (syntax->list #'(τ_inX ...)))]
- #:break (empty? (find-unsolved-Xs Xs cs)))
- (define/with-syntax [a- ty_a] (infer+erase a))
+ [tyXin (in-list (syntax->list #'(τ_inX ...)))])
+ (define ty_in (inst-type/cs Xs cs tyXin))
+ (define/with-syntax [a- ty_a]
+ (infer+erase (if (empty? (find-free-Xs Xs ty_in))
+ (add-expected-ty a ty_in)
+ a)))
(values
(cons #'a- as-)
- (stx-append cs (compute-constraint (list tyXin #'ty_a))))))
+ (add-constraints Xs cs (list (list ty_in #'ty_a))
+ (list (list (inst-type/cs/orig
+ Xs cs ty_in
+ (λ (id1 id2)
+ (equal? (syntax->datum id1)
+ (syntax->datum id2))))
+ #'ty_a))))))
- (list (reverse as-) (find-unsolved-Xs Xs cs) cs)])]))
+ (list (reverse as-) Xs cs)])]))
(define (raise-app-poly-infer-error stx expected-tys given-tys e_fn)
(type-error #:src stx
@@ -130,6 +227,11 @@
;; identifier in Xs is associated with the ith type in tys-solved
(define (inst-type tys-solved Xs ty)
(substs tys-solved Xs ty))
+ ;; inst-type/orig : (Listof Type) (Listof Id) Type (Id Id -> Bool) -> Type
+ ;; like inst-type, but also substitutes within the orig property
+ (define (inst-type/orig tys-solved Xs ty [var=? free-identifier=?])
+ (add-orig (inst-type tys-solved Xs ty)
+ (substs (stx-map get-orig tys-solved) Xs (get-orig ty) var=?)))
;; inst-type/cs : (Stx-Listof Id) Constraints Type-Stx -> Type-Stx
;; Instantiates ty, substituting each identifier in Xs with its mapping in cs.
@@ -141,6 +243,56 @@
(define (inst-types/cs Xs cs tys)
(stx-map (lambda (t) (inst-type/cs Xs cs t)) tys))
+ ;; inst-type/cs/orig :
+ ;; (Stx-Listof Id) Constraints Type-Stx (Id Id -> Bool) -> Type-Stx
+ ;; like inst-type/cs, but also substitutes within the orig property
+ (define (inst-type/cs/orig Xs cs ty [var=? free-identifier=?])
+ (define tys-solved (lookup-Xs/keep-unsolved Xs cs))
+ (inst-type/orig tys-solved Xs ty var=?))
+ ;; inst-types/cs/orig :
+ ;; (Stx-Listof Id) Constraints (Stx-Listof Type-Stx) (Id Id -> Bool) -> (Listof Type-Stx)
+ ;; the plural version of inst-type/cs/orig
+ (define (inst-types/cs/orig Xs cs tys [var=? free-identifier=?])
+ (stx-map (lambda (t) (inst-type/cs/orig Xs cs t var=?)) tys))
+
+ ;; covariant-Xs? : Type -> Bool
+ ;; Takes a possibly polymorphic type, and returns true if all of the
+ ;; type variables are in covariant positions within the type, false
+ ;; otherwise.
+ (define (covariant-Xs? ty)
+ (syntax-parse ((current-type-eval) ty)
+ [(~?∀ Xs ty)
+ (for/and ([X (in-list (syntax->list #'Xs))])
+ (covariant-X? X #'ty))]))
+
+ ;; find-X-variance : Id Type -> Variance
+ ;; Returns the variance of X within the type ty
+ (define (find-X-variance X ty)
+ (syntax-parse ty
+ [A:id #:when (free-identifier=? #'A X) covariant]
+ [(~Any tycons) irrelevant]
+ [(~?∀ () (~Any tycons τ ...))
+ #:when (get-arg-variances #'tycons)
+ #:when (stx-length=? #'[τ ...] (get-arg-variances #'tycons))
+ (for/fold ([acc irrelevant])
+ ([τ (in-list (syntax->list #'[τ ...]))]
+ [arg-variance (in-list (get-arg-variances #'tycons))])
+ (variance-join
+ acc
+ (variance-compose arg-variance (find-X-variance X τ))))]
+ [ty #:when (not (stx-contains-id? #'ty X)) irrelevant]
+ [_ invariant]))
+
+ ;; covariant-X? : Id Type -> Bool
+ ;; Returns true if every place X appears in ty is a covariant position, false otherwise.
+ (define (covariant-X? X ty)
+ (variance-covariant? (find-X-variance X ty)))
+
+ ;; contravariant-X? : Id Type -> Bool
+ ;; Returns true if every place X appears in ty is a contravariant position, false otherwise.
+ (define (contravariant-X? X ty)
+ (variance-contravariant? (find-X-variance X ty)))
+
;; compute unbound tyvars in one unexpanded type ty
(define (compute-tyvar1 ty)
(syntax-parse ty
@@ -182,8 +334,8 @@
;; TODO: check that specified return type is correct
;; - currently cannot do it here; to do the check here, need all types of
;; top-lvl fns, since they can call each other
- #:with (~and ty_fn_expected (~∀ _ (~ext-stlc:→ _ ... out_expected)))
- ((current-type-eval) #'(∀ Ys (ext-stlc:→ τ+orig ...)))
+ #:with (~and ty_fn_expected (~?∀ _ (~ext-stlc:→ _ ... out_expected)))
+ ((current-type-eval) #'(?∀ Ys (ext-stlc:→ τ+orig ...)))
#`(begin
(define-syntax f (make-rename-transformer (⊢ g : ty_fn_expected)))
(define g
@@ -200,15 +352,15 @@
;; TODO: check that specified return type is correct
;; - currently cannot do it here; to do the check here, need all types of
;; top-lvl fns, since they can call each other
- #:with (~and ty_fn_expected (~∀ _ (~ext-stlc:→ _ ... out_expected)))
+ #:with (~and ty_fn_expected (~?∀ _ (~ext-stlc:→ _ ... out_expected)))
(set-stx-prop/preserved
- ((current-type-eval) #'(∀ Ys (ext-stlc:→ τ+orig ...)))
+ ((current-type-eval) #'(?∀ Ys (ext-stlc:→ τ+orig ...)))
'orig
(list #'(→ τ+orig ...)))
#`(begin
(define-syntax f (make-rename-transformer (⊢ g : ty_fn_expected)))
(define g
- (Λ Ys (ext-stlc:λ ([x : τ] ...) (ext-stlc:begin e_body ... e_ann)))))])
+ (?Λ Ys (ext-stlc:λ ([x : τ] ...) (ext-stlc:begin e_body ... e_ann)))))])
;; define-type -----------------------------------------------
;; TODO: should validate τ as part of define-type definition (before it's used)
@@ -275,30 +427,44 @@
#'(StructName ...) #'((fld ...) ...))
#:with (Cons? ...) (stx-map mk-? #'(StructName ...))
#:with (exposed-Cons? ...) (stx-map mk-? #'(Cons ...))
+ #:do [(define expanded-tys
+ (for/list ([τ (in-list (syntax->list #'[τ ... ...]))])
+ (with-handlers ([exn:fail:syntax? (λ (e) #false)])
+ ((current-type-eval) #`(∀ (X ...) #,τ)))))]
+ #:with [arg-variance ...]
+ (for/list ([i (in-range (length (syntax->list #'[X ...])))])
+ (for/fold ([acc irrelevant])
+ ([ty (in-list expanded-tys)])
+ (cond [ty
+ (define/syntax-parse (~?∀ Xs τ) ty)
+ (define X (list-ref (syntax->list #'Xs) i))
+ (variance-join acc (find-X-variance X #'τ))]
+ [else invariant])))
#`(begin
(define-syntax (NameExtraInfo stx)
(syntax-parse stx
[(_ X ...) #'(('Cons 'StructName Cons? [acc τ] ...) ...)]))
(define-type-constructor Name
#:arity = #,(stx-length #'(X ...))
+ #:arg-variances (λ (stx) (list 'arg-variance ...))
#:extra-info 'NameExtraInfo
#:no-provide)
(struct StructName (fld ...) #:reflection-name 'Cons #:transparent) ...
(define-syntax (exposed-acc stx) ; accessor for records
(syntax-parse stx
- [_:id (⊢ acc (∀ (X ...) (ext-stlc:→ (Name X ...) τ)))]
+ [_:id (⊢ acc (?∀ (X ...) (ext-stlc:→ (Name X ...) τ)))]
[(o . rst) ; handle if used in fn position
#:with app (datum->syntax #'o '#%app)
#`(app
- #,(assign-type #'acc #'(∀ (X ...) (ext-stlc:→ (Name X ...) τ)))
+ #,(assign-type #'acc #'(?∀ (X ...) (ext-stlc:→ (Name X ...) τ)))
. rst)])) ... ...
(define-syntax (exposed-Cons? stx) ; predicates for each variant
(syntax-parse stx
- [_:id (⊢ Cons? (∀ (X ...) (ext-stlc:→ (Name X ...) Bool)))]
+ [_:id (⊢ Cons? (?∀ (X ...) (ext-stlc:→ (Name X ...) Bool)))]
[(o . rst) ; handle if used in fn position
#:with app (datum->syntax #'o '#%app)
#`(app
- #,(assign-type #'Cons? #'(∀ (X ...) (ext-stlc:→ (Name X ...) Bool)))
+ #,(assign-type #'Cons? #'(?∀ (X ...) (ext-stlc:→ (Name X ...) Bool)))
. rst)])) ...
(define-syntax (Cons stx)
(syntax-parse stx
@@ -319,7 +485,7 @@
(current-continuation-marks)))
#:with (NameExpander τ-expected-arg (... ...)) ((current-type-eval) #'τ-expected)
#'(C {τ-expected-arg (... ...)})]
- [_:id (⊢ StructName (∀ (X ...) (ext-stlc:→ τ ... (Name X ...))))] ; HO fn
+ [_:id (⊢ StructName (?∀ (X ...) (ext-stlc:→ τ ... (Name X ...))))] ; HO fn
[(C τs e_arg ...)
#:when (brace? #'τs) ; commit to this clause
#:with {~! τ_X:type (... ...)} #'τs
@@ -340,7 +506,7 @@
[(C . args) ; no type annotations, must infer instantiation
#:with StructName/ty
(set-stx-prop/preserved
- (⊢ StructName : (∀ (X ...) (ext-stlc:→ τ ... (Name X ...))))
+ (⊢ StructName : (?∀ (X ...) (ext-stlc:→ τ ... (Name X ...))))
'orig
(list #'C))
; stx/loc transfers expected-type
@@ -651,19 +817,16 @@
(let ([x- (acc z)] ...) e_c-)] ...))
: τ_out)])])])
-(define-syntax → ; wrapping →
- (syntax-parser
- [(_ . rst) (set-stx-prop/preserved #'(∀ () (ext-stlc:→ . rst)) 'orig (list #'(→ . rst)))]))
; special arrow that computes free vars; for use with tests
; (because we can't write explicit forall
(define-syntax →/test
(syntax-parser
[(_ (~and Xs (X:id ...)) . rst)
#:when (brace? #'Xs)
- #'(∀ (X ...) (ext-stlc:→ . rst))]
+ #'(?∀ (X ...) (ext-stlc:→ . rst))]
[(_ . rst)
#:with Xs (compute-tyvars #'rst)
- #'(∀ Xs (ext-stlc:→ . rst))]))
+ #'(?∀ Xs (ext-stlc:→ . rst))]))
; redefine these to use lifted →
(define-primop + : (→ Int Int Int))
@@ -685,7 +848,7 @@
(define-primop even? : (→ Int Bool))
(define-primop odd? : (→ Int Bool))
-; all λs have type (∀ (X ...) (→ τ_in ... τ_out)), even monomorphic fns
+; all λs have type (?∀ (X ...) (→ τ_in ... τ_out))
(define-typed-syntax liftedλ #:export-as λ
[(_ (x:id ...+) body)
#:with (~?∀ Xs expected) (get-expected-type stx)
@@ -696,21 +859,21 @@
(type-error #:src stx #:msg
(format "expected a function of ~a arguments, got one with ~a arguments"
(stx-length #'[arg-ty ...] #'[x ...]))))]
- #`(Λ Xs (ext-stlc:λ ([x : arg-ty] ...) #,(add-expected-ty #'body #'body-ty)))]
+ #`(?Λ Xs (ext-stlc:λ ([x : arg-ty] ...) #,(add-expected-ty #'body #'body-ty)))]
[(_ args body)
#:with (~?∀ () (~ext-stlc:→ arg-ty ... body-ty)) (get-expected-type stx)
- #`(Λ () (ext-stlc:λ args #,(add-expected-ty #'body #'body-ty)))]
+ #`(?Λ () (ext-stlc:λ args #,(add-expected-ty #'body #'body-ty)))]
[(_ (~and x+tys ([_ (~datum :) ty] ...)) . body)
#:with Xs (compute-tyvars #'(ty ...))
;; TODO is there a way to have λs that refer to ids defined after them?
- #'(Λ Xs (ext-stlc:λ x+tys . body))])
+ #'(?Λ Xs (ext-stlc:λ x+tys . body))])
;; #%app --------------------------------------------------
(define-typed-syntax mlish:#%app #:export-as #%app
[(_ e_fn . e_args)
;; ) compute fn type (ie ∀ and →)
- #:with [e_fn- (~∀ Xs (~ext-stlc:→ . tyX_args))] (infer+erase #'e_fn)
+ #:with [e_fn- (~?∀ Xs (~ext-stlc:→ . tyX_args))] (infer+erase #'e_fn)
(cond
[(stx-null? #'Xs)
(syntax-parse #'(e_args tyX_args)
@@ -722,22 +885,17 @@
#'(ext-stlc:#%app e_fn/ty (add-expected e_arg τ_inX) ...)])]
[else
;; ) solve for type variables Xs
- (define/with-syntax ((e_arg1- ...) (unsolved-X ...) cs) (solve #'Xs #'tyX_args stx))
+ (define/with-syntax ((e_arg- ...) Xs* cs) (solve #'Xs #'tyX_args stx))
;; ) instantiate polymorphic function type
- (syntax-parse (inst-types/cs #'Xs #'cs #'tyX_args)
+ (syntax-parse (inst-types/cs #'Xs* #'cs #'tyX_args)
[(τ_in ... τ_out) ; concrete types
+ #:with (unsolved-X ...) (find-free-Xs #'Xs* #'τ_out)
;; ) arity check
#:fail-unless (stx-length=? #'(τ_in ...) #'e_args)
(mk-app-err-msg stx #:expected #'(τ_in ...)
#:note "Wrong number of arguments.")
- ;; ) compute argument types; re-use args expanded during solve
- #:with ([e_arg2- τ_arg2] ...) (let ([n (stx-length #'(e_arg1- ...))])
- (infers+erase
- (stx-map add-expected-ty
- (stx-drop #'e_args n) (stx-drop #'(τ_in ...) n))))
- #:with (τ_arg1 ...) (stx-map typeof #'(e_arg1- ...))
- #:with (τ_arg ...) #'(τ_arg1 ... τ_arg2 ...)
- #:with (e_arg- ...) #'(e_arg1- ... e_arg2- ...)
+ ;; ) compute argument types
+ #:with (τ_arg ...) (stx-map typeof #'(e_arg- ...))
;; ) typecheck args
#:fail-unless (typechecks? #'(τ_arg ...) #'(τ_in ...))
(mk-app-err-msg stx
@@ -749,14 +907,23 @@
(define new-orig
(and old-orig
(substs
- (stx-map get-orig (lookup-Xs/keep-unsolved #'Xs #'cs)) #'Xs old-orig
+ (stx-map get-orig (lookup-Xs/keep-unsolved #'Xs* #'cs))
+ #'Xs*
+ old-orig
(lambda (x y)
(equal? (syntax->datum x) (syntax->datum y))))))
(set-stx-prop/preserved tyin 'orig (list new-orig)))
#'(τ_in ...)))
#:with τ_out* (if (stx-null? #'(unsolved-X ...))
#'τ_out
- (raise-app-poly-infer-error stx #'(τ_in ...) #'(τ_arg ...) #'e_fn))
+ (syntax-parse #'τ_out
+ [(~?∀ (Y ...) τ_out)
+ (unless (→? #'τ_out)
+ (raise-app-poly-infer-error stx #'(τ_in ...) #'(τ_arg ...) #'e_fn))
+ (for ([X (in-list (syntax->list #'(unsolved-X ...)))])
+ (unless (covariant-X? X #'τ_out)
+ (raise-app-poly-infer-error stx #'(τ_in ...) #'(τ_arg ...) #'e_fn)))
+ #'(∀ (unsolved-X ... Y ...) τ_out)]))
(⊢ (#%app e_fn- e_arg- ...) : τ_out*)])])]
[(_ e_fn . e_args) ; err case; e_fn is not a function
#:with [e_fn- τ_fn] (infer+erase #'e_fn)
@@ -814,7 +981,7 @@
;; threads
(define-typed-syntax thread
[(_ th)
- #:with (th- (~∀ () (~ext-stlc:→ τ_out))) (infer+erase #'th)
+ #:with (th- (~?∀ () (~ext-stlc:→ τ_out))) (infer+erase #'th)
(⊢ (thread th-) : Thread)])
(define-primop random : (→ Int Int))
@@ -1177,10 +1344,7 @@
[(_ e ty ...)
#:with [ee tyty] (infer+erase #'e)
#:with [e- ty_e] (infer+erase #'(sysf:inst e ty ...))
- #:with ty_out (if (→? #'ty_e)
- #'(∀ () ty_e)
- #'ty_e)
- (⊢ e- : ty_out)]))
+ (⊢ e- : ty_e)]))
(define-typed-syntax read
[(_)
@@ -1188,3 +1352,30 @@
(cond [(eof-object? x) ""]
[(number? x) (number->string x)]
[(symbol? x) (symbol->string x)])) : String)])
+
+;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
+
+(module+ test
+ (begin-for-syntax
+ (check-true (covariant-Xs? #'Int))
+ (check-true (covariant-Xs? #'(stlc+box:Ref Int)))
+ (check-true (covariant-Xs? #'(→ Int Int)))
+ (check-true (covariant-Xs? #'(∀ (X) X)))
+ (check-false (covariant-Xs? #'(∀ (X) (stlc+box:Ref X))))
+ (check-false (covariant-Xs? #'(∀ (X) (→ X X))))
+ (check-false (covariant-Xs? #'(∀ (X) (→ X Int))))
+ (check-true (covariant-Xs? #'(∀ (X) (→ Int X))))
+ (check-true (covariant-Xs? #'(∀ (X) (→ (→ X Int) X))))
+ (check-false (covariant-Xs? #'(∀ (X) (→ (→ (→ X Int) Int) X))))
+ (check-false (covariant-Xs? #'(∀ (X) (→ (stlc+box:Ref X) Int))))
+ (check-false (covariant-Xs? #'(∀ (X Y) (→ X Y))))
+ (check-true (covariant-Xs? #'(∀ (X Y) (→ (→ X Int) Y))))
+ (check-false (covariant-Xs? #'(∀ (X Y) (→ (→ X Int) (→ Y Int)))))
+ (check-true (covariant-Xs? #'(∀ (X Y) (→ (→ X Int) (→ Int Y)))))
+ (check-false (covariant-Xs? #'(∀ (A B) (→ (→ Int (stlc+rec-iso:× A B))
+ (→ String (stlc+rec-iso:× A B))
+ (stlc+rec-iso:× A B)))))
+ (check-true (covariant-Xs? #'(∀ (A B) (→ (→ (stlc+rec-iso:× A B) Int)
+ (→ (stlc+rec-iso:× A B) String)
+ (stlc+rec-iso:× A B)))))
+ ))
diff --git a/tapl/stlc+tup.rkt b/tapl/stlc+tup.rkt
@@ -1,5 +1,7 @@
#lang s-exp "typecheck.rkt"
(extends "ext-stlc.rkt")
+
+(require (for-syntax racket/list))
;; Simply-Typed Lambda Calculus, plus tuples
;; Types:
@@ -9,7 +11,9 @@
;; - terms from ext-stlc.rkt
;; - tup and proj
-(define-type-constructor × #:arity >= 0)
+(define-type-constructor × #:arity >= 0
+ #:arg-variances (λ (stx)
+ (make-list (stx-length (stx-cdr stx)) covariant)))
(define-typed-syntax tup
[(_ e ...)
diff --git a/tapl/stlc.rkt b/tapl/stlc.rkt
@@ -2,6 +2,8 @@
(provide (for-syntax current-type=? types=?))
(provide (for-syntax mk-app-err-msg))
+(require (for-syntax racket/list))
+
;; Simply-Typed Lambda Calculus
;; - no base types; can't write any terms
;; Types: multi-arg → (1+)
@@ -66,7 +68,13 @@
(define-syntax-category type)
-(define-type-constructor → #:arity >= 1)
+(define-type-constructor → #:arity >= 1
+ #:arg-variances (λ (stx)
+ (syntax-parse stx
+ [(_ τ_in ... τ_out)
+ (append
+ (make-list (stx-length #'[τ_in ...]) contravariant)
+ (list covariant))])))
(define-typed-syntax λ
[(_ bvs:type-ctx e)
diff --git a/tapl/stx-utils.rkt b/tapl/stx-utils.rkt
@@ -1,5 +1,5 @@
#lang racket/base
-(require syntax/stx racket/list version/utils)
+(require syntax/stx syntax/parse racket/list version/utils)
(provide (all-defined-out))
(define (stx-cadr stx) (stx-car (stx-cdr stx)))
@@ -70,6 +70,9 @@
(define (generate-temporariesss stx)
(stx-map generate-temporariess stx))
+;; set-stx-prop/preserved : Stx Any Any -> Stx
+;; Returns a new syntax object with the prop property set to val. If preserved
+;; syntax properties are supported, this also marks the property as preserved.
(define REQUIRED-VERSION "6.5.0.4")
(define VERSION (version))
(define PRESERVED-STX-PROP-SUPPORTED? (version<=? REQUIRED-VERSION VERSION))
@@ -78,6 +81,16 @@
(syntax-property stx prop val #t)
(syntax-property stx prop val)))
+;; stx-contains-id? : Stx Id -> Boolean
+;; Returns true if stx contains the identifier x, false otherwise.
+(define (stx-contains-id? stx x)
+ (syntax-parse stx
+ [a:id (free-identifier=? #'a x)]
+ [(a . b)
+ (or (stx-contains-id? #'a x)
+ (stx-contains-id? #'b x))]
+ [_ #false]))
+
;; based on make-variable-like-transformer from syntax/transformer,
;; but using (#%app id ...) instead of ((#%expression id) ...)
(define (make-variable-like-transformer ref-stx)
diff --git a/tapl/tests/mlish-tests.rkt b/tapl/tests/mlish-tests.rkt
@@ -36,7 +36,7 @@
;; type err
(typecheck-fail (Cons 1 1)
- #:with-msg (expected "Int, (List Int)" #:given "Int, Int"))
+ #:with-msg "expected: \\(List Int\\)\n *given: Int")
;; check Nil still available as tyvar
(define (f11 [x : Nil] -> Nil) x)
@@ -55,7 +55,7 @@
(check-type g2 : (→/test (List Y) (List Y)))
(typecheck-fail (g2 1)
#:with-msg
- (expected "(List Y)" #:given "Int"))
+ "expected: \\(List Y\\)\n *given: Int")
;; todo? allow polymorphic nil?
(check-type (g2 (Nil {Int})) : (List Int) ⇒ (Nil {Int}))
@@ -113,7 +113,7 @@
(check-type (map add1 (Cons 1 (Cons 2 (Cons 3 Nil))))
: (List Int) ⇒ (Cons 2 (Cons 3 (Cons 4 Nil))))
(typecheck-fail (map add1 (Cons "1" Nil))
- #:with-msg (expected "Int, (List Int)" #:given "String, (List Int)"))
+ #:with-msg "expected: Int\n *given: String")
(check-type (map (λ ([x : Int]) (+ x 2)) (Cons 1 (Cons 2 (Cons 3 Nil))))
: (List Int) ⇒ (Cons 3 (Cons 4 (Cons 5 Nil))))
;; ; doesnt work yet: all lambdas need annotations
@@ -179,6 +179,24 @@
(check-type (build-list 5 (λ (x) (add1 (add1 x))))
: (List Int) ⇒ (Cons 6 (Cons 5 (Cons 4 (Cons 3 (Cons 2 Nil))))))
+(define (build-list/comp [i : Int] [n : Int] [nf : (→ Int Int)] [f : (→ Int X)] → (List X))
+ (if (= i n)
+ Nil
+ (Cons (f (nf i)) (build-list/comp (add1 i) n nf f))))
+
+(define built-list-1 (build-list/comp 0 3 (λ (x) (* 2 x)) add1))
+(define built-list-2 (build-list/comp 0 3 (λ (x) (* 2 x)) number->string))
+(check-type built-list-1 : (List Int) -> (Cons 1 (Cons 3 (Cons 5 Nil))))
+(check-type built-list-2 : (List String) -> (Cons "0" (Cons "2" (Cons "4" Nil))))
+
+(define (~>2 [a : A] [f : (→ A A)] [g : (→ A B)] → B)
+ (g (f a)))
+
+(define ~>2-result-1 (~>2 1 (λ (x) (* 2 x)) add1))
+(define ~>2-result-2 (~>2 1 (λ (x) (* 2 x)) number->string))
+(check-type ~>2-result-1 : Int -> 3)
+(check-type ~>2-result-2 : String -> "2")
+
(define (append [lst1 : (List X)] [lst2 : (List X)] → (List X))
(match lst1 with
[Nil -> lst2]
@@ -242,8 +260,7 @@
(typecheck-fail Nil #:with-msg "add annotations")
(typecheck-fail (Cons 1 (Nil {Bool}))
#:with-msg
- (expected "Int, (List Int)" #:given "Int, (List Bool)"
- #:note "Type error applying.*Cons"))
+ "expected: \\(List Int\\)\n *given: \\(List Bool\\)")
(typecheck-fail (Cons {Bool} 1 (Nil {Int}))
#:with-msg
(expected "Bool, (List Bool)" #:given "Int, (List Int)"
@@ -285,6 +302,8 @@
(None)
(Some A))
+(define (None* → (Option A)) None)
+
(check-type (match (tup 1 2) with [a b -> None]) : (Option Int) -> None)
(check-type
(match (list 1 2) with
@@ -380,6 +399,52 @@
(check-type ((inst nn2 Int (List Int) String) 1)
: (→ (× Int (→ (List Int) (List Int)) (List String))))
+(define (nn3 [x : X] -> (→ (× X (Option Y) (Option Z))))
+ (λ () (tup x None None)))
+(check-type (nn3 1) : (→/test (× Int (Option Y) (Option Z))))
+(check-type (nn3 1) : (→ (× Int (Option String) (Option (List Int)))))
+(check-type ((nn3 1)) : (× Int (Option String) (Option (List Int))))
+(check-type ((nn3 1)) : (× Int (Option (List Int)) (Option String)))
+;; test inst order
+(check-type ((inst (nn3 1) String (List Int))) : (× Int (Option String) (Option (List Int))))
+(check-type ((inst (nn3 1) (List Int) String)) : (× Int (Option (List Int)) (Option String)))
+
+(define (nn4 -> (→ (Option X)))
+ (λ () (None*)))
+(check-type (let ([x (nn4)])
+ x)
+ : (→/test (Option X)))
+
+(define (nn5 -> (→ (Ref (Option X))))
+ (λ () (ref (None {X}))))
+(typecheck-fail (let ([x (nn5)])
+ x)
+ #:with-msg "Could not infer instantiation of polymorphic function nn5.")
+
+(define (nn6 -> (→ (Option X)))
+ (let ([r (((inst nn5 X)))])
+ (λ () (deref r))))
+(check-type (nn6) : (→/test (Option X)))
+
+;; A is covariant, B is invariant.
+(define-type (Cps A B)
+ (cps (→ (→ A B) B)))
+(define (cps* [f : (→ (→ A B) B)] → (Cps A B))
+ (cps f))
+
+(define (nn7 -> (→ (Cps (Option A) B)))
+ (let ([r (((inst nn5 A)))])
+ (λ () (cps* (λ (k) (k (deref r)))))))
+(typecheck-fail (let ([x (nn7)])
+ x)
+ #:with-msg "Could not infer instantiation of polymorphic function nn7.")
+
+(define (nn8 -> (→ (Cps (Option A) Int)))
+ (nn7))
+(check-type (let ([x (nn8)])
+ x)
+ : (→/test (Cps (Option A) Int)))
+
(define-type (Result A B)
(Ok A)
(Error B))
@@ -389,6 +454,35 @@
(define (error [b : B] → (Result A B))
(Error b))
+(define (ok-fn [a : A] -> (→ (Result A B)))
+ (λ () (ok a)))
+(define (error-fn [b : B] -> (→ (Result A B)))
+ (λ () (error b)))
+
+(check-type (let ([x (ok-fn 1)])
+ x)
+ : (→/test (Result Int B)))
+(check-type (let ([x (error-fn "bad")])
+ x)
+ : (→/test (Result A String)))
+
+(define (nn9 [a : A] -> (→ (Result A (Ref B))))
+ (ok-fn a))
+(define (nn10 [a : A] -> (→ (Result A (Ref String))))
+ (nn9 a))
+(define (nn11 -> (→ (Result (Option A) (Ref String))))
+ (nn10 (None*)))
+
+(typecheck-fail (let ([x (nn9 1)])
+ x)
+ #:with-msg "Could not infer instantiation of polymorphic function nn9.")
+(check-type (let ([x (nn10 1)])
+ x)
+ : (→ (Result Int (Ref String))))
+(check-type (let ([x (nn11)])
+ x)
+ : (→/test (Result (Option A) (Ref String))))
+
(check-type (if (zero? (random 2))
(ok 0)
(error "didn't get a zero"))
@@ -453,6 +547,21 @@
(λ (b) (Error (Cons b Nil))))
: (Result (List Int) (List String)))
+(define (tup* [a : A] [b : B] -> (× A B))
+ (tup a b))
+
+(define (nn12 -> (→ (× (Option A) (Option B))))
+ (λ () (tup* (None*) (None*))))
+(check-type (let ([x (nn12)])
+ x)
+ : (→/test (× (Option A) (Option B))))
+
+(define (nn13 -> (→ (× (Option A) (Option (Ref B)))))
+ (nn12))
+(typecheck-fail (let ([x (nn13)])
+ x)
+ #:with-msg "Could not infer instantiation of polymorphic function nn13.")
+
;; records and automatically-defined accessors and predicates
(define-type (RecoTest X Y)
(RT1 [x : X] [y : Y] [z : String])
diff --git a/tapl/tests/mlish/alex.mlish b/tapl/tests/mlish/alex.mlish
@@ -14,3 +14,12 @@
(check-type try : (→/test X (→ X Y) X))
+(define (accept-A×A [pair : (× A A)] → (× A A))
+ pair)
+
+(typecheck-fail (accept-A×A (tup 8 "ate"))
+ #:with-msg "couldn't unify Int and String\n *expected: \\(× A A\\)\n *given: \\(× Int String\\)")
+
+(typecheck-fail (ann (accept-A×A (tup 8 "ate")) : (× String String))
+ #:with-msg "expected: \\(× String String\\)\n *given: \\(× Int String\\)")
+
diff --git a/tapl/tests/mlish/match2.mlish b/tapl/tests/mlish/match2.mlish
@@ -62,7 +62,7 @@
(typecheck-fail
(match2 (B 1) with
[B x -> x])
- #:with-msg (expected "(× X X)" #:given "Int"))
+ #:with-msg "expected: \\(× X X\\)\n *given: Int")
(check-type
(match2 (B (tup 2 3)) with
diff --git a/tapl/tests/mlish/queens.mlish b/tapl/tests/mlish/queens.mlish
@@ -46,7 +46,7 @@
(check-type (map add1 (Cons 1 (Cons 2 (Cons 3 Nil))))
: (List Int) ⇒ (Cons 2 (Cons 3 (Cons 4 Nil))))
(typecheck-fail (map add1 (Cons "1" Nil))
- #:with-msg (expected "Int, (List Int)" #:given "String, (List Int)"))
+ #:with-msg "expected: Int\n *given: String")
(check-type (map (λ ([x : Int]) (+ x 2)) (Cons 1 (Cons 2 (Cons 3 Nil))))
: (List Int) ⇒ (Cons 3 (Cons 4 (Cons 5 Nil))))
;; ; doesnt work yet: all lambdas need annotations
diff --git a/tapl/typecheck.rkt b/tapl/typecheck.rkt
@@ -15,6 +15,9 @@
"stx-utils.rkt"))
(for-meta 2 (all-from-out racket/base syntax/parse racket/syntax)))
+(module+ test
+ (require (for-syntax rackunit)))
+
;; type checking functions/forms
;; General type checking strategy:
@@ -420,6 +423,48 @@
(define (brack? stx)
(define paren-shape/#f (syntax-property stx 'paren-shape))
(and paren-shape/#f (char=? paren-shape/#f #\[)))
+
+ (define (iff b1 b2)
+ (boolean=? b1 b2))
+
+ ;; Variance is (variance Boolean Boolean)
+ (struct variance (covariant? contravariant?) #:prefab)
+ (define irrelevant (variance #true #true))
+ (define covariant (variance #true #false))
+ (define contravariant (variance #false #true))
+ (define invariant (variance #false #false))
+ ;; variance-irrelevant? : Variance -> Boolean
+ (define (variance-irrelevant? v)
+ (and (variance-covariant? v) (variance-contravariant? v)))
+ ;; variance-invariant? : Variance -> Boolean
+ (define (variance-invariant? v)
+ (and (not (variance-covariant? v)) (not (variance-contravariant? v))))
+ ;; variance-join : Variance Variance -> Variance
+ (define (variance-join v1 v2)
+ (variance (and (variance-covariant? v1)
+ (variance-covariant? v2))
+ (and (variance-contravariant? v1)
+ (variance-contravariant? v2))))
+ ;; variance-compose : Variance Variance -> Variance
+ (define (variance-compose v1 v2)
+ (variance (or (variance-irrelevant? v1)
+ (variance-irrelevant? v2)
+ (and (variance-covariant? v1) (variance-covariant? v2))
+ (and (variance-contravariant? v1) (variance-contravariant? v2)))
+ (or (variance-irrelevant? v1)
+ (variance-irrelevant? v2)
+ (and (variance-covariant? v1) (variance-contravariant? v2))
+ (and (variance-contravariant? v1) (variance-covariant? v2)))))
+
+ ;; add-arg-variances : Id (Listof Variance) -> Id
+ ;; Takes a type constructor id and adds variance information about the arguments.
+ (define (add-arg-variances id arg-variances)
+ (set-stx-prop/preserved id 'arg-variances arg-variances))
+ ;; get-arg-variances : Id -> (U False (Listof Variance))
+ ;; Takes a type constructor id and gets the argument variance information.
+ (define (get-arg-variances id)
+ (syntax-property id 'arg-variances))
+
;; todo: abstract out the common shape of a type constructor,
;; i.e., the repeated pattern code in the functions below
(define (get-extra-info t)
@@ -482,6 +527,10 @@
#:defaults ([bvs-op #'=][bvs-n #'0]))
(~optional (~seq #:arr (~and (~parse has-annotations? #'#t) tycon))
#:defaults ([tycon #'void]))
+ (~optional (~seq #:arg-variances arg-variances-stx:expr)
+ #:defaults ([arg-variances-stx
+ #`(λ (stx-id) (for/list ([arg (in-list (stx->list (stx-cdr stx-id)))])
+ invariant))]))
(~optional (~seq #:extra-info extra-info)
#:defaults ([extra-info #'void]))
(~optional (~and #:no-provide (~parse no-provide? #'#t))))
@@ -532,6 +581,7 @@
#:msg
"Expected ~a type, got: ~a"
#'τ #'any))))])))
+ (define arg-variances arg-variances-stx)
(define (τ? t)
(and (stx-pair? t)
(syntax-parse t
@@ -565,10 +615,11 @@
#:with k_result (if #,(attribute has-annotations?)
#'(tycon k_arg (... ...))
#'#%kind)
+ #:with τ-internal* (add-arg-variances #'τ-internal (arg-variances stx))
(add-orig
(assign-type
(syntax/loc stx
- (τ-internal (λ bvs- (#%expression extra-info) . τs-)))
+ (τ-internal* (λ bvs- (#%expression extra-info) . τs-)))
#'k_result)
#'(τ . args))]
;; else fail with err msg
@@ -701,3 +752,48 @@
(define (substs τs xs e [cmp bound-identifier=?])
(stx-fold (lambda (ty x res) (subst ty x res cmp)) e τs xs)))
+
+;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
+
+(module+ test
+ (begin-for-syntax
+ (test-case "variance-join"
+ (test-case "joining with irrelevant doesn't change it"
+ (check-equal? (variance-join irrelevant irrelevant) irrelevant)
+ (check-equal? (variance-join irrelevant covariant) covariant)
+ (check-equal? (variance-join irrelevant contravariant) contravariant)
+ (check-equal? (variance-join irrelevant invariant) invariant))
+ (test-case "joining with invariant results in invariant"
+ (check-equal? (variance-join invariant irrelevant) invariant)
+ (check-equal? (variance-join invariant covariant) invariant)
+ (check-equal? (variance-join invariant contravariant) invariant)
+ (check-equal? (variance-join invariant invariant) invariant))
+ (test-case "joining a with a results in a"
+ (check-equal? (variance-join irrelevant irrelevant) irrelevant)
+ (check-equal? (variance-join covariant covariant) covariant)
+ (check-equal? (variance-join contravariant contravariant) contravariant)
+ (check-equal? (variance-join invariant invariant) invariant))
+ (test-case "joining covariant with contravariant results in invariant"
+ (check-equal? (variance-join covariant contravariant) invariant)
+ (check-equal? (variance-join contravariant covariant) invariant)))
+ (test-case "variance-compose"
+ (test-case "composing with covariant doesn't change it"
+ (check-equal? (variance-compose covariant irrelevant) irrelevant)
+ (check-equal? (variance-compose covariant covariant) covariant)
+ (check-equal? (variance-compose covariant contravariant) contravariant)
+ (check-equal? (variance-compose covariant invariant) invariant))
+ (test-case "composing with irrelevant results in irrelevant"
+ (check-equal? (variance-compose irrelevant irrelevant) irrelevant)
+ (check-equal? (variance-compose irrelevant covariant) irrelevant)
+ (check-equal? (variance-compose irrelevant contravariant) irrelevant)
+ (check-equal? (variance-compose irrelevant invariant) irrelevant))
+ (test-case "otherwise composing with invariant results in invariant"
+ (check-equal? (variance-compose invariant covariant) invariant)
+ (check-equal? (variance-compose invariant contravariant) invariant)
+ (check-equal? (variance-compose invariant invariant) invariant))
+ (test-case "composing with with contravariant flips covariant and contravariant"
+ (check-equal? (variance-compose contravariant covariant) contravariant)
+ (check-equal? (variance-compose contravariant contravariant) covariant)
+ (check-equal? (variance-compose contravariant irrelevant) irrelevant)
+ (check-equal? (variance-compose contravariant invariant) invariant)))
+ ))